package com.github.tjake.jlama.safetensors.tokenizer;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/tokenizer/BPETokenizer.class */
/**
 * Base class for byte-pair-encoding (BPE) tokenizers loaded from a {@code tokenizer.json} file.
 *
 * <p>Encoding first splits the input into pieces (added tokens and pre-tokenizer output). It then
 * maps every code point of a piece to a vocabulary id, falling back to per-byte tokens or to the
 * unknown token. Finally it repeatedly applies the lowest-ranked merge whose concatenated result is
 * in the vocabulary.
 *
 * <p>Subclasses supply the raw-byte token mapping through {@link #encodeCharacterAsToken(byte)}
 * and {@link #maybeDecodeTokenAsCharacter(long)}. They may also override the {@code preProcess},
 * {@code postProcess} and {@code postProcessToken} hooks.
 *
 * <p>Not thread-safe: {@link #decode(long)} buffers partial characters in a shared instance field.
 */
public abstract class BPETokenizer implements Tokenizer {
    protected static final Logger logger = LoggerFactory.getLogger(BPETokenizer.class);

    /**
     * Byte values 0-255 that fall outside the ranges 33-126, 161-172 and 174-255. Each such byte is
     * mapped to a substitute code point 256, 257, ... in ascending byte order. The map is
     * immutable.
     */
    public static BiMap<Integer, Integer> alteredBytes;

    static {
        // The printable ranges and the 256+n substitution mirror GPT-2's byte-level BPE
        // (bytes_to_unicode in its encoder.py).
        BiMap<Integer, Integer> remapped = HashBiMap.create();
        int nextCodePoint = 256;
        for (int b = 0; b < 256; b++) {
            boolean printable = (b >= 33 && b <= 126) || (b >= 161 && b <= 172) || (b >= 174 && b <= 255);
            if (!printable) {
                remapped.put(b, nextCodePoint++);
            }
        }
        alteredBytes = ImmutableBiMap.copyOf(remapped);
    }

    protected final TokenizerModel model;
    protected final PromptSupport promptSupport;

    /** Raw bytes of a partially decoded character; filled and flushed by {@link #decode(long)}. */
    protected final ByteBuffer decodeBuffer = ByteBuffer.allocate(4);

    /**
     * Loads the tokenizer model from {@code tokenizer.json} in the given directory.
     *
     * @param path directory expected to contain {@code tokenizer.json}
     * @throws IllegalArgumentException if {@code path/tokenizer.json} does not exist
     * @throws UncheckedIOException if loading the tokenizer fails with an I/O error
     */
    public BPETokenizer(Path path) {
        Preconditions.checkArgument(
                path.resolve("tokenizer.json").toFile().exists(), "No tokenizer.json found in %s", path);
        try {
            this.model = SafeTensorSupport.loadTokenizer(path);
            this.promptSupport = new PromptSupport(this.model);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to load tokenizer from " + path, e);
        }
    }

    @Override
    public TokenizerModel getModel() {
        return model;
    }

    /**
     * Splits {@code sentence} into the pieces that {@link #encode(String)} turns into token ids.
     *
     * <p>When the model has an added-token pattern, the input is split on it. Exact added tokens
     * are kept whole, and the remaining pieces go through the pre-tokenizer (if any). Without an
     * added-token pattern, the whole input goes through the pre-tokenizer. With neither, the input
     * is returned as a single piece.
     *
     * @param sentence text to split
     * @return the pieces in input order; empty for empty input
     */
    @Override
    public List<String> tokenize(String sentence) {
        if (sentence.isEmpty()) {
            return Collections.emptyList();
        }
        if (model.addedTokenPattern() == null) {
            if (model.preTokenizer() == null) {
                // Nothing to split on: the whole input is one piece.
                return Collections.singletonList(sentence);
            }
            return new ArrayList<>(model.preTokenizer().pretokenize(sentence));
        }

        List<String> pieces = new ArrayList<>();
        // The split is inclusive, so the added tokens themselves are among the returned parts.
        for (String part : TokenizerModel.split(model.addedTokenPattern(), sentence, 0, true)) {
            if (part.isEmpty()) {
                continue;
            }
            if (model.addedTokens().containsKey(part)) {
                pieces.add(part);
            } else if (model.preTokenizer() != null) {
                pieces.addAll(model.preTokenizer().pretokenize(part));
            } else {
                pieces.add(part);
            }
        }
        return pieces;
    }

    /** Hook applied to each non-added-token piece before it is split into code points; identity by default. */
    protected String preProcess(String str) {
        return str;
    }

    /**
     * Encodes {@code sentence} into BPE token ids.
     *
     * @param sentence text to encode
     * @return the token ids, in order
     * @throws NullPointerException if a code point must be encoded as the unknown token but that
     *     token is not in the vocabulary
     */
    @Override
    public long[] encode(String sentence) {
        List<Long> tokenIds = new ArrayList<>();
        for (String piece : tokenize(sentence)) {
            if (model.addedTokens() != null && model.addedTokens().containsKey(piece)) {
                tokenIds.add(model.addedTokens().get(piece));
            } else {
                List<Long> pieceIds = initialTokenIds(piece);
                applyMerges(pieceIds);
                tokenIds.addAll(pieceIds);
            }
        }
        return tokenIds.stream().mapToLong(Long::longValue).toArray();
    }

