From dc316980f1a5547527e0c4f42cfdf5e36e7f7b61 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 1 Sep 2014 19:10:02 +0800 Subject: [PATCH 1/6] SPARK-2096 Correctly parse dot notations --- .../apache/spark/sql/catalyst/SqlParser.scala | 13 ++- .../catalyst/plans/logical/LogicalPlan.scala | 39 +++---- .../org/apache/spark/sql/json/JsonSuite.scala | 13 +++ .../apache/spark/sql/json/TestJsonData.scala | 26 +++++ .../spark/sql/parquet/ParquetQuerySuite.scala | 102 +++++------------- 5 files changed, 90 insertions(+), 103 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a04b4a938da64..837a19ad7a174 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -357,6 +357,10 @@ class SqlParser extends StandardTokenParsers with PackratParsers { expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | + dotExpressionHeader | + expression ~ "." ~ ident ^^ { + case base ~ _ ~ fieldName => GetField(base, fieldName) + } | TRUE ^^^ Literal(true, BooleanType) | FALSE ^^^ Literal(false, BooleanType) | cast | @@ -367,6 +371,11 @@ class SqlParser extends StandardTokenParsers with PackratParsers { "*" ^^^ Star(None) | literal + protected lazy val dotExpressionHeader: Parser[Expression] = + ident ~ "." ~ ident ^^ { + case i1 ~ _ ~ i2 => UnresolvedAttribute(i1 + "." + i2) + } + protected lazy val dataType: Parser[DataType] = STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType } @@ -380,7 +389,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { delimiters += ( "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]" + ",", ";", "%", "{", "}", ":", "[", "]", "." ) override lazy val token: Parser[Token] = ( @@ -401,7 +410,7 @@ class SqlLexical(val keywords: Seq[String]) extends StdLexical { | failure("illegal character") ) - override def identChar = letter | elem('_') | elem('.') + override def identChar = letter | elem('_') override def whitespace: Parser[Any] = rep( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index f81d9111945f5..4f5e2abf2ec74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -88,31 +88,24 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { - val parts = name.split("\\.") - // Collect all attributes that are output by this nodes children where either the first part - // matches the name or where the first part matches the scope and the second part matches the - // name. Return these matches along with any remaining parts, which represent dotted access to - // struct fields. - val options = input.flatMap { option => - // If the first part of the desired name matches a qualifier for this possible match, drop it. 
- val remainingParts = - if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts - if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil + def handleResult[A <: NamedExpression](result: Seq[A]) = { + result match { + case Seq(a) => Some(a) + case Seq() => None + case ambiguousReferences => + throw new TreeNodeException( + this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + } } - options.distinct match { - case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - a.dataType match { - case StructType(fields) => - Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) - case _ => None // Don't know how to resolve these field references - } - case Seq() => None // No matches. - case ambiguousReferences => - throw new TreeNodeException( - this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") + name.split("\\.") match { + case Array(s) => handleResult(input.filter(_.name == s)) + case Array(s1, s2) => + handleResult(input.collect { + case a if a.qualifiers.contains(s1) && a.name == s2 => a + case a if a.name == s1 && a.dataType.isInstanceOf[StructType] => Alias(GetField(a, s2), s2)() + }) + case _ => None } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 05513a127150c..0c5c97c93ed7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -581,4 +581,17 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("SPARK-2096 Correctly parse dot notations") { + val jsonSchemaRDD = jsonRDD(complexFieldAndType2) + jsonSchemaRDD.registerTempTable("jsonTable") + checkAnswer( + sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), + (true, "str1") :: Nil + ) + checkAnswer( + sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"), + ("str2", 6) :: Nil + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index a88310b5f1b46..b3f95f08e8044 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -82,4 +82,30 @@ object TestJsonData { """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) + + val complexFieldAndType2 = + TestSQLContext.sparkContext.parallelize( + """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], + "complexArrayOfStruct": [ + { + "field1": [ + { + "inner1": "str1" + }, + { + "inner2": ["str2", "str22"] + }], + "field2": [[1, 2], [3, 4]] + }, + { + "field1": [ + { + "inner2": ["str3", "str33"] + }, + { + "inner1": "str4" + }], + "field2": [[5, 6], [7, 8]] + }] + }""" :: Nil) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 42923b6a288d9..b0a06cd3ca090 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,19 +17,14 @@ package org.apache.spark.sql.parquet +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} - import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.Job - -import org.apache.spark.SparkContext import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{SqlLexical, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.types.{BooleanType, IntegerType} +import org.apache.spark.sql.catalyst.types.IntegerType import org.apache.spark.sql.catalyst.util.getTempFilePath import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -87,11 +82,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA var testRDD: SchemaRDD = null - // TODO: remove this once SqlParser can parse nested select statements - var nestedParserSqlContext: NestedParserSQLContext = null - override def beforeAll() { - nestedParserSqlContext = new NestedParserSQLContext(TestSQLContext.sparkContext) ParquetTestData.writeFile() ParquetTestData.writeFilterFile() ParquetTestData.writeNestedFile1() @@ -718,11 +709,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Projection in addressbook") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD data.registerTempTable("data") - val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") + val query = sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) assert(tmp(0).size === 2) @@ -733,21 +722,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Simple query on nested int data") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir2.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir2.toString).toSchemaRDD data.registerTempTable("data") - val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() + val result1 = sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) assert(result1(0)(0) === 2.5) - val result2 = nestedParserSqlContext.sql("SELECT entries[0] FROM data").collect() + val result2 = sql("SELECT entries[0] FROM data").collect() assert(result2.size === 1) val subresult1 = result2(0)(0).asInstanceOf[CatalystConverter.StructScalaType[_]] assert(subresult1.size === 2) assert(subresult1(0) === 2.5) assert(subresult1(1) === false) - val result3 = nestedParserSqlContext.sql("SELECT outerouter FROM data").collect() + val result3 = sql("SELECT outerouter FROM data").collect() val subresult2 = result3(0)(0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]](0) .asInstanceOf[CatalystConverter.ArrayScalaType[_]] @@ -760,19 +747,18 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("nested structs") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir3.toString) + val data = 
parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD data.registerTempTable("data") - val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() + val result1 = sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) assert(result1(0)(0) === false) - val result2 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() + val result2 = sql("SELECT booleanNumberPairs[0].value[1].truth FROM data").collect() assert(result2.size === 1) assert(result2(0).size === 1) assert(result2(0)(0) === true) - val result3 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() + val result3 = sql("SELECT booleanNumberPairs[1].value[0].truth FROM data").collect() assert(result3.size === 1) assert(result3(0).size === 1) assert(result3(0)(0) === false) @@ -796,11 +782,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("map with struct values") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD data.registerTempTable("mapTable") - val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() + val result1 = sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] @@ -814,7 +798,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(entry2 != null) assert(entry2(0) === 49) assert(entry2(1) === null) - val result2 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() + val result2 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM mapTable""").collect() assert(result2.size === 1) assert(result2(0)(0) === 42.toLong) assert(result2(0)(1) === "the answer") @@ -825,15 +809,12 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA // has no effect in this test case val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) - val result = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir1.toString) - .toSchemaRDD + val result = parquetFile(ParquetTestData.testNestedDir1.toString).toSchemaRDD result.saveAsParquetFile(tmpdir.toString) - nestedParserSqlContext - .parquetFile(tmpdir.toString) + parquetFile(tmpdir.toString) .toSchemaRDD .registerTempTable("tmpcopy") - val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() + val tmpdata = sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) assert(tmpdata(0)(0) === "Julien Le Dem") @@ -844,20 +825,17 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA } test("Writing out Map and reading it back in") { - val data = nestedParserSqlContext - .parquetFile(ParquetTestData.testNestedDir4.toString) - .toSchemaRDD + val data = parquetFile(ParquetTestData.testNestedDir4.toString).toSchemaRDD val tmpdir = Utils.createTempDir() Utils.deleteRecursively(tmpdir) data.saveAsParquetFile(tmpdir.toString) - nestedParserSqlContext - .parquetFile(tmpdir.toString) + parquetFile(tmpdir.toString) .toSchemaRDD .registerTempTable("tmpmapcopy") - val result1 = 
nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() + val result1 = sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) - val result2 = nestedParserSqlContext.sql("SELECT data2 FROM tmpmapcopy").collect() + val result2 = sql("SELECT data2 FROM tmpmapcopy").collect() assert(result2.size === 1) val entry1 = result2(0)(0) .asInstanceOf[CatalystConverter.MapScalaType[String, CatalystConverter.StructScalaType[_]]] @@ -871,42 +849,10 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(entry2 != null) assert(entry2(0) === 49) assert(entry2(1) === null) - val result3 = nestedParserSqlContext.sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() + val result3 = sql("""SELECT data2["seven"].payload1, data2["seven"].payload2 FROM tmpmapcopy""").collect() assert(result3.size === 1) assert(result3(0)(0) === 42.toLong) assert(result3(0)(1) === "the answer") Utils.deleteRecursively(tmpdir) } } - -// TODO: the code below is needed temporarily until the standard parser is able to parse -// nested field expressions correctly -class NestedParserSQLContext(@transient override val sparkContext: SparkContext) extends SQLContext(sparkContext) { - override protected[sql] val parser = new NestedSqlParser() -} - -class NestedSqlLexical(override val keywords: Seq[String]) extends SqlLexical(keywords) { - override def identChar = letter | elem('_') - delimiters += (".") -} - -class NestedSqlParser extends SqlParser { - override val lexical = new NestedSqlLexical(reservedWords) - - override protected lazy val baseExpression: PackratParser[Expression] = - expression ~ "[" ~ expression <~ "]" ^^ { - case base ~ _ ~ ordinal => GetItem(base, ordinal) - } | - expression ~ "." 
~ ident ^^ { - case base ~ _ ~ fieldName => GetField(base, fieldName) - } | - TRUE ^^^ Literal(true, BooleanType) | - FALSE ^^^ Literal(false, BooleanType) | - cast | - "(" ~> expression <~ ")" | - function | - "-" ~> literal ^^ UnaryMinus | - ident ^^ UnresolvedAttribute | - "*" ^^^ Star(None) | - literal -} From 95d733f42cb549cfc97dd01908410b30029161c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 2 Sep 2014 16:05:01 +0800 Subject: [PATCH 2/6] split long line --- .../apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 4f5e2abf2ec74..6bbe27090c97b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -103,7 +103,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { case Array(s1, s2) => handleResult(input.collect { case a if a.qualifiers.contains(s1) && a.name == s2 => a - case a if a.name == s1 && a.dataType.isInstanceOf[StructType] => Alias(GetField(a, s2), s2)() + case a if a.name == s1 && a.dataType.isInstanceOf[StructType] => + Alias(GetField(a, s2), s2)() }) case _ => None } From 16bc4c68bfe006dce134b6f0110bb2c20e036312 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Sep 2014 16:36:53 +0800 Subject: [PATCH 3/6] some enhance --- .../apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../catalyst/plans/logical/LogicalPlan.scala | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 837a19ad7a174..ef5ae85755fa3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -357,7 +357,6 @@ class SqlParser extends StandardTokenParsers with PackratParsers { expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | - dotExpressionHeader | expression ~ "." ~ ident ^^ { case base ~ _ ~ fieldName => GetField(base, fieldName) } | @@ -367,6 +366,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { "(" ~> expression <~ ")" | function | "-" ~> literal ^^ UnaryMinus | + dotExpressionHeader | ident ^^ UnresolvedAttribute | "*" ^^^ Star(None) | literal diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 6bbe27090c97b..db1392471c5c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,10 +86,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { def resolve(name: String): Option[NamedExpression] = resolve(name, output) - /** Performs attribute resolution given a name and a sequence of possible attributes. */ + /** + * Performs attribute resolution given a name and a sequence of possible attributes. + * The only possible formats of name are "ident" and "ident.ident". 
+ */ protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { def handleResult[A <: NamedExpression](result: Seq[A]) = { - result match { + result.distinct match { case Seq(a) => Some(a) case Seq() => None case ambiguousReferences => @@ -99,13 +102,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { } name.split("\\.") match { + // the format of name is "ident", it should match the name of a possible attribute case Array(s) => handleResult(input.filter(_.name == s)) + + // The format of name is "ident.ident", the first part should matches the scope + // and the second part should matches the name. Or the first part matches the + // name and this attribute is struct type. case Array(s1, s2) => handleResult(input.collect { - case a if a.qualifiers.contains(s1) && a.name == s2 => a - case a if a.name == s1 && a.dataType.isInstanceOf[StructType] => + case a if (a.qualifiers.contains(s1) && a.name == s2) => a + case a if (a.name == s1 && a.dataType.isInstanceOf[StructType]) => Alias(GetField(a, s2), s2)() }) + case _ => None } } From a58df403fce8ef9cff956cc02f2504510a5f9341 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 4 Sep 2014 19:02:53 -0700 Subject: [PATCH 4/6] add regression test for doubly nested data --- .../spark/sql/hive/execution/SQLQuerySuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 635a9fb0d56cb..82d3ebcfa387d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,6 +24,10 @@ import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +case class Nested1(f1: Nested2) +case class Nested2(f2: Nested3) +case class Nested3(f3: Int) + /** * A collection of hive query tests where we generate the answers ourselves instead of depending on * Hive to generate them (in contrast to HiveQuerySuite). 
Often this is because the query is @@ -47,4 +51,11 @@ class SQLQuerySuite extends QueryTest { GROUP BY key, value ORDER BY value) a""").collect().toSeq) } + + test("double nested data") { + sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested") + checkAnswer( + sql("SELECT f1.f2.f3 FROM nested"), + 1) + } } From ee8a72424974e4ec23224b42d5cc1744ab9ddac2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Sep 2014 16:43:46 +0800 Subject: [PATCH 5/6] rollback LogicalPlan, support dot operation on nested array type --- .../apache/spark/sql/catalyst/SqlParser.scala | 8 +-- .../catalyst/expressions/complexTypes.scala | 54 ++++++++++++++++--- .../catalyst/plans/logical/LogicalPlan.scala | 47 +++++++--------- .../org/apache/spark/sql/json/JsonSuite.scala | 29 ++++++++++ .../apache/spark/sql/json/TestJsonData.scala | 27 +++++++++- .../sql/hive/execution/SQLQuerySuite.scala | 6 +-- 6 files changed, 128 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ef5ae85755fa3..ca69531c69a77 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -357,8 +357,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { expression ~ "[" ~ expression <~ "]" ^^ { case base ~ _ ~ ordinal => GetItem(base, ordinal) } | - expression ~ "." ~ ident ^^ { - case base ~ _ ~ fieldName => GetField(base, fieldName) + (expression <~ ".") ~ ident ^^ { + case base ~ fieldName => GetField(base, fieldName) } | TRUE ^^^ Literal(true, BooleanType) | FALSE ^^^ Literal(false, BooleanType) | @@ -372,8 +372,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers { literal protected lazy val dotExpressionHeader: Parser[Expression] = - ident ~ "." ~ ident ^^ { - case i1 ~ _ ~ i2 => UnresolvedAttribute(i1 + "." + i2) + (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { + case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", "")) } protected lazy val dataType: Parser[DataType] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index dafd745ec96c6..c6eb8d8c22e6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -71,20 +71,52 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } /** - * Returns the value of fields in the Struct `child`. + * Returns the value of fields in the `child`. + * The type of `child` can be struct, or array of struct, + * or array of array of struct, or array of array ... of struct. 
*/ case class GetField(child: Expression, fieldName: String) extends UnaryExpression { type EvaluatedType = Any - def dataType = field.dataType + lazy val dataType = { + structType + buildDataType(field.dataType) + } + override def nullable = child.nullable || field.nullable override def foldable = child.foldable - protected def structType = child.dataType match { + private var _buildDataType = identity[DataType] _ + private lazy val buildDataType = { + structType + _buildDataType + } + + private var _nestedArrayCount = 0 + private lazy val nestedArrayCount = { + structType + _nestedArrayCount + } + + private def getStructType(t: DataType): StructType = t match { + case ArrayType(elementType, containsNull) => + _buildDataType = {(t: DataType) => ArrayType(t, containsNull)} andThen _buildDataType + _nestedArrayCount += 1 + getStructType(elementType) case s: StructType => s case otherType => sys.error(s"GetField is not valid on fields of type $otherType") } + protected lazy val structType: StructType = { + child match { + case n: GetField => + this._buildDataType = n._buildDataType + this._nestedArrayCount = n._nestedArrayCount + getStructType(n.field.dataType) + case _ => getStructType(child.dataType) + } + } + lazy val field = structType.fields .find(_.name == fieldName) @@ -92,11 +124,21 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio lazy val ordinal = structType.fields.indexOf(field) - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] + override lazy val resolved = childrenResolved override def eval(input: Row): Any = { - val baseValue = child.eval(input).asInstanceOf[Row] - if (baseValue == null) null else baseValue(ordinal) + val baseValue = child.eval(input) + evaluateValue(baseValue, nestedArrayCount) + } + + private def evaluateValue(v: Any, count: Int): Any = { + if (v == null) { + null + } else if (count > 0) { + v.asInstanceOf[Seq[_]].map(r => evaluateValue(r, count - 1)) + } else { + v.asInstanceOf[Row](ordinal) + } } override def toString = s"$child.$fieldName" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index db1392471c5c8..bae491f07c13f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -86,36 +86,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] { def resolve(name: String): Option[NamedExpression] = resolve(name, output) - /** - * Performs attribute resolution given a name and a sequence of possible attributes. - * The only possible formats of name are "ident" and "ident.ident". - */ + /** Performs attribute resolution given a name and a sequence of possible attributes. */ protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = { - def handleResult[A <: NamedExpression](result: Seq[A]) = { - result.distinct match { - case Seq(a) => Some(a) - case Seq() => None - case ambiguousReferences => - throw new TreeNodeException( - this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") - } + val parts = name.split("\\.") + // Collect all attributes that are output by this nodes children where either the first part + // matches the name or where the first part matches the scope and the second part matches the + // name. 
Return these matches along with any remaining parts, which represent dotted access to + // struct fields. + val options = input.flatMap { option => + // If the first part of the desired name matches a qualifier for this possible match, drop it. + val remainingParts = + if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts + if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil } - name.split("\\.") match { - // the format of name is "ident", it should match the name of a possible attribute - case Array(s) => handleResult(input.filter(_.name == s)) - - // The format of name is "ident.ident", the first part should matches the scope - // and the second part should matches the name. Or the first part matches the - // name and this attribute is struct type. - case Array(s1, s2) => - handleResult(input.collect { - case a if (a.qualifiers.contains(s1) && a.name == s2) => a - case a if (a.name == s1 && a.dataType.isInstanceOf[StructType]) => - Alias(GetField(a, s2), s2)() - }) - - case _ => None + options.distinct match { + case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it. + // One match, but we also need to extract the requested nested field. + case Seq((a, nestedFields)) => + Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)()) + case Seq() => None // No matches. + case ambiguousReferences => + throw new TreeNodeException( + this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 0c5c97c93ed7c..9cf3a753269d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -585,6 +585,7 @@ class JsonSuite extends QueryTest { test("SPARK-2096 Correctly parse dot notations") { val jsonSchemaRDD = jsonRDD(complexFieldAndType2) jsonSchemaRDD.registerTempTable("jsonTable") + checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), (true, "str1") :: Nil @@ -593,5 +594,33 @@ class JsonSuite extends QueryTest { sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"), ("str2", 6) :: Nil ) + + checkAnswer( + sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), + (Seq(true, false, null), Seq("str1", null, null)) :: Nil + ) + + checkAnswer( + sql("select complexNestedArray.field, complexNestedArray.field.innerField from jsonTable"), + ( + Seq( + Seq( + Seq("str1", null), + Seq("str2", null) + ), + Seq( + Seq("str3", null), + Seq(null, "str4") + ), + null + ), + + Seq( + Seq("str1", "str2"), + Seq("str3", null), + null + ) + ) :: Nil + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index b3f95f08e8044..75fe848592808 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -106,6 +106,31 @@ object TestJsonData { "inner1": "str4" }], "field2": [[5, 6], [7, 8]] - }] + }], + "complexNestedArray": [ + { + "field": [ + { + "innerField": "str1" + }, + { + "innerField": "str2" + } + ] + }, + { + "field": [ + { + "innerField": "str3" + }, + { + "otherInner": "str4" + } + ] + }, + { + "otherField": "str5" + } + ] }""" 
:: Nil) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 82d3ebcfa387d..b99caf77bce28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,11 +17,7 @@ package org.apache.spark.sql.hive.execution -import scala.reflect.ClassTag - -import org.apache.spark.sql.{SQLConf, QueryTest} -import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHive._ case class Nested1(f1: Nested2) From e1a88986ddb3cf8147cdb04c35addc974f2acba2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 9 Sep 2014 14:22:49 +0800 Subject: [PATCH 6/6] remove support for arbitrary nested arrays --- .../catalyst/expressions/complexTypes.scala | 54 +++---------------- .../org/apache/spark/sql/json/JsonSuite.scala | 28 ---------- .../apache/spark/sql/json/TestJsonData.scala | 27 +--------- 3 files changed, 7 insertions(+), 102 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index c6eb8d8c22e6c..dafd745ec96c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -71,52 +71,20 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } /** - * Returns the value of fields in the `child`. - * The type of `child` can be struct, or array of struct, - * or array of array of struct, or array of array ... of struct. + * Returns the value of fields in the Struct `child`. 
*/ case class GetField(child: Expression, fieldName: String) extends UnaryExpression { type EvaluatedType = Any - lazy val dataType = { - structType - buildDataType(field.dataType) - } - + def dataType = field.dataType override def nullable = child.nullable || field.nullable override def foldable = child.foldable - private var _buildDataType = identity[DataType] _ - private lazy val buildDataType = { - structType - _buildDataType - } - - private var _nestedArrayCount = 0 - private lazy val nestedArrayCount = { - structType - _nestedArrayCount - } - - private def getStructType(t: DataType): StructType = t match { - case ArrayType(elementType, containsNull) => - _buildDataType = {(t: DataType) => ArrayType(t, containsNull)} andThen _buildDataType - _nestedArrayCount += 1 - getStructType(elementType) + protected def structType = child.dataType match { case s: StructType => s case otherType => sys.error(s"GetField is not valid on fields of type $otherType") } - protected lazy val structType: StructType = { - child match { - case n: GetField => - this._buildDataType = n._buildDataType - this._nestedArrayCount = n._nestedArrayCount - getStructType(n.field.dataType) - case _ => getStructType(child.dataType) - } - } - lazy val field = structType.fields .find(_.name == fieldName) @@ -124,21 +92,11 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio lazy val ordinal = structType.fields.indexOf(field) - override lazy val resolved = childrenResolved + override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] override def eval(input: Row): Any = { - val baseValue = child.eval(input) - evaluateValue(baseValue, nestedArrayCount) - } - - private def evaluateValue(v: Any, count: Int): Any = { - if (v == null) { - null - } else if (count > 0) { - v.asInstanceOf[Seq[_]].map(r => evaluateValue(r, count - 1)) - } else { - v.asInstanceOf[Row](ordinal) - } + val baseValue = child.eval(input).asInstanceOf[Row] + if (baseValue == null) null else baseValue(ordinal) } override def toString = s"$child.$fieldName" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 9cf3a753269d2..301d482d27d86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -594,33 +594,5 @@ class JsonSuite extends QueryTest { sql("select complexArrayOfStruct[0].field1[1].inner2[0], complexArrayOfStruct[1].field2[0][1] from jsonTable"), ("str2", 6) :: Nil ) - - checkAnswer( - sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - (Seq(true, false, null), Seq("str1", null, null)) :: Nil - ) - - checkAnswer( - sql("select complexNestedArray.field, complexNestedArray.field.innerField from jsonTable"), - ( - Seq( - Seq( - Seq("str1", null), - Seq("str2", null) - ), - Seq( - Seq("str3", null), - Seq(null, "str4") - ), - null - ), - - Seq( - Seq("str1", "str2"), - Seq("str3", null), - null - ) - ) :: Nil - ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 75fe848592808..b3f95f08e8044 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -106,31 +106,6 @@ object TestJsonData { "inner1": "str4" }], "field2": [[5, 6], [7, 8]] - }], - "complexNestedArray": [ - 
{ - "field": [ - { - "innerField": "str1" - }, - { - "innerField": "str2" - } - ] - }, - { - "field": [ - { - "innerField": "str3" - }, - { - "otherInner": "str4" - } - ] - }, - { - "otherField": "str5" - } - ] + }] }""" :: Nil) }
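
Illustrative note (not part of the patch series): the resolution strategy that the final state of LogicalPlan.resolve keeps — split the attribute name on dots, drop a leading qualifier when it matches, match the head against an attribute, and fold the remaining parts into GetField accesses — can be sketched outside Catalyst as a plain Scala program. The Struct and Attr types and the getField helper below are simplified stand-ins invented for the sketch, not Spark classes.

  object DotResolutionSketch {
    // A "struct" value is just a name -> value map in this sketch.
    type Struct = Map[String, Any]

    // Simplified stand-in for a resolved attribute: a name, its table
    // qualifiers, and the value it would produce for a row.
    case class Attr(name: String, qualifiers: Set[String], value: Any)

    // Mirrors GetField evaluation on struct values: null propagates,
    // otherwise the named field is looked up.
    def getField(value: Any, fieldName: String): Any = value match {
      case null => null
      case s: Map[_, _] => s.asInstanceOf[Struct].getOrElse(fieldName, null)
      case other => sys.error(s"getField is not valid on $other")
    }

    // Mirrors LogicalPlan.resolve: strip a leading qualifier if it matches,
    // match the first remaining part against an attribute name, then fold
    // the rest of the parts into nested field accesses.
    def resolve(name: String, input: Seq[Attr]): Option[Any] = {
      val parts = name.split("\\.")
      val matches = input.flatMap { a =>
        val remaining =
          if (a.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
        if (a.name == remaining.head) Some((a, remaining.tail.toList)) else None
      }
      matches match {
        case Seq((a, nestedFields)) => Some(nestedFields.foldLeft(a.value)(getField))
        case Seq()                  => None
        case _                      => sys.error(s"Ambiguous references to $name")
      }
    }

    def main(args: Array[String]): Unit = {
      val nested: Struct = Map("f2" -> Map("f3" -> 1))
      val attrs = Seq(Attr("f1", Set("nested"), nested))
      println(resolve("f1.f2.f3", attrs))        // Some(1)
      println(resolve("nested.f1.f2.f3", attrs)) // Some(1), qualifier stripped
    }
  }

Run as a standalone program, both calls print Some(1), which matches the behaviour exercised by the "double nested data" regression test (SELECT f1.f2.f3 FROM nested) added in patch 4.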