package com.github.tjake.jlama.model.functions;

import com.github.tjake.jlama.math.VectorMath;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.TensorCache;
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.base.Preconditions;

/* loaded from: input_file:com/github/tjake/jlama/model/functions/EmbedInput.class */
/**
 * Converts input token ids into embedding tensors.
 *
 * <p>Implementations supply the single-token lookup in {@link #inputTokenToEmbedding(int, int)};
 * {@link #batchInputsToEmbeddings(int[], int)} builds a batched tensor on top of it.
 */
public interface EmbedInput {
    /**
     * Returns the embedding tensor for a single input token.
     *
     * @param inputToken the token id to embed
     * @param position the position value associated with this token; batched calls pass
     *     {@code startPos + k} for the k-th token (presumably its position in the sequence —
     *     confirm against implementations)
     * @return the embedding tensor for {@code inputToken}; ownership passes to the caller
     */
    AbstractTensor inputTokenToEmbedding(int inputToken, int position);

    /**
     * Embeds a sequence of tokens into a single batched tensor.
     *
     * <p>With exactly one token, the tensor from {@link #inputTokenToEmbedding(int, int)} is
     * returned as-is. Otherwise a tensor of shape {@code (inputTokens.length, dim)} is taken from
     * {@link TensorCache}, where {@code dim} is the last dimension of the first token's embedding.
     * When that first embedding is sparse, the batch shape is sparsified with the same sparse
     * offset and length. Row {@code k} is filled from
     * {@code inputTokenToEmbedding(inputTokens[k], startPos + k)}; rows after the first are
     * produced through {@link VectorMath#pfor} (presumably in parallel). Every per-token tensor is
     * closed after its data has been copied, including when the copy fails.
     *
     * @param inputTokens token ids to embed; must be non-empty
     * @param startPos the position passed for {@code inputTokens[0]}; token {@code k} receives
     *     {@code startPos + k}
     * @return the batched embedding tensor, or the single-token embedding when only one token is
     *     given
     * @throws IllegalArgumentException if {@code inputTokens} is empty
     */
    default AbstractTensor batchInputsToEmbeddings(int[] inputTokens, int startPos) {
        Preconditions.checkArgument(inputTokens.length > 0, "inputTokens must not be empty");

        AbstractTensor first = inputTokenToEmbedding(inputTokens[0], startPos);
        if (inputTokens.length == 1) {
            return first;
        }

        // Assigned exactly once inside the try, so it stays effectively final for the lambda.
        AbstractTensor batch;
        try {
            TensorShape firstShape = first.shape();
            TensorShape batchShape = TensorShape.of(inputTokens.length, firstShape.last());
            if (firstShape.isSparse()) {
                batchShape = batchShape.sparsify(firstShape.sparseOffset(), firstShape.sparseLength());
            }
            batch = TensorCache.instance.get(first.dType(), batchShape);
            batch.copyFrom(first, 0, 0, firstShape.sparseLength());
        } finally {
            // Release the first embedding even if the cache lookup or the copy throws.
            first.close();
        }

        // Rows are packed back-to-back, so row k starts at k * (per-row length).
        // NOTE(review): assumes every token's embedding has the same sparseLength as the first;
        // otherwise rows would overlap or leave gaps.
        VectorMath.pfor(1, inputTokens.length, k -> {
            AbstractTensor embedding = inputTokenToEmbedding(inputTokens[k], startPos + k);
            try {
                int rowLength = embedding.shape().sparseLength();
                batch.copyFrom(embedding, 0, k * rowLength, rowLength);
            } finally {
                embedding.close();
            }
        });

        return batch;
    }
}