    /** Maps each code point of the pre-processed piece to a vocab id, byte-fallback ids, or the unk id. */
    private List<Long> initialTokenIds(String piece) {
        List<Long> ids = new ArrayList<>();
        for (int codePoint : preProcess(piece).codePoints().toArray()) {
            String symbol = Character.toString(codePoint);
            Long id = (Long) model.vocabLookup.get(symbol);
            if (id != null) {
                ids.add(id);
            } else if (model.byteFallback) {
                // One token per UTF-8 byte of the missing code point.
                for (byte b : symbol.getBytes(StandardCharsets.UTF_8)) {
                    ids.add(encodeCharacterAsToken(b));
                }
            } else if (model.unkToken != null) {
                Long unkId = (Long) model.vocabLookup.get(model.unkToken);
                // Fail here with context rather than with an unboxing NPE further down.
                ids.add(Objects.requireNonNull(
                        unkId, () -> "Unknown token is not in the vocabulary: " + model.unkToken));
            }
            // Otherwise the code point is silently dropped.
        }
        return ids;
    }

    /**
     * Repeatedly merges the adjacent pair with the lowest merge rank whose concatenation is in the
     * vocabulary, in place, until no pair can be merged. Ties go to the leftmost pair.
     */
    private void applyMerges(List<Long> ids) {
        // Cache the decoded text of each id so every pass does not re-decode the whole piece.
        List<String> symbols = new ArrayList<>(ids.size());
        for (Long id : ids) {
            symbols.add(decodeInternal(id));
        }

        while (ids.size() > 1) {
            int bestIndex = -1;
            long bestRank = Long.MAX_VALUE;
            long bestId = -1;
            for (int i = 0; i + 1 < ids.size(); i++) {
                String left = symbols.get(i);
                String right = symbols.get(i + 1);
                Long rank = model.merges.get(left + " " + right);
                if (rank == null || rank >= bestRank) {
                    continue;
                }
                Long mergedId = (Long) model.vocabLookup.get(left + right);
                if (mergedId != null) {
                    bestIndex = i;
                    bestRank = rank;
                    bestId = mergedId;
                }
            }
            if (bestIndex < 0) {
                break;
            }
            ids.set(bestIndex, bestId);
            ids.remove(bestIndex + 1);
            symbols.set(bestIndex, decodeInternal(bestId));
            symbols.remove(bestIndex + 1);
        }
    }

    /** Hook mapping a looked-up vocabulary string to output text; substitutes the unk token for {@code null}. */
    protected String postProcessToken(String str) {
        return str == null ? model.unkToken : str;
    }

    /**
     * Decodes a single token id.
     *
     * <p>If {@link #maybeDecodeTokenAsCharacter(long)} maps the id to a character, that character is
     * returned directly when nothing is buffered and it is not a Unicode identifier part. Otherwise
     * it is buffered as a byte, and an empty string is returned until four bytes have been
     * collected; those bytes are then decoded as UTF-8 and the buffer is reset. Ids without a
     * character mapping are looked up in the vocabulary and passed through
     * {@link #postProcessToken(String)}.
     *
     * <p>NOTE(review): the buffer only flushes at exactly four bytes, so 2- and 3-byte UTF-8
     * sequences are combined with whatever bytes follow. Sizing the flush from the lead byte would
     * fix that; confirm against the subclasses' byte mapping before changing it.
     *
     * @param id token id
     * @return decoded text, possibly empty while a multi-byte character is being assembled
     */
    @Override
    public String decode(long id) {
        return maybeDecodeTokenAsCharacter(id)
                .map(this::bufferOrEmit)
                .orElseGet(() -> postProcessToken((String) model.vocabLookup.inverse().get(id)));
    }

    /** Returns {@code c} as text, or accumulates it into {@link #decodeBuffer} as described in {@link #decode(long)}. */
    private String bufferOrEmit(Character c) {
        char ch = c;
        if (!Character.isUnicodeIdentifierPart(ch) && decodeBuffer.remaining() >= 4) {
            return Character.toString(ch);
        }
        decodeBuffer.put((byte) ch);
        if (decodeBuffer.remaining() != 0) {
            return "";
        }
        // The buffered bytes are UTF-8 (byte fallback in encode emits UTF-8); never use the platform charset.
        String text = new String(decodeBuffer.array(), StandardCharsets.UTF_8);
        decodeBuffer.rewind();
        return text;
    }

    /** Returns the token id that represents the single raw byte {@code b}; used for byte fallback in encoding. */
    protected abstract long encodeCharacterAsToken(byte b);

    /**
     * Returns the raw character a token id stands for, if the subclass represents it that way;
     * empty for tokens that should be looked up in the vocabulary.
     */
    protected abstract Optional<Character> maybeDecodeTokenAsCharacter(long j);

    /**
     * Returns the text of a token id as used for merge lookups: the subclass-mapped character if
     * any, otherwise the vocabulary string, or the unk token if the id is not in the vocabulary.
     */
    protected String decodeInternal(long j) {
        return maybeDecodeTokenAsCharacter(j).map(Object::toString).orElseGet(() -> {
            String token = (String) model.vocabLookup.inverse().get(j);
            return token != null ? token : model.unkToken;
        });
    }

    /** Hook applied to the fully concatenated output of {@link #decode(long[])}; identity by default. */
    protected String postProcess(String str) {
        return str;
    }

    /**
     * Decodes a sequence of token ids by decoding each id in order, concatenating the results,
     * and applying {@link #postProcess(String)}.
     */
    @Override
    public String decode(long[] ids) {
        return postProcess(Arrays.stream(ids).mapToObj(this::decode).collect(Collectors.joining()));
    }

    /** Returns the prompt support only when the model defines prompt templates. */
    @Override
    public Optional<PromptSupport> promptSupport() {
        return model.promptTemplates().isPresent() ? Optional.of(promptSupport) : Optional.empty();
    }
}
