/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.vectors.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.search.Query;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper;

/**
 * Query builder for the {@code knn} query.
 *
 * <p>Holds the name of the target field, the query vector and the number of
 * candidates, and turns them into a Lucene {@link Query} by delegating to the
 * field's {@link DenseVectorFieldMapper.DenseVectorFieldType}.
 *
 * <p>Wire format, in order: field name (string), number of candidates (vInt),
 * query vector (float array). {@link #doWriteTo(StreamOutput)} and
 * {@link #KnnVectorQueryBuilder(StreamInput)} must stay in sync.
 */
public class KnnVectorQueryBuilder
extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
    public static final String NAME = "knn";
    private final String fieldName;
    private final float[] queryVector;
    private final int numCands;

    /**
     * Creates a builder from its three components.
     *
     * <p>The vector array is stored by reference, not copied, so the caller
     * must not modify it afterwards ({@link #doHashCode()} and
     * {@link #doEquals(KnnVectorQueryBuilder)} depend on its contents).
     *
     * @param fieldName   name of the field to search; must not be {@code null}
     * @param queryVector the query vector; must not be {@code null}
     * @param numCands    number of candidates handed to the field type when the query is built;
     *                    not range-checked here
     * @throws NullPointerException if {@code fieldName} or {@code queryVector} is {@code null}
     */
    public KnnVectorQueryBuilder(String fieldName, float[] queryVector, int numCands) {
        // Fail fast: a null here would otherwise only blow up later, during
        // serialization (doWriteTo) or with a confusing "field [null]" message in doToQuery.
        this.fieldName = Objects.requireNonNull(fieldName, "[knn] query requires a non-null field name");
        this.queryVector = Objects.requireNonNull(queryVector, "[knn] query requires a non-null query vector");
        this.numCands = numCands;
    }

    /**
     * Reads a builder from the stream, mirroring {@link #doWriteTo(StreamOutput)}.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public KnnVectorQueryBuilder(StreamInput in) throws IOException {
        super(in);
        // Read order must match the write order in doWriteTo.
        this.fieldName = in.readString();
        this.numCands = in.readVInt();
        this.queryVector = in.readFloatArray();
    }

    /**
     * @return the name of the field this query targets
     */
    public String getFieldName() {
        return this.fieldName;
    }

    /**
     * @return the query vector; this is the builder's own array, not a copy,
     *         so callers must not modify it
     */
    public float[] queryVector() {
        return this.queryVector;
    }

    /**
     * @return the number of candidates passed to the field type when the query is built
     */
    public int numCands() {
        return this.numCands;
    }

    /**
     * Writes the field name, the number of candidates and the query vector, in that order.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeVInt(this.numCands);
        out.writeFloatArray(this.queryVector);
    }

    /**
     * Renders the query as a {@code knn} object with the keys {@code field},
     * {@code vector} and {@code num_candidates}.
     *
     * @param builder target content builder
     * @param params  rendering parameters (not used by this method)
     * @throws IOException if writing to the builder fails
     */
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME)
            .field("field", this.fieldName)
            // The cast selects the generic Object overload for the float[] value.
            .field("vector", (Object)this.queryVector)
            .field("num_candidates", this.numCands);
        builder.endObject();
    }

    /**
     * @return {@link #NAME}
     */
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Resolves the target field in the given context and asks its dense-vector
     * field type to create the kNN query.
     *
     * @param context search execution context used to look up the field
     * @return the query produced by the field type's {@code createKnnQuery}
     * @throws IllegalArgumentException if the field is not mapped, is not a
     *         {@code dense_vector} field, or the context reports a nested parent for it
     */
    protected Query doToQuery(SearchExecutionContext context) {
        MappedFieldType fieldType = context.getFieldType(this.fieldName);
        if (fieldType == null) {
            throw new IllegalArgumentException("field [" + this.fieldName + "] does not exist in the mapping");
        }
        if (!(fieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType)) {
            throw new IllegalArgumentException("[knn] queries are only supported on [dense_vector] fields");
        }
        if (context.getNestedParent(fieldType.name()) != null) {
            throw new IllegalArgumentException("[knn] queries are not supported on nested fields");
        }
        DenseVectorFieldMapper.DenseVectorFieldType vectorFieldType = (DenseVectorFieldMapper.DenseVectorFieldType)fieldType;
        return vectorFieldType.createKnnQuery(this.queryVector, this.numCands);
    }

    /**
     * @return a hash combining the field name, the vector contents and the number of candidates
     */
    protected int doHashCode() {
        // Arrays.hashCode hashes the vector by content, consistent with Arrays.equals in doEquals.
        return Objects.hash(this.fieldName, Arrays.hashCode(this.queryVector), this.numCands);
    }

    /**
     * @param other the builder to compare against
     * @return {@code true} if field name, vector contents and number of candidates all match
     */
    protected boolean doEquals(KnnVectorQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName)
            && Arrays.equals(this.queryVector, other.queryVector)
            && this.numCands == other.numCands;
    }
}

