diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 2c73a80f64eb..96f5431f766c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -343,6 +343,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
     elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars)
 
   protected lazy val baseExpression: PackratParser[Expression] =
+    expression ~ "[" ~ expression ~ "]" ~ expression ^^ {
+      case base ~ _ ~ ordinal ~ _ ~ field => GetField(GetItem(base, ordinal), field.toString)
+    } |
     expression ~ "[" ~ expression <~ "]" ^^ {
       case base ~ _ ~ ordinal => GetItem(base, ordinal)
     } |
@@ -373,8 +376,13 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical {
   )
 
   override lazy val token: Parser[Token] = (
-    identChar ~ rep( identChar | digit ) ^^
-      { case first ~ rest => processIdent(first :: rest mkString "") }
+    identChar ~ rep( identChar | digit ) ^^
+      {
+        case first ~ rest => first match {
+          case '.' => StringLit(rest mkString "")
+          case _ => processIdent(first :: rest mkString "")
+        }
+      }
     | rep1(digit) ~ opt('.' ~> rep(digit)) ^^ {
         case i ~ None => NumericLit(i mkString "")
         case i ~ Some(d) => FloatLit(i.mkString("") + "." + d.mkString(""))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index c1154eb81c31..823e407bdde0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -101,3 +101,41 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
 
   override def toString = s"$child.$fieldName"
 }
+
+/**
+ * Returns an array containing the value of fieldName
+ * for each element in the input array of type struct
+ */
+case class GetArrayField(child: Expression, fieldName: String) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  def dataType = field.dataType
+  override def nullable = child.nullable || field.nullable
+  override def foldable = child.foldable
+
+  protected def arrayType = child.dataType match {
+    case ArrayType(s: StructType, _) => s
+    case otherType => sys.error(s"GetArrayField is not valid on fields of type $otherType")
+  }
+
+  lazy val field = if (arrayType.isInstanceOf[StructType]) {
+    arrayType.fields
+      .find(_.name == fieldName)
+      .getOrElse(sys.error(s"No such field $fieldName in ${child.dataType}"))
+  } else null
+
+
+  lazy val ordinal = arrayType.fields.indexOf(field)
+
+  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[ArrayType]
+
+  override def eval(input: Row): Any = {
+    val value: Seq[Row] = child.eval(input).asInstanceOf[Seq[Row]]
+    val v = value.map { t =>
+      if (t == null) null else t(ordinal)
+    }
+    v
+  }
+
+  override def toString = s"$child.$fieldName"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 278569f0cb14..d08c293e910d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.catalyst.types.{ArrayType, StructType}
 import org.apache.spark.sql.catalyst.trees
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
@@ -108,6 +108,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
         a.dataType match {
           case StructType(fields) =>
             Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
+          case ArrayType(fields, _) => nestedFields.length match {
+            case 1 => Some(Alias(GetArrayField(a, nestedFields.head), nestedFields.last)())
+            case _ => None // can't resolve nested access like arrayOfStruct.field1.field2
+          }
           case _ => None // Don't know how to resolve these field references
         }
       case Seq() => None // No matches.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 58b1e23891a3..c6ca2c5e7cc1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -292,24 +292,29 @@ class JsonSuite extends QueryTest {
       sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from jsonTable"),
       (5, null) :: Nil
     )
-  }
 
-  ignore("Complex field and type inferring (Ignored)") {
-    val jsonSchemaRDD = jsonRDD(complexFieldAndType)
-    jsonSchemaRDD.registerTempTable("jsonTable")
+    checkAnswer(
+      sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
+      (Seq(true, false, null), Seq("str1", null, null)) :: Nil
+    )
 
-    // Right now, "field1" and "field2" are treated as aliases. We should fix it.
     checkAnswer(
       sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
       (true, "str1") :: Nil
     )
 
-    // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
-    // Getting all values of a specific field from an array of structs.
+  }
+
+  ignore("Complex field and type inferring (Ignored)") {
+    val jsonSchemaRDD = jsonRDD(complexFieldAndType)
+    jsonSchemaRDD.registerTempTable("jsonTable")
+
+    // TODO: Filtering on a field of an array of structs is not supported yet; it is quite complex, and it is unclear whether it is necessary.
     checkAnswer(
-      sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
-      (Seq(true, false), Seq("str1", null)) :: Nil
+      sql("select arrayOfStruct.field1 from jsonTable where arrayOfStruct.field1 = true"),
+      (Seq(true)) :: Nil
     )
+
   }
 
   test("Type conflict in primitive field values") {
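
Note for reviewers: the standalone sketch below is not part of the patch. It mirrors the evaluation semantics that GetArrayField.eval implements, using plain Seq[Seq[Any]] as a stand-in for catalyst Rows; the object name GetArrayFieldSketch and the helper extractField are hypothetical. Its outputs match the answers asserted in the updated JsonSuite test for arrayOfStruct.field1 and arrayOfStruct.field2.

object GetArrayFieldSketch {
  // For each struct-like element of `rows`, return the value at `ordinal`,
  // propagating null elements, as GetArrayField.eval does for arrayOfStruct.fieldName.
  def extractField(rows: Seq[Seq[Any]], ordinal: Int): Seq[Any] =
    rows.map(row => if (row == null) null else row(ordinal))

  def main(args: Array[String]): Unit = {
    // Roughly the arrayOfStruct fixture from the test, after schema inference has
    // widened each element to the merged struct {field1, field2}, with nulls for
    // missing fields:
    //   [{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}]
    val arrayOfStruct: Seq[Seq[Any]] = Seq(
      Seq(true, "str1"),
      Seq(false, null),
      Seq(null, null)
    )
    println(extractField(arrayOfStruct, 0)) // List(true, false, null) -- field1
    println(extractField(arrayOfStruct, 1)) // List(str1, null, null)  -- field2
  }
}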