/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.nodes.exec.stream;

import java.util.Collections;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.flink.FlinkVersion;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.PipelineOptions;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.operators.ProcessOperator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.conversion.DataStructureConverter;
import org.apache.flink.table.data.conversion.DataStructureConverters;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.functions.PredictFunction;
import org.apache.flink.table.functions.UserDefinedFunction;
import org.apache.flink.table.ml.AsyncPredictRuntimeProvider;
import org.apache.flink.table.ml.ModelProvider;
import org.apache.flink.table.ml.PredictRuntimeProvider;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeMetadata;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.MultipleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtils;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.collector.ListenableCollector;
import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
import org.apache.flink.table.runtime.generated.GeneratedCollector;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;

/**
 * Stream exec node that evaluates an ML_PREDICT table function call by running a model
 * prediction function for every input row.
 *
 * <p>The prediction is executed through the lookup-join runtime: the model's predict function is
 * wrapped by code generated from {@link LookupJoinCodeGenerator} and driven either by a {@link
 * LookupJoinRunner} (synchronous) or an {@link AsyncLookupJoinRunner} inside an async wait
 * operator (asynchronous). Async execution is selected when {@link #asyncOptions} is non-null.
 */
@ExecNodeMetadata(
        name = "stream-exec-ml-predict-table-function",
        version = 1,
        consumedOptions = {
            "table.exec.async-ml-predict.max-concurrent-operations",
            "table.exec.async-ml-predict.timeout",
            "table.exec.async-ml-predict.output-mode"
        },
        producedTransformations = StreamExecMLPredictTableFunction.ML_PREDICT_TRANSFORMATION,
        minPlanVersion = FlinkVersion.v2_1,
        minStateVersion = FlinkVersion.v2_1)
