/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.List;
import java.util.Map;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

public class TransportInferTrainedModelDeploymentAction
extends TransportTasksAction<TrainedModelDeploymentTask, InferTrainedModelDeploymentAction.Request, InferTrainedModelDeploymentAction.Response, InferTrainedModelDeploymentAction.Response> {
    private final TrainedModelProvider provider;

    @Inject
    public TransportInferTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, TrainedModelProvider provider) {
        super("cluster:monitor/xpack/ml/trained_models/deployment/infer", clusterService, transportService, actionFilters, InferTrainedModelDeploymentAction.Request::new, InferTrainedModelDeploymentAction.Response::new, InferTrainedModelDeploymentAction.Response::new, "same");
        this.provider = provider;
    }

    protected void doExecute(Task task, InferTrainedModelDeploymentAction.Request request, ActionListener<InferTrainedModelDeploymentAction.Response> listener) {
        String deploymentId = request.getDeploymentId();
        TrainedModelAllocation allocation = TrainedModelAllocationMetadata.allocationForModelId(this.clusterService.state(), deploymentId).orElse(null);
        if (allocation == null) {
            this.provider.getTrainedModel(deploymentId, GetTrainedModelsAction.Includes.empty(), (ActionListener<TrainedModelConfig>)ActionListener.wrap(config -> {
                if (config.getModelType() != TrainedModelType.PYTORCH) {
                    listener.onFailure((Exception)((Object)org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper.badRequestException((String)"Only [pytorch] models are supported by _infer, provided model [{}] has type [{}]", (Object[])new Object[]{config.getModelId(), config.getModelType()})));
                    return;
                }
                String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not started";
                listener.onFailure((Exception)((Object)org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper.conflictStatusException((String)message, (Object[])new Object[0])));
            }, arg_0 -> listener.onFailure(arg_0)));
            return;
        }
        String[] randomRunningNode = allocation.getStartedNodes();
        if (randomRunningNode.length == 0) {
            String message = "Cannot perform requested action because deployment [" + deploymentId + "] is not yet running on any node";
            listener.onFailure((Exception)((Object)org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper.conflictStatusException((String)message, (Object[])new Object[0])));
            return;
        }
        int nodeIndex = Randomness.get().nextInt(randomRunningNode.length);
        request.setNodes(new String[]{randomRunningNode[nodeIndex]});
        super.doExecute(task, (BaseTasksRequest)request, listener);
    }

    protected InferTrainedModelDeploymentAction.Response newResponse(InferTrainedModelDeploymentAction.Request request, List<InferTrainedModelDeploymentAction.Response> tasks, List<TaskOperationFailure> taskOperationFailures, List<FailedNodeException> failedNodeExceptions) {
        if (!taskOperationFailures.isEmpty()) {
            throw ExceptionsHelper.convertToElastic((Exception)taskOperationFailures.get(0).getCause());
        }
        if (!failedNodeExceptions.isEmpty()) {
            throw ExceptionsHelper.convertToElastic((Exception)((Exception)failedNodeExceptions.get(0)));
        }
        if (tasks.isEmpty()) {
            throw new ElasticsearchStatusException("[{}] unable to find deployment task for inference please stop and start the deployment or try again momentarily", RestStatus.NOT_FOUND, new Object[]{request.getDeploymentId()});
        }
        return tasks.get(0);
    }

    protected void taskOperation(InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, ActionListener<InferTrainedModelDeploymentAction.Response> listener) {
        task.infer((Map)request.getDocs().get(0), request.getUpdate(), request.getInferenceTimeout(), (ActionListener<InferenceResults>)ActionListener.wrap(pyTorchResult -> listener.onResponse((Object)new InferTrainedModelDeploymentAction.Response(pyTorchResult)), arg_0 -> listener.onFailure(arg_0)));
    }
}

