/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.task;

import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.breaker.MLCircuitBreakerService;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.permission.AccessController;
import org.opensearch.ml.stats.ActionName;
import org.opensearch.ml.stats.MLActionLevelStat;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.ml.stats.otel.counters.MLOperationalMetricsCounter;
import org.opensearch.ml.stats.otel.metrics.OperationalMetric;
import org.opensearch.ml.task.MLTaskDispatcher;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.task.MLTaskRunner;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.StreamTransportResponseHandler;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.stream.StreamTransportResponse;

/**
 * Task runner for ML predict requests.
 *
 * <p>{@link #dispatchTask} selects a worker node for the request (auto-deploying the model first
 * when it has no, or too few, worker nodes and auto deployment is enabled) and then either executes
 * locally or forwards the request over the regular or streaming transport action.
 * {@link #executeTask} runs the prediction on the local node, including the BATCH_PREDICT flow that
 * records a remote batch job as an ML task.
 */
public class MLPredictTaskRunner extends MLTaskRunner<MLPredictionTaskRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(MLPredictTaskRunner.class);
    private final ThreadPool threadPool;
    private final ClusterService clusterService;
    private final Client client;
    private final MLInputDatasetHandler mlInputDatasetHandler;
    private final NamedXContentRegistry xContentRegistry;
    private final MLModelManager mlModelManager;
    private final DiscoveryNodeHelper nodeHelper;
    private final MLEngine mlEngine;
    /** Mirrors ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE; updated live by a cluster-settings consumer. */
    private volatile boolean autoDeploymentEnabled;
    /** Key of the bucket name inside a remote-inference DLQ configuration map. */
    public static final String BUCKET_FIELD = "bucket";
    /** Key of the region inside a remote-inference DLQ configuration map. */
    public static final String REGION_FIELD = "region";

    public MLPredictTaskRunner(ThreadPool threadPool, ClusterService clusterService, Client client, MLTaskManager mlTaskManager, MLStats mlStats, MLInputDatasetHandler mlInputDatasetHandler, MLTaskDispatcher mlTaskDispatcher, MLCircuitBreakerService mlCircuitBreakerService, NamedXContentRegistry xContentRegistry, MLModelManager mlModelManager, DiscoveryNodeHelper nodeHelper, MLEngine mlEngine, Settings settings) {
        super(mlTaskManager, mlStats, nodeHelper, mlTaskDispatcher, mlCircuitBreakerService, clusterService);
        this.threadPool = threadPool;
        this.clusterService = clusterService;
        this.client = client;
        this.mlInputDatasetHandler = mlInputDatasetHandler;
        this.xContentRegistry = xContentRegistry;
        this.mlModelManager = mlModelManager;
        this.nodeHelper = nodeHelper;
        this.mlEngine = mlEngine;
        this.autoDeploymentEnabled = MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE.get(settings);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE, it -> this.autoDeploymentEnabled = it);
    }

    @Override
    protected String getTransportActionName() {
        return "cluster:admin/opensearch/ml/predict";
    }

    @Override
    protected String getTransportStreamActionName() {
        return "cluster:admin/opensearch/ml/predict/stream";
    }

    @Override
    protected TransportResponseHandler<MLTaskResponse> getResponseHandler(ActionListener<MLTaskResponse> listener) {
        return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new);
    }

    /**
     * Builds a handler that relays every streamed response batch from the remote worker to the
     * request's own streaming channel, then completes the stream.
     */
    @Override
    protected TransportResponseHandler<MLTaskResponse> getResponseStreamHandler(MLPredictionTaskRequest request) {
        final TransportChannel channel = request.getStreamingChannel();
        return new StreamTransportResponseHandler<MLTaskResponse>() {

            public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) {
                try {
                    MLTaskResponse response;
                    while ((response = streamResponse.nextResponse()) != null) {
                        channel.sendResponseBatch(response);
                    }
                    channel.completeStream();
                    streamResponse.close();
                } catch (Exception e) {
                    log.error("Failed to relay streamed predict response", e);
                    streamResponse.cancel("Stream error", e);
                }
            }

            public void handleException(TransportException exp) {
                try {
                    channel.sendResponse(exp);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String executor() {
                return "same";
            }

            public MLTaskResponse read(StreamInput in) throws IOException {
                return new MLTaskResponse(in);
            }
        };
    }

    /**
     * Chooses a worker node for the request and executes it there.
     *
     * <p>If the model has no worker nodes, or fewer than its target, and auto deployment is enabled
     * for this function, the model is auto-deployed first. If the function must be deployed first,
     * the request fails. Otherwise any eligible node is used.
     *
     * @throws IllegalArgumentException (thrown directly, not via the listener) when a remote-inference
     *     DLQ configuration is present but lacks a bucket name or region
     */
    @Override
    public void dispatchTask(FunctionName functionName, MLPredictionTaskRequest request, TransportService transportService, ActionListener<MLTaskResponse> listener) {
        String modelId = request.getModelId();
        validateDlq(request.getMlInput());
        try {
            ActionListener<DiscoveryNode> nodeListener = ActionListener.wrap(node -> {
                if (this.clusterService.localNode().getId().equals(node.getId())) {
                    log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
                    request.setDispatchTask(false);
                    this.checkCBAndExecute(functionName, request, listener);
                } else {
                    log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
                    request.setDispatchTask(false);
                    if (this.isStreamingRequest(request)) {
                        log.debug("Using streaming transport for request {}", request.getRequestID());
                        // Streamed responses are relayed to the request's own channel, not to `listener`.
                        transportService.sendRequest(node, this.getTransportStreamActionName(), request, TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), this.getResponseStreamHandler(request));
                    } else {
                        transportService.sendRequest(node, this.getTransportActionName(), request, this.getResponseHandler(listener));
                    }
                }
            }, listener::onFailure);
            String[] workerNodes = this.mlModelManager.getWorkerNodes(modelId, functionName, true);
            String[] targetWorkerNodes = this.mlModelManager.getTargetWorkerNodes(modelId);
            if (!this.requiresAutoDeployment(workerNodes, targetWorkerNodes)) {
                this.mlModelManager.removeAutoDeployModel(modelId);
                this.mlTaskDispatcher.dispatchPredictTask(workerNodes, nodeListener);
                return;
            }
            if (FunctionName.isAutoDeployEnabled(this.autoDeploymentEnabled, functionName)) {
                this.autoDeployAndDispatch(functionName, modelId, request.getTenantId(), nodeListener, listener);
                return;
            }
            if (FunctionName.needDeployFirst(functionName)) {
                listener.onFailure(new IllegalArgumentException("Model not ready yet. Please deploy the model first."));
                return;
            }
            this.mlTaskDispatcher.dispatchPredictTask(this.nodeHelper.getEligibleNodeIds(functionName), nodeListener);
        } catch (Exception e) {
            log.error("Failed to predict model " + modelId, e);
            listener.onFailure(e);
        }
    }

    /** Rejects remote-inference input whose DLQ config is present but lacks a bucket name or region. */
    private static void validateDlq(MLInput mlInput) {
        if (!(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet)) {
            return;
        }
        Map<?, ?> dlq = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getDlq();
        if (dlq == null) {
            return;
        }
        String bucketName = (String) dlq.get(BUCKET_FIELD);
        String region = (String) dlq.get(REGION_FIELD);
        if (bucketName == null || region == null) {
            throw new IllegalArgumentException("DLQ bucketName or region cannot be null");
        }
    }

    /**
     * Loads the model with the caller's thread context stashed. If model-level auto deployment is
     * allowed, the method triggers a deploy action; the deploy request is sent only if this
     * request's model instance was the one placed in the auto-deploy cache. It then dispatches to
     * the planning worker nodes, or to any eligible node when none are planned.
     */
    private void autoDeployAndDispatch(FunctionName functionName, String modelId, String tenantId, ActionListener<DiscoveryNode> nodeListener, ActionListener<MLTaskResponse> listener) {
        try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLModel> modelListener = ActionListener.wrap(model -> {
                Boolean isHidden = model.getIsHidden();
                if (!this.checkModelAutoDeployEnabled(model)) {
                    String errorMsg = StringUtils.getErrorMessage("Auto deployment disabled for this model, please deploy model first", modelId, isHidden);
                    log.info(errorMsg);
                    listener.onFailure(new IllegalArgumentException(errorMsg));
                    return;
                }
                String[] planningWorkerNodes = model.getPlanningWorkerNodes();
                if (model.isDeployToAllNodes()) {
                    planningWorkerNodes = null;
                }
                MLModel modelBeingAutoDeployed = this.mlModelManager.addModelToAutoDeployCache(modelId, model);
                // Identity check: only the request whose instance won the cache slot triggers the deploy.
                if (modelBeingAutoDeployed == model) {
                    log.info(StringUtils.getErrorMessage("Automatically deploy model", modelId, isHidden));
                    MLDeployModelRequest deployModelRequest = new MLDeployModelRequest(modelId, tenantId, planningWorkerNodes, false, true, false);
                    this.client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, ActionListener.wrap(
                        r -> log.info(StringUtils.getErrorMessage("Auto deployment action triggered for the model", modelId, isHidden)),
                        e -> log.error(StringUtils.getErrorMessage("Auto deployment action failed for the given model", modelId, isHidden), e)));
                }
                if (planningWorkerNodes == null || planningWorkerNodes.length == 0) {
                    planningWorkerNodes = this.nodeHelper.getEligibleNodeIds(functionName);
                }
                this.mlTaskDispatcher.dispatchPredictTask(planningWorkerNodes, nodeListener);
            }, e -> {
                log.error("Failed to get model " + modelId, e);
                listener.onFailure(e);
            });
            this.mlModelManager.getModel(modelId, tenantId, ActionListener.runBefore(modelListener, context::restore));
        }
    }

    /**
     * Creates the local ML task for the request and runs the prediction. BATCH_PREDICT requests
     * are first checked against the maximum number of batch prediction tasks.
     */
    @Override
    protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTaskResponse> listener) {
        TransportChannel channel = request.getStreamingChannel();
        String tenantId = request.getTenantId();
        MLInput mlInput = request.getMlInput();
        MLInputDataType inputDataType = mlInput.getInputDataset().getInputDataType();
        Instant now = Instant.now();
        String modelId = request.getModelId();
        FunctionName functionName = mlInput.getFunctionName();
        ConnectorAction.ActionType actionType = null;
        if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
            actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
        }
        if (actionType == null) {
            actionType = ConnectorAction.ActionType.PREDICT;
        }
        boolean isBatchPredict = actionType == ConnectorAction.ActionType.BATCH_PREDICT;
        MLTask mlTask = MLTask.builder()
            .taskId(UUID.randomUUID().toString())
            .modelId(modelId)
            .taskType(isBatchPredict ? MLTaskType.BATCH_PREDICTION : MLTaskType.PREDICTION)
            .inputType(inputDataType)
            .functionName(functionName)
            .state(MLTaskState.CREATED)
            .workerNodes(ImmutableList.of(this.clusterService.localNode().getId()))
            .createTime(now)
            .lastUpdateTime(now)
            .async(false)
            .tenantId(tenantId)
            .build();
        if (!isBatchPredict) {
            this.executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, tenantId, listener, channel);
            return;
        }
        ActionListener<Boolean> limitListener = ActionListener.wrap(exceedLimits -> {
            if (exceedLimits) {
                String error = "Exceeded maximum limit for BATCH_PREDICTION tasks. To increase the limit, update the plugins.ml_commons.max_batch_inference_tasks setting.";
                log.warn(error + " in task " + mlTask.getTaskId());
                listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
            } else {
                this.executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, tenantId, listener, channel);
            }
        }, exception -> {
            log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception);
            listener.onFailure(exception);
        });
        this.mlModelManager.checkMaxBatchJobTask(mlTask, limitListener);
    }

    @Override
    protected boolean isStreamingRequest(MLPredictionTaskRequest request) {
        return request.getStreamingChannel() != null;
    }

    /**
     * For SEARCH_QUERY input, resolves the query into a dataset and then predicts on it.
     * All other input types are predicted directly on the function's predict thread pool.
     */
    private void executePredictionByInputDataType(MLInputDataType inputDataType, String modelId, MLInput mlInput, MLTask mlTask, FunctionName functionName, String tenantId, ActionListener<MLTaskResponse> listener, TransportChannel channel) {
        switch (inputDataType) {
            case SEARCH_QUERY: {
                this.mlInputDatasetHandler.parseSearchQueryInput(mlInput.getInputDataset(), this.threadedActionListener(functionName, ActionListener.wrap(dataSet -> {
                    MLInput newInput = mlInput.toBuilder().inputDataset(dataSet).build();
                    this.predict(modelId, tenantId, mlTask, newInput, listener, channel);
                }, e -> {
                    log.error("Failed to generate DataFrame from search query", e);
                    this.handleAsyncMLTaskFailure(mlTask, e);
                    listener.onFailure(e);
                })));
                break;
            }
            default: {
                String threadPoolName = this.getPredictThreadPool(functionName);
                this.threadPool.executor(threadPoolName).execute(() -> this.predict(modelId, tenantId, mlTask, mlInput, listener, channel));
            }
        }
    }

    /** Returns the model's own auto-deploy flag, defaulting to {@code true} when it is not set. */
    private boolean checkModelAutoDeployEnabled(MLModel mlModel) {
        if (mlModel.getDeploySetting() == null || mlModel.getDeploySetting().getIsAutoDeployEnabled() == null) {
            return true;
        }
        return mlModel.getDeploySetting().getIsAutoDeployEnabled();
    }

    /** Remote models use a dedicated predict thread pool; all other functions share the default one. */
    private String getPredictThreadPool(FunctionName functionName) {
        return functionName == FunctionName.REMOTE ? "opensearch_ml_predict_remote" : "opensearch_ml_predict";
    }

    /**
     * Records request stats, registers the task as RUNNING, and runs the prediction.
     * If the local predictor is not ready and auto deployment is enabled, it first deploys the
     * model to this node.
     */
    private void predict(String modelId, String tenantId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> listener, TransportChannel channel) {
        ActionListener<MLTaskResponse> internalListener = this.wrappedCleanupListener(listener, mlTask.getTaskId());
        ActionName actionName = this.getActionNameFromInput(mlInput);
        this.mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();
        this.mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
        this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), actionName, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        if (modelId != null) {
            this.mlStats.createModelCounterStatIfAbsent(modelId, actionName, MLActionLevelStat.ML_ACTION_REQUEST_COUNT).increment();
        }
        mlTask.setState(MLTaskState.RUNNING);
        this.mlTaskManager.add(mlTask);
        FunctionName functionName = mlInput.getFunctionName();
        Predictable predictor = this.mlModelManager.getPredictor(modelId);
        boolean modelReady = predictor != null && predictor.isModelReady();
        if (modelReady || !FunctionName.isAutoDeployEnabled(this.autoDeploymentEnabled, functionName)) {
            this.runPredict(modelId, tenantId, mlTask, mlInput, functionName, actionName, internalListener, channel);
            return;
        }
        log.info("Auto deploy model {} to local node", modelId);
        Instant now = Instant.now();
        MLTask mlDeployTask = MLTask.builder()
            .taskId(UUID.randomUUID().toString())
            .functionName(functionName)
            .async(false)
            .taskType(MLTaskType.DEPLOY_MODEL)
            .createTime(now)
            .lastUpdateTime(now)
            .state(MLTaskState.RUNNING)
            .workerNodes(Arrays.asList(this.clusterService.localNode().getId()))
            .tenantId(tenantId)
            .build();
        ActionListener<String> deployListener = ActionListener.wrap(
            s -> this.runPredict(modelId, tenantId, mlTask, mlInput, functionName, actionName, internalListener, channel),
            e -> {
                log.error("Failed to auto deploy model {}", modelId, e);
                internalListener.onFailure(e);
            });
        this.mlModelManager.deployModel(modelId, tenantId, null, functionName, false, true, mlDeployTask, deployListener);
    }

    /**
     * Records predict count and latency operational metrics, tagged with the model (and its
     * connector when the model only references one by id), then forwards {@code output}.
     * The response is forwarded even if the metrics lookup fails.
     *
     * <p>NOTE(review): this private method is not called anywhere in this class.
     */
    private void recordPredictMetrics(String modelId, double durationInMs, MLTaskResponse output, ActionListener<MLTaskResponse> internalListener) {
        ActionListener<MLModel> modelListener = ActionListener.wrap(model -> {
            if (model == null) {
                internalListener.onResponse(output);
                return;
            }
            if (model.getConnector() == null && model.getConnectorId() != null) {
                ActionListener<Connector> connectorListener = ActionListener.wrap(connector -> {
                    MLOperationalMetricsCounter.getInstance().incrementCounter(OperationalMetric.MODEL_PREDICT_COUNT, model.getTags(connector));
                    MLOperationalMetricsCounter.getInstance().recordHistogram(OperationalMetric.MODEL_PREDICT_LATENCY, durationInMs, model.getTags(connector));
                    internalListener.onResponse(output);
                }, e -> {
                    log.error("Failed to get connector for latency metrics", e);
                    internalListener.onResponse(output);
                });
                this.mlModelManager.getConnector(model.getConnectorId(), model.getTenantId(), connectorListener);
                return;
            }
            MLOperationalMetricsCounter.getInstance().incrementCounter(OperationalMetric.MODEL_PREDICT_COUNT, model.getTags());
            MLOperationalMetricsCounter.getInstance().recordHistogram(OperationalMetric.MODEL_PREDICT_LATENCY, durationInMs, model.getTags());
            internalListener.onResponse(output);
        }, e -> {
            log.error("Failed to get model for latency metrics", e);
            internalListener.onResponse(output);
        });
        this.mlModelManager.getModel(modelId, modelListener);
    }

    /**
     * Runs the prediction with the locally loaded predictor when one exists. Otherwise it loads the
     * model document from the model index and predicts with the ML engine. Every failure is
     * reported through {@code internalListener}.
     */
    private void runPredict(String modelId, String tenantId, MLTask mlTask, MLInput mlInput, FunctionName algorithm, ActionName actionName, ActionListener<MLTaskResponse> internalListener, TransportChannel channel) {
        if (modelId == null) {
            IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
            log.error("ModelId is invalid", e);
            this.handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
            return;
        }
        Predictable predictor = this.mlModelManager.getPredictor(modelId);
        if (predictor != null) {
            this.predictWithLoadedModel(predictor, modelId, mlTask, mlInput, actionName, internalListener, channel);
            return;
        }
        if (FunctionName.needDeployFirst(algorithm)) {
            // Report via the listener rather than throwing: this method also runs directly inside a
            // thread-pool task, where a thrown exception never reaches the caller and the request hangs.
            IllegalArgumentException e = new IllegalArgumentException("Model not ready to be used: " + modelId);
            log.error("Failed to predict model " + modelId, e);
            this.handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
            return;
        }
        this.predictWithIndexedModel(modelId, tenantId, mlTask, mlInput, actionName, internalListener);
    }

    /**
     * Predicts with an in-memory predictor.
     *
     * <p>REMOTE models predict asynchronously; their responses may be streamed to {@code channel}.
     * All other models predict synchronously, with the duration tracked.
     */
    private void predictWithLoadedModel(Predictable predictor, String modelId, MLTask mlTask, MLInput mlInput, ActionName actionName, ActionListener<MLTaskResponse> internalListener, TransportChannel channel) {
        try {
            if (!predictor.isModelReady()) {
                throw new IllegalArgumentException("Model not ready: " + modelId);
            }
            if (mlInput.getAlgorithm() == FunctionName.REMOTE) {
                long startTime = System.nanoTime();
                ActionListener<MLTaskResponse> trackPredictDurationListener = ActionListener.wrap(
                    output -> this.onRemotePredictResponse(output, startTime, modelId, mlTask, mlInput, internalListener),
                    e -> this.handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName));
                predictor.asyncPredict(mlInput, trackPredictDurationListener, channel);
                return;
            }
            MLOutput output = this.mlModelManager.trackPredictDuration(modelId, () -> predictor.predict(mlInput));
            if (output instanceof MLPredictionOutput) {
                ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
            }
            if (output instanceof ModelTensorOutput) {
                this.validateOutputSchema(modelId, (ModelTensorOutput) output);
            }
            this.handleAsyncMLTaskComplete(mlTask);
            internalListener.onResponse(new MLTaskResponse(output));
        } catch (Exception e) {
            log.error("Failed to predict model " + modelId, e);
            this.handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
        }
    }

    /**
     * Handles a REMOTE prediction result.
     *
     * <p>The output schema is validated first. BATCH_PREDICTION results become a persisted batch
     * task; any other result completes the task and is returned. Exceptions thrown here are routed
     * to the enclosing listener's failure handler by {@code ActionListener.wrap}.
     */
    private void onRemotePredictResponse(MLTaskResponse output, long startTime, String modelId, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> internalListener) {
        if (output.getOutput() instanceof ModelTensorOutput) {
            this.validateOutputSchema(modelId, (ModelTensorOutput) output.getOutput());
        }
        if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION) {
            this.createBatchPredictionTask(output, mlTask, mlInput, internalListener);
            return;
        }
        this.handleAsyncMLTaskComplete(mlTask);
        this.mlModelManager.trackPredictDuration(modelId, startTime);
        internalListener.onResponse(output);
    }

    /**
     * Persists a batch-prediction task from the first tensor of a successful (2xx) remote response.
     * The remote job data is stored together with the input's DLQ config, task polling is started,
     * and a CREATED prediction output carrying the new task id is returned.
     */
    private void createBatchPredictionTask(MLTaskResponse output, MLTask mlTask, MLInput mlInput, ActionListener<MLTaskResponse> internalListener) {
        ModelTensorOutput tensorOutput = (ModelTensorOutput) output.getOutput();
        if (tensorOutput == null || tensorOutput.getMlModelOutputs() == null || tensorOutput.getMlModelOutputs().isEmpty()) {
            log.debug("ML Model Outputs are null or empty.");
            internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job"));
            return;
        }
        ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0);
        Integer statusCode = modelOutput.getStatusCode();
        if (modelOutput.getMlModelTensors() == null || modelOutput.getMlModelTensors().isEmpty()) {
            log.debug("ML Model Tensors are null or empty.");
            internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job"));
            return;
        }
        Map<String, ?> dataAsMap = modelOutput.getMlModelTensors().get(0).getDataAsMap();
        if (dataAsMap == null || statusCode == null || statusCode < 200 || statusCode >= 300) {
            log.debug("Batch transform job output from remote model did not return the job ID");
            internalListener.onFailure(new ResourceNotFoundException("Unable to create batch transform job"));
            return;
        }
        Map<String, Object> remoteJob = new HashMap<>(dataAsMap);
        // The batch task is only created for remote-inference input, so this cast holds.
        remoteJob.put("dlq", ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getDlq());
        mlTask.setRemoteJob(remoteJob);
        // Clear the in-memory id so the index assigns one; the assigned id is set back below.
        mlTask.setTaskId(null);
        ActionListener<IndexResponse> createTaskListener = ActionListener.wrap(response -> {
            String taskId = response.getId();
            mlTask.setTaskId(taskId);
            MLPredictionOutput outputBuilder = new MLPredictionOutput(taskId, MLTaskState.CREATED.name(), remoteJob);
            this.mlTaskManager.startTaskPollingJob();
            MLTaskResponse predictOutput = MLTaskResponse.builder().output(outputBuilder).build();
            internalListener.onResponse(predictOutput);
        }, e -> {
            MLExceptionUtils.logException("Failed to create task for batch predict model", e, log);
            internalListener.onFailure(e);
        });
        this.mlTaskManager.createMLTask(mlTask, createTaskListener);
    }

    /**
     * Fetches the model document from ".plugins-ml-model" with the thread context stashed, then
     * predicts with it on the function's predict thread pool.
     */
    private void predictWithIndexedModel(String modelId, String tenantId, MLTask mlTask, MLInput mlInput, ActionName actionName, ActionListener<MLTaskResponse> internalListener) {
        try (ThreadContext.StoredContext context = this.threadPool.getThreadContext().stashContext()) {
            ActionListener<GetResponse> getModelListener = ActionListener.wrap(
                r -> this.predictWithModelDocument(r, modelId, tenantId, mlTask, mlInput, actionName, internalListener),
                e -> {
                    log.error("Failed to predict " + mlInput.getAlgorithm() + ", modelId: " + mlTask.getModelId(), e);
                    this.handlePredictFailure(mlTask, internalListener, e, true, modelId, actionName);
                });
            GetRequest getRequest = new GetRequest(".plugins-ml-model", mlTask.getModelId());
            this.client.get(getRequest, this.threadedActionListener(mlTask.getFunctionName(), ActionListener.runBefore(getModelListener, context::restore)));
        } catch (Exception e) {
            log.error("Failed to get model " + mlTask.getModelId(), e);
            this.handlePredictFailure(mlTask, internalListener, e, true, modelId, actionName);
        }
    }

    /**
     * Parses a fetched model document and checks that the requesting user may use the model.
     * It then predicts with the ML engine, validates the output schema, and completes the task.
     */
    private void predictWithModelDocument(GetResponse r, String modelId, String tenantId, MLTask mlTask, MLInput mlInput, ActionName actionName, ActionListener<MLTaskResponse> internalListener) {
        if (r == null || !r.isExists()) {
            internalListener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
            return;
        }
        try (XContentParser xContentParser = XContentType.JSON.xContent().createParser(this.xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString())) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xContentParser.nextToken(), xContentParser);
            String algorithmName = r.getSource().get("algorithm").toString();
            MLModel mlModel = MLModel.parse(xContentParser, algorithmName);
            mlModel.setModelId(modelId);
            User resourceUser = mlModel.getUser();
            User requestUser = AccessController.getUserContext(this.client);
            if (!AccessController.checkUserPermissions(requestUser, resourceUser, modelId)) {
                OpenSearchException e = new OpenSearchException("User: " + requestUser.getName() + " does not have permissions to run predict by model: " + modelId);
                this.handlePredictFailure(mlTask, internalListener, e, false, modelId, actionName);
                return;
            }
            if (this.mlTaskManager.contains(mlTask.getTaskId())) {
                this.mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), tenantId, mlTask.isAsync());
            }
            MLOutput output = this.mlEngine.predict(mlInput, mlModel);
            if (output instanceof MLPredictionOutput) {
                ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
            }
            if (output instanceof ModelTensorOutput) {
                this.validateOutputSchema(modelId, (ModelTensorOutput) output);
            }
            this.handleAsyncMLTaskComplete(mlTask);
            MLTaskResponse response = MLTaskResponse.builder().output(output).build();
            internalListener.onResponse(response);
        } catch (Exception e) {
            log.error("Failed to predict model " + modelId, e);
            internalListener.onFailure(e);
        }
    }

    /** Wraps {@code listener} so its callbacks run on the predict thread pool chosen for {@code functionName}. */
    private <T> ThreadedActionListener<T> threadedActionListener(FunctionName functionName, ActionListener<T> listener) {
        String threadPoolName = this.getPredictThreadPool(functionName);
        return new ThreadedActionListener<>(log, this.threadPool, threadPoolName, listener, false);
    }

    /**
     * Fails the task and notifies {@code listener}.
     *
     * @param trackFailure when true, also increments the action-level, per-model, and node-level
     *     failure counters
     */
    private void handlePredictFailure(MLTask mlTask, ActionListener<MLTaskResponse> listener, Exception e, boolean trackFailure, String modelId, ActionName actionName) {
        if (trackFailure) {
            this.mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), actionName, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            // The per-model counter must be incremented too; creating it alone never records the failure.
            if (modelId != null) {
                this.mlStats.createModelCounterStatIfAbsent(modelId, actionName, MLActionLevelStat.ML_ACTION_FAILURE_COUNT).increment();
            }
            this.mlStats.getStat(MLNodeLevelStat.ML_FAILURE_COUNT).increment();
        }
        this.handleAsyncMLTaskFailure(mlTask, e);
        listener.onFailure(e);
    }

    /** Maps a remote-inference action type to its stats action name; defaults to PREDICT. */
    private ActionName getActionNameFromInput(MLInput mlInput) {
        ConnectorAction.ActionType actionType = null;
        if (mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet) {
            actionType = ((RemoteInferenceInputDataSet) mlInput.getInputDataset()).getActionType();
        }
        return actionType == null ? ActionName.PREDICT : ActionName.from(actionType.toString());
    }

    /**
     * Validates {@code output} against the model interface's "output" schema, if one is defined.
     *
     * @throws OpenSearchStatusException with BAD_REQUEST when validation fails; the original
     *     validation error is kept as the cause
     */
    public void validateOutputSchema(String modelId, ModelTensorOutput output) {
        Map<String, String> modelInterface = this.mlModelManager.getModelInterface(modelId);
        if (modelInterface == null || modelInterface.get("output") == null) {
            return;
        }
        String outputSchemaString = modelInterface.get("output");
        try {
            MLNodeUtils.validateSchema(outputSchemaString, output.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString());
        } catch (Exception e) {
            throw new OpenSearchStatusException("Error validating output schema, if you think this is expected, please update your 'output' field in the 'interface' field for this model: " + e.getMessage(), RestStatus.BAD_REQUEST, e);
        }
    }

    /** True when the model has no worker nodes, or fewer than its target worker nodes. */
    private boolean requiresAutoDeployment(String[] workerNodes, String[] targetWorkerNodes) {
        return workerNodes == null || workerNodes.length == 0 || targetWorkerNodes != null && workerNodes.length < targetWorkerNodes.length;
    }
}