public class StreamExecMLPredictTableFunction extends ExecNodeBase<RowData>
        implements MultipleTransformationTranslator<RowData>, StreamExecNode<RowData> {

    /** Name of the single transformation produced by this node. */
    public static final String ML_PREDICT_TRANSFORMATION = "ml-predict-table-function";

    /** JSON plan field holding {@link #mlPredictSpec}. */
    public static final String FIELD_NAME_ML_PREDICT_SPEC = "mlPredictSpec";

    /** JSON plan field holding {@link #modelSpec}. */
    public static final String FIELD_NAME_MODEL_SPEC = "modelSpec";

    /** JSON plan field holding {@link #asyncOptions}. */
    public static final String FIELD_NAME_ASYNC_OPTIONS = "asyncOptions";

    /** Spec of the ML_PREDICT call; supplies the feature parameters and the runtime config. */
    @JsonProperty(value = FIELD_NAME_ML_PREDICT_SPEC)
    private final MLPredictSpec mlPredictSpec;

    /** Spec of the model; supplies the {@link ModelProvider} and the resolved model schema. */
    @JsonProperty(value = FIELD_NAME_MODEL_SPEC)
    private final ModelSpec modelSpec;

    /**
     * Async execution options (buffer capacity, timeout, output mode). {@code null} means the
     * prediction runs synchronously.
     */
    @JsonProperty(value = FIELD_NAME_ASYNC_OPTIONS)
    @Nullable
    private final FunctionCallUtils.AsyncOptions asyncOptions;

    /**
     * Creates a node with a fresh node id and context.
     *
     * @param persistedConfig configuration persisted with the node
     * @param mlPredictSpec spec of the ML_PREDICT call
     * @param modelSpec spec of the model to invoke
     * @param asyncOptions async options, or {@code null} for synchronous prediction
     * @param inputProperty property of the single input
     * @param outputType row type produced by this node
     * @param description human-readable description of the node
     */
    public StreamExecMLPredictTableFunction(
            ReadableConfig persistedConfig,
            MLPredictSpec mlPredictSpec,
            ModelSpec modelSpec,
            @Nullable FunctionCallUtils.AsyncOptions asyncOptions,
            InputProperty inputProperty,
            RowType outputType,
            String description) {
        this(
                ExecNodeContext.newNodeId(),
                ExecNodeContext.newContext(StreamExecMLPredictTableFunction.class),
                persistedConfig,
                mlPredictSpec,
                modelSpec,
                asyncOptions,
                Collections.singletonList(inputProperty),
                outputType,
                description);
    }

    /** Constructor used when deserializing the node from a JSON plan. */
    @JsonCreator
    public StreamExecMLPredictTableFunction(
            @JsonProperty(value = "id") int id,
            @JsonProperty(value = "type") ExecNodeContext context,
            @JsonProperty(value = "configuration") ReadableConfig persistedConfig,
            @JsonProperty(value = FIELD_NAME_ML_PREDICT_SPEC) MLPredictSpec mlPredictSpec,
            @JsonProperty(value = FIELD_NAME_MODEL_SPEC) ModelSpec modelSpec,
            @JsonProperty(value = FIELD_NAME_ASYNC_OPTIONS) @Nullable
                    FunctionCallUtils.AsyncOptions asyncOptions,
            @JsonProperty(value = "inputProperties") List<InputProperty> inputProperties,
            @JsonProperty(value = "outputType") RowType outputType,
            @JsonProperty(value = "description") String description) {
        super(id, context, persistedConfig, inputProperties, outputType, description);
        this.mlPredictSpec = mlPredictSpec;
        this.modelSpec = modelSpec;
        this.asyncOptions = asyncOptions;
    }

    /**
     * Translates the input and wraps it with a sync or async model-prediction operator.
     *
     * @throws TableException if the model provider does not offer the required (sync/async)
     *     predict function
     */
    @Override
    @SuppressWarnings("unchecked")
    protected Transformation<RowData> translateToPlanInternal(
            PlannerBase planner, ExecNodeConfig config) {
        // translateToPlan only exposes a wildcard type; the create* methods below require
        // Transformation<RowData>, so the input is assumed to produce RowData.
        final Transformation<RowData> inputTransformation =
                (Transformation<RowData>) this.getInputEdges().get(0).translateToPlan(planner);
        final FlinkContext context = planner.getFlinkContext();
        final ModelProvider provider = this.modelSpec.getModelProvider(context);
        // The execution mode is decided solely by the presence of async options.
        final boolean async = this.asyncOptions != null;
        final UserDefinedFunction predictFunction = this.findModelFunction(provider, async);
        final DataTypeFactory dataTypeFactory = context.getCatalogManager().getDataTypeFactory();
        final ClassLoader classLoader = context.getClassLoader();

        final RowType inputType = (RowType) this.getInputEdges().get(0).getOutputType();
        final RowType modelOutputType =
                (RowType)
                        this.modelSpec
                                .getContextResolvedModel()
                                .getResolvedModel()
                                .getResolvedOutputSchema()
                                .toPhysicalRowDataType()
                                .getLogicalType();
        final RowType resultType = (RowType) this.getOutputType();

        if (async) {
            return this.createAsyncModelPredict(
                    inputTransformation,
                    config,
                    classLoader,
                    dataTypeFactory,
                    inputType,
                    modelOutputType,
                    resultType,
                    (AsyncPredictFunction) predictFunction);
        }
        return this.createModelPredict(
                inputTransformation,
                config,
                classLoader,
                dataTypeFactory,
                inputType,
                modelOutputType,
                resultType,
                (PredictFunction) predictFunction);
    }

    /**
     * Builds the synchronous prediction transformation: a {@link ProcessOperator} around a
     * {@link LookupJoinRunner} that uses generated fetcher and collector code.
     *
     * @return a one-input transformation with the input's parallelism producing {@code
     *     resultRowType}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Transformation<RowData> createModelPredict(
            Transformation<RowData> inputTransformation,
            ExecNodeConfig config,
            ClassLoader classLoader,
            DataTypeFactory dataTypeFactory,
            RowType inputRowType,
            RowType modelOutputType,
            RowType resultRowType,
            PredictFunction predictFunction) {
        final GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
                LookupJoinCodeGenerator.generateSyncLookupFunction(
                        config,
                        classLoader,
                        dataTypeFactory,
                        inputRowType,
                        modelOutputType,
                        resultRowType,
                        this.mlPredictSpec.getFeatures(),
                        predictFunction,
                        "MLPredict",
                        config.get(PipelineOptions.OBJECT_REUSE));
        final GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
                LookupJoinCodeGenerator.generateCollector(
                        new CodeGeneratorContext(config, classLoader),
                        inputRowType,
                        modelOutputType,
                        resultRowType,
                        JavaScalaConversionUtil.toScala(Optional.empty()),
                        JavaScalaConversionUtil.toScala(Optional.empty()),
                        true);
        // NOTE(review): the null condition presumably means "no pre-filter"; the boolean false
        // is presumably LookupJoinRunner's left-outer-join flag — confirm against its signature.
        final LookupJoinRunner mlPredictRunner =
                new LookupJoinRunner(
                        generatedFetcher,
                        generatedCollector,
                        (GeneratedFunction)
                                FilterCodeGenerator.generateFilterCondition(
                                        config, classLoader, null, inputRowType),
                        false,
                        modelOutputType.getFieldCount());
        final SimpleOperatorFactory<RowData> operatorFactory =
                SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner));
        return ExecNodeUtil.createOneInputTransformation(
                inputTransformation,
                this.createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
                operatorFactory,
                InternalTypeInfo.of(resultRowType),
                inputTransformation.getParallelism(),
                false);
    }

    /**
     * Builds the asynchronous prediction transformation: an async wait operator around an {@link
     * AsyncLookupJoinRunner}, configured from {@link #asyncOptions} (must be non-null here).
     *
     * @return a one-input transformation with the input's parallelism producing {@code
     *     resultRowType}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Transformation<RowData> createAsyncModelPredict(
            Transformation<RowData> inputTransformation,
            ExecNodeConfig config,
            ClassLoader classLoader,
            DataTypeFactory dataTypeFactory,
            RowType inputRowType,
            RowType modelOutputType,
            RowType resultRowType,
            AsyncPredictFunction asyncPredictFunction) {
        final LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<
                        AsyncFunction<RowData, Object>>
                generatedFuncWithType =
                        LookupJoinCodeGenerator.generateAsyncLookupFunction(
                                config,
                                classLoader,
                                dataTypeFactory,
                                inputRowType,
                                modelOutputType,
                                resultRowType,
                                this.mlPredictSpec.getFeatures(),
                                asyncPredictFunction,
                                "AsyncMLPredict");
        final GeneratedResultFuture<TableFunctionResultFuture<RowData>> generatedResultFuture =
                LookupJoinCodeGenerator.generateTableAsyncCollector(
                        config,
                        classLoader,
                        "TableFunctionResultFuture",
                        inputRowType,
                        modelOutputType,
                        JavaScalaConversionUtil.toScala(Optional.empty()));
        final DataStructureConverter fetcherConverter =
                DataStructureConverters.getConverter((DataType) generatedFuncWithType.dataType());
        // NOTE(review): as in the sync path, null presumably means "no pre-filter" and false is
        // presumably the left-outer-join flag — confirm against AsyncLookupJoinRunner.
        final AsyncLookupJoinRunner asyncFunc =
                new AsyncLookupJoinRunner(
                        generatedFuncWithType.tableFunc(),
                        fetcherConverter,
                        generatedResultFuture,
                        (GeneratedFunction)
                                FilterCodeGenerator.generateFilterCondition(
                                        config, classLoader, null, inputRowType),
                        InternalSerializers.create(modelOutputType),
                        false,
                        this.asyncOptions.asyncBufferCapacity);
        final AsyncWaitOperatorFactory<RowData, RowData> operatorFactory =
                new AsyncWaitOperatorFactory<>(
                        asyncFunc,
                        this.asyncOptions.asyncTimeout,
                        this.asyncOptions.asyncBufferCapacity,
                        this.asyncOptions.asyncOutputMode);
        return ExecNodeUtil.createOneInputTransformation(
                inputTransformation,
                this.createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
                operatorFactory,
                InternalTypeInfo.of(resultRowType),
                inputTransformation.getParallelism(),
                false);
    }

    /**
     * Obtains the model's predict function from {@code provider}.
     *
     * @param provider the model provider resolved from {@link #modelSpec}
     * @param async whether an {@link AsyncPredictRuntimeProvider} function is required instead of
     *     a {@link PredictRuntimeProvider} one
     * @return the created predict function
     * @throws TableException if the provider does not implement the required provider interface
     */
    private UserDefinedFunction findModelFunction(ModelProvider provider, boolean async) {
        final ModelPredictRuntimeProviderContext providerContext =
                new ModelPredictRuntimeProviderContext(
                        this.modelSpec.getContextResolvedModel().getResolvedModel(),
                        Configuration.fromMap(this.mlPredictSpec.getRuntimeConfig()));
        if (async && provider instanceof AsyncPredictRuntimeProvider) {
            return ((AsyncPredictRuntimeProvider) provider)
                    .createAsyncPredictFunction(providerContext);
        }
        if (!async && provider instanceof PredictRuntimeProvider) {
            return ((PredictRuntimeProvider) provider).createPredictFunction(providerContext);
        }
        throw new TableException(
                "Required "
                        + (async ? "async" : "sync")
                        + " model function by planner, but ModelProvider does not offer a valid model function.");
    }
}

