Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class Analyzer(
EliminateUnions),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveStar ::
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolvePivot ::
Expand Down Expand Up @@ -369,28 +370,83 @@ class Analyzer(
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
* Expand [[UnresolvedStar]] or [[ResolvedStar]] to the matching attributes in child's output.
*/
object ResolveReferences extends Rule[LogicalPlan] {
object ResolveStar extends Rule[LogicalPlan] {

/**
 * Expands every [[Star]] found in the expressions of a Project, Aggregate or
 * ScriptTransformation against that operator's child, and fails analysis when a star is
 * passed to a generator.
 *
 * @param plan the logical plan to rewrite
 * @return the plan produced by applying the cases below through `resolveOperators`
 */
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
  // Leave the operator untouched until all of its children are resolved.
  case pending: LogicalPlan if !pending.childrenResolved => pending

  // Projection list with stars: replace each star by the attributes it expands to.
  case project: Project if containsStar(project.projectList) =>
    val input = project.child
    val newProjectList = project.projectList.flatMap {
      case star: Star =>
        star.expand(input, resolver)
      case ua @ UnresolvedAlias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
        List(UnresolvedAlias(child = expandStarExpression(ua.child, input)))
      case alias @ Alias(_: UnresolvedFunction | _: CreateArray | _: CreateStruct, _) =>
        val rewrittenChild = expandStarExpression(alias.child, input)
        List(alias.withNewChildren(List(rewrittenChild)).asInstanceOf[Alias])
      case other =>
        List(other)
    }
    Project(newProjectList, input)

  // Aggregate expressions with stars: expand top-level stars and stars nested in arguments.
  case aggregate: Aggregate if containsStar(aggregate.aggregateExpressions) =>
    val input = aggregate.child
    val rewritten = aggregate.aggregateExpressions.flatMap {
      case star: Star => star.expand(input, resolver)
      case nested if containsStar(List(nested)) => List(expandStarExpression(nested, input))
      case other => List(other)
    }
    val newAggregateExpressions = rewritten.map(_.asInstanceOf[NamedExpression])
    aggregate.copy(aggregateExpressions = newAggregateExpressions)

  // Script transformation input with stars: only top-level stars are expanded here.
  case transform: ScriptTransformation if containsStar(transform.input) =>
    val newInput = transform.input.flatMap {
      case star: Star => star.expand(transform.child, resolver)
      case other => List(other)
    }
    transform.copy(input = newInput)

  // A star among a generator's children is rejected.
  case generate: Generate if containsStar(generate.generator.children) =>
    failAnalysis("Invalid usage of '*' in explode/json_tuple/UDTF")
}

/**
 * Returns true when collecting [[Star]] nodes from at least one expression in `exprs`
 * yields a non-empty result.
 */
def containsStar(exprs: Seq[Expression]): Boolean = {
  exprs.exists { expr =>
    val stars = expr.collect { case star: Star => star }
    stars.nonEmpty
  }
}

/**
* For each expression, expands the matching attribute.*'s in `child`'s input for the subtree
* rooted at that expression.
* Expands the matching attribute.*'s in `child`'s output.
*/
def expandStarExpressions(exprs: Seq[Expression], child: LogicalPlan): Seq[Expression] = {
exprs.flatMap {
case s: Star => s.expand(child, resolver)
case e =>
e.transformDown {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
} :: Nil
def expandStarExpression(expr: Expression, child: LogicalPlan): Expression = {
expr.transformUp {
case f1: UnresolvedFunction if containsStar(f1.children) =>
f1.copy(children = f1.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateStruct if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
case c: CreateArray if containsStar(c.children) =>
c.copy(children = c.children.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
})
// count(*) has been replaced by count(1)
case o if containsStar(o.children) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a method:

private def mayContainsStar(expr: Expression): Boolean = expr.isInstanceOf[UnresolvedFunction] || expr.isInstanceOf[CreateStruct]...

then we can simplify this to:

expr.transformUp {
  case e if mayContainsStar(e) =>
    e.copy(children = ...)
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great idea! : )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried it, but copy cannot be used here. When the static type is Expression (an abstract type), we cannot call the copy function to change the children. In addition, withNewChildren requires the same number of children. Do you have any idea how to fix it? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. I don't have a better idea; let's just keep it this way.

failAnalysis(s"Invalid usage of '*' in expression '${o.prettyName}'")
}
}
}

/**
* Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from
* a logical plan node's children.
*/
object ResolveReferences extends Rule[LogicalPlan] {
/**
* Generate a new logical plan for the right child with different expression IDs
* for all conflicting attributes.
Expand Down Expand Up @@ -452,48 +508,6 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p

// If the projection list contains Stars, expand it.
case p @ Project(projectList, child) if containsStar(projectList) =>
Project(
projectList.flatMap {
case s: Star => s.expand(child, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil
case a @ Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val newChildren = expandStarExpressions(args, child)
Alias(child = f.copy(children = newChildren), name)(
isGenerated = a.isGenerated) :: Nil
case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child, resolver)
case o => o :: Nil
}
UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)

case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
input = t.input.flatMap {
case s: Star => s.expand(t.child, resolver)
case o => o :: Nil
}
)

// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
val expanded = expandStarExpressions(a.aggregateExpressions, a.child)
.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)

// To resolve duplicate expression IDs for Join and Intersect
case j @ Join(left, right, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
Expand Down Expand Up @@ -588,12 +602,6 @@ class Analyzer(
/**
 * Builds an [[AttributeSet]] from the attributes of every [[Alias]] in `projectList`;
 * entries that are not aliases are ignored.
 */
def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
  val aliasAttributes = projectList.collect { case alias: Alias => alias.toAttribute }
  AttributeSet(aliasAttributes)
}

/**
 * Checks whether at least one expression in `exprs` yields a [[Star]] when collected.
 */
def containsStar(exprs: Seq[Expression]): Boolean = {
  def hasStar(expr: Expression): Boolean = expr.collect { case _: Star => true }.nonEmpty
  exprs.exists(hasStar)
}
}

private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
Expand Down Expand Up @@ -893,8 +901,6 @@ class Analyzer(
*/
object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case g: Generate if ResolveReferences.containsStar(g.generator.children) =>
failAnalysis("Cannot explode *, explode can only be applied on a specific column.")
case p: Generate if !p.child.resolved || !p.generator.resolved => p
case g: Generate if !g.resolved =>
g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ class AnalysisErrorSuite extends AnalysisTest {
.orderBy('havingCondition.asc),
"cannot resolve" :: "havingCondition" :: Nil)

// A star passed to sum(...) must be rejected with an "Invalid usage of '*'" error that
// names the offending expression. The test description previously said "max" although the
// plan and the expected message both use sum.
errorTest(
  "unresolved star expansion in sum",
  testRelation2.groupBy('a)(sum(UnresolvedStar(None))),
  "Invalid usage of '*'" :: "in expression 'sum'" :: Nil)

errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
Expand Down
35 changes: 28 additions & 7 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,43 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
)
}

test("SPARK-8930: explode should fail with a meaningful message if it takes a star") {
test("Star Expansion - CreateStruct and CreateArray") {
  val records = testData2.select("a", "b").as("record")

  // Aggregate expressions: the star sits inside struct(...) / array(...) under min(...).
  val minStruct = records.groupBy($"a").agg(min(struct($"record.*"))).first()
  assert(minStruct == Row(3, Row(3, 1)))
  val minArray = records.groupBy($"a").agg(min(array($"record.*"))).first()
  assert(minArray == Row(3, Seq(3, 1)))

  // Project list, no explicit alias on the struct/array expression.
  val plainStruct = records.select(struct($"record.*")).first()
  assert(plainStruct == Row(Row(1, 1)))
  val plainArray = records.select(array($"record.*")).first()
  assert(plainArray.getAs[Seq[Int]](0) === Seq(1, 1))

  // Project list, struct/array expression wrapped in an explicit alias.
  val aliasedStruct = records.select(struct($"record.*").as("a")).first()
  assert(aliasedStruct == Row(Row(1, 1)))
  val aliasedArray = records.select(array($"record.*").as("a")).first()
  assert(aliasedArray.getAs[Seq[Int]](0) === Seq(1, 1))
}

test("Star Expansion - explode should fail with a meaningful message if it takes a star") {
val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
val e = intercept[AnalysisException] {
df.explode($"*") { case Row(prefix: String, csv: String) =>
csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
}.queryExecution.assertAnalyzed()
}
assert(e.getMessage.contains(
"Cannot explode *, explode can only be applied on a specific column."))
assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF"))

df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
}.queryExecution.assertAnalyzed()
checkAnswer(
df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) =>
csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq
},
Row("1", "1,2", "1:1") ::
Row("1", "1,2", "1:2") ::
Row("2", "4", "2:4") ::
Row("3", "7,8,9", "3:7") ::
Row("3", "7,8,9", "3:8") ::
Row("3", "7,8,9", "3:9") :: Nil)
}

test("explode alias and star") {
test("Star Expansion - explode alias and star") {
val df = Seq((Array("a"), 1)).toDF("a", "b")

checkAnswer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -737,20 +737,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
.queryExecution.analyzed
}

test("Star Expansion - script transform") {
  // Register a three-column temp table, then pipe every column (via '*') through `cat`.
  val rows = (1 to 100000).map { i => (i, i, i) }
  rows.toDF("d1", "d2", "d3").registerTempTable("script_trans")
  val transformed = sql("SELECT TRANSFORM (*) USING 'cat' FROM script_trans")
  // The script transformation should emit one output row per input row.
  assert(100000 === transformed.count())
}

test("test script transform for stdout") {
val data = (1 to 100000).map { i => (i, i, i) }
data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
assert(100000 ===
sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans")
.queryExecution.toRdd.count())
sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat' AS (a,b,c) FROM script_trans").count())
}

test("test script transform for stderr") {
val data = (1 to 100000).map { i => (i, i, i) }
data.toDF("d1", "d2", "d3").registerTempTable("script_trans")
assert(0 ===
sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans")
.queryExecution.toRdd.count())
sql("SELECT TRANSFORM (d1, d2, d3) USING 'cat 1>&2' AS (a,b,c) FROM script_trans").count())
}

test("test script transform data type") {
Expand Down