Skip to content

Commit aec68a8

Browse files
ueshin authored and cloud-fan committed
[SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier.
When using a higher-order function with the same variable name as the existing columns in `Filter` or something which uses `Analyzer.resolveExpressionBottomUp` during the resolution, e.g.,: ```scala val df = Seq( (Seq(1, 9, 8, 7), 1, 2), (Seq(5, 9, 7), 2, 2), (Seq.empty, 3, 2), (null, 4, 2) ).toDF("i", "x", "d") checkAnswer(df.filter("exists(i, x -> x % d == 0)"), Seq(Row(Seq(1, 9, 8, 7), 1, 2))) checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"), Seq(Row(1))) ``` the following exception happens: ``` java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.AbstractTraversable.map(Traversable.scala:108) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237) at scala.collection.immutable.List.foreach(List.scala:392) at scala.collection.TraversableLike.map(TraversableLike.scala:237) at scala.collection.TraversableLike.map$(TraversableLike.scala:230) at scala.collection.immutable.List.map(List.scala:298) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145) at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369) at 
org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176) at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190) at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185) at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source) at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216) at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215) ... ``` because the `UnresolvedAttribute`s in `LambdaFunction` are unexpectedly resolved by the rule. This pr modified to use a placeholder `UnresolvedNamedLambdaVariable` to prevent unexpected resolution. Added a test and modified some tests. Closes #23320 from ueshin/issues/SPARK-26370/hof_resolution. Authored-by: Takuya UESHIN <[email protected]> Signed-off-by: Wenchen Fan <[email protected]> (cherry picked from commit 3dda58a) Signed-off-by: Wenchen Fan <[email protected]>
1 parent a2c5bea commit aec68a8

File tree

6 files changed

+62
-12
lines changed

6 files changed

+62
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -148,13 +148,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
148148
val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
149149
l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
150150

151-
case u @ UnresolvedAttribute(name +: nestedFields) =>
151+
case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
152152
parentLambdaMap.get(canonicalizer(name)) match {
153153
case Some(lambda) =>
154154
nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
155155
ExtractValue(expr, Literal(fieldName), conf.resolver)
156156
}
157-
case None => u
157+
case None =>
158+
UnresolvedAttribute(u.nameParts)
158159
}
159160

160161
case _ =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference
2222
import scala.collection.mutable
2323

2424
import org.apache.spark.sql.catalyst.InternalRow
25-
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
25+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.util._
2828
import org.apache.spark.sql.types._
2929
import org.apache.spark.unsafe.array.ByteArrayMethods
3030

31+
/**
32+
* A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
33+
*/
34+
case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
35+
extends LeafExpression with NamedExpression with Unevaluable {
36+
37+
override def name: String =
38+
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
39+
40+
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
41+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
42+
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
43+
override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
44+
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
45+
override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
46+
override lazy val resolved = false
47+
48+
override def toString: String = s"lambda '$name"
49+
50+
override def sql: String = name
51+
}
52+
3153
/**
3254
* A named lambda variable.
3355
*/
@@ -79,7 +101,7 @@ case class LambdaFunction(
79101

80102
object LambdaFunction {
81103
val identity: LambdaFunction = {
82-
val id = UnresolvedAttribute.quoted("id")
104+
val id = UnresolvedNamedLambdaVariable(Seq("id"))
83105
LambdaFunction(id, Seq(id))
84106
}
85107
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1336,9 +1336,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
13361336
*/
13371337
override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
13381338
val arguments = ctx.IDENTIFIER().asScala.map { name =>
1339-
UnresolvedAttribute.quoted(name.getText)
1339+
UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
13401340
}
1341-
LambdaFunction(expression(ctx.expression), arguments)
1341+
val function = expression(ctx.expression).transformUp {
1342+
case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
1343+
}
1344+
LambdaFunction(function, arguments)
13421345
}
13431346

13441347
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
4949
comparePlans(Analyzer.execute(plan(e1)), plan(e2))
5050
}
5151

52+
private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
53+
5254
test("resolution - no op") {
5355
checkExpression(key, key)
5456
}
5557

5658
test("resolution - simple") {
57-
val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
59+
val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
5860
val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
5961
checkExpression(in, out)
6062
}
6163

6264
test("resolution - nested") {
6365
val in = ArrayTransform(values2, LambdaFunction(
64-
ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
66+
ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
6567
val out = ArrayTransform(values2, LambdaFunction(
6668
ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
6769
checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {
7577

7678
test("fail - name collisions") {
7779
val p = plan(ArrayTransform(values1,
78-
LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
80+
LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
7981
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
8082
assert(msg.contains("arguments should not have names that are semantically the same"))
8183
}
8284

8385
test("fail - lambda arguments") {
8486
val p = plan(ArrayTransform(values1,
85-
LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
87+
LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
8688
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
8789
assert(msg.contains("does not match the number of arguments expected"))
8890
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
246246
intercept("foo(a x)", "extraneous input 'x'")
247247
}
248248

249+
private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
250+
249251
test("lambda functions") {
250-
assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
251-
assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr)))
252+
assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
253+
assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y))))
252254
}
253255

254256
test("window function expressions") {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2486,6 +2486,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
24862486
}
24872487
assert(ex.getMessage.contains("Cannot use null as map key"))
24882488
}
2489+
2490+
test("SPARK-26370: Fix resolution of higher-order function for the same identifier") {
2491+
val df = Seq(
2492+
(Seq(1, 9, 8, 7), 1, 2),
2493+
(Seq(5, 9, 7), 2, 2),
2494+
(Seq.empty, 3, 2),
2495+
(null, 4, 2)
2496+
).toDF("i", "x", "d")
2497+
2498+
checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
2499+
Seq(
2500+
Row(1, true),
2501+
Row(2, false),
2502+
Row(3, false),
2503+
Row(4, null)))
2504+
checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
2505+
Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
2506+
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
2507+
Seq(Row(1)))
2508+
}
24892509
}
24902510

24912511
object DataFrameFunctionsSuite {

0 commit comments

Comments (0)