/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * A single dense layer of the language identification network.
 *
 * <p>The layer holds a weight matrix stored row-major in a flat {@code double[]} of
 * {@code weightRows * weightCols} values, plus a bias vector. It can be parsed from x-content
 * (fields {@code weights}, {@code num_cols}, {@code num_rows}, {@code bias}). It can be written to
 * and read from a transport stream. It reports its heap footprint via {@link Accountable}.
 */
public class LangNetLayer implements ToXContentObject, Writeable, Accountable {

    public static final ParseField NAME = new ParseField("lang_net_layer", new String[0]);
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LangNetLayer.class);

    public static final ParseField NUM_ROWS = new ParseField("num_rows", new String[0]);
    public static final ParseField NUM_COLS = new ParseField("num_cols", new String[0]);
    public static final ParseField WEIGHTS = new ParseField("weights", new String[0]);
    public static final ParseField BIAS = new ParseField("bias", new String[0]);

    // Declared after the ParseFields above: createParser reads them during static initialization.
    public static final ConstructingObjectParser<LangNetLayer, Void> STRICT_PARSER = LangNetLayer.createParser(false);
    public static final ConstructingObjectParser<LangNetLayer, Void> LENIENT_PARSER = LangNetLayer.createParser(true);

    /** Row-major weight matrix: element (row, col) lives at {@code row * weightCols + col}. */
    private final double[] weights;
    private final int weightRows;
    private final int weightCols;
    private final double[] bias;

    /**
     * Builds the x-content parser for this layer.
     *
     * @param lenient passed through to the {@link ConstructingObjectParser} as its
     *                ignore-unknown-fields flag
     * @return a parser whose constructor arguments are, in order: weights, num_cols, num_rows, bias
     */
    @SuppressWarnings("unchecked") // declareDoubleArray below guarantees a[0] and a[3] are List<Double>
    private static ConstructingObjectParser<LangNetLayer, Void> createParser(boolean lenient) {
        ConstructingObjectParser<LangNetLayer, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new LangNetLayer((List<Double>) a[0], (Integer) a[1], (Integer) a[2], (List<Double>) a[3])
        );
        // Declaration order defines the indices into `a` used by the constructor lambda above.
        parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS);
        parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_COLS);
        parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_ROWS);
        parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), BIAS);
        return parser;
    }

    /** Parser entry point: unboxes the parsed lists and delegates to the array constructor. */
    private LangNetLayer(List<Double> weights, int numCols, int numRows, List<Double> bias) {
        this(
            weights.stream().mapToDouble(Double::doubleValue).toArray(),
            numCols,
            numRows,
            bias.stream().mapToDouble(Double::doubleValue).toArray()
        );
    }

    /**
     * Creates a layer from raw arrays.
     *
     * @param weights row-major weight matrix; must contain exactly {@code numCols * numRows} values
     * @param numCols number of columns in the weight matrix; must be non-negative
     * @param numRows number of rows in the weight matrix; must be non-negative
     * @param bias    bias vector added to the product in {@link #productPlusBias}
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) if a dimension is negative
     *         or {@code weights.length} does not equal {@code numCols * numRows}
     */
    LangNetLayer(double[] weights, int numCols, int numRows, double[] bias) {
        if (numCols < 0 || numRows < 0) {
            throw ExceptionsHelper.badRequestException(
                "malformed network layer. Dimensions must be non-negative but were [{}] x [{}].",
                numCols,
                numRows
            );
        }
        // Multiply in long: an int product of large dimensions can overflow and spuriously
        // equal weights.length (e.g. 65536 x 65536 wraps to 0).
        if (weights.length != (long) numCols * numRows) {
            throw ExceptionsHelper.badRequestException(
                "malformed network layer. Total vector size [{}] does not equal [{}] x [{}].",
                weights.length,
                numCols,
                numRows
            );
        }
        this.weights = weights;
        this.weightCols = numCols;
        this.weightRows = numRows;
        this.bias = bias;
    }

    /**
     * Reads a layer from a transport stream.
     *
     * <p>The read order is weights, bias, rows, cols, mirroring {@link #writeTo}.
     */
    LangNetLayer(StreamInput in) throws IOException {
        this.weights = in.readDoubleArray();
        this.bias = in.readDoubleArray();
        this.weightRows = in.readInt();
        this.weightCols = in.readInt();
    }

    /**
     * Computes {@code bias + x . W}.
     *
     * <p>Row {@code i} of the weight matrix is scaled by {@code x[i]} and accumulated into a copy
     * of the bias vector. When {@code applyRelu} is true, ReLU is applied to the <em>input</em>:
     * any {@code x[i]} that is not strictly positive (including NaN) contributes nothing.
     *
     * <p>NOTE(review): this assumes {@code bias.length <= weightCols} and
     * {@code x.length <= weightRows}. Neither is validated here or in the constructors.
     *
     * @param applyRelu whether to clamp non-positive inputs to zero before the product
     * @param x         input vector
     * @return a new array of length {@code bias.length}
     */
    double[] productPlusBias(boolean applyRelu, double[] x) {
        double[] y = Arrays.copyOf(this.bias, this.bias.length);
        for (int i = 0; i < x.length; ++i) {
            double scale = x[i];
            // !(scale > 0.0) rather than scale <= 0.0 so that NaN inputs are also skipped under ReLU.
            if (applyRelu && !(scale > 0.0)) {
                continue;
            }
            int rowOffset = LangNetLayer.rowMajorIndex(i, this.weightCols, 0);
            for (int j = 0; j < y.length; ++j) {
                y[j] += this.weights[rowOffset + j] * scale;
            }
        }
        return y;
    }

    /** Flat index of element (row, col) in a row-major matrix with {@code colDim} columns. */
    private static int rowMajorIndex(int row, int colDim, int col) {
        return row * colDim + col;
    }

    /** Returns the internal weight array (not a copy). */
    double[] getWeights() {
        return this.weights;
    }

    int getWeightRows() {
        return this.weightRows;
    }

    int getWeightCols() {
        return this.weightCols;
    }

    /** Returns the internal bias array (not a copy). */
    double[] getBias() {
        return this.bias;
    }

    /** Shallow instance size plus the two backing arrays. */
    @Override
    public long ramBytesUsed() {
        return SHALLOW_SIZE + RamUsageEstimator.sizeOf(this.weights) + RamUsageEstimator.sizeOf(this.bias);
    }

    /** Writes weights, bias, rows, cols, in the order read back by {@link #LangNetLayer(StreamInput)}. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeDoubleArray(this.weights);
        out.writeDoubleArray(this.bias);
        out.writeInt(this.weightRows);
        out.writeInt(this.weightCols);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(NUM_COLS.getPreferredName(), this.weightCols);
        builder.field(NUM_ROWS.getPreferredName(), this.weightRows);
        builder.field(WEIGHTS.getPreferredName(), this.weights);
        builder.field(BIAS.getPreferredName(), this.bias);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        LangNetLayer that = (LangNetLayer) o;
        return this.weightCols == that.weightCols
            && this.weightRows == that.weightRows
            && Arrays.equals(this.weights, that.weights)
            && Arrays.equals(this.bias, that.bias);
    }

    @Override
    public int hashCode() {
        return Objects.hash(Arrays.hashCode(this.weights), Arrays.hashCode(this.bias), this.weightCols, this.weightRows);
    }
}

