package com.github.tjake.jlama.model;

import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.tensor.AbstractTensor;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;

/* loaded from: input_file:com/github/tjake/jlama/model/DistributedContext.class */
public class DistributedContext {
    private final Config c;
    private final int modelShard;
    private final int numModelShards;
    private final int layerShard;
    private final int numLayerShards;
    // May be null when the builder was never given one; see syncTensors().
    private final Consumer<List<AbstractTensor>> tensorSync;

    /** Half-open slice [start, end) of {@code config.embeddingLength} owned by this model shard. */
    public final int embeddingSegmentStart;
    public final int embeddingSegmentLength;
    public final int embeddingSegmentEnd;

    /** Half-open slice [start, end) of {@code config.attentionLength} owned by this model shard. */
    public final int attentionSegmentStart;
    public final int attentionSegmentLength;
    public final int attentionSegmentEnd;

    /** Half-open slice [start, end) of {@code config.hiddenLength} owned by this model shard. */
    public final int hiddenSegmentStart;
    public final int hiddenSegmentLength;
    public final int hiddenSegmentEnd;

    /** The embedding segment bounds divided by {@code config.headGroupSize}. */
    public final int kvSegmentStart;
    public final int kvSegmentLength;
    public final int kvSegmentEnd;

    /** The embedding segment bounds divided by {@code config.headSize}. */
    public final int headStart;
    public final int headEnd;

    /** The kv segment bounds divided by {@code config.headSize}. */
    public final int groupHeadStart;
    public final int groupHeadEnd;

    /** Number of layers assigned to each layer shard. */
    public final int numberOfLayers;

    /** Half-open range [layerStart, layerEnd) of layer indices owned by this layer shard. */
    public final int layerStart;
    public final int layerEnd;

    /** Fluent builder for {@link DistributedContext}; defaults to a single, unsharded context. */
    public static class Builder {
        private final Config c;
        private int modelShard = 0;
        private int numModelShards = 1;
        private int layerShard = 0;
        private int numLayerShards = 1;
        private Consumer<List<AbstractTensor>> tensorSync;

        /**
         * @param config model configuration whose dimensions will be sharded
         * @throws NullPointerException if {@code config} is null
         */
        public Builder(Config config) {
            this.c = Objects.requireNonNull(config, "config");
        }

        /** Sets the zero-based index of this model shard. */
        public Builder setModelShard(int modelShard) {
            this.modelShard = modelShard;
            return this;
        }

        /** Sets the total number of model shards (must be >= 1). */
        public Builder setNumModelShards(int numModelShards) {
            this.numModelShards = numModelShards;
            return this;
        }

        /** Sets the zero-based index of this layer shard. */
        public Builder setLayerShard(int layerShard) {
            this.layerShard = layerShard;
            return this;
        }

        /** Sets the total number of layer shards (must be >= 1). */
        public Builder setNumLayerShards(int numLayerShards) {
            this.numLayerShards = numLayerShards;
            return this;
        }

        /** Sets the consumer invoked by {@link DistributedContext#syncTensors(List)}. */
        public Builder setTensorSync(Consumer<List<AbstractTensor>> tensorSync) {
            this.tensorSync = tensorSync;
            return this;
        }

        /**
         * Builds the context.
         *
         * @throws IllegalArgumentException if a shard count is < 1, a shard index is out of range,
         *     or a sharded dimension / layer count is not evenly divisible by its shard count
         */
        public DistributedContext build() {
            return new DistributedContext(
                    this.c, this.modelShard, this.numModelShards, this.layerShard, this.numLayerShards, this.tensorSync);
        }
    }

