/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.job;

import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.stream.Collectors;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Strings;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;

public class NodeLoadDetector {
    private final MlMemoryTracker mlMemoryTracker;

    public NodeLoadDetector(MlMemoryTracker memoryTracker) {
        this.mlMemoryTracker = memoryTracker;
    }

    public MlMemoryTracker getMlMemoryTracker() {
        return this.mlMemoryTracker;
    }

    public NodeLoad detectNodeLoad(ClusterState clusterState, DiscoveryNode node, int dynamicMaxOpenJobs, int maxMachineMemoryPercent, boolean useAutoMachineMemoryCalculation) {
        return this.detectNodeLoad(clusterState, TrainedModelAllocationMetadata.fromState(clusterState), node, dynamicMaxOpenJobs, maxMachineMemoryPercent, useAutoMachineMemoryCalculation);
    }

    public NodeLoad detectNodeLoad(ClusterState clusterState, TrainedModelAllocationMetadata allocationMetadata, DiscoveryNode node, int maxNumberOfOpenJobs, int maxMachineMemoryPercent, boolean useAutoMachineMemoryCalculation) {
        PersistentTasksCustomMetadata persistentTasks = (PersistentTasksCustomMetadata)clusterState.getMetadata().custom("persistent_tasks");
        Map nodeAttributes = node.getAttributes();
        ArrayList<CallSite> errors = new ArrayList<CallSite>();
        OptionalLong maxMlMemory = NativeMemoryCalculator.allowedBytesForMl(node, maxMachineMemoryPercent, useAutoMachineMemoryCalculation);
        if (maxMlMemory.isEmpty()) {
            errors.add((CallSite)((Object)("ml.machine_memory attribute [" + (String)nodeAttributes.get("ml.machine_memory") + "] is not a long")));
        }
        NodeLoad.Builder nodeLoad = NodeLoad.builder(node.getId()).setMaxMemory(maxMlMemory.orElse(-1L)).setMaxJobs(maxNumberOfOpenJobs).setUseMemory(true);
        if (!errors.isEmpty()) {
            return nodeLoad.setError(Strings.collectionToCommaDelimitedString(errors)).build();
        }
        this.updateLoadGivenTasks(nodeLoad, persistentTasks);
        this.updateLoadGivenModelAllocations(nodeLoad, allocationMetadata);
        return nodeLoad.build();
    }

    private void updateLoadGivenTasks(NodeLoad.Builder nodeLoad, PersistentTasksCustomMetadata persistentTasks) {
        if (persistentTasks != null) {
            Collection<PersistentTasksCustomMetadata.PersistentTask<?>> memoryTrackedTasks = NodeLoadDetector.findAllMemoryTrackedTasks(persistentTasks, nodeLoad.getNodeId());
            for (PersistentTasksCustomMetadata.PersistentTask<?> task : memoryTrackedTasks) {
                MemoryTrackedTaskState state = MlTasks.getMemoryTrackedTaskState(task);
                if (state != null && !state.consumesMemory()) continue;
                MlTaskParams taskParams = (MlTaskParams)task.getParams();
                nodeLoad.addTask(task.getTaskName(), taskParams.getMlId(), state.isAllocating(), this.mlMemoryTracker);
            }
            if (nodeLoad.getNumAssignedJobs() > 0L) {
                nodeLoad.incAssignedJobMemory(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes());
            }
        }
    }

    private void updateLoadGivenModelAllocations(NodeLoad.Builder nodeLoad, TrainedModelAllocationMetadata trainedModelAllocationMetadata) {
        if (trainedModelAllocationMetadata != null && !trainedModelAllocationMetadata.modelAllocations().isEmpty()) {
            for (TrainedModelAllocation allocation : trainedModelAllocationMetadata.modelAllocations().values()) {
                if (!Optional.ofNullable((RoutingStateAndReason)allocation.getNodeRoutingTable().get(nodeLoad.getNodeId())).map(RoutingStateAndReason::getState).orElse(RoutingState.STOPPED).consumesMemory()) continue;
                nodeLoad.incNumAssignedJobs();
                nodeLoad.incAssignedJobMemory(allocation.getTaskParams().estimateMemoryUsageBytes());
            }
        }
    }

    private static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> findAllMemoryTrackedTasks(PersistentTasksCustomMetadata persistentTasks, String nodeId) {
        return persistentTasks.tasks().stream().filter(NodeLoadDetector::isMemoryTrackedTask).filter(task -> nodeId.equals(task.getExecutorNode())).collect(Collectors.toList());
    }

    private static boolean isMemoryTrackedTask(PersistentTasksCustomMetadata.PersistentTask<?> task) {
        return "xpack/ml/job".equals(task.getTaskName()) || "xpack/ml/job/snapshot/upgrade".equals(task.getTaskName()) || "xpack/ml/data_frame/analytics".equals(task.getTaskName());
    }
}

