/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.sql.execution.python;

import java.io.Serializable;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.SerializedLambda;
import org.apache.spark.SparkException$;
import org.apache.spark.api.python.PythonEvalType$;
import org.apache.spark.internal.LogEntry$;
import org.apache.spark.internal.LogKey;
import org.apache.spark.internal.LogKeys;
import org.apache.spark.internal.MDC;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.AttributeReference$;
import org.apache.spark.sql.catalyst.expressions.ExprId;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.ExpressionSet$;
import org.apache.spark.sql.catalyst.expressions.PythonUDF;
import org.apache.spark.sql.catalyst.expressions.PythonUDF$;
import org.apache.spark.sql.catalyst.plans.QueryPlan;
import org.apache.spark.sql.catalyst.plans.logical.ArrowEvalPython;
import org.apache.spark.sql.catalyst.plans.logical.BatchEvalPython;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project;
import org.apache.spark.sql.catalyst.plans.logical.Subquery;
import org.apache.spark.sql.catalyst.rules.Rule;
import org.apache.spark.sql.catalyst.trees.TreePattern$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UserDefinedType;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.PartialFunction;
import scala.Predef$;
import scala.Some;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnce;
import scala.collection.IterableOnceOps;
import scala.collection.IterableOps;
import scala.collection.SeqFactory;
import scala.collection.SeqOps;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.collection.immutable.Set;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashMap$;
import scala.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.LambdaDeserialize;
import scala.runtime.ObjectRef;
import scala.runtime.ScalaRunTime$;

/**
 * Optimizer rule that pulls scalar Python UDF calls out of logical plan operators.
 *
 * <p>For each operator, the evaluable UDFs are grouped by the child whose output they read.
 * Each group is evaluated by a new {@code BatchEvalPython} or {@code ArrowEvalPython} node
 * inserted above that child. The UDF calls in the operator are then replaced by references
 * to the new result attributes ({@code pythonUDF0}, {@code pythonUDF1}, ...).
 *
 * <p>This listing was produced by the CFR decompiler. {@code MODULE$} is the singleton
 * instance (the usual encoding of a Scala {@code object}).
 *
 * <p>NOTE(review): several constructs below are decompiler artifacts and are not valid or
 * type-correct Java. Do not try to recompile this file as-is. The artifacts are:
 * <ul>
 *   <li>anonymous {@code Serializable} classes cast to {@code PartialFunction};</li>
 *   <li>{@code new .colon.colon(...)};</li>
 *   <li>the {@code $deserializeLambda$} body;</li>
 *   <li>an {@code ArrowEvalPython} stored in a {@code BatchEvalPython} local.</li>
 * </ul>
 * Each one is marked where it occurs.
 */
