Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ class SqlParser extends AbstractSparkSQLParser {
| expression ~ ("[" ~> expression <~ "]") ^^
{ case base ~ ordinal => GetItem(base, ordinal) }
| (expression <~ ".") ~ ident ^^
{ case base ~ fieldName => GetField(base, fieldName) }
{ case base ~ fieldName => UnresolvedGetField(base, fieldName) }
| cast
| "(" ~> expression <~ ")"
| function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ class Analyzer(catalog: Catalog,

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressions {
q transformExpressionsUp {
case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
// Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics
Expand All @@ -295,15 +295,8 @@ class Analyzer(catalog: Catalog,
val result = q.resolveChildren(name, resolver).getOrElse(u)
logDebug(s"Resolving $u to $result")
result

// Resolve field names using the resolver.
case f @ GetField(child, fieldName) if !f.resolved && child.resolved =>
child.dataType match {
case StructType(fields) =>
val resolvedFieldName = fields.map(_.name).find(resolver(_, fieldName))
resolvedFieldName.map(n => f.copy(fieldName = n)).getOrElse(f)
case _ => f
}
case UnresolvedGetField(child, fieldName) if child.resolved =>
resolveGetField(child, fieldName)
}
}

Expand All @@ -312,6 +305,27 @@ class Analyzer(catalog: Catalog,
*/
protected def containsStar(exprs: Seq[Expression]): Boolean =
exprs.exists(_.collect { case _: Star => true }.nonEmpty)

/**
 * Builds a resolved `GetField` for the field of `expr` whose name matches `fieldName`
 * according to `resolver`.
 *
 * Fails via `sys.error` when `expr` is not of struct type, when no field matches, or when
 * more than one field matches.
 */
protected def resolveGetField(expr: Expression, fieldName: String): Expression =
  expr.dataType match {
    case StructType(fields) =>
      val candidates = fields.filter(candidate => resolver(candidate.name, fieldName))
      candidates.length match {
        case 0 =>
          sys.error(
            s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
        case 1 =>
          // Exactly one match: record both the field and its position in the struct.
          val matched = candidates(0)
          GetField(expr, matched, fields.indexOf(matched))
        case _ =>
          sys.error(s"Ambiguous reference to fields ${candidates.mkString(", ")}")
      }
    case otherType =>
      sys.error(s"GetField is not valid on fields of type $otherType")
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,15 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star {
override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions
override def toString = expressions.mkString("ResolvedStar(", ", ", ")")
}

/**
 * A field access (`child.fieldName`) that still carries the field as a plain name.
 *
 * This node always reports itself as unresolved: `dataType`, `foldable` and `nullable` throw
 * `UnresolvedException`, and `eval` throws `TreeNodeException`.
 */
case class UnresolvedGetField(child: Expression, fieldName: String) extends UnaryExpression {
  override lazy val resolved = false

  override def toString = s"${child}.${fieldName}"

  // Type-level properties are unavailable while the node is unresolved.
  override def dataType = throw new UnresolvedException(this, "dataType")
  override def nullable = throw new UnresolvedException(this, "nullable")
  override def foldable = throw new UnresolvedException(this, "foldable")

  override def eval(input: Row = null): EvaluatedType = {
    val message = s"No function to evaluate expression. type: ${this.nodeName}"
    throw new TreeNodeException(this, message)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
Expand Down Expand Up @@ -101,7 +101,7 @@ package object dsl {
def isNotNull = IsNotNull(expr)

def getItem(ordinal: Expression) = GetItem(expr, ordinal)
def getField(fieldName: String) = GetField(expr, fieldName)
def getField(fieldName: String) = UnresolvedGetField(expr, fieldName)

def cast(to: DataType) = Cast(expr, to)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,39 +73,19 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
/**
* Returns the value of fields in the Struct `child`.
*/
case class GetField(child: Expression, fieldName: String) extends UnaryExpression {
case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
type EvaluatedType = Any

def dataType = field.dataType
override def nullable = child.nullable || field.nullable
override def foldable = child.foldable

protected def structType = child.dataType match {
case s: StructType => s
case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
}

lazy val field =
structType.fields
.find(_.name == fieldName)
.getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))

lazy val ordinal = structType.fields.indexOf(field)

override lazy val resolved = childrenResolved && fieldResolved

/** Returns true only if the fieldName is found in the child struct. */
private def fieldResolved = child.dataType match {
case StructType(fields) => fields.map(_.name).contains(fieldName)
case _ => false
}

override def eval(input: Row): Any = {
val baseValue = child.eval(input).asInstanceOf[Row]
if (baseValue == null) null else baseValue(ordinal)
}

override def toString = s"$child.$fieldName"
override def toString = s"$child.${field.name}"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ object NullPropagation extends Rule[LogicalPlan] {
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType)
case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
case e @ Count(expr) if !expr.nullable => Count(Literal(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
Expand Down Expand Up @@ -160,11 +160,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {

// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
val aliased =
Alias(
resolveNesting(nestedFields, a, resolver),
nestedFields.last)() // Preserve the case of the user's field access.
Some(aliased)
Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)())

// No matches.
case Seq() =>
Expand All @@ -177,31 +173,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
}
}

/**
* Given a list of successive nested field accesses, and a based expression, attempt to resolve
* the actual field lookups on this expression.
*/
private def resolveNesting(
nestedFields: List[String],
expression: Expression,
resolver: Resolver): Expression = {

(nestedFields, expression.dataType) match {
case (Nil, _) => expression
case (requestedField :: rest, StructType(fields)) =>
val actualField = fields.filter(f => resolver(f.name, requestedField))
if (actualField.length == 0) {
sys.error(
s"No such struct field $requestedField in ${fields.map(_.name).mkString(", ")}")
} else if (actualField.length == 1) {
resolveNesting(rest, GetField(expression, actualField(0).name), resolver)
} else {
sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
}
case (_, dt) => sys.error(s"Can't access nested field in type $dt")
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -846,23 +847,33 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(GetItem(BoundReference(4, typeArray, true),
Literal(null, IntegerType)), null, row)

checkEvaluation(GetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(GetField(Literal(null, typeS), "a"), null, row)
// Test helper: builds a `GetField` for the field of `expr` whose name equals `fieldName`.
// Throws if `expr` is not struct-typed or the struct has no field with that exact name.
def quickBuildGetField(expr: Expression, fieldName: String) = {
  val StructType(structFields) = expr.dataType
  val target = structFields.find(_.name == fieldName).get
  GetField(expr, target, structFields.indexOf(target))
}

// Test helper: turns an `UnresolvedGetField` into a `GetField` by passing its child and
// field name to `quickBuildGetField`.
def quickResolve(u: UnresolvedGetField) = quickBuildGetField(u.child, u.fieldName)

checkEvaluation(quickBuildGetField(BoundReference(2, typeS, nullable = true), "a"), "aa", row)
checkEvaluation(quickBuildGetField(Literal(null, typeS), "a"), null, row)

val typeS_notNullable = StructType(
StructField("a", StringType, nullable = false)
:: StructField("b", StringType, nullable = false) :: Nil
)

assert(GetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
assert(GetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)
assert(quickBuildGetField(BoundReference(2,typeS, nullable = true), "a").nullable === true)
assert(quickBuildGetField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable === false)

assert(GetField(Literal(null, typeS), "a").nullable === true)
assert(GetField(Literal(null, typeS_notNullable), "a").nullable === true)
assert(quickBuildGetField(Literal(null, typeS), "a").nullable === true)
assert(quickBuildGetField(Literal(null, typeS_notNullable), "a").nullable === true)

checkEvaluation('c.map(typeMap).at(3).getItem("aa"), "bb", row)
checkEvaluation('c.array(typeArray.elementType).at(4).getItem(1), "bb", row)
checkEvaluation('c.struct(typeS).at(2).getField("a"), "aa", row)
checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row)
}

test("arithmetic") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
Expand Down Expand Up @@ -184,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest {

GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
GetField(
UnresolvedGetField(
Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,

Expand Down
3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Column.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.language.implicitConversions
import org.apache.spark.sql.Dsl.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -505,7 +506,7 @@ trait Column extends DataFrame {
/**
* An expression that gets a field by name in a [[StructType]].
*/
def getField(fieldName: String): Column = exprToColumn(GetField(expr, fieldName))
def getField(fieldName: String): Column = exprToColumn(UnresolvedGetField(expr, fieldName))

/**
* An expression that returns a substring.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
nodeToExpr(qualifier) match {
case UnresolvedAttribute(qualifierName) =>
UnresolvedAttribute(qualifierName + "." + cleanIdentifier(attr))
case other => GetField(other, attr)
case other => UnresolvedGetField(other, attr)
}

/* Stars (*) */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive.{sparkContext, sql}
import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql}
import org.apache.spark.sql.hive.test.TestHive.implicits._

case class Nested(a: Int, B: Int)
Expand All @@ -29,16 +28,24 @@ case class Data(a: Int, B: Int, n: Nested, nestedArray: Seq[Nested])
*/
class HiveResolutionSuite extends HiveComparisonTest {

case class NestedData(a: Seq[NestedData2], B: NestedData2)
case class NestedData2(a: NestedData3, B: NestedData3)
case class NestedData3(a: Int, B: Int)

test("SPARK-3698: case insensitive test for nested data") {
sparkContext.makeRDD(Seq.empty[NestedData]).registerTempTable("nested")
jsonRDD(sparkContext.makeRDD(
"""{"a": [{"a": {"a": 1}}]}""" :: Nil)).registerTempTable("nested")
// This should be successfully analyzed
sql("SELECT a[0].A.A from nested").queryExecution.analyzed
}

test("SPARK-5278: check ambiguous reference to fields") {
  val records = sparkContext.makeRDD("""{"a": [{"b": 1, "B": 2}]}""" :: Nil)
  jsonRDD(records).registerTempTable("nested")

  // Two struct fields ("b" and "B") match the requested field name "b", so analysis
  // should report an ambiguous-reference error.
  val thrown = intercept[RuntimeException] {
    sql("SELECT a[0].b from nested").queryExecution.analyzed
  }
  assert(thrown.getMessage.contains("Ambiguous reference to fields"))
}

createQueryTest("table.attr",
"SELECT src.key FROM src ORDER BY key LIMIT 1")

Expand Down Expand Up @@ -68,7 +75,7 @@ class HiveResolutionSuite extends HiveComparisonTest {

test("case insensitivity with scala reflection") {
// Test resolution with Scala Reflection
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("caseSensitivityTest")

val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
Expand All @@ -79,14 +86,14 @@ class HiveResolutionSuite extends HiveComparisonTest {

ignore("case insensitivity with scala reflection joins") {
// Test resolution with Scala Reflection
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("caseSensitivityTest")

sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect()
}

test("nested repeated resolution") {
TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
.registerTempTable("nestedRepeatedTest")
assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1)
}
Expand Down