@@ -363,7 +363,7 @@ class SqlParser extends AbstractSparkSQLParser {
     | expression ~ ("[" ~> expression <~ "]") ^^
       { case base ~ ordinal => GetItem(base, ordinal) }
     | (expression <~ ".") ~ ident ^^
-      { case base ~ fieldName => GetField(base, fieldName) }
+      { case base ~ fieldName => UnresolvedGetField(base, fieldName) }
     | cast
     | "(" ~> expression <~ ")"
     | function
@@ -52,6 +52,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       NewRelationInstances),
     Batch("Resolution", fixedPoint,
       ResolveReferences ::
+      ResolveGetField ::
       ResolveRelations ::
       ResolveSortReferences ::
       NewRelationInstances ::
@@ -165,6 +166,19 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
     }
   }

+  /**
+   * Replaces [[UnresolvedGetField]]s with concrete [[GetField]]s.
+   */
+  object ResolveGetField extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+      case q: LogicalPlan if q.childrenResolved =>
+        q transformExpressionsUp {
+          case u @ UnresolvedGetField(child, fieldName) if child.resolved =>
+            GetField(u.child, u.fieldName, resolver)
+        }
+    }
+  }
+
   /**
    * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
    * clause. This rule detects such queries and adds the required attributes to the original
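Taken together with the parser change above, field access is now resolved in two phases: the parser emits an UnresolvedGetField placeholder, and this new rule later swaps in a concrete GetField via the analyzer's resolver once the child is resolved. As a rough, self-contained sketch of that pattern (toy names invented for illustration, not the real Catalyst classes):

// Toy model of the two-phase resolution in this patch; Expr/Attr/Field are
// invented stand-ins for Catalyst's Expression/Attribute/GetField.
sealed trait Expr { def resolved: Boolean }
case class Attr(name: String, fields: Map[String, Int]) extends Expr {
  val resolved = true
}
// What the parser now emits: a placeholder that knows only the field name.
case class UnresolvedField(child: Expr, name: String) extends Expr {
  val resolved = false
}
// What the analyzer rule produces once the child's schema is known.
case class Field(child: Expr, name: String, ordinal: Int) extends Expr {
  val resolved = true
}

// Analogue of ResolveGetField: rewrite bottom-up once the child is resolved,
// comparing names with a pluggable resolver.
def resolve(e: Expr, resolver: (String, String) => Boolean): Expr = e match {
  case UnresolvedField(child, name) =>
    resolve(child, resolver) match {
      case a @ Attr(_, fields) =>
        fields.keys.find(resolver(_, name))
          .map(k => Field(a, k, fields(k)))
          .getOrElse(sys.error(s"No such field $name"))
      case other => UnresolvedField(other, name)
    }
  case other => other
}

// With a case-insensitive resolver, "a" finds the struct field "A":
// resolve(UnresolvedField(Attr("s", Map("A" -> 0)), "a"), _ equalsIgnoreCase _)
// ==> Field(Attr("s", Map("A" -> 0)), "A", 0)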
@@ -121,3 +121,15 @@ case class Star(

   override def toString = table.map(_ + ".").getOrElse("") + "*"
 }
+
+case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
+  override def dataType = throw new UnresolvedException(this, "dataType")
+  override def foldable = throw new UnresolvedException(this, "foldable")
+  override def nullable = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+
+  override def eval(input: Row = null): EvaluatedType =
+    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+
+  override def toString = s"$child.$fieldName"
+}
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}

 import scala.language.implicitConversions

-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -96,7 +96,7 @@ package object dsl {
     def isNotNull = IsNotNull(expr)

     def getItem(ordinal: Expression) = GetItem(expr, ordinal)
-    def getField(fieldName: String) = GetField(expr, fieldName)
+    def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)

     def cast(to: DataType) = Cast(expr, to)
@@ -17,6 +17,8 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+
 import scala.collection.Map

 import org.apache.spark.sql.catalyst.types._
@@ -73,33 +75,38 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
 /**
  * Returns the value of fields in the Struct `child`.
  */
-case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
+case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
   type EvaluatedType = Any

   def dataType = field.dataType
   override def nullable = child.nullable || field.nullable
   override def foldable = child.foldable

-  protected def structType = child.dataType match {
-    case s: StructType => s
-    case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
-  }
-
-  lazy val field =
-    structType.fields
-      .find(_.name == fieldName)
-      .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))
-
-  lazy val ordinal = structType.fields.indexOf(field)
-
-  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType]
-
   override def eval(input: Row): Any = {
     val baseValue = child.eval(input).asInstanceOf[Row]
     if (baseValue == null) null else baseValue(ordinal)
   }

-  override def toString = s"$child.$fieldName"
+  override def toString = s"$child.${field.name}"
 }

+object GetField {

    [Contributor] If possible, I think it might be clearer to keep the resolver logic in the Analyzer rule.

    [Contributor Author] I was going to put this logic into the Analyzer rule, but found that some tests depend on GetField(child, fieldName), so I had to create this constructor for GetField. And the two are so similar that I combined them. Maybe I should fix those tests instead?

+  def apply(
+      e: Expression,

    [Contributor] Here we can check whether the field actually exists in the Struct; otherwise resolved = false.

    [Contributor Author] Currently all GetFields are resolved, since I try to resolve them in the ResolveGetField rule. If one can't be resolved (the field doesn't exist in the Struct, etc.), the rule will throw an exception. That's why I removed the resolved field.

+      fieldName: String,
+      equality: (String, String) => Boolean = _ == _): GetField = {
+    val structType = e.dataType match {
+      case s: StructType => s
+      case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
+    }
+    val field = structType.fields
+      .find(f => equality(f.name, fieldName))
+      .getOrElse(sys.error(s"No such field $fieldName in ${e.dataType}"))
+    val ordinal = structType.fields.indexOf(field)
+    GetField(e, field, ordinal)
+  }
+
+  def apply(ug: UnresolvedGetField): GetField = GetField(ug.child, ug.fieldName)
+}

 /**
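As a quick, hedged sketch of how the new apply might be called — the struct and values below are invented, while the signature and the Literal(null, StructType(...)) shape are taken from elsewhere in this diff:

// Hypothetical calls to the new GetField.apply (values invented for illustration).
val struct = StructType(Seq(StructField("A", IntegerType, true)))
val child  = Literal(null, struct)

GetField(child, "A")                                   // default equality: exact match
GetField(child, "a", (l, r) => l.equalsIgnoreCase(r))  // case-insensitive resolver

The pluggable equality is what lets the same resolution logic serve both the case-sensitive and case-insensitive analyzers.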
@@ -206,7 +206,7 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
       case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
       case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
-      case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
+      case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
@@ -18,13 +18,12 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedGetField}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.trees.TreeNode

 /**
  * Estimates of various statistics. The default estimation logic simply lazily multiplies the
@@ -160,11 +159,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
-        val aliased =
-          Alias(
-            resolveNesting(nestedFields, a, resolver),
-            nestedFields.last)() // Preserve the case of the user's field access.
-        Some(aliased)
+        Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)())

    [Contributor Author] For something like a.b[0].c.d, the original logic here only works for a and b, but not c and d. So I just simplified the logic here and let the ResolveGetField rule do its job.

       // No matches.
       case Seq() =>
@@ -177,32 +172,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
           this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
     }
   }
-
-  /**
-   * Given a list of successive nested field accesses, and a based expression, attempt to resolve
-   * the actual field lookups on this expression.
-   */
-  private def resolveNesting(
-      nestedFields: List[String],
-      expression: Expression,
-      resolver: Resolver): Expression = {
-
-    (nestedFields, expression.dataType) match {
-      case (Nil, _) => expression
-      case (requestedField :: rest, StructType(fields)) =>
-        val actualField = fields.filter(f => resolver(f.name, requestedField))
-        actualField match {
-          case Seq() =>
-            sys.error(
-              s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
-          case Seq(singleMatch) =>
-            resolveNesting(rest, GetField(expression, singleMatch.name), resolver)
-          case multipleMatches =>
-            sys.error(s"Ambiguous reference to fields ${multipleMatches.mkString(", ")}")
-        }
-      case (_, dt) => sys.error(s"Can't access nested field in type $dt")
-    }
-  }
 }

 /**
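The foldLeft above threads each nested field name through UnresolvedGetField, producing a left-nested chain of placeholders that ResolveGetField can later resolve one level at a time. A minimal, self-contained analogue (a toy Node class standing in for UnresolvedGetField) shows the shape it builds:

// Toy stand-in for UnresolvedGetField, just to show the nesting foldLeft builds.
case class Node(child: Any, name: String)

val base  = "a" // stands in for the resolved attribute `a`
val chain = List("b", "c", "d").foldLeft(base: Any)(Node.apply)
// chain == Node(Node(Node("a", "b"), "c"), "d")
// mirroring UnresolvedGetField(UnresolvedGetField(UnresolvedGetField(a, "b"), "c"), "d")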
@@ -507,7 +507,7 @@ class ExpressionEvaluationSuite extends FunSuite {

     checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
     checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
-    checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
+    checkEvaluation(GetField('c.struct(typeS).at(2).getField("a")), "aa", row)
   }

   test("arithmetic") {
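Since the DSL's getField now returns an UnresolvedGetField, the test wraps it in the new single-argument GetField.apply, which converts the placeholder using the default exact-match equality. Spelled out step by step (all names from this diff; typeS and row come from the surrounding test):

// Equivalent spelling of the updated assertion, with the implicit steps expanded.
val unresolved = 'c.struct(typeS).at(2).getField("a")  // now an UnresolvedGetField
val resolved   = GetField(unresolved)                  // GetField(ug.child, ug.fieldName)
checkEvaluation(resolved, "aa", row)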
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.optimizer

-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest {

         GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
         GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
-        GetField(
+        UnresolvedGetField(
           Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
           "a") as 'c5,
@@ -895,7 +895,7 @@ private[hive] object HiveQl {
       nodeToExpr(qualifier) match {
         case UnresolvedAttribute(qualifierName) =>
           UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
-        case other => GetField(other, attr)
+        case other => UnresolvedGetField(other, attr)
       }

     /* Stars (*) */
@@ -715,6 +715,16 @@ class HiveQuerySuite extends HiveComparisonTest {
     }
   }

+  case class NestedData(a: Seq[NestedData2], B: NestedData2)
+  case class NestedData2(a: NestedData3, B: NestedData3)
+  case class NestedData3(a: Int, B: Int)
+
+  test("SPARK-3698: case insensitive test for nested data") {
+    sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested")
+    // This should be successfully analyzed
+    sql("SELECT a[0].A.A from nested").queryExecution.analyzed
+  }
+
   test("parse HQL set commands") {
     // Adapted from its SQL counterpart.
     val testKey = "spark.sql.key.usedfortestonly"
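Why this query used to fail: as the author notes on LogicalPlan.resolve above, the old resolveNesting only walked struct types, so the GetItem produced by a[0] blocked resolution of the trailing field names. With this patch the parsed expression looks roughly like the following (constructor shapes from this diff, shown for illustration; the exact tree HiveQl builds may differ slightly):

UnresolvedGetField(
  UnresolvedGetField(
    GetItem(UnresolvedAttribute("a"), Literal(0)),  // a[0]
    "A"),                                           // a[0].A
  "A")                                              // a[0].A.A

Each UnresolvedGetField is then resolved independently by ResolveGetField with the case-insensitive resolver, which is why the uppercase A in the query matches the lowercase a fields in NestedData2 and NestedData3.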