/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.allocation;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

public class TrainedModelAllocationNodeService
implements ClusterStateListener {
    // Stop reason passed by clusterChanged when an allocation's routing table has no entry for the local node.
    private static final String NODE_NO_LONGER_REFERENCED = "node no longer referenced in model routing table";
    // Stop reason passed by clusterChanged when a locally tracked model id has no allocation in the cluster state.
    private static final String ALLOCATION_NO_LONGER_EXISTS = "model allocation no longer exists";
    // Delay between successive runs of loadQueuedModels, as scheduled in start().
    private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds((long)1L);
    private static final Logger logger = LogManager.getLogger(TrainedModelAllocationNodeService.class);
    // Used to report this node's routing state changes (see updateStoredState).
    private final TrainedModelAllocationService trainedModelAllocationService;
    // Performs the actual start/stop/infer/stats operations for a deployment task.
    private final DeploymentManager deploymentManager;
    // Deployment tasks are registered with, and unregistered from, this task manager.
    private final TaskManager taskManager;
    // Locally tracked deployment tasks keyed by model id; a concurrent map because it is touched
    // both from clusterChanged and from work submitted to the "ml_utility" executor.
    private final Map<String, TrainedModelDeploymentTask> modelIdToTask;
    private final ThreadPool threadPool;
    // Tasks waiting to be loaded; filled by prepareModelToLoad and drained by loadQueuedModels.
    private final Deque<TrainedModelDeploymentTask> loadingModels;
    // Handed to every TrainedModelDeploymentTask created by taskAwareRequest.
    private final XPackLicenseState licenseState;
    // Handle of the recurring loadQueuedModels job; assigned in start(), cancelled in stop().
    private volatile Scheduler.Cancellable scheduledFuture;
    // Set by stop() and cleared by start(); several methods become no-ops while this is true.
    private volatile boolean stopped;
    // Id of the local node, sent with every routing state update. Either supplied to the
    // package-private constructor or captured in the lifecycle afterStart callback.
    private volatile String nodeId;

    /**
     * Creates the service and ties it to the lifecycle of {@code clusterService}.
     * <p>
     * After the cluster service has started, the local node id is read from it and
     * {@link #start()} is invoked; before it stops, {@link #stop()} is invoked.
     *
     * @param trainedModelAllocationService used to publish routing state updates
     * @param clusterService                source of the local node id and of lifecycle callbacks
     * @param deploymentManager             performs the deployment operations
     * @param taskManager                   registry for the deployment tasks created by this service
     * @param threadPool                    used for scheduling and for asynchronous stops
     * @param licenseState                  passed on to every deployment task that is created
     */
    public TrainedModelAllocationNodeService(TrainedModelAllocationService trainedModelAllocationService, final ClusterService clusterService, DeploymentManager deploymentManager, TaskManager taskManager, ThreadPool threadPool, XPackLicenseState licenseState) {
        // Collaborators.
        this.trainedModelAllocationService = trainedModelAllocationService;
        this.deploymentManager = deploymentManager;
        this.taskManager = taskManager;
        this.threadPool = threadPool;
        this.licenseState = licenseState;
        // Local bookkeeping.
        this.modelIdToTask = new ConcurrentHashMap<>();
        this.loadingModels = new ConcurrentLinkedDeque<>();
        clusterService.addLifecycleListener(new LifecycleListener() {
            @Override
            public void afterStart() {
                // The node id is taken from the cluster service in this callback rather than in the constructor.
                nodeId = clusterService.localNode().getId();
                start();
            }

            @Override
            public void beforeStop() {
                stop();
            }
        });
    }

    /**
     * Package-private variant that takes the local node id directly instead of
     * reading it from the cluster service once that has started.
     * <p>
     * The lifecycle hooks are otherwise the same: {@link #start()} after the cluster
     * service starts and {@link #stop()} before it stops.
     *
     * @param trainedModelAllocationService used to publish routing state updates
     * @param clusterService                source of lifecycle callbacks
     * @param deploymentManager             performs the deployment operations
     * @param taskManager                   registry for the deployment tasks created by this service
     * @param threadPool                    used for scheduling and for asynchronous stops
     * @param nodeId                        id reported as the local node in routing state updates
     * @param licenseState                  passed on to every deployment task that is created
     */
    TrainedModelAllocationNodeService(TrainedModelAllocationService trainedModelAllocationService, ClusterService clusterService, DeploymentManager deploymentManager, TaskManager taskManager, ThreadPool threadPool, String nodeId, XPackLicenseState licenseState) {
        // Collaborators and fixed identity.
        this.trainedModelAllocationService = trainedModelAllocationService;
        this.deploymentManager = deploymentManager;
        this.taskManager = taskManager;
        this.threadPool = threadPool;
        this.licenseState = licenseState;
        this.nodeId = nodeId;
        // Local bookkeeping.
        this.modelIdToTask = new ConcurrentHashMap<>();
        this.loadingModels = new ConcurrentLinkedDeque<>();
        clusterService.addLifecycleListener(new LifecycleListener() {
            @Override
            public void afterStart() {
                start();
            }

            @Override
            public void beforeStop() {
                stop();
            }
        });
    }

    /**
     * Marks {@code task} as stopped and tears its deployment down on the
     * {@code ml_utility} executor.
     * <p>
     * When this service has itself been stopped the call does nothing at all; in
     * particular {@code listener} is never completed in that case.
     *
     * @param task     the deployment task to stop
     * @param reason   recorded on the task as the reason it stopped
     * @param listener receives {@code null} once the deployment is stopped, the task is
     *                 unregistered and it is no longer tracked locally, or the exception
     *                 thrown while doing so
     */
    void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<Void> listener) {
        if (stopped == false) {
            task.stopWithoutNotification(reason);
            final Runnable teardown = () -> {
                try {
                    deploymentManager.stopDeployment(task);
                    taskManager.unregister(task);
                    modelIdToTask.remove(task.getModelId());
                    listener.onResponse(null);
                } catch (Exception e) {
                    listener.onFailure(e);
                }
            };
            threadPool.executor("ml_utility").execute(teardown);
        }
    }

    /**
     * Clears the stopped flag and schedules {@link #loadQueuedModels()} to run
     * repeatedly on the {@code ml_utility} executor, with
     * {@code MODEL_LOADING_CHECK_INTERVAL} between runs.
     */
    public void start() {
        stopped = false;
        final Runnable loadPending = this::loadQueuedModels;
        scheduledFuture = threadPool.scheduleWithFixedDelay(loadPending, MODEL_LOADING_CHECK_INTERVAL, "ml_utility");
    }

    /**
     * Sets the stopped flag and cancels the recurring model-loading job, if one
     * has been scheduled.
     */
    public void stop() {
        stopped = true;
        // Read the volatile field once so the null check and the cancel act on the same handle.
        final Scheduler.Cancellable pending = scheduledFuture;
        if (pending == null) {
            return;
        }
        pending.cancel();
    }

    /**
     * Drains the queue of models waiting to be loaded and tries to start each one.
     * <ul>
     *   <li>Tasks that are already stopped are dropped.</li>
     *   <li>If this service is stopped, the method returns immediately.</li>
     *   <li>A successful start is passed to {@code handleLoadSuccess}.</li>
     *   <li>A failure whose cause is a {@link ResourceNotFoundException} is reported as a missing trained model.</li>
     *   <li>A failure whose cause is a {@link SearchPhaseExecutionException} puts the task back on the
     *       queue once the drain has finished.</li>
     *   <li>Any other failure is passed to {@code handleLoadFailure} unchanged.</li>
     * </ul>
     * Each start is awaited synchronously on the calling thread.
     */
    void loadQueuedModels() {
        logger.trace("attempting to load all currently queued models");
        final Deque<TrainedModelDeploymentTask> retryLater = new ArrayDeque<>();
        for (TrainedModelDeploymentTask candidate = loadingModels.poll(); candidate != null; candidate = loadingModels.poll()) {
            final String modelId = candidate.getModelId();
            if (candidate.isStopped()) {
                if (logger.isTraceEnabled()) {
                    final String reason = candidate.stoppedReason().orElse("_unknown_");
                    logger.trace("[{}] attempted to load stopped task with reason [{}]", modelId, reason);
                }
                continue;
            }
            if (stopped) {
                return;
            }
            logger.trace(() -> new ParameterizedMessage("[{}] attempting to load model", modelId));
            final PlainActionFuture<TrainedModelDeploymentTask> started = new PlainActionFuture<>();
            try {
                deploymentManager.startDeployment(candidate, started);
                // Blocks until the deployment manager completes the future.
                handleLoadSuccess(started.actionGet());
            } catch (Exception ex) {
                final Throwable cause = ExceptionsHelper.unwrapCause(ex);
                if (cause instanceof ResourceNotFoundException) {
                    handleLoadFailure(candidate, ExceptionsHelper.missingTrainedModel(candidate.getModelId()));
                } else if (cause instanceof SearchPhaseExecutionException) {
                    // Re-queued only after the drain loop so it is not polled again in this run.
                    retryLater.add(candidate);
                } else {
                    handleLoadFailure(candidate, ex);
                }
            }
        }
        loadingModels.addAll(retryLater);
    }

    /**
     * Stops a deployment and reports the transition through the allocation service.
     * <p>
     * The routing state is first set to {@code STOPPING}. Whether or not that update
     * succeeds, the deployment is then stopped asynchronously, and once the stop has
     * finished (successfully or not) the routing state is set to {@code STOPPED}.
     * The outcome of the final {@code STOPPED} update is ignored.
     *
     * @param task   the deployment task to stop
     * @param reason reason attached to the {@code STOPPING} and {@code STOPPED} routing states
     */
    public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason) {
        // Final step, shared by the success and the failure path of the local stop.
        final Runnable markStopped = () -> updateStoredState(
            task.getModelId(),
            new RoutingStateAndReason(RoutingState.STOPPED, reason),
            ActionListener.wrap(ack -> {}, ignored -> {})
        );
        final ActionListener<Void> notifyDeploymentOfStopped = ActionListener.wrap(
            done -> markStopped.run(),
            failed -> {
                logger.warn(() -> new ParameterizedMessage("[{}] failed to stop due to error", task.getModelId()), failed);
                markStopped.run();
            }
        );
        updateStoredState(
            task.getModelId(),
            new RoutingStateAndReason(RoutingState.STOPPING, reason),
            ActionListener.wrap(
                ack -> stopDeploymentAsync(task, "task locally canceled", notifyDeploymentOfStopped),
                e -> {
                    final boolean allocationGone = ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException;
                    if (allocationGone) {
                        logger.debug(() -> new ParameterizedMessage("[{}] failed to set routing state to stopping as allocation already removed", task.getModelId()), e);
                    } else {
                        logger.warn(() -> new ParameterizedMessage("[{}] failed to set routing state to stopping due to error", task.getModelId()), e);
                    }
                    // The local deployment is stopped even though the STOPPING update failed.
                    stopDeploymentAsync(task, reason, notifyDeploymentOfStopped);
                }
            )
        );
    }

    /**
     * Runs inference for {@code task} by delegating directly to the deployment manager.
     *
     * @param task     the deployment task to run inference against
     * @param config   inference configuration, passed through unchanged
     * @param doc      the document to run inference on, passed through unchanged
     * @param timeout  timeout, passed through unchanged
     * @param listener handed to the deployment manager to receive the result or failure
     */
    public void infer(TrainedModelDeploymentTask task, InferenceConfig config, Map<String, Object> doc, TimeValue timeout, ActionListener<InferenceResults> listener) {
        this.deploymentManager.infer(task, config, doc, timeout, listener);
    }

    /**
     * Returns whatever statistics the deployment manager reports for {@code task}.
     *
     * @param task the deployment task to look up
     * @return the deployment manager's result for this task, unchanged
     */
    public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {
        return this.deploymentManager.getStats(task);
    }

    /**
     * Builds the request object given to the task manager when a deployment task is registered.
     * <p>
     * The returned request has no parent task and rejects attempts to set one. Its
     * {@code createTask} produces a {@link TrainedModelDeploymentTask} bound to
     * {@code params}, to this service and to this service's license state.
     *
     * @param params the deployment parameters the created task will carry
     * @return a request that creates the deployment task
     */
    private TaskAwareRequest taskAwareRequest(final StartTrainedModelDeploymentAction.TaskParams params) {
        return new TaskAwareRequest() {
            @Override
            public TaskId getParentTask() {
                return TaskId.EMPTY_TASK_ID;
            }

            @Override
            public void setParentTask(TaskId taskId) {
                throw new UnsupportedOperationException("parent task id for model allocation tasks shouldn't change");
            }

            @Override
            public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                final TrainedModelAllocationNodeService owner = TrainedModelAllocationNodeService.this;
                return new TrainedModelDeploymentTask(id, type, action, parentTaskId, headers, params, owner, owner.licenseState, MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE);
            }
        };
    }

    /**
     * Reconciles the deployments tracked on this node with the allocations in the new cluster state.
     * <p>
     * Nothing happens unless the cluster metadata changed. Otherwise, for every allocation:
     * <ul>
     *   <li>if the local node is routed with state {@code STARTING} or {@code STARTED}, no task is
     *       tracked yet for the model and ML is not in reset mode, the model is queued for loading;</li>
     *   <li>if the local node is not in the routing table but a task is tracked, that task is stopped.</li>
     * </ul>
     * Finally, every tracked task whose model id has no allocation at all is stopped.
     *
     * @param event the cluster change notification
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        if (event.metadataChanged() == false) {
            return;
        }
        final boolean isResetMode = MlMetadata.getMlMetadata(event.state()).isResetMode();
        final TrainedModelAllocationMetadata modelAllocationMetadata = TrainedModelAllocationMetadata.fromState(event.state());
        final String currentNode = event.state().nodes().getLocalNodeId();
        for (TrainedModelAllocation trainedModelAllocation : modelAllocationMetadata.modelAllocations().values()) {
            final String modelId = trainedModelAllocation.getTaskParams().getModelId();
            final RoutingStateAndReason routingStateAndReason = trainedModelAllocation.getNodeRoutingTable().get(currentNode);
            if (routingStateAndReason != null) {
                if (routingStateAndReason.getState().isAnyOf(new RoutingState[]{RoutingState.STARTING, RoutingState.STARTED})
                    && this.modelIdToTask.containsKey(modelId) == false
                    && isResetMode == false) {
                    this.prepareModelToLoad(trainedModelAllocation.getTaskParams());
                }
                continue;
            }
            // The allocation exists but no longer routes to this node.
            final TrainedModelDeploymentTask task = this.modelIdToTask.remove(modelId);
            if (task != null) {
                this.stopDeploymentAndLog(task, NODE_NO_LONGER_REFERENCED);
            }
        }
        final ArrayList<TrainedModelDeploymentTask> toCancel = new ArrayList<>();
        for (String modelId : Sets.difference(this.modelIdToTask.keySet(), modelAllocationMetadata.modelAllocations().keySet())) {
            // An asynchronous stop running on the ml_utility executor removes entries from the same
            // map, so the task may already be gone by the time we get here. Skip it rather than
            // queueing a null that would be dereferenced below.
            final TrainedModelDeploymentTask removed = this.modelIdToTask.remove(modelId);
            if (removed != null) {
                toCancel.add(removed);
            }
        }
        for (TrainedModelDeploymentTask t : toCancel) {
            this.stopDeploymentAndLog(t, ALLOCATION_NO_LONGER_EXISTS);
        }
    }

    /**
     * Stops {@code task} asynchronously and only logs the outcome: trace on success, warn on failure.
     */
    private void stopDeploymentAndLog(TrainedModelDeploymentTask task, String reason) {
        this.stopDeploymentAsync(task, reason, ActionListener.wrap(
            r -> logger.trace(() -> new ParameterizedMessage("[{}] stopped deployment", task.getModelId())),
            e -> logger.warn(() -> new ParameterizedMessage("[{}] failed to fully stop deployment", task.getModelId()), e)
        ));
    }

    /**
     * Looks up the locally tracked deployment task for a model.
     *
     * @param modelId the model id to look up
     * @return the tracked task, or {@code null} if none is tracked under that id
     */
    TrainedModelDeploymentTask getTask(String modelId) {
        return this.modelIdToTask.get(modelId);
    }

    /**
     * Registers a deployment task for {@code taskParams} and queues it for loading.
     * <p>
     * The task is registered with the task manager first. If another task is already
     * tracked under the same model id, the freshly registered one is unregistered again
     * and nothing is queued.
     *
     * @param taskParams parameters of the deployment to prepare
     */
    void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
        logger.debug(() -> new ParameterizedMessage("[{}] preparing to load model with task params: {}", taskParams.getModelId(), taskParams));
        final Task registered = taskManager.register(
            "trained_model_allocation",
            "xpack/ml/allocation-" + taskParams.getModelId(),
            taskAwareRequest(taskParams)
        );
        final TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) registered;
        final boolean alreadyTracked = modelIdToTask.putIfAbsent(taskParams.getModelId(), task) != null;
        if (alreadyTracked) {
            // Lost the race against another registration for this model id; undo ours.
            taskManager.unregister(task);
            return;
        }
        loadingModels.add(task);
    }

    /**
     * Called once a model has been loaded; reports the {@code STARTED} routing state.
     * <p>
     * If the task was stopped while loading, nothing is reported. A failure to publish
     * the state is only logged: at debug level when the cause is a
     * {@link ResourceNotFoundException} (the allocation to this node was removed), and at
     * warn level for anything else.
     *
     * @param task the task whose model finished loading
     */
    private void handleLoadSuccess(TrainedModelDeploymentTask task) {
        final String modelId = task.getModelId();
        logger.debug(() -> new ParameterizedMessage("[{}] model successfully loaded and ready for inference. Notifying master node", modelId));
        if (task.isStopped()) {
            logger.debug(() -> new ParameterizedMessage("[{}] model loaded successfully, but stopped before routing table was updated; reason [{}]", modelId, task.stoppedReason().orElse("_unknown_")));
            return;
        }
        this.updateStoredState(modelId, new RoutingStateAndReason(RoutingState.STARTED, ""), ActionListener.wrap(
            r -> logger.debug(() -> new ParameterizedMessage("[{}] model loaded and accepting routes", modelId)),
            e -> {
                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                    // Expected when the allocation or this node's route was removed in the meantime.
                    logger.debug(() -> new ParameterizedMessage("[{}] model loaded but failed to start accepting routes as allocation to this node was removed", modelId), e);
                } else {
                    // Anything else is unexpected. Previously this warning was also emitted for the
                    // expected case above because the branch fell through.
                    logger.warn(() -> new ParameterizedMessage("[{}] model loaded but failed to start accepting routes", modelId), e);
                }
            }
        ));
    }

    /**
     * Sends this node's routing state for {@code modelId} to the allocation service.
     * <p>
     * Does nothing, and never completes {@code listener}, when this service is stopped.
     * Otherwise the outcome is logged and relayed: on success the listener always
     * receives {@code AcknowledgedResponse.TRUE}; on failure it receives the error.
     *
     * @param modelId               the model whose routing entry is updated
     * @param routingStateAndReason the new state and reason for this node
     * @param listener              notified of the outcome
     */
    private void updateStoredState(String modelId, RoutingStateAndReason routingStateAndReason, ActionListener<AcknowledgedResponse> listener) {
        if (stopped) {
            return;
        }
        final UpdateTrainedModelAllocationStateAction.Request request =
            new UpdateTrainedModelAllocationStateAction.Request(nodeId, modelId, routingStateAndReason);
        final ActionListener<AcknowledgedResponse> relay = ActionListener.wrap(
            ack -> {
                logger.debug(() -> new ParameterizedMessage("[{}] model is [{}] and master notified", modelId, routingStateAndReason.getState()));
                listener.onResponse(AcknowledgedResponse.TRUE);
            },
            failure -> {
                logger.warn(() -> new ParameterizedMessage("[{}] model is [{}] but failed to notify master", modelId, routingStateAndReason.getState()), failure);
                listener.onFailure(failure);
            }
        );
        trainedModelAllocationService.updateModelAllocationState(request, relay);
    }

    /**
     * Handles a model that failed to load: logs the error, reports the {@code FAILED}
     * routing state and then stops the local task.
     * <p>
     * The local stop runs once the state update completes, regardless of whether the
     * update succeeded. If the task is already stopped, that is logged at debug level
     * but the state update and the stop are still attempted.
     *
     * @param task the task whose model failed to load
     * @param ex   the failure; its unwrapped cause's message becomes the routing reason
     */
    private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) {
        logger.error(() -> new ParameterizedMessage("[{}] model failed to load", task.getModelId()), ex);
        if (task.isStopped()) {
            logger.debug(() -> new ParameterizedMessage("[{}] model failed to load, but is now stopped; reason [{}]", task.getModelId(), task.stoppedReason().orElse("_unknown_")));
        }
        // Stops the local task; the outcome of the stop itself is deliberately ignored.
        final Runnable stopLocalTask = () -> {
            final ActionListener<Void> ignoreOutcome = ActionListener.wrap(r -> {}, e -> {});
            stopDeploymentAsync(task, "model failed to load; reason [" + ex.getMessage() + "]", ignoreOutcome);
        };
        final RoutingStateAndReason failedState = new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage());
        updateStoredState(task.getModelId(), failedState, ActionListener.wrap(ack -> stopLocalTask.run(), failure -> stopLocalTask.run()));
    }

    /**
     * Reports the {@code FAILED} routing state for {@code task} with the given reason.
     * The outcome of the update is only logged: debug on success, error on failure.
     *
     * @param task   the deployment task whose allocation failed
     * @param reason reason attached to the {@code FAILED} routing state
     */
    public void failAllocation(TrainedModelDeploymentTask task, String reason) {
        final RoutingStateAndReason failedState = new RoutingStateAndReason(RoutingState.FAILED, reason);
        final ActionListener<AcknowledgedResponse> logOutcome = ActionListener.wrap(
            ack -> logger.debug(
                new ParameterizedMessage("[{}] Successfully updating allocation state to [{}] with reason [{}]", task.getModelId(), RoutingState.FAILED, reason)
            ),
            failure -> logger.error(
                new ParameterizedMessage("[{}] Error while updating allocation state to [{}] with reason [{}]", task.getModelId(), RoutingState.FAILED, reason),
                failure
            )
        );
        updateStoredState(task.getModelId(), failedState, logOutcome);
    }
}

