Skip to content

Commit 4793c84

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-5278][SQL] Introduce UnresolvedGetField and complete the check of ambiguous reference to fields
When the `GetField` chain(`a.b.c.d.....`) is interrupted by `GetItem` like `a.b[0].c.d....`, then the check of ambiguous reference to fields is broken. The reason is that: for something like `a.b[0].c.d`, we first parse it to `GetField(GetField(GetItem(Unresolved("a.b"), 0), "c"), "d")`. Then in `LogicalPlan#resolve`, we resolve `"a.b"` and build a `GetField` chain from bottom(the relation). But for the 2 outer `GetFiled`, we have to resolve them in `Analyzer` or do it in `GetField` lazily, check data type of child, search needed field, etc. which is similar to what we have done in `LogicalPlan#resolve`. So in this PR, the fix is just copy the same logic in `LogicalPlan#resolve` to `Analyzer`, which is simple and quick, but I do suggest introduce `UnresolvedGetFiled` like I explained in #2405. Author: Wenchen Fan <[email protected]> Closes #4068 from cloud-fan/simple and squashes the following commits: a6857b5 [Wenchen Fan] fix import order 8411c40 [Wenchen Fan] use UnresolvedGetField
1 parent bc36356 commit 4793c84

File tree

12 files changed

+84
-88
lines changed

12 files changed

+84
-88
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ class SqlParser extends AbstractSparkSQLParser {
372372
| expression ~ ("[" ~> expression <~ "]") ^^
373373
{ case base ~ ordinal => GetItem(base, ordinal) }
374374
| (expression <~ ".") ~ ident ^^
375-
{ case base ~ fieldName => GetField(base, fieldName) }
375+
{ case base ~ fieldName => UnresolvedGetField(base, fieldName) }
376376
| cast
377377
| "(" ~> expression <~ ")"
378378
| function

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class Analyzer(catalog: Catalog,
285285

286286
case q: LogicalPlan =>
287287
logTrace(s"Attempting to resolve ${q.simpleString}")
288-
q transformExpressions {
288+
q transformExpressionsUp {
289289
case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) &&
290290
q.isInstanceOf[GroupingAnalytics] =>
291291
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
@@ -295,15 +295,8 @@ class Analyzer(catalog: Catalog,
295295
val result = q.resolveChildren(name, resolver).getOrElse(u)
296296
logDebug(s"Resolving $u to $result")
297297
result
298-
299-
// Resolve field names using the resolver.
300-
case f @ GetField(child, fieldName) if !f.resolved && child.resolved =>
301-
child.dataType match {
302-
case StructType(fields) =>
303-
val resolvedFieldName = fields.map(_.name).find(resolver(_, fieldName))
304-
resolvedFieldName.map(n => f.copy(fieldName = n)).getOrElse(f)
305-
case _ => f
306-
}
298+
case UnresolvedGetField(child, fieldName) if child.resolved =>
299+
resolveGetField(child, fieldName)
307300
}
308301
}
309302

@@ -312,6 +305,27 @@ class Analyzer(catalog: Catalog,
312305
*/
313306
protected def containsStar(exprs: Seq[Expression]): Boolean =
314307
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
308+
309+
/**
310+
* Returns the resolved `GetField`, and report error if no desired field or over one
311+
* desired fields are found.
312+
*/
313+
protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
314+
expr.dataType match {
315+
case StructType(fields) =>
316+
val actualField = fields.filter(f => resolver(f.name, fieldName))
317+
if (actualField.length == 0) {
318+
sys.error(
319+
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
320+
} else if (actualField.length == 1) {
321+
val field = actualField(0)
322+
GetField(expr, field, fields.indexOf(field))
323+
} else {
324+
sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
325+
}
326+
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
327+
}
328+
}
315329
}
316330

317331
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,15 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
177177
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
178178
override def toString = expressions.mkString("ResolvedStar(", ", ", ")")
179179
}
180+
181+
case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
182+
override def dataType = throw new UnresolvedException(this, "dataType")
183+
override def foldable = throw new UnresolvedException(this, "foldable")
184+
override def nullable = throw new UnresolvedException(this, "nullable")
185+
override lazy val resolved = false
186+
187+
override def eval(input: Row = null): EvaluatedType =
188+
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
189+
190+
override def toString = s"$child.$fieldName"
191+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
2222
import scala.language.implicitConversions
2323
import scala.reflect.runtime.universe.{TypeTag, typeTag}
2424

25-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
25+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
2626
import org.apache.spark.sql.catalyst.expressions._
2727
import org.apache.spark.sql.catalyst.plans.logical._
2828
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -101,7 +101,7 @@ package object dsl {
101101
def isNotNull = IsNotNull(expr)
102102

103103
def getItem(ordinal: Expression) = GetItem(expr, ordinal)
104-
def getField(fieldName: String) = GetField(expr, fieldName)
104+
def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)
105105

106106
def cast(to: DataType) = Cast(expr, to)
107107

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -73,39 +73,19 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
7373
/**
7474
* Returns the value of fields in the Struct `child`.
7575
*/
76-
case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
76+
case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
7777
type EvaluatedType = Any
7878

7979
def dataType = field.dataType
8080
override def nullable = child.nullable || field.nullable
8181
override def foldable = child.foldable
8282

83-
protected def structType = child.dataType match {
84-
case s: StructType => s
85-
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
86-
}
87-
88-
lazy val field =
89-
structType.fields
90-
.find(_.name == fieldName)
91-
.getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))
92-
93-
lazy val ordinal = structType.fields.indexOf(field)
94-
95-
override lazy val resolved = childrenResolved && fieldResolved
96-
97-
/** Returns true only if the fieldName is found in the child struct. */
98-
private def fieldResolved = child.dataType match {
99-
case StructType(fields) => fields.map(_.name).contains(fieldName)
100-
case _ => false
101-
}
102-
10383
override def eval(input: Row): Any = {
10484
val baseValue = child.eval(input).asInstanceOf[Row]
10585
if (baseValue == null) null else baseValue(ordinal)
10686
}
10787