    private DistributedContext(
            Config config,
            int modelShard,
            int numModelShards,
            int layerShard,
            int numLayerShards,
            Consumer<List<AbstractTensor>> tensorSync) {
        Objects.requireNonNull(config, "config");
        checkShardIndex("model", modelShard, numModelShards);
        checkShardIndex("layer", layerShard, numLayerShards);
        // Segments are computed with floor division; any remainder would be owned by no shard,
        // silently dropping layers or dimension elements, so require an exact split.
        checkDivisible("numberOfLayers", config.numberOfLayers, numLayerShards);
        checkDivisible("embeddingLength", config.embeddingLength, numModelShards);
        checkDivisible("attentionLength", config.attentionLength, numModelShards);
        checkDivisible("hiddenLength", config.hiddenLength, numModelShards);

        this.c = config;
        this.modelShard = modelShard;
        this.numModelShards = numModelShards;
        this.layerShard = layerShard;
        this.numLayerShards = numLayerShards;
        this.tensorSync = tensorSync;

        this.numberOfLayers = config.numberOfLayers / numLayerShards;
        this.layerStart = this.numberOfLayers * layerShard;
        this.layerEnd = this.layerStart + this.numberOfLayers;

        this.embeddingSegmentLength = config.embeddingLength / numModelShards;
        this.embeddingSegmentStart = this.embeddingSegmentLength * modelShard;
        this.embeddingSegmentEnd = this.embeddingSegmentStart + this.embeddingSegmentLength;

        this.attentionSegmentLength = config.attentionLength / numModelShards;
        this.attentionSegmentStart = this.attentionSegmentLength * modelShard;
        this.attentionSegmentEnd = this.attentionSegmentStart + this.attentionSegmentLength;

        this.hiddenSegmentLength = config.hiddenLength / numModelShards;
        this.hiddenSegmentStart = this.hiddenSegmentLength * modelShard;
        this.hiddenSegmentEnd = this.hiddenSegmentStart + this.hiddenSegmentLength;

        this.kvSegmentStart = this.embeddingSegmentStart / config.headGroupSize;
        this.kvSegmentEnd = this.embeddingSegmentEnd / config.headGroupSize;
        this.kvSegmentLength = this.embeddingSegmentLength / config.headGroupSize;

        this.headStart = this.embeddingSegmentStart / config.headSize;
        this.headEnd = this.embeddingSegmentEnd / config.headSize;

        this.groupHeadStart = this.kvSegmentStart / config.headSize;
        this.groupHeadEnd = this.kvSegmentEnd / config.headSize;
    }

    /** Rejects non-positive shard counts and shard indices outside [0, numShards). */
    private static void checkShardIndex(String kind, int shard, int numShards) {
        if (numShards < 1) {
            throw new IllegalArgumentException(kind + " shard count must be >= 1, got " + numShards);
        }
        if (shard < 0 || shard >= numShards) {
            throw new IllegalArgumentException(
                    kind + " shard index " + shard + " out of range [0, " + numShards + ")");
        }
    }

    /** Rejects a length that cannot be split into equal per-shard pieces. */
    private static void checkDivisible(String what, int length, int numShards) {
        if (length % numShards != 0) {
            throw new IllegalArgumentException(
                    what + " (" + length + ") is not evenly divisible by " + numShards + " shards");
        }
    }

    /** Returns true when the model dimensions are split across more than one model shard. */
    public boolean hasModelShard() {
        return this.numModelShards > 1;
    }

    /**
     * Hands {@code tensors} to the tensor sync consumer supplied via
     * {@link Builder#setTensorSync(Consumer)}.
     *
     * @throws IllegalStateException if no tensor sync consumer was configured
     */
    public void syncTensors(List<AbstractTensor> tensors) {
        if (this.tensorSync == null) {
            throw new IllegalStateException("No tensor sync configured; call Builder.setTensorSync(...)");
        }
        this.tensorSync.accept(tensors);
    }

    /** Returns this model shard's start offset within a dimension of {@code length}, split equally. */
    public int getShardOffsetForLength(int length) {
        return (length / this.numModelShards) * this.modelShard;
    }

    /** Returns the per-model-shard length of a dimension of {@code length}, split equally. */
    public int getShardLength(int length) {
        return length / this.numModelShards;
    }

    /** Creates a builder for the given model configuration. */
    public static Builder builder(Config config) {
        return new Builder(config);
    }
}
