package com.github.tjake.jlama.safetensors.tokenizer;

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.google.common.base.Preconditions;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:com/github/tjake/jlama/safetensors/tokenizer/WordPieceTokenizer.class */
public class WordPieceTokenizer implements Tokenizer {
    protected final TokenizerModel model;
    protected final PromptSupport promptSupport;
    protected final long sepToken;
    protected final long clsToken;
    protected final long unkToken;
    protected static final String sepString = "[SEP]";
    protected static final String clsString = "[CLS]";
    protected static final String unkString = "[UNK]";

    public WordPieceTokenizer(Path path) {
        Preconditions.checkArgument(path.resolve("tokenizer.json").toFile().exists(), "No tokenizer.json found in " + String.valueOf(path));
        try {
            this.model = SafeTensorSupport.loadTokenizer(path);
            Preconditions.checkArgument(this.model.type == null || this.model.type.equalsIgnoreCase("WordPiece"), "Invalid model type: " + this.model.type);
            this.promptSupport = new PromptSupport(this.model);
            this.sepToken = ((Long) this.model.vocabLookup.get(sepString)).longValue();
            this.clsToken = ((Long) this.model.vocabLookup.get(clsString)).longValue();
            this.unkToken = ((Long) this.model.vocabLookup.get(unkString)).longValue();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override // com.github.tjake.jlama.safetensors.tokenizer.Tokenizer
    public List<String> tokenize(String str) {
        String[] split = preProcess(str).split("\\s+");
        ArrayList arrayList = new ArrayList();
        arrayList.add(clsString);
        arrayList.addAll((List) Arrays.stream(split).flatMap(this::splitByPunctuation).map(str2 -> {
            return str2.length() > 200 ? this.model.unkToken : str2;
        }).flatMap(str3 -> {
            boolean z = false;
            ArrayList arrayList2 = new ArrayList();
            int i = 0;
            while (true) {
                int i2 = i;
                if (i2 >= str3.length()) {
                    break;
                }
                int length = str3.length();
                String str3 = null;
                while (true) {
                    if (i2 >= length) {
                        break;
                    }
                    String substring = str3.substring(i2, length);
                    if (i2 > 0) {
                        substring = "##" + substring;
                    }
                    if (this.model.vocabLookup.containsKey(substring)) {
                        str3 = substring;
                        break;
                    }
                    length--;
                }
                if (str3 == null) {
                    z = true;
                    break;
                }
                arrayList2.add(str3);
                i = length;
            }
            if (z) {
                arrayList2.add(this.model.unkToken);
            }
            return arrayList2.stream();
        }).collect(Collectors.toList()));
        arrayList.add(sepString);
        return arrayList;
    }

    protected String preProcess(String str) {
        return cleanText(str.toLowerCase().strip());
    }

    static boolean isControl(Integer num) {
        if (num.intValue() == 9 || num.intValue() == 10 || num.intValue() == 13) {
            return false;
        }
        return Character.isISOControl(num.intValue());
    }

    static boolean isPunctuation(Integer num) {
        if (num.intValue() >= 33 && num.intValue() <= 47) {
            return true;
        }
        if (num.intValue() >= 58 && num.intValue() <= 64) {
            return true;
        }
        if (num.intValue() >= 91 && num.intValue() <= 96) {
            return true;
        }
        if (num.intValue() >= 123 && num.intValue() <= 126) {
            return true;
        }
        int type = Character.getType(num.intValue());
        return type >= 20 && type <= 24;
    }

    String cleanText(String str) {
        return (String) str.codePoints().map(i -> {
            if (i == 0 || i == 65533 || isControl(Integer.valueOf(i))) {
                return -1;
            }
            if (Character.isWhitespace(i)) {
                return 32;
            }
            return i;
        }).filter(i2 -> {
            return i2 != -1;
        }).mapToObj(Character::toString).collect(Collectors.joining());
    }

    Stream<String> splitByPunctuation(String str) {
        ArrayList arrayList = new ArrayList();
        int i = 0;
        int i2 = 0;
        while (true) {
            int i3 = i2;
            if (i3 >= str.length()) {
                break;
            }
            int codePointAt = str.codePointAt(i3);
            if (isPunctuation(Integer.valueOf(codePointAt))) {
                if (i3 != i) {
                    arrayList.add(str.substring(i, i3));
                }
                arrayList.add(str.substring(i3, i3 + Character.charCount(codePointAt)));
                i = i3 + Character.charCount(codePointAt);
            }
            i2 = i3 + Character.charCount(codePointAt);
        }
        if (i != str.length()) {
            arrayList.add(str.substring(i));
        }
        return arrayList.stream();
    }

    @Override // com.github.tjake.jlama.safetensors.tokenizer.Tokenizer
    public long[] encode(String str) {
        return tokenize(str).stream().mapToLong(str2 -> {
            return ((Long) this.model.vocabLookup.get(str2)).longValue();
        }).toArray();
    }

    protected String postProcessToken(String str) {
        return str.startsWith("##") ? str.substring(2) : " " + str;
    }

    @Override // com.github.tjake.jlama.safetensors.tokenizer.Tokenizer
    public String decode(long j) {
        return postProcessToken((String) this.model.vocabLookup.inverse().get(Long.valueOf(j)));
    }

    protected String postProcess(String str) {
        return str.strip();
    }

    @Override // com.github.tjake.jlama.safetensors.tokenizer.Tokenizer
    public String decode(long[] jArr) {
        return postProcess((String) Arrays.stream(jArr).mapToObj(this::decode).collect(Collectors.joining()));
    }

    @Override // com.github.tjake.jlama.safetensors.tokenizer.Tokenizer
    public Optional<PromptSupport> promptSupport() {
        return this.promptSupport.hasPromptTemplates() ? Optional.of(this.promptSupport) : Optional.empty();
    }
}
