Skip to content

Commit 08a228a

Browse files
committed
support dot notation on array of struct
1 parent 1390e56 commit 08a228a

File tree

5 files changed

+53
-22
lines changed

5 files changed

+53
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.plans.logical._
2424
import org.apache.spark.sql.catalyst.rules._
25-
import org.apache.spark.sql.types.StructType
26-
import org.apache.spark.sql.types.IntegerType
25+
import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
2726

2827
/**
2928
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog,
311310
* desired fields are found.
312311
*/
313312
protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
  // Locates the one field whose name the `resolver` accepts for `fieldName` and returns its
  // position. Zero matches or more than one match abort with a RuntimeException (sys.error).
  def findField(fields: Array[StructField]): Int = {
    val nameMatches: StructField => Boolean = f => resolver(f.name, fieldName)
    fields.indexWhere(nameMatches) match {
      case -1 =>
        sys.error(
          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
      // A second hit after the first one means the name is ambiguous.
      case first if fields.indexWhere(nameMatches, first + 1) != -1 =>
        sys.error(s"Ambiguous reference to fields ${fields.filter(nameMatches).mkString(", ")}")
      case first =>
        first
    }
  }

  // Dispatch on the type being dereferenced: a plain struct yields a single value, an array
  // of structs yields an array of values; every other type is rejected.
  expr.dataType match {
    case StructType(structFields) =>
      val position = findField(structFields)
      StructGetField(expr, structFields(position), position)

    case ArrayType(StructType(elementFields), containsNull) =>
      val position = findField(elementFields)
      ArrayGetField(expr, elementFields(position), position, containsNull)

    case unsupported =>
      sys.error(s"GetField is not valid on fields of type $unsupported")
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
7070
}
7171
}
7272

73+
74+
/**
 * Shared behaviour of the expressions that pull a named field out of their `child`.
 *
 * Implementations supply the [[StructField]] being extracted; this trait derives
 * foldability and the textual form from it and from the child expression.
 */
trait GetField extends UnaryExpression {
  self: Product =>

  /** The field this expression extracts. */
  def field: StructField

  type EvaluatedType = Any

  // Foldable exactly when the child is.
  override def foldable = child.foldable

  // Rendered in dot notation: <child>.<field name>
  override def toString = s"${child}.${field.name}"
}
83+
7384
/**
7485
* Returns the value of the field `field` (at position `ordinal`) in the struct `child`.
7586
*/
76-
case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
77-
type EvaluatedType = Any
87+
case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {

  // The result has the type of the extracted field.
  def dataType = field.dataType

  // Null if either the struct itself or the field inside it may be null.
  override def nullable = child.nullable || field.nullable

  // Evaluates the child to a Row and reads position `ordinal`; a null struct yields null.
  override def eval(input: Row): Any = {
    child.eval(input).asInstanceOf[Row] match {
      case null => null
      case struct => struct(ordinal)
    }
  }
}
8797

88-
override def toString = s"$child.${field.name}"
98+
/**
 * Returns an array holding the value of `field` for every struct in the array of structs
 * `child`.
 *
 * `ordinal` is the position of `field` inside each struct, and `containsNull` states whether
 * the source array may hold null structs. A null struct produces a null element in the result;
 * a null `child` value produces null for the whole expression.
 */
case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
  extends GetField {

  // An element of the result is null either because the source struct is null (containsNull)
  // or because the struct is present but its field holds null (field.nullable). Both cases
  // must be reflected in the declared element nullability, otherwise the type would promise
  // a null-free array that eval can violate.
  def dataType = ArrayType(field.dataType, containsNull || field.nullable)

  // The array as a whole is null only when the child is.
  override def nullable = child.nullable

  override def eval(input: Row): Any = {
    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
    if (baseValue == null) {
      null
    } else {
      baseValue.map { row =>
        // Keep null structs as null entries so that positions line up with the source array.
        if (row == null) null else row(ordinal)
      }
    }
  }
}
90116

91117
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ object NullPropagation extends Rule[LogicalPlan] {
206206
case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
207207
case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
208208
case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
209-
case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
209+
case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
210+
case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
210211
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
211212
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
212213
case e @ Count(expr) if !expr.nullable => Count(Literal(1))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite {
851851
expr.dataType match {
852852
case StructType(fields) =>
853853
val field = fields.find(_.name == fieldName).get
854-
GetField(expr, field, fields.indexOf(field))
854+
StructGetField(expr, field, fields.indexOf(field))
855855
}
856856
}
857857

sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,21 +342,19 @@ class JsonSuite extends QueryTest {
342342
)
343343
}
344344

345-
ignore("Complex field and type inferring (Ignored)") {
345+
test("GetField operation on complex data type") {
  // Expose the fixture under the table name used by the queries below.
  jsonRDD(complexFieldAndType1).registerTempTable("jsonTable")

  // Dot notation on a single element of the array of structs.
  val firstElementFields = Row(true, "str1")
  checkAnswer(
    sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
    firstElementFields
  )

  // Dot notation on the array itself: collects the field from every struct in the array.
  val allElementFields = Row(Seq(true, false, null), Seq("str1", null, null))
  checkAnswer(
    sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
    allElementFields
  )
}
362360

0 commit comments

Comments
 (0)