@@ -127,7 +127,12 @@ case class TypedAggregateExpression(

dataType match {
case s: StructType =>
ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil)
val objRef = outputSerializer.head.find(_.isInstanceOf[BoundReference]).get
val struct = If(
IsNull(objRef),
Literal.create(null, dataType),
CreateStruct(outputSerializer))
ReferenceToExpressions(struct, resultObj :: Nil)
case _ =>
assert(outputSerializer.length == 1)
outputSerializer.head transform {
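For context, the guard above makes the serialized aggregation result a single NULL when the object returned by Aggregator.finish is null, instead of attempting to build a struct from a null object. A minimal, self-contained sketch of the user-facing behavior this enables, on a Spark build that includes this change (the names Item and MaybeNullAgg are illustrative only and mirror the NullResultAgg test added below):

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class Item(a: Int, b: String)

// An aggregator whose finish() may return null for some groups.
object MaybeNullAgg extends Aggregator[Item, Item, Item] {
  override def zero: Item = Item(0, "")
  override def reduce(buf: Item, in: Item): Item = Item(buf.a + in.a, buf.b + in.b)
  override def merge(b1: Item, b2: Item): Item = Item(b1.a + b2.a, b1.b + b2.b)
  override def finish(red: Item): Item = if (red.a % 2 == 0) null else red
  override def bufferEncoder: Encoder[Item] = Encoders.product[Item]
  override def outputEncoder: Encoder[Item] = Encoders.product[Item]
}

object NullAggExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("null-agg-sketch").getOrCreate()
    import spark.implicits._

    val ds = Seq(Item(1, "one"), Item(2, "two")).toDS()
    // Group 1 finishes with Item(1, "one"); group 2 reduces to an even value,
    // so finish() returns null and the aggregated column is null for that key.
    ds.groupByKey(_.a).agg(MaybeNullAgg.toColumn).show()

    spark.stop()
  }
}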
@@ -115,11 +115,23 @@ object RowAgg extends Aggregator[Row, Int, Int] {
override def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

object NullResultAgg extends Aggregator[AggData, AggData, AggData] {
override def zero: AggData = AggData(0, "")
override def reduce(b: AggData, a: AggData): AggData = AggData(b.a + a.a, b.b + a.b)
override def finish(reduction: AggData): AggData = {
if (reduction.a % 2 == 0) null else reduction
}
override def merge(b1: AggData, b2: AggData): AggData = AggData(b1.a + b2.a, b1.b + b2.b)
override def bufferEncoder: Encoder[AggData] = Encoders.product[AggData]
override def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
}

class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._

private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b)

test("typed aggregation: TypedAggregator") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

@@ -204,7 +216,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.select(expr("avg(a)").as[Double], ComplexBufferAgg.toColumn),
(1.5, 2))

checkDataset(
checkDatasetUnorderly(
ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn),
("one", 1), ("two", 1))
}
@@ -271,4 +283,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
"RowAgg(org.apache.spark.sql.Row)")
assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
}

test("SPARK-15814 Aggregator can return null result") {
val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
checkDatasetUnorderly(
ds.groupByKey(_.a).agg(NullResultAgg.toColumn),
1 -> AggData(1, "one"), 2 -> null)
}
}
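A side note on the new SPARK-15814 test: the expected pair 2 -> null survives the sort performed by checkDatasetUnorderly because Scala's derived Tuple2 ordering decides on the Int key and never invokes the AggData ordering on the null value. A standalone sketch of that behavior in plain Scala (no Spark involved; the case class simply mirrors AggData above):

case class AggData(a: Int, b: String)

implicit val aggDataOrdering: Ordering[AggData] = Ordering.by((c: AggData) => c.a -> c.b)

// The keys differ, so the derived Ordering[(Int, AggData)] never dereferences the null value.
val sortedPairs = Seq(2 -> (null: AggData), 1 -> AggData(1, "one")).sorted
println(sortedPairs)  // List((1,AggData(1,one)), (2,null))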
@@ -82,7 +82,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("groupBy function, keys") {
val ds = Seq(1, 2, 3, 4, 5).toDS()
val grouped = ds.groupByKey(_ % 2)
checkDataset(
checkDatasetUnorderly(
grouped.keys,
0, 1)
}
@@ -95,7 +95,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
(name, iter.size)
}

checkDataset(
checkDatasetUnorderly(
agged,
("even", 5), ("odd", 6))
}
@@ -105,7 +105,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
val grouped = ds.groupByKey(_.length)
val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) }

checkDataset(
checkDatasetUnorderly(
agged,
"1", "abc", "3", "xyz", "5", "hello")
}
38 changes: 20 additions & 18 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._

private implicit val ordering = Ordering.by((c: ClassData) => c.a -> c.b)

test("toDS") {
val data = Seq(("a", 1), ("b", 2), ("c", 3))
checkDataset(
@@ -95,12 +97,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}

assert(ds.repartition(10).rdd.partitions.length == 10)
checkDataset(
checkDatasetUnorderly(
ds.repartition(10),
data: _*)

assert(ds.coalesce(1).rdd.partitions.length == 1)
checkDataset(
checkDatasetUnorderly(
ds.coalesce(1),
data: _*)
}
@@ -163,7 +165,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
.map(c => ClassData(c.a, c.b + 1))
.groupByKey(p => p).count()

checkDataset(
checkDatasetUnorderly(
ds,
(ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
}
@@ -204,7 +206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {

test("select 2, primitive and class, fields reordered") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
checkDecoding(
checkDataset(
ds.select(
expr("_1").as[String],
expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
@@ -291,7 +293,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("groupBy function, keys") {
val ds = Seq(("a", 1), ("b", 1)).toDS()
val grouped = ds.groupByKey(v => (1, v._2))
checkDataset(
checkDatasetUnorderly(
grouped.keys,
(1, 1))
}
@@ -301,7 +303,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val grouped = ds.groupByKey(v => (v._1, "word"))
val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }

checkDataset(
checkDatasetUnorderly(
agged,
("a", 30), ("b", 3), ("c", 1))
}
@@ -313,7 +315,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Iterator(g._1, iter.map(_._2).sum.toString)
}

checkDataset(
checkDatasetUnorderly(
agged,
"a", "30", "b", "3", "c", "1")
}
@@ -322,7 +324,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = Seq("abc", "xyz", "hello").toDS()
val agged = ds.groupByKey(_.length).reduceGroups(_ + _)

checkDataset(
checkDatasetUnorderly(
agged,
3 -> "abcxyz", 5 -> "hello")
}
@@ -340,31 +342,31 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("typed aggregation: expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

checkDataset(
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long]),
("a", 30L), ("b", 3L), ("c", 1L))
}

test("typed aggregation: expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

checkDataset(
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
}

test("typed aggregation: expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

checkDataset(
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
}

test("typed aggregation: expr, expr, expr, expr") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

checkDataset(
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(
sum("_2").as[Long],
sum($"_2" + 1).as[Long],
@@ -380,7 +382,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
}

checkDataset(
checkDatasetUnorderly(
cogrouped,
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
}
@@ -392,7 +394,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
}

checkDataset(
checkDatasetUnorderly(
cogrouped,
1 -> "a", 2 -> "bc", 3 -> "d")
}
@@ -482,8 +484,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(
ds1.joinWith(ds2, lit(true)),
((nullInt, "1"), (nullInt, "1")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
((nullInt, "1"), (new java.lang.Integer(22), "2")),
((new java.lang.Integer(22), "2"), (nullInt, "1")),
((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
}

@@ -776,9 +778,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds1 = ds.as("d1")
val ds2 = ds.as("d2")

checkDataset(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
checkDataset(ds1.intersect(ds2), 2, 3, 4)
checkDataset(ds1.except(ds1))
checkDatasetUnorderly(ds1.joinWith(ds2, $"d1.value" === $"d2.value"), (2, 2), (3, 3), (4, 4))
checkDatasetUnorderly(ds1.intersect(ds2), 2, 3, 4)
checkDatasetUnorderly(ds1.except(ds1))
}

test("SPARK-15441: Dataset outer join") {
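The recurring substitution in this file tracks the split made in QueryTest below: checkDataset keeps comparing collected results in order, while checkDatasetUnorderly sorts both the actual and the expected rows before comparing. A hypothetical usage sketch inside such a suite (assuming testImplicits and an ordering for the result type are in scope, as in the tests above):

val ds = Seq("abc", "xyz", "hello").toDS()

// groupByKey gives no ordering guarantee across groups, so the unordered check fits:
checkDatasetUnorderly(
  ds.groupByKey(_.length).reduceGroups(_ + _),
  3 -> "abcxyz", 5 -> "hello")

// map preserves the input order of this local Dataset, so the ordered check still applies:
checkDataset(ds.map(_.length), 3, 3, 5)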
95 changes: 61 additions & 34 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -68,28 +68,62 @@ abstract class QueryTest extends PlanTest {
/**
* Evaluates a dataset to make sure that the result of calling collect matches the given
* expected answer.
* - Special handling is done based on whether the query plan should be expected to return
* the results in sorted order.
* - This function also checks to make sure that the schema for serializing the expected answer
* matches that produced by the dataset (i.e. does manual construction of object match
* the constructed encoder for cases like joins, etc). Note that this means that it will fail
* for cases where reordering is done on fields. For such tests, use `checkDecoding` instead
* which performs a subset of the checks done by this function.
*/
protected def checkDataset[T](
ds: Dataset[T],
ds: => Dataset[T],
expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
val result = getResult(ds)

checkDecoding(ds, expectedAnswer: _*)
if (!compare(result.toSeq, expectedAnswer)) {
fail(
s"""
|Decoded objects do not match expected objects:
|expected: $expectedAnswer
|actual: ${result.toSeq}
|${ds.exprEnc.deserializer.treeString}
""".stripMargin)
}
}

protected def checkDecoding[T](
/**
* Evaluates a dataset to make sure that the result of calling collect matches the given
* expected answer, after sorting both the actual and expected rows.
*/
protected def checkDatasetUnorderly[T : Ordering](
ds: => Dataset[T],
expectedAnswer: T*): Unit = {
val decoded = try ds.collect().toSet catch {
val result = getResult(ds)

if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) {
fail(
s"""
|Decoded objects do not match expected objects:
|expected: $expectedAnswer
|actual: ${result.toSeq}
|${ds.exprEnc.deserializer.treeString}
""".stripMargin)
}
}

private def getResult[T](ds: => Dataset[T]): Array[T] = {
val analyzedDS = try ds catch {
case ae: AnalysisException =>
if (ae.plan.isDefined) {
fail(
s"""
|Failed to analyze query: $ae
|${ae.plan.get}
|
|${stackTraceToString(ae)}
""".stripMargin)
} else {
throw ae
}
}
checkJsonFormat(analyzedDS)
assertEmptyMissingInput(analyzedDS)

try ds.collect() catch {
case e: Exception =>
fail(
s"""
@@ -99,24 +133,17 @@ abstract class QueryTest extends PlanTest {
|${ds.queryExecution}
""".stripMargin, e)
}
}

// Handle the case where the return type is an array
val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
def normalEquality = decoded == expectedAnswer.toSet
def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)

if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted

val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
fail(
s"""Decoded objects do not match expected objects:
|$comparison
|${ds.exprEnc.deserializer.treeString}
""".stripMargin)
}
private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match {
case (null, null) => true
case (null, _) => false
case (_, null) => false
case (a: Array[_], b: Array[_]) =>
a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
case (a: Iterable[_], b: Iterable[_]) =>
a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
case (a, b) => a == b
}

/**
@@ -143,7 +170,7 @@ abstract class QueryTest extends PlanTest {

checkJsonFormat(analyzedDF)

assertEmptyMissingInput(df)
assertEmptyMissingInput(analyzedDF)

QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
@@ -201,10 +228,10 @@ abstract class QueryTest extends PlanTest {
planWithCaching)
}

private def checkJsonFormat(df: DataFrame): Unit = {
private def checkJsonFormat(ds: Dataset[_]): Unit = {
// Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that
// RDD and Data resolution does not break.
val logicalPlan = df.queryExecution.analyzed
val logicalPlan = ds.queryExecution.analyzed

// bypass some cases that we can't handle currently.
logicalPlan.transform {
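The private compare helper above exists because collected results may now contain nulls, and because Scala arrays use reference equality, so a plain == on answers containing arrays would fail even when the contents match. A standalone sketch of the same idea (assumption: semantics identical to the helper above):

def structuralEquals(a: Any, b: Any): Boolean = (a, b) match {
  case (null, null) => true
  case (null, _) | (_, null) => false
  case (x: Array[_], y: Array[_]) =>
    x.length == y.length && x.zip(y).forall { case (l, r) => structuralEquals(l, r) }
  case (x: Iterable[_], y: Iterable[_]) =>
    x.size == y.size && x.zip(y).forall { case (l, r) => structuralEquals(l, r) }
  case (x, y) => x == y
}

// Array equality is reference equality, so the element-wise walk is what makes this pass:
println(Array(1, 2) == Array(1, 2))                 // false
println(structuralEquals(Array(1, 2), Array(1, 2))) // true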
@@ -131,9 +131,9 @@ class TextSuite extends QueryTest with SharedSQLContext {
ds1.write.text(s"$path/part=a")
ds1.write.text(s"$path/part=b")

checkDataset(
checkAnswer(
spark.read.format("text").load(path).select($"part"),
Row("a"), Row("b"))
Row("a") :: Row("b") :: Nil)
}
}
