/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.function.Function;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.nlp.BertRequestBuilder;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.WordPieceTokenizer;

/**
 * Tokenizer that converts text into word-piece tokens and their vocabulary ids,
 * optionally framing the output with the {@code [CLS]} and {@code [SEP]} special tokens.
 *
 * <p>Input is first split by a {@link BasicTokenizer}; each resulting token is then either
 * kept whole (when it is in the never-split set) or broken into word pieces by a
 * {@link WordPieceTokenizer}. Every emitted token records the index of the
 * {@link BasicTokenizer} token it came from; special tokens record
 * {@link #SPECIAL_TOKEN_POSITION} instead.
 *
 * <p>Instances are created through {@link #builder(List, Tokenization)}.
 */
public class BertTokenizer implements NlpTokenizer {
    public static final String UNKNOWN_TOKEN = "[UNK]";
    public static final String SEPARATOR_TOKEN = "[SEP]";
    public static final String PAD_TOKEN = "[PAD]";
    public static final String CLASS_TOKEN = "[CLS]";
    public static final String MASK_TOKEN = "[MASK]";
    /** Position recorded in the token map for special tokens, which have no source token. */
    public static final int SPECIAL_TOKEN_POSITION = -1;
    public static final int DEFAULT_MAX_INPUT_CHARS_PER_WORD = 100;

    /** Tokens that are always kept whole, in addition to any caller-supplied ones. */
    private static final Set<String> NEVER_SPLIT = Set.of(MASK_TOKEN);

    /** Number of special tokens framing a sequence pair: {@code [CLS] seq1 [SEP] seq2 [SEP]}. */
    private static final int PAIR_SPECIAL_TOKEN_COUNT = 3;

    private final WordPieceTokenizer wordPieceTokenizer;
    private final List<String> originalVocab;
    private final SortedMap<String, Integer> vocab;
    private final boolean doLowerCase;
    private final boolean doTokenizeCjKChars;
    private final boolean doStripAccents;
    private final boolean withSpecialTokens;
    private final Tokenization.Truncate truncate;
    private final Set<String> neverSplit;
    private final int maxSequenceLength;
    private final NlpTask.RequestBuilder requestBuilder;

    /**
     * Creates a tokenizer. Prefer {@link #builder(List, Tokenization)}.
     *
     * @param originalVocab         the vocabulary in its original order; handed to every
     *                              {@link TokenizationResult} this tokenizer builds
     * @param vocab                 token to id lookup
     * @param doLowerCase           passed through to the {@link BasicTokenizer}
     * @param doTokenizeCjKChars    passed through to the {@link BasicTokenizer}
     * @param doStripAccents        passed through to the {@link BasicTokenizer}
     * @param withSpecialTokens     whether single-sequence output is framed with
     *                              {@code [CLS]} / {@code [SEP]}; must be {@code true} for
     *                              sequence-pair tokenization
     * @param truncate              what to do when the output exceeds {@code maxSequenceLength}
     * @param maxSequenceLength     maximum number of output tokens, special tokens included
     * @param requestBuilderFactory creates the request builder from this tokenizer
     * @param neverSplit            extra tokens to keep whole; {@code [MASK]} is always added
     */
    protected BertTokenizer(List<String> originalVocab, SortedMap<String, Integer> vocab, boolean doLowerCase, boolean doTokenizeCjKChars, boolean doStripAccents, boolean withSpecialTokens, Tokenization.Truncate truncate, int maxSequenceLength, Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory, Set<String> neverSplit) {
        this.wordPieceTokenizer = new WordPieceTokenizer(vocab, UNKNOWN_TOKEN, DEFAULT_MAX_INPUT_CHARS_PER_WORD);
        this.originalVocab = originalVocab;
        this.vocab = vocab;
        this.doLowerCase = doLowerCase;
        this.doTokenizeCjKChars = doTokenizeCjKChars;
        this.doStripAccents = doStripAccents;
        this.withSpecialTokens = withSpecialTokens;
        this.truncate = truncate;
        this.neverSplit = Sets.union(neverSplit, NEVER_SPLIT);
        this.maxSequenceLength = maxSequenceLength;
        // Must stay last: the factory receives this instance and may read any field above.
        this.requestBuilder = requestBuilderFactory.apply(this);
    }

