Commit 3305939

hvanhovell authored and HyukjinKwon committed
[SPARK-49026][CONNECT] Add ColumnNode to Proto conversion
### What changes were proposed in this pull request?

This PR adds a converter that converts ColumnNodes into Connect proto.Expression.

### Why are the changes needed?

TBD

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added a test suite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47812 from hvanhovell/SPARK-49026.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent bc7bfbc commit 3305939

File tree

7 files changed: +604 -11 lines changed

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connect

import scala.jdk.CollectionConverters._

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Expression.SortOrder.NullOrdering.{SORT_NULLS_FIRST, SORT_NULLS_LAST}
import org.apache.spark.connect.proto.Expression.SortOrder.SortDirection.{SORT_DIRECTION_ASCENDING, SORT_DIRECTION_DESCENDING}
import org.apache.spark.connect.proto.Expression.Window.WindowFrame.{FrameBoundary, FrameType}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
import org.apache.spark.sql.expressions.ScalaUserDefinedFunction
import org.apache.spark.sql.internal._

/**
 * Converter for [[ColumnNode]] to [[proto.Expression]] conversions.
 */
object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
  override def apply(node: ColumnNode): proto.Expression = {
    val builder = proto.Expression.newBuilder()
    // TODO(SPARK-49273) support Origin in Connect Scala Client.
    node match {
      case Literal(value, None, _) =>
        builder.setLiteral(toLiteralProtoBuilder(value))

      case Literal(value, Some(dataType), _) =>
        builder.setLiteral(toLiteralProtoBuilder(value, dataType))

      case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
        val b = builder.getUnresolvedAttributeBuilder
          .setUnparsedIdentifier(unparsedIdentifier)
          .setIsMetadataColumn(isMetadataColumn)
        planId.foreach(b.setPlanId)

      case UnresolvedStar(unparsedTarget, planId, _) =>
        val b = builder.getUnresolvedStarBuilder
        unparsedTarget.foreach(b.setUnparsedTarget)
        planId.foreach(b.setPlanId)

      case UnresolvedRegex(regex, planId, _) =>
        val b = builder.getUnresolvedRegexBuilder
          .setColName(regex)
        planId.foreach(b.setPlanId)

      case UnresolvedFunction(functionName, arguments, isDistinct, isUserDefinedFunction, _, _) =>
        // TODO(SPARK-49087) use internal namespace.
        builder.getUnresolvedFunctionBuilder
          .setFunctionName(functionName)
          .setIsUserDefinedFunction(isUserDefinedFunction)
          .setIsDistinct(isDistinct)
          .addAllArguments(arguments.map(apply).asJava)

      case Alias(child, name, metadata, _) =>
        val b = builder.getAliasBuilder.setExpr(apply(child))
        name.foreach(b.addName)
        metadata.foreach(m => b.setMetadata(m.json))

      case Cast(child, dataType, evalMode, _) =>
        val b = builder.getCastBuilder
          .setExpr(apply(child))
          .setType(DataTypeProtoConverter.toConnectProtoType(dataType))
        evalMode.foreach { mode =>
          val convertedMode = mode match {
            case Cast.Try => proto.Expression.Cast.EvalMode.EVAL_MODE_TRY
            case Cast.Ansi => proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI
            case Cast.Legacy => proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY
          }
          b.setEvalMode(convertedMode)
        }

      case SqlExpression(expression, _) =>
        builder.getExpressionStringBuilder.setExpression(expression)

      case s: SortOrder =>
        builder.setSortOrder(convertSortOrder(s))

      case Window(windowFunction, windowSpec, _) =>
        val b = builder.getWindowBuilder
          .setWindowFunction(apply(windowFunction))
          .addAllPartitionSpec(windowSpec.partitionColumns.map(apply).asJava)
          .addAllOrderSpec(windowSpec.sortColumns.map(convertSortOrder).asJava)
        windowSpec.frame.foreach { frame =>
          b.getFrameSpecBuilder
            .setFrameType(frame.frameType match {
              case WindowFrame.Row => FrameType.FRAME_TYPE_ROW
              case WindowFrame.Range => FrameType.FRAME_TYPE_RANGE
            })
            .setLower(convertFrameBoundary(frame.lower))
            .setUpper(convertFrameBoundary(frame.upper))
        }

      case UnresolvedExtractValue(child, extraction, _) =>
        builder.getUnresolvedExtractValueBuilder
          .setChild(apply(child))
          .setExtraction(apply(extraction))

      case UpdateFields(structExpression, fieldName, valueExpression, _) =>
        val b = builder.getUpdateFieldsBuilder
          .setStructExpression(apply(structExpression))
          .setFieldName(fieldName)
        valueExpression.foreach(v => b.setValueExpression(apply(v)))

      case v: UnresolvedNamedLambdaVariable =>
        builder.setUnresolvedNamedLambdaVariable(convertNamedLambdaVariable(v))

      case LambdaFunction(function, arguments, _) =>
        builder.getLambdaFunctionBuilder
          .setFunction(apply(function))
          .addAllArguments(arguments.map(convertNamedLambdaVariable).asJava)

      case InvokeInlineUserDefinedFunction(udf: ScalaUserDefinedFunction, arguments, false, _) =>
        val b = builder.getCommonInlineUserDefinedFunctionBuilder
          .setScalarScalaUdf(udf.udf)
          .setDeterministic(udf.deterministic)
          .addAllArguments(arguments.map(apply).asJava)
        udf.givenName.foreach(b.setFunctionName)

      case CaseWhenOtherwise(branches, otherwise, _) =>
        val b = builder.getUnresolvedFunctionBuilder
          .setFunctionName("when")
        branches.foreach { case (condition, value) =>
          b.addArguments(apply(condition))
          b.addArguments(apply(value))
        }
        otherwise.foreach { value =>
          b.addArguments(apply(value))
        }

      case ProtoColumnNode(e, _) =>
        return e

      case node =>
        throw SparkException.internalError("Unsupported ColumnNode: " + node)
    }
    builder.build()
  }

  private def convertSortOrder(s: SortOrder): proto.Expression.SortOrder = {
    proto.Expression.SortOrder
      .newBuilder()
      .setChild(apply(s.child))
      .setDirection(s.sortDirection match {
        case SortOrder.Ascending => SORT_DIRECTION_ASCENDING
        case SortOrder.Descending => SORT_DIRECTION_DESCENDING
      })
      .setNullOrdering(s.nullOrdering match {
        case SortOrder.NullsFirst => SORT_NULLS_FIRST
        case SortOrder.NullsLast => SORT_NULLS_LAST
      })
      .build()
  }

  private def convertFrameBoundary(boundary: WindowFrame.FrameBoundary): FrameBoundary = {
    val builder = FrameBoundary.newBuilder()
    boundary match {
      case WindowFrame.UnboundedPreceding => builder.setUnbounded(true)
      case WindowFrame.UnboundedFollowing => builder.setUnbounded(true)
      case WindowFrame.CurrentRow => builder.setCurrentRow(true)
      case WindowFrame.Value(value) => builder.setValue(apply(value))
    }
    builder.build()
  }

  private def convertNamedLambdaVariable(
      v: UnresolvedNamedLambdaVariable): proto.Expression.UnresolvedNamedLambdaVariable = {
    proto.Expression.UnresolvedNamedLambdaVariable.newBuilder().addNameParts(v.name).build()
  }
}

case class ProtoColumnNode(
    expr: proto.Expression,
    override val origin: Origin = CurrentOrigin.get)
    extends ColumnNode {
  override def sql: String = expr.toString
}
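
For orientation, here is a minimal REPL-style sketch (not part of the commit) of the converter in action. It uses only what the new file defines: ColumnNodeToProtoConverter is a ColumnNode => proto.Expression function, and ProtoColumnNode wraps a pre-built proto.Expression that the `return e` branch passes through unchanged; the Literal builder calls are standard protobuf generated-code APIs.

// Sketch only: exercising the converter from the Connect Scala client REPL.
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.{ColumnNodeToProtoConverter, ProtoColumnNode}

// Build a proto literal directly via the generated protobuf builders.
val lit: proto.Expression = proto.Expression
  .newBuilder()
  .setLiteral(proto.Expression.Literal.newBuilder().setInteger(1))
  .build()

// ProtoColumnNode short-circuits conversion: the `return e` branch above
// hands the wrapped expression back unchanged.
val converted: proto.Expression = ColumnNodeToProtoConverter(ProtoColumnNode(lit))
assert(converted eq lit)

// Any ColumnNode without a matching case falls through to
// SparkException.internalError("Unsupported ColumnNode: ...").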

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 11 additions & 7 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
+import org.apache.spark.sql.internal.UserDefinedFunctionLike
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils}
@@ -101,13 +102,14 @@ case class ScalaUserDefinedFunction private[sql] (
     serializedUdfPacket: Array[Byte],
     inputTypes: Seq[proto.DataType],
     outputType: proto.DataType,
-    name: Option[String],
+    givenName: Option[String],
     override val nullable: Boolean,
     override val deterministic: Boolean,
     aggregate: Boolean)
-    extends UserDefinedFunction {
+    extends UserDefinedFunction
+    with UserDefinedFunctionLike {

-  private[expressions] lazy val udf = {
+  private[sql] lazy val udf = {
     val scalaUdfBuilder = proto.ScalarScalaUDF
       .newBuilder()
       .setPayload(ByteString.copyFrom(serializedUdfPacket))
@@ -128,10 +130,10 @@ case class ScalaUserDefinedFunction private[sql] (
       .setScalarScalaUdf(udf)
       .addAllArguments(exprs.map(_.expr).asJava)

-    name.foreach(udfBuilder.setFunctionName)
+    givenName.foreach(udfBuilder.setFunctionName)
   }

-  override def withName(name: String): ScalaUserDefinedFunction = copy(name = Option(name))
+  override def withName(name: String): ScalaUserDefinedFunction = copy(givenName = Option(name))

   override def asNonNullable(): ScalaUserDefinedFunction = copy(nullable = false)
@@ -143,9 +145,11 @@ case class ScalaUserDefinedFunction private[sql] (
       .setDeterministic(deterministic)
       .setScalarScalaUdf(udf)

-    name.foreach(builder.setFunctionName)
+    givenName.foreach(builder.setFunctionName)
     builder.build()
   }
+
+  override def name: String = givenName.getOrElse("UDF")
 }

 object ScalaUserDefinedFunction {
@@ -195,7 +199,7 @@ object ScalaUserDefinedFunction {
       serializedUdfPacket = udfPacketBytes,
       inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType),
       outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType),
-      name = None,
+      givenName = None,
       nullable = true,
       deterministic = true,
       aggregate = aggregate)
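
The name-to-givenName rename frees up the name member that UserDefinedFunctionLike introduces, with "UDF" as the fallback when no name was supplied. A hedged sketch of the resulting behavior, assuming functions.udf in the Connect client builds its UDF through the ScalaUserDefinedFunction factory above (the cast is illustrative only):

// Sketch only; assumes functions.udf is backed by ScalaUserDefinedFunction,
// as the factory above (givenName = None) suggests.
import org.apache.spark.sql.expressions.ScalaUserDefinedFunction
import org.apache.spark.sql.functions.udf

val plusOne = udf((i: Int) => i + 1).asInstanceOf[ScalaUserDefinedFunction]
assert(plusOne.name == "UDF")   // givenName is None, so the fallback applies

// withName copies the case class with givenName = Some("plus_one"); both
// name and the function name serialized into the proto then reflect it.
assert(plusOne.withName("plus_one").name == "plus_one")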

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class UDFClassLoadingE2ESuite extends ConnectFunSuite with RemoteSparkSession {
       serializedUdfPacket = udfByteArray,
       inputTypes = Seq(ProtoDataTypes.IntegerType),
       outputType = ProtoDataTypes.IntegerType,
-      name = Some("dummyUdf"),
+      givenName = Some("dummyUdf"),
       nullable = true,
       deterministic = true,
       aggregate = false)
