
Commit cd47e23

cloud-fan authored and hvanhovell committed
[SPARK-15814][SQL] Aggregator can return null result
## What changes were proposed in this pull request?

It's similar to the bug fixed in #13425: we should account for a null result object and wrap the `CreateStruct` with `If` to do a null check. This PR also improves the test framework to compare the objects of a `Dataset[T]` directly, instead of calling `toDF` and comparing rows.

## How was this patch tested?

New test in `DatasetAggregatorSuite`.

Author: Wenchen Fan <[email protected]>

Closes #13553 from cloud-fan/agg-null.
1 parent d681742 commit cd47e23

8 files changed: +117 −64 lines
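
For context, here is a minimal sketch of the user-facing scenario this patch targets: a user-defined `Aggregator` whose `finish` can legitimately return null. Names here are illustrative and not part of the patch; the `NullResultAgg` test added below exercises the same case.

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical input/output type and aggregator; finish() may return null as a valid result.
case class Data(a: Int, b: String)

object MaybeNullAgg extends Aggregator[Data, Data, Data] {
  override def zero: Data = Data(0, "")
  override def reduce(buf: Data, in: Data): Data = Data(buf.a + in.a, buf.b + in.b)
  override def merge(b1: Data, b2: Data): Data = Data(b1.a + b2.a, b1.b + b2.b)
  override def finish(reduction: Data): Data =
    if (reduction.a % 2 == 0) null else reduction   // null is a legitimate result
  override def bufferEncoder: Encoder[Data] = Encoders.product[Data]
  override def outputEncoder: Encoder[Data] = Encoders.product[Data]
}

// Usage sketch (assuming a SparkSession named `spark`):
// import spark.implicits._
// val ds = Seq(Data(1, "one"), Data(2, "two")).toDS()
// ds.groupByKey(_.a).agg(MaybeNullAgg.toColumn).collect()
// // with this patch: Array((1, Data(1, "one")), (2, null))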

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 6 additions & 1 deletion
@@ -127,7 +127,12 @@ case class TypedAggregateExpression(

     dataType match {
       case s: StructType =>
-        ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
+        val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
+        val struct = If(
+          IsNull(objRef),
+          Literal.create(null, dataType),
+          CreateStruct(outputSerializer))
+        ReferenceToExpressions(struct, resultObj :: Nil)
       case _ =>
         assert(outputSerializer.length == 1)
         outputSerializer.head transform {
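
In plain terms, the guard above makes result serialization null-tolerant: when the object produced by the aggregator is null, the generated expression now yields a null struct instead of building a struct from a null object's fields. A rough, non-Catalyst sketch of that behaviour, with hypothetical names (`Data` refers to the illustrative case class from the sketch near the top):

import org.apache.spark.sql.Row

object NullGuardSketch {
  // Hypothetical stand-in for the generated serialization logic: `fields` plays the role
  // of the output serializer, producing the struct's column values from the result object.
  def serializeResult[T <: AnyRef](obj: T)(fields: T => Seq[Any]): Row =
    if (obj == null) null            // the new branch: a null result object becomes a null struct
    else Row.fromSeq(fields(obj))    // otherwise build the struct from the object's fields
}

// NullGuardSketch.serializeResult(Data(1, "one"))(d => Seq(d.a, d.b))          // Row(1, "one")
// NullGuardSketch.serializeResult(null.asInstanceOf[Data])(d => Seq(d.a, d.b)) // null, no NPE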

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

Lines changed: 21 additions & 2 deletions
@@ -115,11 +115,23 @@ object RowAgg extends Aggregator[Row, Int, Int] {
   override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }

+object NullResultAgg extends Aggregator[AggData, AggData, AggData] {
+  override def zero: AggData = AggData(0, "")
+  override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, b.b + a.b)
+  override def finish(reduction: AggData): AggData = {
+    if (reduction.a % 2 == 0) null else reduction
+  }
+  override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b)
+  override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
+  override def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
+}

-class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {

+class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

+  private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b)
+
   test("typed aggregation: TypedAggregator") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

@@ -204,7 +216,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
       (1.5, 2))

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
       ("one", 1), ("two", 1))
   }
@@ -271,4 +283,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       "RowAgg(org.apache.spark.sql.Row)")
     assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
   }
+
+  test("SPARK-15814 Aggregator can return null result") {
+    val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+    checkDatasetUnorderly(
+      ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
+      1 -> AggData(1, "one"), 2 -> null)
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -82,7 +82,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, keys") {
     val ds = Seq(1, 2, 3, 4, 5).toDS()
     val grouped = ds.groupByKey(_ % 2)
-    checkDataset(
+    checkDatasetUnorderly(
       grouped.keys,
       0, 1)
   }
@@ -95,7 +95,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       (name, iter.size)
     }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       ("even", 5), ("odd", 6))
   }
@@ -105,7 +105,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     val grouped = ds.groupByKey(_.length)
     val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       "1", "abc", "3", "xyz", "5", "hello")
   }

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 20 additions & 18 deletions
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

+  private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b)
+
   test("toDS") {
     val data = Seq(("a", 1), ("b", 2), ("c", 3))
     checkDataset(
@@ -95,12 +97,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }

     assert(ds.repartition(10).rdd.partitions.length == 10)
-    checkDataset(
+    checkDatasetUnorderly(
       ds.repartition(10),
       data: _*)

     assert(ds.coalesce(1).rdd.partitions.length == 1)
-    checkDataset(
+    checkDatasetUnorderly(
       ds.coalesce(1),
       data: _*)
   }
@@ -163,7 +165,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       .map(c => ClassData(c.a, c.b + 1))
       .groupByKey(p => p).count()

-    checkDataset(
+    checkDatasetUnorderly(
       ds,
       (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
   }
@@ -204,7 +206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {

   test("select 2, primitive and class, fields reordered") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
-    checkDecoding(
+    checkDataset(
       ds.select(
         expr("_1").as[String],
         expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
@@ -291,7 +293,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
     val grouped = ds.groupByKey(v => (1, v._2))
-    checkDataset(
+    checkDatasetUnorderly(
       grouped.keys,
       (1, 1))
   }
@@ -301,7 +303,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val grouped = ds.groupByKey(v => (v._1, "word"))
     val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
@@ -313,7 +315,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Iterator(g._1, iter.map(_._2).sum.toString)
     }

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       "a", "30", "b", "3", "c", "1")
   }
@@ -322,7 +324,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq("abc", "xyz", "hello").toDS()
     val agged = ds.groupByKey(_.length).reduceGroups(_ + _)

-    checkDataset(
+    checkDatasetUnorderly(
       agged,
       3 -> "abcxyz", 5 -> "hello")
   }
@@ -340,31 +342,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("typed aggregation: expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long]),
       ("a", 30L), ("b", 3L), ("c", 1L))
   }

   test("typed aggregation: expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
       ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
   }

   test("typed aggregation: expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
       ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
   }

   test("typed aggregation: expr, expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

-    checkDataset(
+    checkDatasetUnorderly(
       ds.groupByKey(_._1).agg(
         sum("_2").as[Long],
         sum($"_2" + 1).as[Long],
@@ -380,7 +382,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
     }

-    checkDataset(
+    checkDatasetUnorderly(
       cogrouped,
       1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
   }
@@ -392,7 +394,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
     }

-    checkDataset(
+    checkDatasetUnorderly(
       cogrouped,
       1 -> "a", 2 -> "bc", 3 -> "d")
   }
@@ -482,8 +484,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(
       ds1.joinWith(ds2, lit(true)),
       ((nullInt, "1"), (nullInt, "1")),
-      ((new java.lang.Integer(22), "2"), (nullInt, "1")),
       ((nullInt, "1"), (new java.lang.Integer(22), "2")),
+      ((new java.lang.Integer(22), "2"), (nullInt, "1")),
       ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
   }

@@ -776,9 +778,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds1 = ds.as("d1")
     val ds2 = ds.as("d2")

-    checkDataset(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
-    checkDataset(ds1.intersect(ds2), 2, 3, 4)
-    checkDataset(ds1.except(ds1))
+    checkDatasetUnorderly(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
+    checkDatasetUnorderly(ds1.intersect(ds2), 2, 3, 4)
+    checkDatasetUnorderly(ds1.except(ds1))
   }

   test("SPARK-15441: Dataset outer join") {

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 61 additions & 34 deletions
@@ -68,28 +68,62 @@ abstract class QueryTest extends PlanTest {
   /**
    * Evaluates a dataset to make sure that the result of calling collect matches the given
    * expected answer.
-   *  - Special handling is done based on whether the query plan should be expected to return
-   *    the results in sorted order.
-   *  - This function also checks to make sure that the schema for serializing the expected answer
-   *    matches that produced by the dataset (i.e. does manual construction of object match
-   *    the constructed encoder for cases like joins, etc). Note that this means that it will fail
-   *    for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
-   *    which performs a subset of the checks done by this function.
    */
   protected def checkDataset[T](
-      ds: Dataset[T],
+      ds: => Dataset[T],
       expectedAnswer: T*): Unit = {
-    checkAnswer(
-      ds.toDF(),
-      spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
+    val result = getResult(ds)

-    checkDecoding(ds, expectedAnswer: _*)
+    if (!compare(result.toSeq, expectedAnswer)) {
+      fail(
+        s"""
+           |Decoded objects do not match expected objects:
+           |expected: $expectedAnswer
+           |actual: ${result.toSeq}
+           |${ds.exprEnc.deserializer.treeString}
+         """.stripMargin)
+    }
   }

-  protected def checkDecoding[T](
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer, after sort.
+   */
+  protected def checkDatasetUnorderly[T : Ordering](
       ds: => Dataset[T],
       expectedAnswer: T*): Unit = {
-    val decoded = try ds.collect().toSet catch {
+    val result = getResult(ds)
+
+    if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) {
+      fail(
+        s"""
+           |Decoded objects do not match expected objects:
+           |expected: $expectedAnswer
+           |actual: ${result.toSeq}
+           |${ds.exprEnc.deserializer.treeString}
+         """.stripMargin)
+    }
+  }
+
+  private def getResult[T](ds: => Dataset[T]): Array[T] = {
+    val analyzedDS = try ds catch {
+      case ae: AnalysisException =>
+        if (ae.plan.isDefined) {
+          fail(
+            s"""
+               |Failed to analyze query: $ae
+               |${ae.plan.get}
+               |
+               |${stackTraceToString(ae)}
+             """.stripMargin)
+        } else {
+          throw ae
+        }
+    }
+    checkJsonFormat(analyzedDS)
+    assertEmptyMissingInput(analyzedDS)
+
+    try ds.collect() catch {
       case e: Exception =>
         fail(
           s"""
@@ -99,24 +133,17 @@ abstract class QueryTest extends PlanTest {
            |${ds.queryExecution}
          """.stripMargin, e)
     }
+  }

-    // Handle the case where the return type is an array
-    val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
-    def normalEquality = decoded == expectedAnswer.toSet
-    def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
-    def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
-
-    if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
-      val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
-      val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
-
-      val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
-      fail(
-        s"""Decoded objects do not match expected objects:
-           |$comparison
-           |${ds.exprEnc.deserializer.treeString}
-         """.stripMargin)
-    }
+  private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
+    case (null, null) => true
+    case (null, _) => false
+    case (_, null) => false
+    case (a: Array[_], b: Array[_]) =>
+      a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
+    case (a: Iterable[_], b: Iterable[_]) =>
+      a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
+    case (a, b) => a == b
   }

   /**
@@ -143,7 +170,7 @@ abstract class QueryTest extends PlanTest {

     checkJsonFormat(analyzedDF)

-    assertEmptyMissingInput(df)
+    assertEmptyMissingInput(analyzedDF)

     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
@@ -201,10 +228,10 @@ abstract class QueryTest extends PlanTest {
       planWithCaching)
   }

-  private def checkJsonFormat(df: DataFrame): Unit = {
+  private def checkJsonFormat(ds: Dataset[_]): Unit = {
     // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
     // RDD and Data resolution does not break.
-    val logicalPlan = df.queryExecution.analyzed
+    val logicalPlan = ds.queryExecution.analyzed

     // bypass some cases that we can't handle currently.
     logicalPlan.transform {
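
A minimal sketch of how a suite is expected to use the reworked helpers (the suite and test data below are hypothetical; `AggData(a: Int, b: String)` mirrors the case class used in `DatasetAggregatorSuite`). The by-name `ds` parameter lets `getResult` catch an `AnalysisException` thrown while the `Dataset` itself is being built, and the `Ordering` context bound is what allows the unordered variant to sort both sides before comparing.

// Hypothetical example suite; only the QueryTest helpers themselves come from this patch.
class NullResultExampleSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // Required by checkDatasetUnorderly's [T : Ordering] context bound.
  private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b)

  test("unordered comparison example") {
    val ds = Seq(AggData(1, "one"), AggData(2, "two"), AggData(3, "three")).toDS()
    // repartition gives no ordering guarantee, so compare after sorting both sides.
    checkDatasetUnorderly(
      ds.repartition(2),
      AggData(1, "one"), AggData(2, "two"), AggData(3, "three"))
  }
}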

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -132,9 +132,9 @@ class TextSuite extends QueryTest with SharedSQLContext {
     ds1.write.text(s"$path/part=a")
     ds1.write.text(s"$path/part=b")

-    checkDataset(
+    checkAnswer(
       spark.read.format("text").load(path).select($"part"),
-      Row("a"), Row("b"))
+      Row("a") :: Row("b") :: Nil)
   }
 }
