@@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
       val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
       l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))

-    case u @ UnresolvedAttribute(name +: nestedFields) =>
+    case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
       parentLambdaMap.get(canonicalizer(name)) match {
         case Some(lambda) =>
           nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
             ExtractValue(expr, Literal(fieldName), conf.resolver)
           }
-        case None => u
+        case None =>
+          UnresolvedAttribute(u.nameParts)
       }

     case _ =>
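The key behavioral change is in the `None` branch: a wrapped name that doesn't match any enclosing lambda argument no longer stays as-is, but degrades to an ordinary `UnresolvedAttribute` so the analyzer can still resolve it as a column. A standalone sketch of that rule (hypothetical helper, mirroring the match above but ignoring nested-field extraction and name canonicalization for brevity):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, UnresolvedNamedLambdaVariable}

// Hypothetical helper, not part of the patch: a name that matches an
// enclosing lambda argument resolves to that argument; anything else
// falls back to a plain attribute resolved against the child plan.
def resolveVariable(
    u: UnresolvedNamedLambdaVariable,
    parentLambdaMap: Map[String, NamedExpression]): Expression =
  parentLambdaMap.get(u.nameParts.head) match {
    case Some(lambda) => lambda
    case None => UnresolvedAttribute(u.nameParts)
  }
```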
@@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference
 import scala.collection.mutable

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods

+/**
+ * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])

Contributor: Can a lambda variable name have multiple name parts?

Contributor: In this implementation, this placeholder/wrapper type isn't just used to represent lambda variables, but also for all UnresolvedAttributes inside a LambdaFunction, so it needs to hold enough information to fall back to (i.e. reconstruct) an UnresolvedAttribute if the name turns out not to be a lambda variable.

Contributor: Ah, I see!


+  extends LeafExpression with NamedExpression with Unevaluable {
+
Contributor: Can we provide something like toString to give better explain output? It looks like this node prints as just unresolvednamedlambdavariable() without the nameParts right now.

Member (author): Added toString and sql.


+  override def name: String =
+    nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+  override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
+  override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+  override lazy val resolved = false
+
+  override def toString: String = s"lambda '$name"
+
+  override def sql: String = name
+}

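Following up on the thread above, a quick sketch (not part of the patch) of how the new overrides render, assuming a single-part name:

```scala
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable

// Sketch only: rendering for a single-part name; a part containing "."
// would be backquoted by `name`, e.g. `a.b`.
val v = UnresolvedNamedLambdaVariable(Seq("x"))
assert(v.name == "x")
assert(v.toString == "lambda 'x")
assert(v.sql == "x")
```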
 /**
  * A named lambda variable.
  */
@@ -79,7 +101,7 @@ case class LambdaFunction(

 object LambdaFunction {
   val identity: LambdaFunction = {
-    val id = UnresolvedAttribute.quoted("id")
+    val id = UnresolvedNamedLambdaVariable(Seq("id"))
     LambdaFunction(id, Seq(id))
   }
 }
@@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
     val arguments = ctx.IDENTIFIER().asScala.map { name =>
-      UnresolvedAttribute.quoted(name.getText)
+      UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
     }
-    LambdaFunction(expression(ctx.expression), arguments)
+    val function = expression(ctx.expression).transformUp {
+      case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+    }
+    LambdaFunction(function, arguments)
   }

   /**
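The parser change above means every identifier inside a lambda body is wrapped before analysis ever runs. A sketch of the resulting tree (assuming catalyst's parser and expression DSL, as the parser suite below also checks):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{LambdaFunction, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// "x -> x + 1" now parses with both the argument and the body reference
// wrapped as UnresolvedNamedLambdaVariable instead of UnresolvedAttribute.
val x = UnresolvedNamedLambdaVariable(Seq("x"))
assert(CatalystSqlParser.parseExpression("x -> x + 1") == LambdaFunction(x + 1, Seq(x)))
```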
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
     comparePlans(Analyzer.execute(plan(e1)), plan(e2))
   }

+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("resolution - no op") {
     checkExpression(key, key)
   }

   test("resolution - simple") {
-    val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
     val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
     checkExpression(in, out)
   }

   test("resolution - nested") {
     val in = ArrayTransform(values2, LambdaFunction(
-      ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
+      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
     val out = ArrayTransform(values2, LambdaFunction(
       ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
     checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {

test("fail - name collisions") {
val p = plan(ArrayTransform(values1,
LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("arguments should not have names that are semantically the same"))
}

test("fail - lambda arguments") {
val p = plan(ArrayTransform(values1,
LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("does not match the number of arguments expected"))
}
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
     testProjection(originalExpr = column, expectedExpr = column)
   }

+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("replace nulls in lambda function of ArrayFilter") {
-    testHigherOrderFunc('a, ArrayFilter, Seq('e))
+    testHigherOrderFunc('a, ArrayFilter, Seq(lv('e)))
   }

   test("replace nulls in lambda function of ArrayExists") {
-    testHigherOrderFunc('a, ArrayExists, Seq('e))
+    testHigherOrderFunc('a, ArrayExists, Seq(lv('e)))
   }

   test("replace nulls in lambda function of MapFilter") {
-    testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+    testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v)))
   }

   test("inability to replace nulls in arbitrary higher-order function") {
     val lambdaFunc = LambdaFunction(
-      function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
-      arguments = Seq[NamedExpression]('e))
+      function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral),
+      arguments = Seq[NamedExpression](lv('e)))
     val column = ArrayTransform('a, lambdaFunc)
     testProjection(originalExpr = column, expectedExpr = column)
   }
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
     intercept("foo(a x)", "extraneous input 'x'")
   }

+  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
   test("lambda functions") {
-    assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
-    assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr)))
+    assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+    assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y))))
   }

   test("window function expressions") {
@@ -85,7 +85,7 @@ FROM various_maps
 struct<>
 -- !query 5 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7


 -- !query 6
@@ -113,7 +113,7 @@ FROM various_maps
 struct<>
 -- !query 8 output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7


 -- !query 9
@@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
     assert(ex.getMessage.contains("Cannot use null as map key"))
   }
+
+  test("SPARK-26370: Fix resolution of higher-order function for the same identifier") {
+    val df = Seq(
+      (Seq(1, 9, 8, 7), 1, 2),
+      (Seq(5, 9, 7), 2, 2),
+      (Seq.empty, 3, 2),
+      (null, 4, 2)
+    ).toDF("i", "x", "d")
+
+    checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+      Seq(
+        Row(1, true),
+        Row(2, false),
+        Row(3, false),
+        Row(4, null)))
+    checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+    checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+      Seq(Row(1)))
+  }
 }

 object DataFrameFunctionsSuite {
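To see the behavior the new regression test pins down outside the test harness, a minimal, hypothetical driver (object name and master setting are illustrative; assumes a Spark version with this fix):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: inside the lambda, `x` is the array element and shadows the
// column `x`, while `d` (not a lambda argument) falls back to the column `d`.
object Spark26370Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("spark-26370").getOrCreate()
    import spark.implicits._

    val df = Seq((Seq(1, 9, 8, 7), 1, 2)).toDF("i", "x", "d")
    // With the fix, this prints [1, true]: 8 % 2 == 0, so `exists` is true.
    df.selectExpr("x", "exists(i, x -> x % d == 0)").show()

    spark.stop()
  }
}
```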