    /**
     * @return the vocabulary id of {@code [PAD]}, or empty when the vocabulary has no such entry
     */
    @Override
    public OptionalInt getPadToken() {
        Integer pad = this.vocab.get(PAD_TOKEN);
        return pad == null ? OptionalInt.empty() : OptionalInt.of(pad);
    }

    /**
     * Collects the given tokenizations into a single result backed by the original vocabulary.
     *
     * @param tokenizations the tokenizations to add, in order
     * @return a new result containing every tokenization
     */
    @Override
    public TokenizationResult buildTokenizationResult(List<TokenizationResult.Tokenization> tokenizations) {
        TokenizationResult tokenizationResult = new TokenizationResult(this.originalVocab);
        for (TokenizationResult.Tokenization tokenization : tokenizations) {
            tokenizationResult.addTokenization(tokenization);
        }
        return tokenizationResult;
    }

    /**
     * Tokenizes a single sequence.
     *
     * <p>When the token count (plus two special tokens if enabled) exceeds the maximum
     * sequence length, the word pieces are cut from the end for truncate modes
     * {@code FIRST} and {@code SECOND}, and the request is rejected for {@code NONE}.
     *
     * @param seq the text to tokenize
     * @return tokens, token ids and, per token, the index of its source token
     *         ({@link #SPECIAL_TOKEN_POSITION} for special tokens)
     */
    @Override
    public TokenizationResult.Tokenization tokenize(String seq) {
        Tuple<List<WordPieceTokenizer.TokenAndId>, List<Integer>> innerResult = this.innerTokenize(seq);
        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = innerResult.v1();
        List<Integer> tokenPositionMap = innerResult.v2();
        int numSpecialTokens = this.withSpecialTokens ? 2 : 0;
        int numTokens = wordPieceTokens.size() + numSpecialTokens;
        boolean isTruncated = false;
        if (numTokens > this.maxSequenceLength) {
            switch (this.truncate) {
                case FIRST:
                case SECOND:
                    // With a single sequence both modes mean the same thing: drop the tail.
                    isTruncated = true;
                    wordPieceTokens = wordPieceTokens.subList(0, this.maxSequenceLength - numSpecialTokens);
                    break;
                case NONE:
                    throw ExceptionsHelper.badRequestException("Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]", numTokens, this.maxSequenceLength);
                default:
                    throw new IllegalArgumentException("Unsupported truncation mode [" + this.truncate + "]");
            }
            numTokens = this.maxSequenceLength;
        }
        String[] tokens = new String[numTokens];
        int[] tokenIds = new int[numTokens];
        int[] tokenMap = new int[numTokens];
        int i = 0;
        if (this.withSpecialTokens) {
            tokens[i] = CLASS_TOKEN;
            tokenIds[i] = this.vocab.get(CLASS_TOKEN);
            tokenMap[i] = SPECIAL_TOKEN_POSITION;
            ++i;
        }
        // The position map is not truncated; it is read by index, so only the kept prefix is used.
        int j = 0;
        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokens) {
            tokens[i] = tokenAndId.getToken();
            tokenIds[i] = tokenAndId.getId();
            tokenMap[i] = tokenPositionMap.get(j);
            ++i;
            ++j;
        }
        if (this.withSpecialTokens) {
            tokens[i] = SEPARATOR_TOKEN;
            tokenIds[i] = this.vocab.get(SEPARATOR_TOKEN);
            tokenMap[i] = SPECIAL_TOKEN_POSITION;
        }
        return new TokenizationResult.Tokenization(seq, isTruncated, tokens, tokenIds, tokenMap);
    }

    /**
     * Tokenizes a pair of sequences as {@code [CLS] seq1 [SEP] seq2 [SEP]}.
     *
     * <p>When the combined length exceeds the maximum sequence length, truncate mode
     * {@code FIRST} shortens only the first sequence and {@code SECOND} only the second;
     * if the sequence that must be kept whole does not fit on its own the request is
     * rejected. Mode {@code NONE} always rejects over-long input.
     *
     * @param seq1 the first sequence
     * @param seq2 the second sequence
     * @return the combined tokenization; its source text is {@code seq1 + seq2}
     * @throws IllegalArgumentException if this tokenizer was built without special tokens
     */
    @Override
    public TokenizationResult.Tokenization tokenize(String seq1, String seq2) {
        // Fail fast: nothing below is valid without the framing tokens, so skip the tokenization work.
        if (!this.withSpecialTokens) {
            throw new IllegalArgumentException("Unable to do sequence pair tokenization without special tokens");
        }
        Tuple<List<WordPieceTokenizer.TokenAndId>, List<Integer>> innerResult = this.innerTokenize(seq1);
        List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq1s = innerResult.v1();
        List<Integer> tokenPositionMapSeq1 = innerResult.v2();
        innerResult = this.innerTokenize(seq2);
        List<WordPieceTokenizer.TokenAndId> wordPieceTokenSeq2s = innerResult.v1();
        List<Integer> tokenPositionMapSeq2 = innerResult.v2();

        int maxContentTokens = this.maxSequenceLength - PAIR_SPECIAL_TOKEN_COUNT;
        int numTokens = wordPieceTokenSeq1s.size() + wordPieceTokenSeq2s.size() + PAIR_SPECIAL_TOKEN_COUNT;
        boolean isTruncated = false;
        if (numTokens > this.maxSequenceLength) {
            switch (this.truncate) {
                case FIRST:
                    isTruncated = true;
                    if (wordPieceTokenSeq2s.size() > maxContentTokens) {
                        throw ExceptionsHelper.badRequestException("Attempting truncation [{}] but input is too large for the second sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account", this.truncate.toString(), wordPieceTokenSeq2s.size(), maxContentTokens);
                    }
                    wordPieceTokenSeq1s = wordPieceTokenSeq1s.subList(0, maxContentTokens - wordPieceTokenSeq2s.size());
                    break;
                case SECOND:
                    isTruncated = true;
                    if (wordPieceTokenSeq1s.size() > maxContentTokens) {
                        // Report the first sequence's length: it is the one that does not fit.
                        throw ExceptionsHelper.badRequestException("Attempting truncation [{}] but input is too large for the first sequence. The tokenized input length [{}] exceeds the maximum sequence length [{}], when taking special tokens into account", this.truncate.toString(), wordPieceTokenSeq1s.size(), maxContentTokens);
                    }
                    wordPieceTokenSeq2s = wordPieceTokenSeq2s.subList(0, maxContentTokens - wordPieceTokenSeq1s.size());
                    break;
                case NONE:
                    throw ExceptionsHelper.badRequestException("Input too large. The tokenized input length [{}] exceeds the maximum sequence length [{}]", numTokens, this.maxSequenceLength);
                default:
                    throw new IllegalArgumentException("Unsupported truncation mode [" + this.truncate + "]");
            }
            numTokens = this.maxSequenceLength;
        }
        String[] tokens = new String[numTokens];
        int[] tokenIds = new int[numTokens];
        int[] tokenMap = new int[numTokens];
        int separatorId = this.vocab.get(SEPARATOR_TOKEN);

        int i = 0;
        tokens[i] = CLASS_TOKEN;
        tokenIds[i] = this.vocab.get(CLASS_TOKEN);
        tokenMap[i] = SPECIAL_TOKEN_POSITION;
        ++i;

        int j = 0;
        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq1s) {
            tokens[i] = tokenAndId.getToken();
            tokenIds[i] = tokenAndId.getId();
            tokenMap[i] = tokenPositionMapSeq1.get(j);
            ++i;
            ++j;
        }
        tokens[i] = SEPARATOR_TOKEN;
        tokenIds[i] = separatorId;
        tokenMap[i] = SPECIAL_TOKEN_POSITION;
        ++i;

        // Positions restart at zero: the map indexes into seq2's own token list.
        j = 0;
        for (WordPieceTokenizer.TokenAndId tokenAndId : wordPieceTokenSeq2s) {
            tokens[i] = tokenAndId.getToken();
            tokenIds[i] = tokenAndId.getId();
            tokenMap[i] = tokenPositionMapSeq2.get(j);
            ++i;
            ++j;
        }
        tokens[i] = SEPARATOR_TOKEN;
        tokenIds[i] = separatorId;
        tokenMap[i] = SPECIAL_TOKEN_POSITION;
        return new TokenizationResult.Tokenization(seq1 + seq2, isTruncated, tokens, tokenIds, tokenMap);
    }

    /**
     * Runs the basic tokenizer followed by word-piece tokenization, without adding
     * special tokens or truncating.
     *
     * @param seq the text to tokenize
     * @return the word-piece tokens paired with a list of the same length giving, for each
     *         word piece, the index of the basic-tokenizer token it came from
     */
    private Tuple<List<WordPieceTokenizer.TokenAndId>, List<Integer>> innerTokenize(String seq) {
        BasicTokenizer basicTokenizer = new BasicTokenizer(this.doLowerCase, this.doTokenizeCjKChars, this.doStripAccents, this.neverSplit);
        List<String> delineatedTokens = basicTokenizer.tokenize(seq);
        List<WordPieceTokenizer.TokenAndId> wordPieceTokens = new ArrayList<>();
        List<Integer> tokenPositionMap = new ArrayList<>();
        for (int sourceIndex = 0; sourceIndex < delineatedTokens.size(); ++sourceIndex) {
            String token = delineatedTokens.get(sourceIndex);
            if (this.neverSplit.contains(token)) {
                // Never-split tokens bypass word-piece splitting; fall back to [UNK] if not in the vocabulary.
                wordPieceTokens.add(new WordPieceTokenizer.TokenAndId(token, this.vocab.getOrDefault(token, this.vocab.get(UNKNOWN_TOKEN))));
                tokenPositionMap.add(sourceIndex);
                continue;
            }
            List<WordPieceTokenizer.TokenAndId> tokens = this.wordPieceTokenizer.tokenize(token);
            // Every word piece of this token maps back to the same source index.
            for (int tokenCount = 0; tokenCount < tokens.size(); ++tokenCount) {
                tokenPositionMap.add(sourceIndex);
            }
            wordPieceTokens.addAll(tokens);
        }
        return Tuple.tuple(wordPieceTokens, tokenPositionMap);
    }

    /**
     * @return the request builder created for this tokenizer at construction time
     */
    @Override
    public NlpTask.RequestBuilder requestBuilder() {
        return this.requestBuilder;
    }

    /**
     * @return the maximum number of tokens, special tokens included, in a tokenization
     */
    public int getMaxSequenceLength() {
        return this.maxSequenceLength;
    }

    /**
     * Starts a builder seeded from the given tokenization settings.
     *
     * @param vocab        the vocabulary; a token's id is its index in this list
     * @param tokenization supplies lower-casing, special-token, max-length and truncate settings
     * @return a new builder
     */
    public static Builder builder(List<String> vocab, Tokenization tokenization) {
        return new Builder(vocab, tokenization);
    }

    /**
     * Fluent builder for {@link BertTokenizer}.
     *
     * <p>Lower-casing, special tokens, maximum sequence length and truncation are taken
     * from the {@link Tokenization} passed to the constructor and may be overridden with
     * the setters. CJK character tokenization defaults to {@code true}. Accent stripping,
     * unless set explicitly, follows the lower-casing setting at {@link #build()} time.
     */
    public static class Builder {
        protected final List<String> originalVocab;
        protected final SortedMap<String, Integer> vocab;
        protected boolean doLowerCase = false;
        protected boolean doTokenizeCjKChars = true;
        protected boolean withSpecialTokens = true;
        protected Tokenization.Truncate truncate = Tokenization.Truncate.FIRST;
        protected int maxSequenceLength;
        /** {@code null} means "not set": resolved to {@code doLowerCase} in {@link #build()}. */
        protected Boolean doStripAccents = null;
        protected Set<String> neverSplit;
        protected Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory = BertRequestBuilder::new;

        /**
         * @param vocab        the vocabulary; a token's id is its index in this list
         * @param tokenization source of the initial lower-casing, special-token,
         *                     max-length and truncate settings
         */
        protected Builder(List<String> vocab, Tokenization tokenization) {
            this.originalVocab = vocab;
            this.vocab = Builder.buildSortedVocab(vocab);
            this.doLowerCase = tokenization.doLowerCase();
            this.withSpecialTokens = tokenization.withSpecialTokens();
            this.maxSequenceLength = tokenization.maxSequenceLength();
            this.truncate = tokenization.getTruncate();
        }

        /**
         * Maps each vocabulary entry to its list index. If an entry occurs more than once,
         * the last occurrence's index wins.
         */
        private static SortedMap<String, Integer> buildSortedVocab(List<String> vocab) {
            SortedMap<String, Integer> sortedVocab = new TreeMap<>();
            for (int i = 0; i < vocab.size(); ++i) {
                sortedVocab.put(vocab.get(i), i);
            }
            return sortedVocab;
        }

        public Builder setDoLowerCase(boolean doLowerCase) {
            this.doLowerCase = doLowerCase;
            return this;
        }

        public Builder setDoTokenizeCjKChars(boolean doTokenizeCjKChars) {
            this.doTokenizeCjKChars = doTokenizeCjKChars;
            return this;
        }

        /**
         * @param doStripAccents explicit setting, or {@code null} to follow the lower-casing setting
         */
        public Builder setDoStripAccents(Boolean doStripAccents) {
            this.doStripAccents = doStripAccents;
            return this;
        }

        /**
         * @param neverSplit tokens to keep whole, or {@code null} for none beyond the built-in {@code [MASK]}
         */
        public Builder setNeverSplit(Set<String> neverSplit) {
            this.neverSplit = neverSplit;
            return this;
        }

        public Builder setMaxSequenceLength(int maxSequenceLength) {
            this.maxSequenceLength = maxSequenceLength;
            return this;
        }

        public Builder setWithSpecialTokens(boolean withSpecialTokens) {
            this.withSpecialTokens = withSpecialTokens;
            return this;
        }

        public Builder setRequestBuilderFactory(Function<BertTokenizer, NlpTask.RequestBuilder> requestBuilderFactory) {
            this.requestBuilderFactory = requestBuilderFactory;
            return this;
        }

        public Builder setTruncate(Tokenization.Truncate truncate) {
            this.truncate = truncate;
            return this;
        }

        /**
         * Resolves unset options to their defaults and creates the tokenizer.
         * The resolved defaults are written back to this builder.
         *
         * @return a new tokenizer
         */
        public BertTokenizer build() {
            if (this.doStripAccents == null) {
                this.doStripAccents = this.doLowerCase;
            }
            if (this.neverSplit == null) {
                this.neverSplit = Collections.emptySet();
            }
            return new BertTokenizer(this.originalVocab, this.vocab, this.doLowerCase, this.doTokenizeCjKChars, this.doStripAccents, this.withSpecialTokens, this.truncate, this.maxSequenceLength, this.requestBuilderFactory, this.neverSplit);
        }
    }
}

