Commit 0346b18
[SPARK-49004][CONNECT] Use separate registry for Column API internal functions
### What changes were proposed in this pull request?
This PR introduces a separate FunctionRegistry for functions used by the Column API that should not be exposed in the global function namespace. This internal registry is only used when the `UnresolvedFunction` has the `isInternal` flag set to `true`.

### Why are the changes needed?
We want to create a Column API shared by the Classic and Connect Scala Clients. This requires that we fully decouple the Column API from Catalyst. Part of this work is decoupling function resolution.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47572 from hvanhovell/SPARK-49004.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent 6246b18 commit 0346b18
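For illustration only (this sketch is not part of the commit): the flow the change enables, written as if inside the `org.apache.spark.sql` package because `Column.internalFn` is `private[sql]`.

```scala
import org.apache.spark.sql.{functions, Column}

// Builds UnresolvedFunction(Seq("product"), Seq(...), isDistinct = false,
// isInternal = true). Because isInternal is set, the Analyzer resolves the
// name against FunctionRegistry.internal instead of the global registry,
// so "product" never appears in the user-visible function namespace.
val productAgg: Column = Column.internalFn("product", functions.lit(1))
```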

8 files changed: +109 −89 lines

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala

Lines changed: 7 additions & 1 deletion
@@ -295,7 +295,13 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.artifact.util.ArtifactUtils"),
       ProblemFilters.exclude[MissingClassProblem](
-        "org.apache.spark.sql.artifact.util.ArtifactUtils$")) ++
+        "org.apache.spark.sql.artifact.util.ArtifactUtils$"),
+
+      // Datasource V2 partition transforms
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.PartitionTransform$ExtractTransform")) ++
       mergeIntoWriterExcludeRules

     checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
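These exclusions are needed because `PartitionTransform` (added to `DataFrameWriterV2.scala` below) is a server-side helper with no counterpart in the Connect client jar. As a hedged sketch, the same filter in a plain sbt MiMa setup would look like this; Spark instead feeds the rules to `checkMiMaCompatibility` as shown above.

```scala
// build.sbt sketch, assuming the sbt-mima-plugin is applied to the project.
import com.typesafe.tools.mima.core._

mimaBinaryIssueFilters ++= Seq(
  ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"))
```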

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 21 additions & 14 deletions
@@ -2095,8 +2095,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     val externalFunctionNameSet = new mutable.HashSet[Seq[String]]()

     plan.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
-      case f @ UnresolvedFunction(nameParts, _, _, _, _, _) =>
-        if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts).isDefined) {
+      case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _) =>
+        if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) {
           f
         } else {
           val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts)
@@ -2141,7 +2141,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       UNRESOLVED_TABLE_VALUED_FUNCTION, UNRESOLVED_TVF_ALIASES), ruleId) {
       // Resolve functions with concrete relations from v2 catalog.
       case u @ UnresolvedFunctionName(nameParts, cmd, requirePersistentFunc, mismatchHint, _) =>
-        lookupBuiltinOrTempFunction(nameParts)
+        lookupBuiltinOrTempFunction(nameParts, None)
           .orElse(lookupBuiltinOrTempTableFunction(nameParts)).map { info =>
           if (requirePersistentFunc) {
             throw QueryCompilationErrors.expectPersistentFuncError(
@@ -2263,9 +2263,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       q.transformExpressionsUpWithPruning(
         _.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR),
         ruleId) {
-        case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _)
+        case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _, _)
             if hasLambdaAndResolvedArguments(arguments) => withPosition(u) {
-          resolveBuiltinOrTempFunction(nameParts, arguments, Some(u)).map {
+          resolveBuiltinOrTempFunction(nameParts, arguments, u).map {
             case func: HigherOrderFunction => func
             case other => other.failAnalysis(
               errorClass = "INVALID_LAMBDA_FUNCTION_CALL.NON_HIGHER_ORDER_FUNCTION",
@@ -2292,8 +2292,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         }
       }

-      case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _) => withPosition(u) {
-        resolveBuiltinOrTempFunction(nameParts, arguments, Some(u)).getOrElse {
+      case u @ UnresolvedFunction(nameParts, arguments, _, _, _, _, _) => withPosition(u) {
+        resolveBuiltinOrTempFunction(nameParts, arguments, u).getOrElse {
           val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts)
           if (CatalogV2Util.isSessionCatalog(catalog)) {
             resolveV1Function(ident.asFunctionIdentifier, arguments, u)
@@ -2333,8 +2333,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       lambdas.nonEmpty && others.forall(_.resolved)
     }

-    def lookupBuiltinOrTempFunction(name: Seq[String]): Option[ExpressionInfo] = {
-      if (name.length == 1) {
+    def lookupBuiltinOrTempFunction(
+        name: Seq[String],
+        u: Option[UnresolvedFunction]): Option[ExpressionInfo] = {
+      if (name.size == 1 && u.exists(_.isInternal)) {
+        FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head))
+      } else if (name.size == 1) {
         v1SessionCatalog.lookupBuiltinOrTempFunction(name.head)
       } else {
         None
@@ -2352,14 +2356,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     private def resolveBuiltinOrTempFunction(
         name: Seq[String],
         arguments: Seq[Expression],
-        u: Option[UnresolvedFunction]): Option[Expression] = {
-      if (name.length == 1) {
-        v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments).map { func =>
-          if (u.isDefined) validateFunction(func, arguments.length, u.get) else func
-        }
+        u: UnresolvedFunction): Option[Expression] = {
+      val expression = if (name.size == 1 && u.isInternal) {
+        Option(FunctionRegistry.internal.lookupFunction(FunctionIdentifier(name.head), arguments))
+      } else if (name.size == 1) {
+        v1SessionCatalog.resolveBuiltinOrTempFunction(name.head, arguments)
       } else {
         None
       }
+      expression.map { func =>
+        validateFunction(func, arguments.length, u)
+      }
     }

     private def resolveBuiltinOrTempTableFunction(
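The net effect of the two registries, sketched below. This assumes catalyst on the classpath and is written as if inside the `org.apache.spark.sql` package, since `FunctionRegistry.internal` is `private[sql]`; per this commit, `product` is registered only internally.

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry

// The same single-part name resolves differently depending on which registry
// the Analyzer consults; the isInternal flag is what selects the registry.
assert(FunctionRegistry.internal.lookupFunction(FunctionIdentifier("product")).isDefined)
assert(FunctionRegistry.builtin.lookupFunction(FunctionIdentifier("product")).isEmpty)
```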

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 20 additions & 0 deletions
@@ -883,6 +883,26 @@ object FunctionRegistry {

   val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet

+  /** Registry for internal functions used by Connect and the Column API. */
+  private[sql] val internal: SimpleFunctionRegistry = new SimpleFunctionRegistry
+
+  private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = {
+    val (info, builder) = FunctionRegistryBase.build(name, None)
+    internal.internalRegisterFunction(FunctionIdentifier(name), info, builder)
+  }
+
+  registerInternalExpression[Product]("product")
+  registerInternalExpression[BloomFilterAggregate]("bloom_filter_agg")
+  registerInternalExpression[Years]("years")
+  registerInternalExpression[Months]("months")
+  registerInternalExpression[Days]("days")
+  registerInternalExpression[Hours]("hours")
+  registerInternalExpression[UnwrapUDT]("unwrap_udt")
+  registerInternalExpression[DistributedSequenceID]("distributed_sequence_id")
+  registerInternalExpression[PandasSkewness]("pandas_skew")
+  registerInternalExpression[PandasKurtosis]("pandas_kurt")
+  registerInternalExpression[NullIndex]("null_index")
+
   private def makeExprInfoForVirtualOperator(name: String, usage: String): ExpressionInfo = {
     new ExpressionInfo(
       null,
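A short usage sketch of the registry just added (again as if inside `org.apache.spark.sql`, since `internal` is `private[sql]`); both calls appear in the Analyzer and planner diffs in this commit.

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.Literal

// Existence check, as used by SparkConnectPlanner to decide on rerouting.
val isInternal = FunctionRegistry.internal.functionExists(FunctionIdentifier("unwrap_udt"))

// Resolution: the builder generated by FunctionRegistryBase.build instantiates
// the expression from the argument list, as in resolveBuiltinOrTempFunction.
val expr = FunctionRegistry.internal.lookupFunction(
  FunctionIdentifier("null_index"), Seq(Literal(1)))
```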

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 2 additions & 1 deletion
@@ -342,7 +342,8 @@ case class UnresolvedFunction(
     isDistinct: Boolean,
     filter: Option[Expression] = None,
     ignoreNulls: Boolean = false,
-    orderingWithinGroup: Seq[SortOrder] = Seq.empty)
+    orderingWithinGroup: Seq[SortOrder] = Seq.empty,
+    isInternal: Boolean = false)
   extends Expression with Unevaluable {
   import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

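For illustration, the new shape can be constructed directly; the flag defaults to `false`, so existing call sites are unaffected and existing pattern matches only needed the extra `_`.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction
import org.apache.spark.sql.catalyst.expressions.Literal

// Seven fields now, with isInternal last; this is the shape the Analyzer's
// `case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _)` patterns match.
val internalCall = UnresolvedFunction(
  nameParts = Seq("product"),
  arguments = Seq(Literal(1)),
  isDistinct = false,
  isInternal = true)
```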
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 16 additions & 49 deletions
@@ -47,11 +47,11 @@ import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, Task
 import org.apache.spark.sql.{withOrigin, Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
 import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
@@ -1614,14 +1614,23 @@ class SparkConnectPlanner(
       fun: proto.Expression.UnresolvedFunction): Expression = {
     if (fun.getIsUserDefinedFunction) {
       UnresolvedFunction(
-        parser.parseFunctionIdentifier(fun.getFunctionName),
+        parser.parseMultipartIdentifier(fun.getFunctionName),
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
         isDistinct = fun.getIsDistinct)
     } else {
+      // Spark Connect historically used the global namespace to lookup a couple of internal
+      // functions (e.g. product, collect_top_k, unwrap_udt, ...). In Spark 4 we moved these
+      // functions to a dedicated namespace, however in order to stay backwards compatible we still
+      // need to allow connect to use the global namespace. Here we check if a function is
+      // registered in the internal function registry, and we reroute the lookup to the internal
+      // registry.
+      val name = fun.getFunctionName
+      val internal = FunctionRegistry.internal.functionExists(FunctionIdentifier(name))
       UnresolvedFunction(
-        FunctionIdentifier(fun.getFunctionName),
+        name :: Nil,
         fun.getArgumentsList.asScala.map(transformExpression).toSeq,
-        isDistinct = fun.getIsDistinct)
+        isDistinct = fun.getIsDistinct,
+        isInternal = internal)
     }
   }
@@ -1832,18 +1841,6 @@ class SparkConnectPlanner(
   private def transformUnregisteredFunction(
       fun: proto.Expression.UnresolvedFunction): Option[Expression] = {
     fun.getFunctionName match {
-      case "product" if fun.getArgumentsCount == 1 =>
-        Some(
-          aggregate
-            .Product(transformExpression(fun.getArgumentsList.asScala.head))
-            .toAggregateExpression())
-
-      case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
-        // [col, expectedNumItems: Long, numBits: Long]
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(
-          new BloomFilterAggregate(children(0), children(1), children(2))
-            .toAggregateExpression())

       case "timestampdiff" if fun.getArgumentsCount == 3 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1864,21 +1861,6 @@ class SparkConnectPlanner(
           throw InvalidPlanInput(s"numBuckets should be a literal integer, but got $other")
       }

-      case "years" if fun.getArgumentsCount == 1 =>
-        Some(Years(transformExpression(fun.getArguments(0))))
-
-      case "months" if fun.getArgumentsCount == 1 =>
-        Some(Months(transformExpression(fun.getArguments(0))))
-
-      case "days" if fun.getArgumentsCount == 1 =>
-        Some(Days(transformExpression(fun.getArguments(0))))
-
-      case "hours" if fun.getArgumentsCount == 1 =>
-        Some(Hours(transformExpression(fun.getArguments(0))))
-
-      case "unwrap_udt" if fun.getArgumentsCount == 1 =>
-        Some(UnwrapUDT(transformExpression(fun.getArguments(0))))
-
       // Avro-specific functions
       case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
@@ -1898,9 +1880,6 @@ class SparkConnectPlanner(
         Some(CatalystDataToAvro(children.head, jsonFormatSchema))

       // PS(Pandas API on Spark)-specific functions
-      case "distributed_sequence_id" if fun.getArgumentsCount == 0 =>
-        Some(DistributedSequenceID())
-
       case "pandas_product" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val dropna = extractBoolean(children(1), "dropna")
@@ -1911,14 +1890,6 @@ class SparkConnectPlanner(
         val ddof = extractInteger(children(1), "ddof")
         Some(aggregate.PandasStddev(children(0), ddof).toAggregateExpression(false))

-      case "pandas_skew" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(aggregate.PandasSkewness(children(0)).toAggregateExpression(false))
-
-      case "pandas_kurt" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(aggregate.PandasKurtosis(children(0)).toAggregateExpression(false))
-
       case "pandas_var" if fun.getArgumentsCount == 2 =>
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val ddof = extractInteger(children(1), "ddof")
@@ -1938,11 +1909,7 @@ class SparkConnectPlanner(
         val children = fun.getArgumentsList.asScala.map(transformExpression)
         val alpha = extractDouble(children(1), "alpha")
         val ignoreNA = extractBoolean(children(2), "ignoreNA")
-        Some(EWM(children(0), alpha, ignoreNA))
-
-      case "null_index" if fun.getArgumentsCount == 1 =>
-        val children = fun.getArgumentsList.asScala.map(transformExpression)
-        Some(NullIndex(children(0)))
+        Some(new EWM(children(0), alpha, ignoreNA))

       // ML-specific functions
       case "vector_to_array" if fun.getArgumentsCount == 2 =>
@@ -2044,7 +2011,7 @@ class SparkConnectPlanner(
   @scala.annotation.tailrec
   private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match {
     case map: CreateMap => ExprUtils.convertToMapData(map)
-    case UnresolvedFunction(Seq("map"), args, _, _, _, _) =>
+    case UnresolvedFunction(Seq("map"), args, _, _, _, _, _) =>
       extractMapData(CreateMap(args), field)
     case other => throw InvalidPlanInput(s"$field should be created by map, but got $other")
   }
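The backwards-compatibility reroute boils down to the sketch below: a hypothetical helper mirroring the `else` branch above, not code from the commit.

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.Expression

// Any single-part name found in the internal registry is flagged isInternal,
// so old clients that still send e.g. "pandas_skew" or "unwrap_udt" through
// the global namespace keep working.
def toUnresolved(name: String, args: Seq[Expression]): UnresolvedFunction = {
  val internal = FunctionRegistry.internal.functionExists(FunctionIdentifier(name))
  UnresolvedFunction(name :: Nil, args, isDistinct = false, isInternal = internal)
}
```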

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 14 additions & 8 deletions
@@ -61,21 +61,27 @@ private[sql] object Column {
   }

   private[sql] def fn(name: String, inputs: Column*): Column = {
-    fn(name, isDistinct = false, ignoreNulls = false, inputs: _*)
+    fn(name, isDistinct = false, inputs: _*)
   }

   private[sql] def fn(name: String, isDistinct: Boolean, inputs: Column*): Column = {
-    fn(name, isDistinct = isDistinct, ignoreNulls = false, inputs: _*)
+    fn(name, isDistinct = isDistinct, isInternal = false, inputs)
   }

-  private[sql] def fn(
+  private[sql] def internalFn(name: String, inputs: Column*): Column = {
+    fn(name, isDistinct = false, isInternal = true, inputs)
+  }
+
+  private def fn(
       name: String,
       isDistinct: Boolean,
-      ignoreNulls: Boolean,
-      inputs: Column*): Column = withOrigin {
-    Column {
-      UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls)
-    }
+      isInternal: Boolean,
+      inputs: Seq[Column]): Column = withOrigin {
+    Column(UnresolvedFunction(
+      name :: Nil,
+      inputs.map(_.expr),
+      isDistinct = isDistinct,
+      isInternal = isInternal))
   }
 }

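A usage sketch for the new entry point (callable from Spark's own code, since `internalFn` is `private[sql]`): function helpers can now surface internal expressions without registering them globally. The two helpers below are hypothetical illustrations, not commit code.

```scala
import org.apache.spark.sql.Column

// Hypothetical helpers, as they might appear inside org.apache.spark.sql:
def product(e: Column): Column = Column.internalFn("product", e)
def distributedSequenceId(): Column = Column.internalFn("distributed_sequence_id")
```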
sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala

Lines changed: 22 additions & 6 deletions
@@ -21,8 +21,8 @@ import scala.collection.mutable
 import scala.jdk.CollectionConverters._

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedIdentifier, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OptionList, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, UnresolvedTableSpec}
 import org.apache.spark.sql.connector.expressions.{ClusterByTransform, FieldReference, LogicalExpressions, NamedReference, Transform}
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -89,13 +89,13 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
     def ref(name: String): NamedReference = LogicalExpressions.parseReference(name)

     val asTransforms = (column +: columns).map(_.expr).map {
-      case Years(attr: Attribute) =>
+      case PartitionTransform.YEARS(Seq(attr: Attribute)) =>
         LogicalExpressions.years(ref(attr.name))
-      case Months(attr: Attribute) =>
+      case PartitionTransform.MONTHS(Seq(attr: Attribute)) =>
         LogicalExpressions.months(ref(attr.name))
-      case Days(attr: Attribute) =>
+      case PartitionTransform.DAYS(Seq(attr: Attribute)) =>
         LogicalExpressions.days(ref(attr.name))
-      case Hours(attr: Attribute) =>
+      case PartitionTransform.HOURS(Seq(attr: Attribute)) =>
         LogicalExpressions.hours(ref(attr.name))
       case Bucket(Literal(numBuckets: Int, IntegerType), attr: Attribute) =>
         LogicalExpressions.bucket(numBuckets, Array(ref(attr.name)))
@@ -235,6 +235,22 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   }
 }

+private object PartitionTransform {
+  class ExtractTransform(name: String) {
+    private val NAMES = Seq(name)
+
+    def unapply(e: Expression): Option[Seq[Expression]] = e match {
+      case UnresolvedFunction(NAMES, children, false, None, false, Nil, true) => Option(children)
+      case _ => None
+    }
+  }
+
+  val HOURS = new ExtractTransform("hours")
+  val DAYS = new ExtractTransform("days")
+  val MONTHS = new ExtractTransform("months")
+  val YEARS = new ExtractTransform("years")
+}
+
 /**
  * Configuration methods common to create/replace operations and insert/overwrite operations.
  * @tparam R builder type to return
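End-user code is unchanged: `partitionedBy` still accepts the partition transform helpers, which now produce internal `UnresolvedFunction`s that the extractor above recovers. A sketch, where `df` is any Dataset and the table name is illustrative:

```scala
import org.apache.spark.sql.functions.{col, years}

// years(col("ts")) now builds an internal UnresolvedFunction("years", ...),
// matched by PartitionTransform.YEARS in DataFrameWriterV2.
df.writeTo("catalog.db.events")
  .partitionedBy(years(col("ts")))
  .create()
```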
