Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,19 @@ object NullResultAgg extends Aggregator[AggData, AggData, AggData] {
override def outputEncoder: Encoder[AggData] = Encoders.product[AggData]
}

case class ComplexAggData(d1: AggData, d2: AggData)

object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] {
override def zero: String = ""
override def reduce(buffer: String, input: Row): String = buffer + input.getString(1)
override def merge(b1: String, b2: String): String = b1 + b2
override def finish(reduction: String): ComplexAggData = {
ComplexAggData(AggData(reduction.length, reduction), AggData(reduction.length, reduction))
}
override def bufferEncoder: Encoder[String] = Encoders.STRING
override def outputEncoder: Encoder[ComplexAggData] = Encoders.product[ComplexAggData]
}


class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -312,4 +325,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
assert(ds3.select(NameAgg.toColumn).schema.head.nullable === true)
}

test("SPARK-18147: very complex aggregator result type") {
val df = Seq(1 -> "a", 2 -> "b", 2 -> "c").toDF("i", "j")

checkAnswer(
df.groupBy($"i").agg(VeryComplexResultAgg.toColumn),
Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil)
}
}