package com.github.tjake.jlama.tensor;

import com.github.tjake.jlama.util.Pair;
import com.google.common.base.Preconditions;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/github/tjake/jlama/tensor/TensorShape.class */
/**
 * Describes the dimensions of a tensor, optionally with a sparse (offset, length) window applied to
 * the last dimension.
 *
 * <p>A shape always has at least two dimensions; {@link #of(int...)} promotes a single dimension
 * {@code n} to {@code [1, n]}. When a sparse range is present, {@link #size()} and
 * {@link #getOffset(int...)} use the sparse length in place of the full last dimension, and
 * {@link #getOffset(int...)} subtracts the sparse offset from the computed index.
 *
 * <p>Instances are not modified after construction; methods such as {@link #scaleLastDim(float)},
 * {@link #setDimValue(int, int)}, {@link #sparsify(int, int)} and {@link #slice(int)} return new
 * shapes.
 */
public class TensorShape {
    /** Shared {@code [1, 1]} shape. */
    public static final TensorShape one = of(1, 1);

    /** Dimension sizes; a private copy that is never exposed, so it cannot change after construction. */
    private final int[] tshape;
    /** Product of all leading dimensions times {@link #sparseLength}. */
    private final long capacity;
    /** Optional (offset, length) window over the last dimension. */
    private final Optional<Pair<Integer, Integer>> sparseRange;
    private final boolean isSparse;
    /** Sparse window offset, or 0 when dense. */
    private final int sparseOffset;
    /** Sparse window length, or the full last dimension when dense. */
    private final int sparseLength;

    /**
     * Creates a dense shape.
     *
     * @param iArr dimension sizes; a single value {@code n} is promoted to {@code [1, n]}
     * @return the new shape
     * @throws IllegalArgumentException if no dimensions are given
     */
    public static TensorShape of(int... iArr) {
        if (iArr.length == 1) {
            iArr = new int[]{1, iArr[0]};
        }
        return new TensorShape(iArr, Optional.empty());
    }

    /**
     * Creates a shape with a sparse window over its last dimension.
     *
     * @param iArr dimension sizes (at least two)
     * @param pair (offset, length) of the window over the last dimension
     * @return the new shape
     * @throws IllegalArgumentException if fewer than two dimensions are given
     */
    public static TensorShape sparse(int[] iArr, Pair<Integer, Integer> pair) {
        return new TensorShape(iArr, Optional.of(pair));
    }

    private TensorShape(int[] shape, Optional<Pair<Integer, Integer>> range) {
        Preconditions.checkArgument(shape.length > 1, "Shape must have at least two dimensions, even if first is 1 (to represent a vector)");
        // Defensive copy: of(int...) and sparse(int[], ...) receive caller-owned arrays; without
        // the copy a later mutation by the caller would change this shape (and its equals/hashCode)
        // while the cached capacity stayed stale.
        this.tshape = shape.clone();
        this.sparseRange = range;
        this.isSparse = range.isPresent();
        this.sparseOffset = range.map(Pair::left).orElse(0);
        this.sparseLength = range.map(Pair::right).orElse(this.tshape[this.tshape.length - 1]);

        long leading = 1;
        for (int d = 0; d < this.tshape.length - 1; d++) {
            leading *= this.tshape[d];
        }
        // Only the sparse window of the last dimension counts towards the capacity.
        this.capacity = leading * this.sparseLength;
    }

    /** Returns {@code true} if this shape carries a sparse range over its last dimension. */
    public final boolean isSparse() {
        return this.isSparse;
    }

    /** Returns the number of dimensions (always at least two). */
    public int dims() {
        return this.tshape.length;
    }

    /**
     * Returns the size of dimension {@code i}.
     *
     * @throws IllegalArgumentException if {@code i >= dims()}
     */
    public int dim(int i) {
        Preconditions.checkArgument(i < this.tshape.length);
        return this.tshape[i];
    }

