package com.github.tjake.jlama.tensor.operations;

import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.base.Preconditions;

/**
 * Implementation of {@link TensorOperations} built from plain scalar Java loops.
 *
 * <p>Every operation reads and writes one element at a time through {@link AbstractTensor}'s
 * {@code get}/{@code set} accessors; there are no vectorized code paths.
 */
public class NaiveTensorOperations implements TensorOperations {

    @Override
    public String name() {
        return "Naive Java Operations";
    }

    /** This implementation does not require off-heap tensors. */
    @Override
    public boolean requiresOffHeapTensor() {
        return false;
    }

    /**
     * Adds {@code b} into {@code a}, in place, over columns {@code [offset, offset + length)}.
     *
     * <p>Iterates over every index along {@code a}'s first dimension. When {@code b}'s first dimension
     * is larger than 1, slice {@code r} of {@code b} is added to slice {@code r} of {@code a}; otherwise
     * {@code b} itself is added to every slice of {@code a} (broadcast).
     *
     * @param a tensor updated in place
     * @param b tensor whose values are added; must have the same number of dimensions as {@code a}
     * @param offset first column to update
     * @param length number of columns to update
     * @throws IllegalArgumentException if {@code a} and {@code b} differ in number of dimensions
     */
    @Override
    public void accumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        Preconditions.checkArgument(a.dims() == b.dims(),
                "accumulate: tensors must have the same number of dimensions (a=%s, b=%s)", a.dims(), b.dims());
        boolean bIsBatched = b.shape().first() > 1;
        int sliceCount = a.shape().first();
        int limit = offset + length;
        for (int r = 0; r < sliceCount; r++) {
            AbstractTensor aSlice = a.slice(r);
            AbstractTensor bSlice = bIsBatched ? b.slice(r) : b;
            for (int c = offset; c < limit; c++) {
                aSlice.set(aSlice.get(0, c) + bSlice.get(0, c), 0, c);
            }
        }
    }

    /**
     * Multiplies {@code a} by {@code b} element-wise, in place, over indices
     * {@code [offset, offset + length)}.
     *
     * @param a one-dimensional tensor updated in place
     * @param b one-dimensional tensor of the same size as {@code a}
     * @param offset first index to update
     * @param length number of elements to update
     * @throws IllegalArgumentException unless both tensors are one-dimensional and of equal size
     */
    @Override
    public void maccumulate(AbstractTensor a, AbstractTensor b, int offset, int length) {
        Preconditions.checkArgument(a.size() == b.size() && a.dims() == b.dims() && a.dims() == 1,
                "maccumulate: requires two 1-D tensors of equal size (a: dims=%s size=%s, b: dims=%s size=%s)",
                a.dims(), a.size(), b.dims(), b.size());
        int limit = offset + length;
        for (int i = offset; i < limit; i++) {
            a.set(a.get(i) * b.get(i), i);
        }
    }

    /**
     * Returns the sum over {@code k in [0, length)} of
     * {@code a(0, aColumnOffset + k) * b(0, bColumnOffset + k)}, accumulated in {@code float}.
     *
     * @param a left operand
     * @param b right operand; must have the same number of dimensions as {@code a}
     * @param aColumnOffset first column read from {@code a}
     * @param bColumnOffset first column read from {@code b}
     * @param length number of products summed; a non-positive value yields {@code 0}
     * @return the dot product
     * @throws IllegalArgumentException if {@code a} and {@code b} differ in number of dimensions
     */
    @Override
    public float dotProduct(AbstractTensor a, AbstractTensor b, int aColumnOffset, int bColumnOffset, int length) {
        int aDims = a.dims();
        int bDims = b.dims();
        Preconditions.checkArgument(aDims == bDims,
                "dotProduct: tensors must have the same number of dimensions (a=%s, b=%s)", aDims, bDims);
        float sum = 0.0f;
        // Both column cursors advance in lockstep, so a single counter bounds the loop.
        for (int k = 0; k < length; k++) {
            sum += a.get(0, aColumnOffset + k) * b.get(0, bColumnOffset + k);
        }
        return sum;
    }

    /**
     * For every index {@code i} along {@code a}'s first dimension and every {@code j} in
     * {@code [bRowOffset, bRowOffset + rowChunkSize)}, stores in {@code result(i, j)} the
     * {@link #dotProduct dot product} of slice {@code i} of {@code a} and slice {@code j} of
     * {@code b} over {@code columnLength} columns.
     *
     * @param result two-dimensional output tensor
     * @param a two-dimensional left operand
     * @param b two-dimensional right operand
     * @param aColumnOffset first column read from each slice of {@code a}
     * @param bColumnOffset first column read from each slice of {@code b}
     * @param columnLength number of columns in each dot product
     * @param bRowOffset first slice of {@code b} (and first column of {@code result}) processed
     * @param rowChunkSize number of slices of {@code b} processed
     * @throws IllegalArgumentException unless all three tensors are two-dimensional
     */
    @Override
    public void batchDotProduct(AbstractTensor result, AbstractTensor a, AbstractTensor b,
            int aColumnOffset, int bColumnOffset, int columnLength, int bRowOffset, int rowChunkSize) {
        Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2,
                "batchDotProduct: requires 2-D tensors (result=%s, a=%s, b=%s)",
                result.dims(), a.dims(), b.dims());
        int aSliceCount = a.shape().first();
        int bRowLimit = bRowOffset + rowChunkSize;
        for (int i = 0; i < aSliceCount; i++) {
            // The slice of a is invariant across the inner loop, so take it once per i.
            AbstractTensor aSlice = a.slice(i);
            for (int j = bRowOffset; j < bRowLimit; j++) {
                result.set(dotProduct(aSlice, b.slice(j), aColumnOffset, bColumnOffset, columnLength), i, j);
            }
        }
    }

    /**
     * Computes {@code y(0, yOffset + k) = alpha * x(0, xOffset + k) + y(0, yOffset + k)} for
     * {@code k in [0, length)}, updating {@code y} in place.
     *
     * @throws IllegalArgumentException unless both tensors have a first dimension of size 1
     */
    @Override
    public void saxpy(float alpha, AbstractTensor x, AbstractTensor y, int xOffset, int yOffset, int length) {
        Preconditions.checkArgument(x.shape().first() == 1 && y.shape().first() == 1,
                "saxpy: x and y must have a first dimension of size 1 (x=%s, y=%s)",
                x.shape().first(), y.shape().first());
        for (int k = 0; k < length; k++) {
            int yi = yOffset + k;
            y.set((alpha * x.get(0, xOffset + k)) + y.get(0, yi), 0, yi);
        }
    }

    /**
     * Computes {@code y(0, yOffset + k) = x(0, xOffset + k) + beta * y(0, yOffset + k)} for
     * {@code k in [0, length)}, updating {@code y} in place.
     *
     * @throws IllegalArgumentException unless both tensors have a first dimension of size 1
     */
    @Override
    public void sxpby(float beta, AbstractTensor x, AbstractTensor y, int xOffset, int yOffset, int length) {
        Preconditions.checkArgument(x.shape().first() == 1 && y.shape().first() == 1,
                "sxpby: x and y must have a first dimension of size 1 (x=%s, y=%s)",
                x.shape().first(), y.shape().first());
        for (int k = 0; k < length; k++) {
            int yi = yOffset + k;
            y.set(x.get(0, xOffset + k) + (beta * y.get(0, yi)), 0, yi);
        }
    }

    /**
     * Multiplies {@code x(0, i)} by {@code factor}, in place, for {@code i in [offset, offset + length)}.
     *
     * @throws IllegalArgumentException unless {@code x} has a first dimension of size 1
     */
    @Override
    public void scale(float factor, AbstractTensor x, int offset, int length) {
        Preconditions.checkArgument(x.shape().first() == 1,
                "scale: x must have a first dimension of size 1 (was %s)", x.shape().first());
        int limit = offset + length;
        for (int i = offset; i < limit; i++) {
            x.set(x.get(0, i) * factor, 0, i);
        }
    }
}