public final class ExtractPythonUDFs$
extends Rule<LogicalPlan> {
    // Singleton instance; all internal callbacks go through MODULE$.
    public static final ExtractPythonUDFs$ MODULE$ = new ExtractPythonUDFs$();

    /**
     * Returns whether the expression tree {@code e2} contains any node for which
     * {@code PythonUDF.isScalarPythonUDF} holds (checked via {@code Expression.exists}).
     */
    private boolean hasScalarPythonUDF(Expression e2) {
        return e2.exists((Function1 & Serializable)e -> BoxesRunTime.boxToBoolean((boolean)PythonUDF$.MODULE$.isScalarPythonUDF(e)));
    }

    /**
     * Returns whether the PythonUDF {@code e2} can be evaluated in a single Python evaluation
     * node together with any PythonUDFs chained directly beneath it.
     *
     * <p>The loop walks down while the current UDF has exactly one child and that child is a
     * {@code PythonUDF}. If any link has a different corrected eval type (see
     * {@link #correctEvalType}) than its child, the result is {@code false}. When the children
     * are no longer a single PythonUDF, the result is {@code true} only if none of those
     * children contains a scalar Python UDF.
     *
     * <p>The loop is likely the compiled form of a tail-recursive {@code match} on
     * {@code Seq(u: PythonUDF)} in the Scala source (TODO confirm).
     */
    private boolean canEvaluateInPython(PythonUDF e2) {
        Expression u;
        SeqOps seqOps;
        Seq seq;
        // Condition: children match Seq(u) with exactly one element, and that element is a PythonUDF.
        while ((seq = e2.children()) != null && !SeqFactory.UnapplySeqWrapper$.MODULE$.isEmpty$extension(seqOps = package$.MODULE$.Seq().unapplySeq((SeqOps)seq)) && new SeqFactory.UnapplySeqWrapper(SeqFactory.UnapplySeqWrapper$.MODULE$.get$extension(seqOps)) != null && SeqFactory.UnapplySeqWrapper$.MODULE$.lengthCompare$extension(SeqFactory.UnapplySeqWrapper$.MODULE$.get$extension(seqOps), 1) == 0 && (u = (Expression)SeqFactory.UnapplySeqWrapper$.MODULE$.apply$extension(SeqFactory.UnapplySeqWrapper$.MODULE$.get$extension(seqOps), 0)) instanceof PythonUDF) {
            PythonUDF pythonUDF = (PythonUDF)u;
            // A chained parent/child pair must share the same effective eval type.
            if (this.correctEvalType(e2) == this.correctEvalType(pythonUDF)) {
                e2 = pythonUDF;
                continue;
            }
            return false;
        }
        // End of the chain. `seq` still holds the children of the current UDF; none may contain a scalar Python UDF.
        return !seq.exists((Function1 & Serializable)e -> BoxesRunTime.boxToBoolean((boolean)ExtractPythonUDFs$.MODULE$.hasScalarPythonUDF(e)));
    }

    /**
     * Collects, across all {@code expressions}, the PythonUDFs that can be extracted in this
     * pass.
     *
     * <p>The {@code ObjectRef} holds an {@code Option} with the corrected eval type of the first
     * UDF collected. It is shared by every call to {@code collectEvaluableUDFs$1}, so later UDFs
     * are only collected if they are chainable with that first eval type (see
     * {@code canChainUDF$1}).
     */
    private Seq<PythonUDF> collectEvaluableUDFsFromExpressions(Seq<Expression> expressions) {
        ObjectRef firstVisitedScalarUDFEvalType = ObjectRef.create((Object)None$.MODULE$);
        return (Seq)expressions.flatMap((Function1 & Serializable)expr -> this.collectEvaluableUDFs$1((Expression)expr, firstVisitedScalarUDFEvalType));
    }

    /**
     * Applies the rule to {@code plan}.
     *
     * <p>A {@code Subquery} whose {@code correlated()} flag is set is returned unchanged.
     * Otherwise the plan is transformed bottom-up, visiting only subtrees that contain the
     * {@code PYTHON_UDF} tree pattern. Existing {@code BatchEvalPython} and
     * {@code ArrowEvalPython} nodes are returned unchanged. Every other non-null node is passed
     * to {@code extract}.
     */
    public LogicalPlan apply(LogicalPlan plan) {
        Subquery subquery;
        LogicalPlan logicalPlan2 = plan;
        if (logicalPlan2 instanceof Subquery && (subquery = (Subquery)logicalPlan2).correlated()) {
            return plan;
        }
        // NOTE(review): decompiler rendering of a Scala partial-function literal. In real Java this
        // anonymous class only implements Serializable, so the (PartialFunction) cast would fail.
        return plan.transformUpWithPruning((Function1 & Serializable)x$7 -> BoxesRunTime.boxToBoolean((boolean)x$7.containsPattern(TreePattern$.MODULE$.PYTHON_UDF())), plan.transformUpWithPruning$default$2(), (PartialFunction)new Serializable(){
            private static final long serialVersionUID = 0L;

            public final <A1 extends LogicalPlan, B1> B1 applyOrElse(A1 x1, Function1<A1, B1> function1) {
                A1 A1 = x1;
                // Nodes that already evaluate Python UDFs are left untouched.
                if (A1 instanceof BatchEvalPython) {
                    BatchEvalPython batchEvalPython = (BatchEvalPython)A1;
                    return (B1)batchEvalPython;
                }
                if (A1 instanceof ArrowEvalPython) {
                    ArrowEvalPython arrowEvalPython = (ArrowEvalPython)A1;
                    return (B1)arrowEvalPython;
                }
                if (A1 != null) {
                    A1 A12 = A1;
                    return (B1)ExtractPythonUDFs$.MODULE$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(A12);
                }
                return (B1)function1.apply(x1);
            }

            // Defined for every non-null plan node; mirrors the cases in applyOrElse.
            public final boolean isDefinedAt(LogicalPlan x1) {
                LogicalPlan logicalPlan2 = x1;
                if (logicalPlan2 instanceof BatchEvalPython) {
                    return true;
                }
                if (logicalPlan2 instanceof ArrowEvalPython) {
                    return true;
                }
                return logicalPlan2 != null;
            }
        });
    }

    /**
     * Returns the key used to identify {@code u} in the UDF-to-attribute map built by
     * {@code extract}.
     *
     * <p>A deterministic UDF is replaced by its {@code canonicalized()} form. A non-deterministic
     * UDF is returned as-is. Presumably this lets equivalent deterministic calls share one result
     * attribute while non-deterministic calls stay distinct (TODO confirm).
     */
    public PythonUDF org$apache$spark$sql$execution$python$ExtractPythonUDFs$$canonicalizeDeterministic(PythonUDF u) {
        if (u.deterministic()) {
            return (PythonUDF)u.canonicalized();
        }
        return u;
    }

    /**
     * Returns the eval type that will actually be used for {@code udf}.
     *
     * <p>A {@code SQL_ARROW_BATCHED_UDF} is downgraded to {@code SQL_BATCHED_UDF} when its return
     * type, or the data type of any child, contains a user-defined type (see
     * {@link #containsUDT}). Every other eval type is returned unchanged.
     */
    private int correctEvalType(PythonUDF udf) {
        if (udf.evalType() == PythonEvalType$.MODULE$.SQL_ARROW_BATCHED_UDF()) {
            if (this.containsUDT(udf.dataType()) || udf.children().exists((Function1 & Serializable)expr -> BoxesRunTime.boxToBoolean((boolean)ExtractPythonUDFs$.MODULE$.containsUDT(expr.dataType())))) {
                return PythonEvalType$.MODULE$.SQL_BATCHED_UDF();
            }
            return PythonEvalType$.MODULE$.SQL_ARROW_BATCHED_UDF();
        }
        return udf.evalType();
    }

    /**
     * Returns whether {@code dataType} is, or nests, a {@code UserDefinedType}.
     *
     * <p>The method descends into {@code ArrayType} element types, {@code StructType} field types,
     * and {@code MapType} key and value types. Any other type yields {@code false}.
     *
     * <p>Array elements and map values are handled by looping ({@code continue}) rather than by
     * recursion. Struct fields and map keys are checked recursively.
     */
    private boolean containsUDT(DataType dataType) {
        block4: {
            while (true) {
                DataType dataType2;
                if ((dataType2 = dataType) instanceof UserDefinedType) {
                    return true;
                }
                if (dataType2 instanceof ArrayType) {
                    DataType elementType;
                    ArrayType arrayType = (ArrayType)dataType2;
                    dataType = elementType = arrayType.elementType();
                    continue;
                }
                if (dataType2 instanceof StructType) {
                    StructType structType = (StructType)dataType2;
                    StructField[] fields = structType.fields();
                    return ArrayOps$.MODULE$.exists$extension(Predef$.MODULE$.refArrayOps((Object[])fields), (Function1 & Serializable)field -> BoxesRunTime.boxToBoolean((boolean)ExtractPythonUDFs$.MODULE$.containsUDT(field.dataType())));
                }
                // Not a UDT, array, struct or map: falls through to `return false`.
                if (!(dataType2 instanceof MapType)) break block4;
                MapType mapType = (MapType)dataType2;
                DataType keyType = mapType.keyType();
                DataType valueType = mapType.valueType();
                // If the key type contains a UDT, break to `return true`; otherwise continue with the value type.
                if (this.containsUDT(keyType)) break;
                dataType = valueType;
            }
            return true;
        }
        return false;
    }

    /**
     * Extracts evaluable PythonUDFs from {@code plan} into evaluation nodes placed under it.
     *
     * <p>Steps:
     * <ol>
     *   <li>Collect the evaluable UDFs and deduplicate them with {@code ExpressionSet}. Keep only
     *       those whose references are a subset of {@code plan.inputSet()}. If none remain,
     *       return {@code plan} unchanged.</li>
     *   <li>For each child, take the UDFs whose references are a subset of that child's
     *       {@code outputSet()}. Wrap the child in a {@code BatchEvalPython} (effective eval type
     *       {@code SQL_BATCHED_UDF}) or an {@code ArrowEvalPython} (pandas, pandas-iter, or
     *       Arrow-batched). Record each canonicalized UDF mapped to its result attribute.</li>
     *   <li>Throw an internal error for any UDF that was not assigned to a single child.</li>
     *   <li>Replace the UDF calls in the plan with their result attributes, then call
     *       {@code extract} again on the rewritten plan. This picks up UDFs that were not
     *       collected in this pass.</li>
     *   <li>If the resulting output differs from {@code plan.output()}, wrap it in a
     *       {@code Project} that restores the original output.</li>
     * </ol>
     *
     * @throws IllegalArgumentException (via {@code require}) if a child's UDFs are not all
     *     scalar; exception type assumed from Scala's {@code Predef.require}
     * @throws Throwable whatever {@code SparkException.internalError} throws, in three cases:
     *     mixed eval types within one child, an unexpected eval type, or a UDF that needs
     *     attributes from more than one child
     */
    public LogicalPlan org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(LogicalPlan plan) {
        // Only keep UDFs whose inputs are available from this operator's input.
        Seq udfs = ExpressionSet$.MODULE$.apply(this.collectEvaluableUDFsFromExpressions((Seq<Expression>)plan.expressions())).filter((Function1 & Serializable)udf -> BoxesRunTime.boxToBoolean((boolean)ExtractPythonUDFs$.$anonfun$extract$4(plan, udf))).toSeq();
        if (udfs.isEmpty()) {
            return plan;
        }
        // Canonicalized PythonUDF -> result AttributeReference, filled in per child below.
        HashMap attributeMap = (HashMap)HashMap$.MODULE$.apply((Seq)Nil$.MODULE$);
        Seq newChildren = (Seq)plan.children().map((Function1 & Serializable)child -> {
            // UDFs that can be evaluated from this child's output alone.
            Seq validUdfs = (Seq)udfs.filter((Function1 & Serializable)udf -> BoxesRunTime.boxToBoolean((boolean)ExtractPythonUDFs$.$anonfun$extract$6(child, udf)));
            if (validUdfs.nonEmpty()) {
                // NOTE(review): the static type of this local should be LogicalPlan. It is later
                // assigned an ArrowEvalPython, which is matched as a separate type in apply().
                BatchEvalPython batchEvalPython;
                int evalType;
                Predef$.MODULE$.require(validUdfs.forall((Function1 & Serializable)e -> BoxesRunTime.boxToBoolean((boolean)PythonUDF$.MODULE$.isScalarPythonUDF(e))), (Function0 & Serializable)() -> "Can only extract scalar vectorized udf or sql batch udf");
                // One output attribute per UDF, named pythonUDF<index> and typed like the UDF.
                Seq resultAttrs = (Seq)((IterableOps)validUdfs.zipWithIndex()).map((Function1 & Serializable)x0$1 -> {
                    Tuple2 tuple2 = x0$1;
                    if (tuple2 != null) {
                        PythonUDF u = (PythonUDF)tuple2._1();
                        int i = tuple2._2$mcI$sp();
                        String x$1 = "pythonUDF" + i;
                        DataType x$2 = u.dataType();
                        boolean x$3 = AttributeReference$.MODULE$.apply$default$3();
                        Metadata x$4 = AttributeReference$.MODULE$.apply$default$4();
                        ExprId x$5 = AttributeReference$.MODULE$.apply$default$5(x$1, x$2, x$3, x$4);
                        Seq x$6 = AttributeReference$.MODULE$.apply$default$6(x$1, x$2, x$3, x$4);
                        return new AttributeReference(x$1, x$2, x$3, x$4, x$5, x$6);
                    }
                    throw new MatchError((Object)tuple2);
                });
                // All UDFs evaluated by one node must share a single effective eval type.
                Set evalTypes = ((IterableOnceOps)validUdfs.map((Function1 & Serializable)udf -> BoxesRunTime.boxToInteger((int)ExtractPythonUDFs$.MODULE$.correctEvalType(udf)))).toSet();
                if (evalTypes.size() != 1) {
                    throw SparkException$.MODULE$.internalError("Expected udfs have the same evalType but got different evalTypes: " + evalTypes.mkString(","));
                }
                int n = evalType = BoxesRunTime.unboxToInt((Object)evalTypes.head());
                if (PythonEvalType$.MODULE$.SQL_BATCHED_UDF() == n) {
                    // Warn when a UDF declared with another eval type was downgraded by correctEvalType.
                    if (validUdfs.exists((Function1 & Serializable)x$8 -> BoxesRunTime.boxToBoolean((boolean)ExtractPythonUDFs$.$anonfun$extract$11(x$8)))) {
                        MODULE$.logWarning(LogEntry$.MODULE$.from((Function0 & Serializable)() -> MODULE$.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"Arrow optimization disabled due to "}))).log((Seq)Nil$.MODULE$).$plus(MODULE$.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"", ". "}))).log((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new MDC[]{new MDC((LogKey)LogKeys.REASON$.MODULE$, (Object)"UDT input or return type")}))).$plus(MODULE$.LogStringContext(new StringContext((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new String[]{"Falling back to non-Arrow-optimized UDF execution."}))).log((Seq)Nil$.MODULE$))));
                    }
                    batchEvalPython = new BatchEvalPython(validUdfs, resultAttrs, child);
                // Pandas, pandas-iterator and Arrow-batched UDFs go through ArrowEvalPython.
                } else if (PythonEvalType$.MODULE$.SQL_SCALAR_PANDAS_UDF() == n ? true : (PythonEvalType$.MODULE$.SQL_SCALAR_PANDAS_ITER_UDF() == n ? true : PythonEvalType$.MODULE$.SQL_ARROW_BATCHED_UDF() == n)) {
                    batchEvalPython = new ArrowEvalPython(validUdfs, resultAttrs, child, evalType);
                } else {
                    throw SparkException$.MODULE$.internalError("Unexpected UDF evalType");
                }
                BatchEvalPython evaluation = batchEvalPython;
                // Remember which attribute replaces each (canonicalized) UDF call.
                attributeMap.$plus$plus$eq((IterableOnce)((IterableOps)validUdfs.map((Function1 & Serializable)u -> MODULE$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$canonicalizeDeterministic((PythonUDF)u))).zip((IterableOnce)resultAttrs));
                return (LogicalPlan)evaluation;
            }
            return child;
        });
        // Any UDF without an attribute could not be evaluated from a single child's output.
        ((IterableOnceOps)((IterableOps)udfs.map((Function1 & Serializable)u -> MODULE$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$canonicalizeDeterministic((PythonUDF)u))).filterNot((Function1 & Serializable)key -> BoxesRunTime.boxToBoolean((boolean)attributeMap.contains((Object)key)))).foreach((Function1 & Serializable)udf -> {
            throw SparkException$.MODULE$.internalError("Invalid PythonUDF " + udf + ", requires attributes from more than one child.");
        });
        // Swap in the new children, then replace each mapped UDF call with its result attribute.
        // Unmapped UDFs are kept as-is.
        // NOTE(review): decompiler artifact. Java does not allow constructor arguments for an
        // interface-anonymous class, and `attributeMap$2` in the initializer is undefined. The
        // intent is to capture the local `attributeMap`. The (PartialFunction) cast has the same
        // problem as in apply().
        LogicalPlan rewritten = (LogicalPlan)((QueryPlan)plan.withNewChildren(newChildren)).transformExpressions((PartialFunction)new Serializable(attributeMap){
            private static final long serialVersionUID = 0L;
            private final HashMap attributeMap$2;

            public final <A1 extends Expression, B1> B1 applyOrElse(A1 x1, Function1<A1, B1> function1) {
                A1 A1 = x1;
                if (A1 instanceof PythonUDF) {
                    PythonUDF pythonUDF = (PythonUDF)A1;
                    return (B1)this.attributeMap$2.getOrElse((Object)ExtractPythonUDFs$.MODULE$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$canonicalizeDeterministic(pythonUDF), (Function0 & Serializable)() -> pythonUDF);
                }
                return (B1)function1.apply(x1);
            }

            public final boolean isDefinedAt(Expression x1) {
                Expression expression = x1;
                return expression instanceof PythonUDF;
            }
            {
                this.attributeMap$2 = attributeMap$2;
            }

            // NOTE(review): decompiler output for lambda deserialization support; this line is not valid Java syntax.
            private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
                return LambdaDeserialize.bootstrap("lambdaDeserialize", new MethodHandle[]{$anonfun$applyOrElse$5(org.apache.spark.sql.catalyst.expressions.PythonUDF )}, serializedLambda);
            }
        });
        // Recurse to extract UDFs that were not collected in this pass.
        LogicalPlan newPlan = this.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(rewritten);
        Seq seq = newPlan.output();
        Seq seq2 = plan.output();
        // Restore the original output columns if extraction changed them.
        if (seq == null ? seq2 != null : !seq.equals(seq2)) {
            return new Project(plan.output(), newPlan);
        }
        return newPlan;
    }

    /**
     * Returns whether a UDF with corrected eval type {@code evalType} may be collected in the
     * same pass as the first collected UDF.
     *
     * <p>{@code SQL_SCALAR_PANDAS_ITER_UDF} never chains. Any other type chains only if it
     * equals the first collected eval type.
     *
     * <p>The {@code .get()} is safe at the only call site (L284 in the original listing):
     * {@code collectEvaluableUDFs$1} reaches it only when its first case was skipped because the
     * Option was already non-empty.
     */
    private static final boolean canChainUDF$1(int evalType, ObjectRef firstVisitedScalarUDFEvalType$1) {
        if (evalType == PythonEvalType$.MODULE$.SQL_SCALAR_PANDAS_ITER_UDF()) {
            return false;
        }
        return evalType == BoxesRunTime.unboxToInt((Object)((Option)firstVisitedScalarUDFEvalType$1.elem).get());
    }

    /**
     * Collects the evaluable scalar PythonUDFs in {@code expr2}.
     *
     * <p>The cases are tried in order:
     * <ol>
     *   <li>If the Option is still empty and {@code expr2} is an evaluable scalar PythonUDF,
     *       record its corrected eval type and return it.</li>
     *   <li>If {@code expr2} is an evaluable scalar PythonUDF that can chain with the recorded
     *       eval type, return it.</li>
     *   <li>Otherwise, recurse into the children.</li>
     * </ol>
     * Children of a collected UDF are not searched. The {@code bl} flag carries the result of the
     * {@code instanceof} test into the second case (decompiled pattern-match form).
     *
     * <p>NOTE(review): {@code new .colon.colon(x, Nil)} is the decompiler's rendering of a
     * one-element list, presumably built with {@code scala.collection.immutable.::}. It is not
     * valid Java as written.
     */
    private final Seq collectEvaluableUDFs$1(Expression expr2, ObjectRef firstVisitedScalarUDFEvalType$1) {
        boolean bl = false;
        PythonUDF pythonUDF = null;
        Expression expression = expr2;
        if (expression instanceof PythonUDF) {
            bl = true;
            pythonUDF = (PythonUDF)expression;
            // First evaluable scalar UDF seen: it fixes the eval type for the rest of this pass.
            if (PythonUDF$.MODULE$.isScalarPythonUDF((Expression)pythonUDF) && this.canEvaluateInPython(pythonUDF) && ((Option)firstVisitedScalarUDFEvalType$1.elem).isEmpty()) {
                firstVisitedScalarUDFEvalType$1.elem = new Some((Object)BoxesRunTime.boxToInteger((int)this.correctEvalType(pythonUDF)));
                return new .colon.colon((Object)pythonUDF, (List)Nil$.MODULE$);
            }
        }
        // Later evaluable scalar UDFs are collected only if they chain with the first eval type.
        if (bl && PythonUDF$.MODULE$.isScalarPythonUDF((Expression)pythonUDF) && this.canEvaluateInPython(pythonUDF) && ExtractPythonUDFs$.canChainUDF$1(this.correctEvalType(pythonUDF), firstVisitedScalarUDFEvalType$1)) {
            return new .colon.colon((Object)pythonUDF, (List)Nil$.MODULE$);
        }
        return (Seq)expression.children().flatMap((Function1 & Serializable)expr -> this.collectEvaluableUDFs$1((Expression)expr, firstVisitedScalarUDFEvalType$1));
    }

    // Lambda body from extract(): does the UDF only reference attributes in the plan's input set?
    public static final /* synthetic */ boolean $anonfun$extract$4(LogicalPlan plan$1, Expression udf) {
        return udf.references().subsetOf(plan$1.inputSet());
    }

    // Lambda body from extract(): can the UDF be evaluated from this child's output alone?
    public static final /* synthetic */ boolean $anonfun$extract$6(LogicalPlan child$1, PythonUDF udf) {
        return udf.references().subsetOf(child$1.outputSet());
    }

    // Lambda body from extract(): was this UDF declared with an eval type other than SQL_BATCHED_UDF?
    public static final /* synthetic */ boolean $anonfun$extract$11(PythonUDF x$8) {
        return x$8.evalType() != PythonEvalType$.MODULE$.SQL_BATCHED_UDF();
    }

    // Private constructor; the only instance is MODULE$.
    private ExtractPythonUDFs$() {
    }
}

