From 3b2b448c640eae5b50deb69346409581e8448af3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 15 Feb 2016 09:58:46 -0800 Subject: [PATCH 1/9] structStarExpansion --- .../spark/sql/catalyst/analysis/Analyzer.scala | 16 +++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 8 ++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 ++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 26c3d286b19f..7154b49362a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -362,12 +362,26 @@ class Analyzer( exprs.flatMap { case s: Star => s.expand(child, resolver) case e => - e.transformDown { + e.transformUp { + // ResolveFunctions can handle the case when the number of variables is not valid case f1: UnresolvedFunction if containsStar(f1.children) => f1.copy(children = f1.children.flatMap { case s: Star => s.expand(child, resolver) case o => o :: Nil }) + case c: CreateStruct if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateStructUnsafe if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + // count(*) has been replaced by count(1) + case f2: ExpectsInputTypes if containsStar(f2.children) => + failAnalysis(s"Invalid usage of '*' in function '${f2.prettyName}'") } :: Nil } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f9ba60770022..737af539da21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -528,6 +528,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) } + test("verify star in functions fail with a good error") { + val ds = Seq(("a", 1, "c"), ("b", 2, "d")).map(a => (a._1, a._3)) + val e = intercept[AnalysisException] { + ds.toDF().groupBy($"_1").agg(sum($"*") as "sumOccurances") + } + assert(e.getMessage.contains("Invalid usage of '*' in function 'sum'"), e.getMessage) + } + test("runtime nullability check") { val schema = StructType(Seq( StructField("f", StructType(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f665a1c87bd7..18e88a70f15c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1834,6 +1834,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """.stripMargin).select($"r.*"), Row(3, 2) :: Nil) + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) + // With GROUP BY checkAnswer(sql( """ From ac71f3913148f97c25eac6957aacf64015532583 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 18 Feb 2016 18:45:02 -0800 Subject: [PATCH 2/9] add a new rule for resolving star. --- .../sql/catalyst/analysis/Analyzer.scala | 152 +++++++++--------- .../spark/sql/DataFrameComplexTypeSuite.scala | 2 + .../org/apache/spark/sql/DatasetSuite.scala | 2 +- 3 files changed, 75 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7154b49362a2..9d64e849602b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( EliminateUnions), Batch("Resolution", fixedPoint, ResolveRelations :: + ResolveStar :: ResolveReferences :: ResolveGroupingAnalytics :: ResolvePivot :: @@ -350,42 +351,83 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from - * a logical plan node's children. + * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output. */ - object ResolveReferences extends Rule[LogicalPlan] { + object ResolveStar extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p: LogicalPlan if !p.childrenResolved => p + + // If the projection list contains Stars, expand it. + case p: Project if containsStar(p.projectList) => + val expanded = p.projectList.flatMap { + case s: Star => s.expand(p.child, resolver) + case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => + UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil + case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => + Alias(child = expandStarExpression(a.child, p.child), a.name)( + isGenerated = a.isGenerated) :: Nil + case o => o :: Nil + } + Project(projectList = expanded, p.child) + // If the aggregate function argument contains Stars, expand it. + case a: Aggregate if containsStar(a.aggregateExpressions) => + val expanded = a.aggregateExpressions.flatMap { + case s: Star => s.expand(a.child, resolver) + case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil + case o => o :: Nil + }.map(_.asInstanceOf[NamedExpression]) + a.copy(aggregateExpressions = expanded) + // If the script transformation input contains Stars, expand it. + case t: ScriptTransformation if containsStar(t.input) => + t.copy( + input = t.input.flatMap { + case s: Star => s.expand(t.child, resolver) + case o => o :: Nil + } + ) + case g: Generate if containsStar(g.generator.children) => + failAnalysis("Cannot explode *, explode can only be applied on a specific column.") + } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + def containsStar(exprs: Seq[Expression]): Boolean = + exprs.exists(_.collect { case _: Star => true }.nonEmpty) + /** - * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree - * rooted at each expression. + * Expands the matching attribute.*'s in `child`'s output. */ - def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = { - exprs.flatMap { - case s: Star => s.expand(child, resolver) - case e => - e.transformUp { - // ResolveFunctions can handle the case when the number of variables is not valid - case f1: UnresolvedFunction if containsStar(f1.children) => - f1.copy(children = f1.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) - case c: CreateStruct if containsStar(c.children) => - c.copy(children = c.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) - case c: CreateStructUnsafe if containsStar(c.children) => - c.copy(children = c.children.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - }) - // count(*) has been replaced by count(1) - case f2: ExpectsInputTypes if containsStar(f2.children) => - failAnalysis(s"Invalid usage of '*' in function '${f2.prettyName}'") - } :: Nil + def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = { + expr.transformUp { + case f1: UnresolvedFunction if containsStar(f1.children) => + f1.copy(children = f1.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateStruct if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + case c: CreateArray if containsStar(c.children) => + c.copy(children = c.children.flatMap { + case s: Star => s.expand(child, resolver) + case o => o :: Nil + }) + // count(*) has been replaced by count(1) + case o if containsStar(o.children) => + failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'") } } + } + /** + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. + */ + object ResolveReferences extends Rule[LogicalPlan] { /** * Generate a new logical plan for the right child with different expression IDs * for all conflicting attributes. @@ -446,48 +488,6 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p - // If the projection list contains Stars, expand it. - case p @ Project(projectList, child) if containsStar(projectList) => - Project( - projectList.flatMap { - case s: Star => s.expand(child, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => - val newChildren = expandStarExpressions(args, child) - UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil - case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => - val newChildren = expandStarExpressions(args, child) - Alias(child = f.copy(children = newChildren), name)( - isGenerated = a.isGenerated) :: Nil - case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - } - UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) => - val expandedArgs = args.flatMap { - case s: Star => s.expand(child, resolver) - case o => o :: Nil - } - UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case o => o :: Nil - }, - child) - - case t: ScriptTransformation if containsStar(t.input) => - t.copy( - input = t.input.flatMap { - case s: Star => s.expand(t.child, resolver) - case o => o :: Nil - } - ) - - // If the aggregate function argument contains Stars, expand it. - case a: Aggregate if containsStar(a.aggregateExpressions) => - val expanded = expandStarExpressions(a.aggregateExpressions, a.child) - .map(_.asInstanceOf[NamedExpression]) - a.copy(aggregateExpressions = expanded) - // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) @@ -575,12 +575,6 @@ class Analyzer( def findAliases(projectList: Seq[NamedExpression]): AttributeSet = { AttributeSet(projectList.collect { case a: Alias => a.toAttribute }) } - - /** - * Returns true if `exprs` contains a [[Star]]. - */ - def containsStar(exprs: Seq[Expression]): Boolean = - exprs.exists(_.collect { case _: Star => true }.nonEmpty) } private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = { @@ -847,8 +841,6 @@ class Analyzer( */ object ResolveGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case g: Generate if ResolveReferences.containsStar(g.generator.children) => - failAnalysis("Cannot explode *, explode can only be applied on a specific column.") case p: Generate if !p.child.resolved || !p.generator.resolved => p case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index b76fc73b7fa0..5debc6260460 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -30,6 +30,7 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(struct($"a").as("s")).select(f($"s.a")).collect() + df.select(struct($"*").as("s")).select(f($"s.a")).collect() } test("UDF on named_struct") { @@ -42,6 +43,7 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() + df.select(array($"*").as("s")).select(f(expr("s[0]"))).collect() } test("SPARK-12477 accessing null element in array field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 737af539da21..0efbab60e47a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -533,7 +533,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { ds.toDF().groupBy($"_1").agg(sum($"*") as "sumOccurances") } - assert(e.getMessage.contains("Invalid usage of '*' in function 'sum'"), e.getMessage) + assert(e.getMessage.contains("Invalid usage of '*' in expression 'sum'"), e.getMessage) } test("runtime nullability check") { From 8d809bc7d2541e71b8e1a1f145c2c42ee5528f92 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 18 Feb 2016 23:30:42 -0800 Subject: [PATCH 3/9] address comments. --- .../catalyst/analysis/AnalysisErrorSuite.scala | 5 +++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index e0cec09742eb..f8ea612b87da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -183,6 +183,11 @@ class AnalysisErrorSuite extends AnalysisTest { .orderBy('havingCondition.asc), "cannot resolve" :: "havingCondition" :: Nil) + errorTest( + "unresolved star expansion in max", + testRelation2.groupBy('a)(sum(UnresolvedStar(None))), + "Invalid usage of '*'" :: "in expression 'sum'" :: Nil) + errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 18e88a70f15c..260acd950872 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1834,8 +1834,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { """.stripMargin).select($"r.*"), Row(3, 2) :: Nil) - assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) - // With GROUP BY checkAnswer(sql( """ @@ -1940,6 +1938,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("Star Expansion - CreateStruct and CreateArray") { + val structDf = testData2.select("a", "b").as("record") + // CreateStruct and CreateArray in aggregateExpressions + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) + assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) + + // CreateStruct and CreateArray in project list (unresolved alias) + assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Array(1, 1)) + + // CreateStruct and CreateArray in project list (alias) + assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Array(1, 1)) + } + test("Common subexpression elimination") { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { From 2c72edf662b037b0dba845f81e95dadfc35bf648 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 18 Feb 2016 23:31:12 -0800 Subject: [PATCH 4/9] address comments. --- .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 8 -------- 1 file changed, 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 0efbab60e47a..f9ba60770022 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -528,14 +528,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) } - test("verify star in functions fail with a good error") { - val ds = Seq(("a", 1, "c"), ("b", 2, "d")).map(a => (a._1, a._3)) - val e = intercept[AnalysisException] { - ds.toDF().groupBy($"_1").agg(sum($"*") as "sumOccurances") - } - assert(e.getMessage.contains("Invalid usage of '*' in expression 'sum'"), e.getMessage) - } - test("runtime nullability check") { val schema = StructType(Seq( StructField("f", StructType(Seq( From 6b2d60996831fd216b4821e62ed9bea5a3892ab5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 22 Feb 2016 20:13:50 -0800 Subject: [PATCH 5/9] address comments. --- .../org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 2 -- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 5debc6260460..b76fc73b7fa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -30,7 +30,6 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(struct($"a").as("s")).select(f($"s.a")).collect() - df.select(struct($"*").as("s")).select(f($"s.a")).collect() } test("UDF on named_struct") { @@ -43,7 +42,6 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() - df.select(array($"*").as("s")).select(f(expr("s[0]"))).collect() } test("SPARK-12477 accessing null element in array field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 260acd950872..c0859d3491f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1946,11 +1946,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { // CreateStruct and CreateArray in project list (unresolved alias) assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) - assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Array(1, 1)) + assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1)) // CreateStruct and CreateArray in project list (alias) assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) - assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Array(1, 1)) + assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) } test("Common subexpression elimination") { From e060deaaf09d122966f090bf3b86895636418664 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 23 Feb 2016 23:01:06 -0800 Subject: [PATCH 6/9] address comments. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 42ab64a5f7b9..6a38d8d3989f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -406,7 +406,7 @@ class Analyzer( } ) case g: Generate if containsStar(g.generator.children) => - failAnalysis("Cannot explode *, explode can only be applied on a specific column.") + failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF") } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4930c485da83..54871920022b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -178,8 +178,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq }.queryExecution.assertAnalyzed() } - assert(e.getMessage.contains( - "Cannot explode *, explode can only be applied on a specific column.")) + assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq From ba3fe7ce3d42e93bfde7ca4e3f893d84cfa82604 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 20 Mar 2016 22:03:14 -0700 Subject: [PATCH 7/9] added test cases. --- .../org/apache/spark/sql/DataFrameSuite.scala | 22 ------------------- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 +++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 8 +++++++ 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b854726795f5..f4f7e990fa78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -164,28 +164,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-8930: explode should fail with a meaningful message if it takes a star") { - val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") - val e = intercept[AnalysisException] { - df.explode($"*") { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() - } - assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() - } - - test("explode alias and star") { - val df = Seq((Array("a"), 1)).toDF("a", "b") - - checkAnswer( - df.select(explode($"a").as("a"), $"*"), - Row("a", Seq("a"), 1) :: Nil) - } - test("sort after generate with join=true") { val df = Seq((Array("a"), 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 781637b1169f..c34a36019c1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1949,6 +1949,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) } + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) + + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + + test("Star Expansion - explode alias and star") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select(explode($"a").as("a"), $"*"), + Row("a", Seq("a"), 1) :: Nil) + } + test("Common subexpression elimination") { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index b42f00e90f31..f59db4d43b65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -737,6 +737,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { .queryExecution.analyzed } + test("Star Expansion - script transform") { + val data = (1 to 100000).map { i => (i, i, i) } + data.toDF("d1", "d2", "d3").registerTempTable("script_trans") + assert(100000 === + sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans") + .queryExecution.toRdd.count()) + } + test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") From 0fce0752fb74b4eb49931c36fdd6e43fc2ec04f2 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 20 Mar 2016 23:52:38 -0700 Subject: [PATCH 8/9] address comments. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 13 ++++++++++--- .../spark/sql/hive/execution/SQLQuerySuite.scala | 10 +++------- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c668a5e31d76..9497f852f756 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -384,8 +384,8 @@ class Analyzer( case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) => - Alias(child = expandStarExpression(a.child, p.child), a.name)( - isGenerated = a.isGenerated) :: Nil + a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil) + .asInstanceOf[Alias] :: Nil case o => o :: Nil } Project(projectList = expanded, p.child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c34a36019c1f..54cd22cc4dce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1958,9 +1958,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() + checkAnswer( + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }, + Row("1", "1,2", "1:1") :: + Row("1", "1,2", "1:2") :: + Row("2", "4", "2:4") :: + Row("3", "7,8,9", "3:7") :: + Row("3", "7,8,9", "3:8") :: + Row("3", "7,8,9", "3:9") :: Nil) } test("Star Expansion - explode alias and star") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f59db4d43b65..ada2d5209621 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -740,25 +740,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("Star Expansion - script transform") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") - assert(100000 === - sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans") - .queryExecution.toRdd.count()) + assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count()) } test("test script transform for stdout") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(100000 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count()) } test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === - sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans") - .queryExecution.toRdd.count()) + sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count()) } test("test script transform data type") { From 50abeec3ea7fd4da83ac89ed90fc478d493d3dba Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 21 Mar 2016 09:50:52 -0700 Subject: [PATCH 9/9] address comments. --- .../org/apache/spark/sql/DataFrameSuite.scala | 44 +++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 44 ------------------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f4f7e990fa78..477d66f87a64 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -164,6 +164,50 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } + test("Star Expansion - CreateStruct and CreateArray") { + val structDf = testData2.select("a", "b").as("record") + // CreateStruct and CreateArray in aggregateExpressions + assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) + assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) + + // CreateStruct and CreateArray in project list (unresolved alias) + assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1)) + + // CreateStruct and CreateArray in project list (alias) + assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) + assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) + } + + test("Star Expansion - explode should fail with a meaningful message if it takes a star") { + val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") + val e = intercept[AnalysisException] { + df.explode($"*") { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }.queryExecution.assertAnalyzed() + } + assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) + + checkAnswer( + df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq + }, + Row("1", "1,2", "1:1") :: + Row("1", "1,2", "1:2") :: + Row("2", "4", "2:4") :: + Row("3", "7,8,9", "3:7") :: + Row("3", "7,8,9", "3:8") :: + Row("3", "7,8,9", "3:9") :: Nil) + } + + test("Star Expansion - explode alias and star") { + val df = Seq((Array("a"), 1)).toDF("a", "b") + + checkAnswer( + df.select(explode($"a").as("a"), $"*"), + Row("a", Seq("a"), 1) :: Nil) + } + test("sort after generate with join=true") { val df = Seq((Array("a"), 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 54cd22cc4dce..182f287dd001 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1934,50 +1934,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("Star Expansion - CreateStruct and CreateArray") { - val structDf = testData2.select("a", "b").as("record") - // CreateStruct and CreateArray in aggregateExpressions - assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1))) - assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1))) - - // CreateStruct and CreateArray in project list (unresolved alias) - assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1))) - assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1)) - - // CreateStruct and CreateArray in project list (alias) - assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1))) - assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1)) - } - - test("Star Expansion - explode should fail with a meaningful message if it takes a star") { - val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv") - val e = intercept[AnalysisException] { - df.explode($"*") { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }.queryExecution.assertAnalyzed() - } - assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) - - checkAnswer( - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => - csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq - }, - Row("1", "1,2", "1:1") :: - Row("1", "1,2", "1:2") :: - Row("2", "4", "2:4") :: - Row("3", "7,8,9", "3:7") :: - Row("3", "7,8,9", "3:8") :: - Row("3", "7,8,9", "3:9") :: Nil) - } - - test("Star Expansion - explode alias and star") { - val df = Seq((Array("a"), 1)).toDF("a", "b") - - checkAnswer( - df.select(explode($"a").as("a"), $"*"), - Row("a", Seq("a"), 1) :: Nil) - } - test("Common subexpression elimination") { // TODO: support subexpression elimination in whole stage codegen withSQLConf("spark.sql.codegen.wholeStage" -> "false") {