Skip to content

Commit 3b2b448

Browse files
committed
structStarExpansion
1 parent 374c4b2 commit 3b2b448

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,12 +362,26 @@ class Analyzer(
362362
exprs.flatMap {
363363
case s: Star => s.expand(child, resolver)
364364
case e =>
365-
e.transformDown {
365+
e.transformUp {
366+
// ResolveFunctions can handle the case when the number of arguments is not valid after the star is expanded
366367
case f1: UnresolvedFunction if containsStar(f1.children) =>
367368
f1.copy(children = f1.children.flatMap {
368369
case s: Star => s.expand(child, resolver)
369370
case o => o :: Nil
370371
})
372+
case c: CreateStruct if containsStar(c.children) =>
373+
c.copy(children = c.children.flatMap {
374+
case s: Star => s.expand(child, resolver)
375+
case o => o :: Nil
376+
})
377+
case c: CreateStructUnsafe if containsStar(c.children) =>
378+
c.copy(children = c.children.flatMap {
379+
case s: Star => s.expand(child, resolver)
380+
case o => o :: Nil
381+
})
382+
// count(*) has been replaced by count(1)
383+
case f2: ExpectsInputTypes if containsStar(f2.children) =>
384+
failAnalysis(s"Invalid usage of '*' in function '${f2.prettyName}'")
371385
} :: Nil
372386
}
373387
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
528528
assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
529529
}
530530

531+
test("verify star in functions fail with a good error") {
532+
val ds = Seq(("a", 1, "c"), ("b", 2, "d")).map(a => (a._1, a._3))
533+
val e = intercept[AnalysisException] {
534+
ds.toDF().groupBy($"_1").agg(sum($"*") as "sumOccurances")
535+
}
536+
assert(e.getMessage.contains("Invalid usage of '*' in function 'sum'"), e.getMessage)
537+
}
538+
531539
test("runtime nullability check") {
532540
val schema = StructType(Seq(
533541
StructField("f", StructType(Seq(

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
18341834
""".stripMargin).select($"r.*"),
18351835
Row(3, 2) :: Nil)
18361836

1837+
assert(structDf.groupBy($"a").agg(min(struct($"record.*"))).first() == Row(3, Row(3, 1)))
1838+
18371839
// With GROUP BY
18381840
checkAnswer(sql(
18391841
"""

0 commit comments

Comments
 (0)