    /**
     * Computes the flat element index for the given coordinates.
     *
     * <p>Indices are laid out row-major with {@link #sparseLength()} used as the stride of the last
     * dimension, and {@link #sparseOffset()} is subtracted from the result. The 1-, 2- and
     * 3-argument forms are hard-coded fast paths; the single-argument form gives the index of the
     * start of row {@code iArr[0]}.
     * NOTE(review): the fast paths assume the coordinate count matches the shape's layout
     * (e.g. two coordinates for a 2-D shape) — confirm callers honour this.
     *
     * @param iArr coordinates, outermost first
     * @return the flat index into the (possibly sparse) backing storage
     */
    public final int getOffset(int... iArr) {
        switch (iArr.length) {
            case 1:
                return (this.sparseLength * iArr[0]) - this.sparseOffset;
            case 2:
                return ((this.sparseLength * iArr[0]) + iArr[1]) - this.sparseOffset;
            case 3:
                return ((((this.sparseLength * this.tshape[1]) * iArr[0]) + (this.sparseLength * iArr[1])) + iArr[2]) - this.sparseOffset;
            default:
                int total = 0;
                // Every coordinate except the last is scaled by the product of the dimensions
                // after it, using sparseLength in place of the last dimension.
                for (int d = 0; d < iArr.length - 1; d++) {
                    int stride = this.sparseLength;
                    for (int k = this.tshape.length - 2; k > d; k--) {
                        stride *= this.tshape[k];
                    }
                    total += iArr[d] * stride;
                }
                return (total + iArr[iArr.length - 1]) - this.sparseOffset;
        }
    }

    /** Returns the sparse window length, or the full last dimension when dense. */
    public int sparseLength() {
        return this.sparseLength;
    }

    /** Returns the sparse window offset, or 0 when dense. */
    public int sparseOffset() {
        return this.sparseOffset;
    }

    /** Returns {@code i} shifted by the sparse offset ({@code i - sparseOffset()}). */
    public int sparseAdjustment(int i) {
        return i - this.sparseOffset;
    }

    /**
     * Returns a copy of this shape with the last dimension multiplied by {@code f} and truncated
     * to an int. A sparse range, if present, has both its offset and length scaled the same way.
     *
     * @param f scale factor
     * @return the scaled shape
     */
    public TensorShape scaleLastDim(float f) {
        int[] copy = Arrays.copyOf(this.tshape, this.tshape.length);
        int lastIdx = copy.length - 1;
        copy[lastIdx] = (int) (copy[lastIdx] * f);
        return this.isSparse
                ? sparse(copy, Pair.create((int) (this.sparseOffset * f), (int) (this.sparseLength * f)))
                : of(copy);
    }

    /**
     * Returns a copy of this shape with dimension {@code i} set to {@code i2}.
     * NOTE(review): for a sparse shape the returned sparse length is the (possibly updated) full
     * last dimension, not the previous sparse length — confirm this is intended.
     *
     * @throws IllegalArgumentException if {@code i >= dims()}
     */
    public TensorShape setDimValue(int i, int i2) {
        Preconditions.checkArgument(i < this.tshape.length);
        int[] copy = Arrays.copyOf(this.tshape, this.tshape.length);
        copy[i] = i2;
        return this.isSparse ? sparse(copy, Pair.create(this.sparseOffset, copy[copy.length - 1])) : of(copy);
    }

    /** Returns the size of the first dimension. */
    public int first() {
        return this.tshape[0];
    }

    /** Returns the size of the last dimension (ignoring any sparse window). */
    public int last() {
        return this.tshape[this.tshape.length - 1];
    }

    /** Returns the element count: leading dimensions times the (sparse) last-dimension length. */
    public long size() {
        return this.capacity;
    }

    /**
     * Returns a copy of this dense shape with a sparse window over the last dimension.
     *
     * @param i window offset
     * @param i2 window length
     * @throws IllegalArgumentException if this shape is already sparse
     */
    public TensorShape sparsify(int i, int i2) {
        Preconditions.checkArgument(!this.isSparse, "Cannot sparsify a sparse tensor");
        return new TensorShape(this.tshape, Optional.of(Pair.create(i, i2)));
    }

    /**
     * Drops the first {@code i} dimensions. If only the last dimension would remain, the result is
     * promoted to {@code [1, last]}. Any sparse range is carried over unchanged.
     *
     * @param i number of leading dimensions to drop
     * @throws IllegalArgumentException if {@code i >= dims()}
     */
    public TensorShape slice(int i) {
        Preconditions.checkArgument(i < this.tshape.length, "Too many dimensions specified for tensor");
        if (this.tshape.length - i == 1) {
            return new TensorShape(new int[]{1, this.tshape[this.tshape.length - 1]}, this.sparseRange);
        }
        return new TensorShape(Arrays.copyOfRange(this.tshape, i, this.tshape.length), this.sparseRange);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        TensorShape other = (TensorShape) obj;
        return Arrays.equals(this.tshape, other.tshape) && Objects.equals(this.sparseRange, other.sparseRange);
    }

    @Override
    public int hashCode() {
        return (31 * Objects.hash(this.sparseRange)) + Arrays.hashCode(this.tshape);
    }

    @Override
    public String toString() {
        // Previously printed the tshape array under the sparseRange label.
        return "TensorShape{tshape=" + Arrays.toString(this.tshape)
                + ", capacity=" + this.capacity
                + ", sparseRange=" + this.sparseRange + "}";
    }
}
