@@ -47,11 +47,11 @@ import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, Task
 import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -1614,14 +1614,23 @@ class SparkConnectPlanner(
       fun: proto.Expression.UnresolvedFunction): Expression = {
     if (fun.getIsUserDefinedFunction) {
       UnresolvedFunction(
-        parser.parseFunctionIdentifier(fun.getFunctionName),
+        parser.parseMultipartIdentifier(fun.getFunctionName),
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = fun.getIsDistinct)
     } else {
+      // Spark Connect historically used the global namespace to look up a couple of internal
+      // functions (e.g. product, collect_top_k, unwrap_udt, ...). In Spark 4 we moved these
+      // functions to a dedicated namespace; however, in order to stay backwards compatible we
+      // still need to allow Connect to use the global namespace. Here we check whether the
+      // function is registered in the internal function registry, and if so we reroute the
+      // lookup to the internal registry.
+      val name = fun.getFunctionName
+      val internal = FunctionRegistry.internal.functionExists(FunctionIdentifier(name))
       UnresolvedFunction(
-        FunctionIdentifier(fun.getFunctionName),
+        name :: Nil,
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
-        isDistinct = fun.getIsDistinct)
+        isDistinct = fun.getIsDistinct,
+        isInternal = internal)
     }
   }

@@ -1832,18 +1841,6 @@ class SparkConnectPlanner(
   private def transformUnregisteredFunction(
       fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
     fun.getFunctionName match {
-      case "product" if fun.getArgumentsCount == 1 =>
-        Some(
-          aggregate
-            .Product(transformExpression(fun.getArgumentsList.asScala.head))
-            .toAggregateExpression())
-
-      case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
-        // [col, expectedNumItems: Long, numBits: Long]
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(
-          new BloomFilterAggregate(children(0), children(1), children(2))
-            .toAggregateExpression())

       case "timestampdiff" if fun.getArgumentsCount == 3 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1864,21 +1861,6 @@ class SparkConnectPlanner(
           throw InvalidPlanInput(s"numBuckets should be a literal integer, but got $other")
       }

-      case "years" if fun.getArgumentsCount == 1 =>
-        Some(Years(transformExpression(fun.getArguments(0))))
-
-      case "months" if fun.getArgumentsCount == 1 =>
-        Some(Months(transformExpression(fun.getArguments(0))))
-
-      case "days" if fun.getArgumentsCount == 1 =>
-        Some(Days(transformExpression(fun.getArguments(0))))
-
-      case "hours" if fun.getArgumentsCount == 1 =>
-        Some(Hours(transformExpression(fun.getArguments(0))))
-
-      case "unwrap_udt" if fun.getArgumentsCount == 1 =>
-        Some(UnwrapUDT(transformExpression(fun.getArguments(0))))
-
       // Avro-specific functions
       case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1898,9 +1880,6 @@ class SparkConnectPlanner(
         Some(CatalystDataToAvro(children.head, jsonFormatSchema))

       // PS(Pandas API on Spark)-specific functions
-      case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
-        Some(DistributedSequenceID())
-
       case "pandas_product" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val dropna = extractBoolean(children(1), "dropna")
@@ -1911,14 +1890,6 @@ class SparkConnectPlanner(
         val ddof = extractInteger(children(1), "ddof")
         Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))

-      case "pandas_skew" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(aggregate.PandasSkewness(children(0)).toAggregateExpression(false))
-
-      case "pandas_kurt" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(aggregate.PandasKurtosis(children(0)).toAggregateExpression(false))
-
       case "pandas_var" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val ddof = extractInteger(children(1), "ddof")
@@ -1938,11 +1909,7 @@ class SparkConnectPlanner(
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val alpha = extractDouble(children(1), "alpha")
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
-        Some(EWM(children(0), alpha, ignoreNA))
-
-      case "null_index" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(NullIndex(children(0)))
+        Some(new EWM(children(0), alpha, ignoreNA))

       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
@@ -2044,7 +2011,7 @@ class SparkConnectPlanner(
   @scala.annotation.tailrec
   private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match {
     case map: CreateMap => ExprUtils.convertToMapData(map)
-    case UnresolvedFunction(Seq("map"), args, _, _, _, _) =>
+    case UnresolvedFunction(Seq("map"), args, _, _, _, _, _) =>
       extractMapData(CreateMap(args), field)
     case other => throw InvalidPlanInput(s"$field should be created by map, but got $other")
   }
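Taken together, these hunks drop the hard-coded planner cases (product, bloom_filter_agg, years/months/days/hours, unwrap_udt, distributed_sequence_id, pandas_skew, pandas_kurt, null_index) and instead route any name found in `FunctionRegistry.internal` through an `UnresolvedFunction` flagged with `isInternal = true`; the extra `_` in the `extractMapData` pattern accounts for that new constructor field. A minimal sketch of the rerouting decision, where `internalNames` and `UnresolvedCall` are hypothetical stand-ins for the real `FunctionRegistry.internal` and catalyst `UnresolvedFunction`:

```scala
// Sketch only: `internalNames` and `UnresolvedCall` stand in for
// FunctionRegistry.internal and org.apache.spark.sql.catalyst.analysis.UnresolvedFunction.
object InternalLookupSketch {
  // A few of the names this commit moves out of the global namespace (from the diff).
  private val internalNames: Set[String] =
    Set("product", "collect_top_k", "unwrap_udt", "distributed_sequence_id", "null_index")

  // Mirrors UnresolvedFunction(name :: Nil, args, isDistinct, isInternal = ...).
  final case class UnresolvedCall(
      nameParts: Seq[String],
      isDistinct: Boolean,
      isInternal: Boolean)

  def transform(name: String, isDistinct: Boolean): UnresolvedCall = {
    // Old Connect clients still send these names in the global namespace;
    // flagging the call reroutes resolution to the internal registry.
    val internal = internalNames.contains(name)
    UnresolvedCall(name :: Nil, isDistinct, isInternal = internal)
  }

  def main(args: Array[String]): Unit = {
    println(transform("product", isDistinct = false)) // isInternal = true
    println(transform("sum", isDistinct = false))     // isInternal = false
  }
}
```

The payoff of this shape, presumably, is that the planner no longer needs a case per function: anything registered internally resolves through the analyzer, so new internal functions become reachable from Connect without further planner changes.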