/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.nlp;

import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;

/**
 * Turns text (or an already-tokenized batch) into the JSON request document for a BERT-style
 * model.
 *
 * <p>The document holds the request id and four arrays:
 * <ul>
 *   <li>{@value #TOKENS}: the token ids, written via {@code writePaddedTokens} with the
 *       tokenizer's [PAD] id as its pad argument;</li>
 *   <li>{@value #ARG1}: the value {@code 1} for every token position, also written via
 *       {@code writePaddedTokens} with the [PAD] id;</li>
 *   <li>{@value #ARG2}: the value {@code 0} for every position;</li>
 *   <li>{@value #ARG3}: the position index itself.</li>
 * </ul>
 * The meaning of the three {@code arg_*} arrays to the model is not established here.
 * NOTE(review): they presumably correspond to the attention mask, token type ids and
 * position ids; confirm against the model.
 */
public class BertRequestBuilder implements NlpTask.RequestBuilder {
    static final String REQUEST_ID = "request_id";
    static final String TOKENS = "tokens";
    static final String ARG1 = "arg_1";
    static final String ARG2 = "arg_2";
    static final String ARG3 = "arg_3";

    private final BertTokenizer tokenizer;

    /**
     * @param tokenizer used to tokenize raw inputs, to assemble the batch result, and to look up
     *     the [PAD] token id
     */
    public BertRequestBuilder(BertTokenizer tokenizer) {
        this.tokenizer = tokenizer;
    }

    /**
     * Tokenizes every input string, combines the results into one {@link TokenizationResult},
     * and builds the request from it.
     *
     * @param inputs raw input texts, one per batch entry
     * @param requestId identifier written under {@value #REQUEST_ID}
     * @return the request wrapping the tokenization and its JSON body
     * @throws IllegalStateException if the tokenizer has no [PAD] token (checked before any
     *     tokenization work is done)
     * @throws IOException if writing the JSON body fails
     */
    @Override
    public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
        requirePadToken();
        TokenizationResult batch = tokenizer.buildTokenizationResult(
            inputs.stream().map(tokenizer::tokenize).collect(Collectors.toList())
        );
        return buildRequest(batch, requestId);
    }

    /**
     * Builds the request for an already-tokenized batch.
     *
     * @param tokenization the tokenized batch
     * @param requestId identifier written under {@value #REQUEST_ID}
     * @return the request wrapping {@code tokenization} and its JSON body
     * @throws IllegalStateException if the tokenizer has no [PAD] token
     * @throws IOException if writing the JSON body fails
     */
    @Override
    public NlpTask.Request buildRequest(TokenizationResult tokenization, String requestId) throws IOException {
        requirePadToken();
        int padTokenId = tokenizer.getPadToken().getAsInt();
        BytesReference body = jsonRequest(tokenization, padTokenId, requestId);
        return new NlpTask.Request(tokenization, body);
    }

    /** Throws when the tokenizer's vocabulary does not define a [PAD] token. */
    private void requirePadToken() {
        if (tokenizer.getPadToken().isEmpty()) {
            throw new IllegalStateException("The input tokenizer does not have a [PAD] token in its vocabulary");
        }
    }

    /**
     * Serializes {@code tokenization} into the JSON object described on the class.
     *
     * @param tokenization the tokenized batch to serialize
     * @param padToken value passed to {@code writePaddedTokens} for the {@value #TOKENS} and
     *     {@value #ARG1} arrays
     * @param requestId identifier written under {@value #REQUEST_ID}
     * @return the serialized JSON bytes
     * @throws IOException if the builder fails to write
     */
    static BytesReference jsonRequest(TokenizationResult tokenization, int padToken, String requestId) throws IOException {
        XContentBuilder json = XContentFactory.jsonBuilder();
        json.startObject().field(REQUEST_ID, requestId);

        NlpTask.RequestBuilder.writePaddedTokens(TOKENS, tokenization, padToken, (seq, pos) -> seq.getTokenIds()[pos], json);
        // NOTE(review): arg_1 gets the same pad value as the token ids. If arg_1 is an attention
        // mask, padded positions presumably ought to be 0; the two only coincide when the [PAD]
        // id is 0. Confirm against writePaddedTokens and the model.
        NlpTask.RequestBuilder.writePaddedTokens(ARG1, tokenization, padToken, (seq, pos) -> 1, json);

        int batchSize = tokenization.getTokenizations().size();
        int sequenceLength = tokenization.getLongestSequenceLength();
        NlpTask.RequestBuilder.writeNonPaddedArguments(ARG2, batchSize, sequenceLength, pos -> 0, json);
        NlpTask.RequestBuilder.writeNonPaddedArguments(ARG3, batchSize, sequenceLength, pos -> pos, json);

        json.endObject();
        return BytesReference.bytes(json);
    }
}

