diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 53ea3cfef678..9497f852f756 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -76,6 +76,7 @@ class Analyzer(
       EliminateUnions),
     Batch("Resolution", fixedPoint,
       ResolveRelations ::
+      ResolveStar ::
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
@@ -369,28 +370,83 @@ class Analyzer(
   }
 
   /**
-   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
-   * a logical plan node's children.
+   * Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
    */
-  object ResolveReferences extends Rule[LogicalPlan] {
+  object ResolveStar extends Rule[LogicalPlan] {
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p: LogicalPlan if !p.childrenResolved => p
+
+      // If the projection list contains Stars, expand it.
+      case p: Project if containsStar(p.projectList) =>
+        val expanded = p.projectList.flatMap {
+          case s: Star => s.expand(p.child, resolver)
+          case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+            UnresolvedAlias(child = expandStarExpression(ua.child, p.child)) :: Nil
+          case a @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
+            a.withNewChildren(expandStarExpression(a.child, p.child) :: Nil)
+              .asInstanceOf[Alias] :: Nil
+          case o => o :: Nil
+        }
+        Project(projectList = expanded, p.child)
+      // If the aggregate function argument contains Stars, expand it.
+      case a: Aggregate if containsStar(a.aggregateExpressions) =>
+        val expanded = a.aggregateExpressions.flatMap {
+          case s: Star => s.expand(a.child, resolver)
+          case o if containsStar(o :: Nil) => expandStarExpression(o, a.child) :: Nil
+          case o => o :: Nil
+        }.map(_.asInstanceOf[NamedExpression])
+        a.copy(aggregateExpressions = expanded)
+      // If the script transformation input contains Stars, expand it.
+      case t: ScriptTransformation if containsStar(t.input) =>
+        t.copy(
+          input = t.input.flatMap {
+            case s: Star => s.expand(t.child, resolver)
+            case o => o :: Nil
+          }
+        )
+      case g: Generate if containsStar(g.generator.children) =>
+        failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
+    }
+
+    /**
+     * Returns true if `exprs` contains a [[Star]].
+     */
+    def containsStar(exprs: Seq[Expression]): Boolean =
+      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+
     /**
-     * Foreach expression, expands the matching attribute.*'s in `child`'s input for the subtree
-     * rooted at each expression.
+     * Expands the matching attribute.*'s in `child`'s output.
      */
-    def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
-      exprs.flatMap {
-        case s: Star => s.expand(child, resolver)
-        case e =>
-          e.transformDown {
-            case f1: UnresolvedFunction if containsStar(f1.children) =>
-              f1.copy(children = f1.children.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              })
-          } :: Nil
+    def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
+      expr.transformUp {
+        case f1: UnresolvedFunction if containsStar(f1.children) =>
+          f1.copy(children = f1.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateStruct if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        case c: CreateArray if containsStar(c.children) =>
+          c.copy(children = c.children.flatMap {
+            case s: Star => s.expand(child, resolver)
+            case o => o :: Nil
+          })
+        // count(*) has been replaced by count(1)
+        case o if containsStar(o.children) =>
+          failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
       }
     }
+  }
 
+  /**
+   * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
+   * a logical plan node's children.
+   */
+  object ResolveReferences extends Rule[LogicalPlan] {
     /**
      * Generate a new logical plan for the right child with different expression IDs
      * for all conflicting attributes.
@@ -452,48 +508,6 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case p: LogicalPlan if !p.childrenResolved => p
 
-      // If the projection list contains Stars, expand it.
-      case p @ Project(projectList, child) if containsStar(projectList) =>
-        Project(
-          projectList.flatMap {
-            case s: Star => s.expand(child, resolver)
-            case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
-              val newChildren = expandStarExpressions(args, child)
-              UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
-            case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
-              val newChildren = expandStarExpressions(args, child)
-              Alias(child = f.copy(children = newChildren), name)(
-                isGenerated = a.isGenerated) :: Nil
-            case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
-              val expandedArgs = args.flatMap {
-                case s: Star => s.expand(child, resolver)
-                case o => o :: Nil
-              }
-              UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
-            case o => o :: Nil
-          },
-          child)
-
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(
-          input = t.input.flatMap {
-            case s: Star => s.expand(t.child, resolver)
-            case o => o :: Nil
-          }
-        )
-
-      // If the aggregate function argument contains Stars, expand it.
-      case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
-          .map(_.asInstanceOf[NamedExpression])
-        a.copy(aggregateExpressions = expanded)
-
       // To resolve duplicate expression IDs for Join and Intersect
       case j @ Join(left, right, _, _) if !j.duplicateResolved =>
         j.copy(right = dedupRight(left, right))
@@ -588,12 +602,6 @@ class Analyzer(
     def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
       AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
     }
-
-    /**
-     * Returns true if `exprs` contains a [[Star]].
-     */
-    def containsStar(exprs: Seq[Expression]): Boolean =
-      exprs.exists(_.collect { case _: Star => true }.nonEmpty)
   }
 
   private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
@@ -893,8 +901,6 @@ class Analyzer(
    */
  object ResolveGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
-        failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
       case p: Generate if !p.child.resolved || !p.generator.resolved => p
       case g: Generate if !g.resolved =>
         g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index de9a56dc9c06..49596f7e5303 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -196,6 +196,11 @@ class AnalysisErrorSuite extends AnalysisTest {
       .orderBy('havingCondition.asc),
     "cannot resolve" :: "havingCondition" :: Nil)
 
+  errorTest(
+    "unresolved star expansion in max",
+    testRelation2.groupBy('a)(sum(UnresolvedStar(None))),
+    "Invalid usage of '*'" :: "in expression 'sum'" :: Nil)
+
   errorTest(
     "bad casts",
     testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 46cd380a797e..477d66f87a64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -164,22 +164,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("SPARK-8930: explode should fail with a meaningful message if it takes a star") {
+  test("Star Expansion - CreateStruct and CreateArray") {
+    val structDf = testData2.select("a", "b").as("record")
+    // CreateStruct and CreateArray in aggregateExpressions
+    assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1)))
+    assert(structDf.groupBy($"a").agg(min(array($"record.*"))).first() == Row(3, Seq(3, 1)))
+
+    // CreateStruct and CreateArray in project list (unresolved alias)
+    assert(structDf.select(struct($"record.*")).first() == Row(Row(1, 1)))
+    assert(structDf.select(array($"record.*")).first().getAs[Seq[Int]](0) === Seq(1, 1))
+
+    // CreateStruct and CreateArray in project list (alias)
+    assert(structDf.select(struct($"record.*").as("a")).first() == Row(Row(1, 1)))
+    assert(structDf.select(array($"record.*").as("a")).first().getAs[Seq[Int]](0) === Seq(1, 1))
+  }
+
+  test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
     val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
     val e = intercept[AnalysisException] {
       df.explode($"*") { case Row(prefix: String, csv: String) =>
         csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
       }.queryExecution.assertAnalyzed()
     }
-    assert(e.getMessage.contains(
-      "Cannot explode *, explode can only be applied on a specific column."))
+    assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF"))
 
-    df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
-      csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
-    }.queryExecution.assertAnalyzed()
+    checkAnswer(
+      df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
+        csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
+      },
+      Row("1", "1,2", "1:1") ::
+      Row("1", "1,2", "1:2") ::
+      Row("2", "4", "2:4") ::
+      Row("3", "7,8,9", "3:7") ::
+      Row("3", "7,8,9", "3:8") ::
+      Row("3", "7,8,9", "3:9") :: Nil)
   }
 
-  test("explode alias and star") {
+  test("Star Expansion - explode alias and star") {
     val df = Seq((Array("a"), 1)).toDF("a", "b")
 
     checkAnswer(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index b42f00e90f31..ada2d5209621 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -737,20 +737,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       .queryExecution.analyzed
   }
 
+  test("Star Expansion - script transform") {
+    val data = (1 to 100000).map { i => (i, i, i) }
+    data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
+    assert(100000 === sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans").count())
+  }
+
   test("test script transform for stdout") {
     val data = (1 to 100000).map { i => (i, i, i) }
     data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
     assert(100000 ===
-      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans")
-        .queryExecution.toRdd.count())
+      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count())
   }
 
   test("test script transform for stderr") {
     val data = (1 to 100000).map { i => (i, i, i) }
     data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
     assert(0 ===
-      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans")
-        .queryExecution.toRdd.count())
+      sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count())
   }
 
   test("test script transform data type") {
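
Illustrative usage sketch (not part of the patch): the spark-shell style Scala snippet below exercises the behavior the new ResolveStar rule covers, mirroring the tests above. The DataFrame df and its column names are hypothetical stand-ins for the test fixtures (e.g. testData2), and the snippet assumes sqlContext.implicits._ and org.apache.spark.sql.functions._ are in scope, as they are in spark-shell.

  // '*' now expands inside CreateStruct/CreateArray, both in project lists and in aggregates.
  val df = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)).toDF("a", "b").as("record")
  df.select(struct($"record.*").as("s")).show()
  df.groupBy($"a").agg(min(struct($"record.*"))).show()

  // A star nested inside any other expression is rejected during analysis, e.g.:
  //   org.apache.spark.sql.AnalysisException: Invalid usage of '*' in expression 'sum'
  df.groupBy($"a").agg(sum($"*")).show()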