@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog,
    * desired fields are found.
    */
   protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
+    def findField(fields: Array[StructField]): Int = {
+      val checkField = (f: StructField) => resolver(f.name, fieldName)
+      val ordinal = fields.indexWhere(checkField)
+      if (ordinal == -1) {
+        sys.error(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+        sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+      } else {
+        ordinal
+      }
+    }
     expr.dataType match {
       case StructType(fields) =>
-        val actualField = fields.filter(f => resolver(f.name, fieldName))
-        if (actualField.length == 0) {
-          sys.error(
-            s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-        } else if (actualField.length == 1) {
-          val field = actualField(0)
-          GetField(expr, field, fields.indexOf(field))
-        } else {
-          sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
-        }
+        val ordinal = findField(fields)
+        StructGetField(expr, fields(ordinal), ordinal)
+      case ArrayType(StructType(fields), containsNull) =>
+        val ordinal = findField(fields)
+        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
       case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
     }
   }
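Not part of the diff: a minimal standalone sketch of the two-pass lookup that findField performs, with a hypothetical case-insensitive resolver standing in for the analyzer's resolver, to show how the missing-field and ambiguous-field errors fall out of the two indexWhere calls.

// Sketch only; FindFieldSketch and this resolver are illustrative, not Spark code.
object FindFieldSketch extends App {
  // Stand-in for the analyzer's resolver (here: case-insensitive name matching).
  val resolver = (a: String, b: String) => a.equalsIgnoreCase(b)

  def findField(fieldNames: Seq[String], fieldName: String): Int = {
    val checkField = (name: String) => resolver(name, fieldName)
    val ordinal = fieldNames.indexWhere(checkField)
    if (ordinal == -1) {
      sys.error(s"No such struct field $fieldName in ${fieldNames.mkString(", ")}")
    } else if (fieldNames.indexWhere(checkField, ordinal + 1) != -1) {
      // A second match after the first one means the reference is ambiguous.
      sys.error(s"Ambiguous reference to fields ${fieldNames.filter(checkField).mkString(", ")}")
    } else {
      ordinal
    }
  }

  println(findField(Seq("a", "b", "c"), "B"))  // 1: unique case-insensitive match
  // findField(Seq("a", "B", "b"), "b")        // errors: ambiguous reference
  // findField(Seq("a", "b"), "c")             // errors: no such struct field
}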
@@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
   }
 }
 
+
+trait GetField extends UnaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+  override def foldable = child.foldable
+  override def toString = s"$child.${field.name}"
+
+  def field: StructField
+}
+
 /**
  * Returns the value of fields in the Struct `child`.
  */
-case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
-  type EvaluatedType = Any
+case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
 
   def dataType = field.dataType
   override def nullable = child.nullable || field.nullable
-  override def foldable = child.foldable
 
   override def eval(input: Row): Any = {
     val baseValue = child.eval(input).asInstanceOf[Row]
     if (baseValue == null) null else baseValue(ordinal)
   }
+}
 
-  override def toString = s"$child.${field.name}"
+/**
+ * Returns the array of value of fields in the Array of Struct `child`.
+ */
+case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
+  extends GetField {
+
+  def dataType = ArrayType(field.dataType, containsNull)
+  override def nullable = child.nullable
+
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+    if (baseValue == null) null else {
+      baseValue.map { row =>
+        if (row == null) null else row(ordinal)
+      }
+    }
+  }
 }
 
 /**
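Not part of the diff: a self-contained sketch of the null-preserving mapping that ArrayGetField.eval performs, with Seq[Any] standing in for Catalyst's Row, matching the semantics asserted in the JsonSuite change below.

// Sketch only; ExtractFieldSketch is illustrative, not Spark code.
object ExtractFieldSketch extends App {
  // Same shape as ArrayGetField.eval: a null array yields null, a null element
  // yields a null entry, otherwise the value at `ordinal` is picked from each struct.
  def extractField(rows: Seq[Seq[Any]], ordinal: Int): Seq[Any] =
    if (rows == null) null
    else rows.map(row => if (row == null) null else row(ordinal))

  val arrayOfStruct: Seq[Seq[Any]] = Seq(Seq(true, "str1"), Seq(false, null), null)
  println(extractField(arrayOfStruct, 0))  // List(true, false, null)
  println(extractField(arrayOfStruct, 1))  // List(str1, null, null)
}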
@@ -206,7 +206,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
       case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
       case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
-      case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+      case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+      case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))
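Not part of the diff: what the two new NullPropagation cases fold, written out for one hypothetical field; the types package and the Literal(value, dataType) constructor are the ones already used elsewhere in this patch.

// Illustration only (e.g. in a Spark shell with this patch applied).
import org.apache.spark.sql.catalyst.expressions.{ArrayGetField, Literal, StructGetField}
import org.apache.spark.sql.types._

val f = StructField("field1", BooleanType, nullable = true)
// A GetField over a null struct literal folds to a typed null literal:
//   StructGetField(Literal(null, StructType(Seq(f))), f, 0)
//     => Literal(null, BooleanType)
// and likewise over a null array-of-struct literal:
//   ArrayGetField(Literal(null, ArrayType(StructType(Seq(f)))), f, 0, containsNull = true)
//     => Literal(null, ArrayType(BooleanType, true))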
@@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite {
     expr.dataType match {
       case StructType(fields) =>
         val field = fields.find(_.name == fieldName).get
-        GetField(expr, field, fields.indexOf(field))
+        StructGetField(expr, field, fields.indexOf(field))
     }
   }
 
@@ -342,21 +342,19 @@ class JsonSuite extends QueryTest {
     )
   }
 
-  ignore("Complex field and type inferring (Ignored)") {
+  test("GetField operation on complex data type") {
     val jsonDF = jsonRDD(complexFieldAndType1)
     jsonDF.registerTempTable("jsonTable")
 
-    // Right now, "field1" and "field2" are treated as aliases. We should fix it.
     checkAnswer(
       sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
       Row(true, "str1")
     )
 
-    // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
     // Getting all values of a specific field from an array of structs.
     checkAnswer(
       sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
-      Row(Seq(true, false), Seq("str1", null))
+      Row(Seq(true, false, null), Seq("str1", null, null))
     )
   }
 
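Not part of the diff: a usage sketch of the same dotted access from a SQL session; `sqlContext` and the registered jsonTable are assumed to be set up as in the test above.

// Usage sketch only: dotted field access now also resolves through arrays of structs.
val rows = sqlContext.sql(
  "select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable").collect()
// Per the assertion above, the single result row carries one sequence per column:
//   field1 -> Seq(true, false, null), field2 -> Seq("str1", null, null),
// with null entries where an array element's struct does not define the field.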