108-
override def toString = s"$child.$fieldName"
88+
override def toString = s"$child.${field.name}"
10989
}
11090

11191
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ object NullPropagation extends Rule[LogicalPlan] {
206206
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
207207
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
208208
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
209-
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
209+
case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
210210
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
211211
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
212212
case e @ Count(expr) if !expr.nullable => Count(Literal(1))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.Logging
21-
import org.apache.spark.sql.catalyst.analysis.Resolver
21+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
2222
import org.apache.spark.sql.catalyst.errors.TreeNodeException
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -160,11 +160,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
160160

161161
// One match, but we also need to extract the requested nested field.
162162
case Seq((a, nestedFields)) =>
163-
val aliased =
164-
Alias(
165-
resolveNesting(nestedFields, a, resolver),
166-
nestedFields.last)() // Preserve the case of the user's field access.
167-
Some(aliased)
163+
Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)())
168164

169165
// No matches.
170166
case Seq() =>
@@ -177,31 +173,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
177173
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
178174
}
179175
}
180-
181-
/**
182-
* Given a list of successive nested field accesses, and a based expression, attempt to resolve
183-
* the actual field lookups on this expression.
184-
*/
185-
private def resolveNesting(
186-
nestedFields: List[String],
187-
expression: Expression,
188-
resolver: Resolver): Expression = {
189-
190-
(nestedFields, expression.dataType) match {
191-
case (Nil, _) => expression
192-
case (requestedField :: rest, StructType(fields)) =>
193-
val actualField = fields.filter(f => resolver(f.name, requestedField))
194-
if (actualField.length == 0) {
195-
sys.error(
196-
s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
197-
} else if (actualField.length == 1) {
198-
resolveNesting(rest, GetField(expression, actualField(0).name), resolver)
199-
} else {
200-
sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
201-
}
202-
case (_, dt) => sys.error(s"Can't access nested field in type $dt")
203-
}
204-
}
205176
}
206177

207178
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.scalatest.FunSuite
2626
import org.scalatest.Matchers._
2727

2828
import org.apache.spark.sql.catalyst.dsl.expressions._
29+
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
2930
import org.apache.spark.sql.types._
3031

3132

@@ -846,23 +847,33 @@ class ExpressionEvaluationSuite extends FunSuite {
846847
checkEvaluation(GetItem(BoundReference(4, typeArray, true),
847848
Literal(null, IntegerType)), null, row)
848849

849-
checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
850-
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
850+
def quickBuildGetField(expr: Expression, fieldName: String) = {
851+
expr.dataType match {
852+
case StructType(fields) =>
853+
val field = fields.find(_.name == fieldName).get
854+
GetField(expr, field, fields.indexOf(field))
855+
}
856+
}
857+
858+
def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName)
859+
860+
checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
861+
checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row)
851862

852863
val typeS_notNullable = StructType(
853864
StructField("a", StringType, nullable = false)
854865
:: StructField("b", StringType, nullable = false) :: Nil
855866
)
856867

857-
assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
858-
assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
868+
assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
869+
assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
859870

860-
assert(GetField(Literal(null, typeS), "a").nullable === true)
861-
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
871+
assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true)
872+
assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true)
862873

863874
checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
864875
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
865-
checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
876+
checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
866877
}
867878

868879
test("arithmetic") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.optimizer
1919

20-
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
20+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
2323
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest {
184184

185185
GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
186186
GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
187-
GetField(
187+
UnresolvedGetField(
188188
Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
189189
"a") as 'c5,
190190

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import scala.language.implicitConversions
2323
import org.apache.spark.sql.Dsl.lit
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
26+
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
2627
import org.apache.spark.sql.types._
2728

2829

@@ -505,7 +506,7 @@ trait Column extends DataFrame {
505506
/**
506507
* An expression that gets a field by name in a [[StructField]].
507508
*/
508-
def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName))
509+
def getField(fieldName: String): Column = exprToColumn(UnresolvedGetField(expr, fieldName))
509510

510511
/**
511512
* An expression that returns a substring.

0 commit comments

Comments
 (0)