Skip to content

Commit 32cfcb7

Browse files
committed
improve nullability inference for Aggregator
1 parent 00ad4f0 commit 32cfcb7

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,11 @@ object TypedAggregateExpression {
3838
bufferSerializer.map(_.toAttribute))
3939

4040
val outputEncoder = encoderFor[OUT]
41-
val outputType = if (outputEncoder.flat) {
42-
outputEncoder.schema.head.dataType
41+
val (outputType, outputNullable) = if (outputEncoder.flat) {
42+
val sf = outputEncoder.schema.head
43+
(sf.dataType, sf.nullable)
4344
} else {
44-
outputEncoder.schema
45+
(outputEncoder.schema, true)
4546
}
4647

4748
new TypedAggregateExpression(
@@ -51,7 +52,8 @@ object TypedAggregateExpression {
5152
bufferDeserializer,
5253
outputEncoder.serializer,
5354
outputEncoder.deserializer.dataType,
54-
outputType)
55+
outputType,
56+
outputNullable)
5557
}
5658
}
5759

@@ -65,9 +67,8 @@ case class TypedAggregateExpression(
6567
bufferDeserializer: Expression,
6668
outputSerializer: Seq[Expression],
6769
outputExternalType: DataType,
68-
dataType: DataType) extends DeclarativeAggregate with NonSQLExpression {
69-
70-
override def nullable: Boolean = true
70+
dataType: DataType,
71+
nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression {
7172

7273
override def deterministic: Boolean = true
7374

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -271,4 +271,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
271271
"RowAgg(org.apache.spark.sql.Row)")
272272
assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
273273
}
274+
275+
test("spark-15204 improve nullability inference for Aggregator") {
276+
val ds1 = Seq(1, 3, 2, 5).toDS()
277+
assert(ds1.select(typed.sum((i: Int) => i)).schema.head.nullable === false)
278+
val ds2 = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
279+
assert(ds2.groupByKey(_.b).agg(SeqAgg.toColumn).schema(1).nullable === true)
280+
val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
281+
assert(ds3.groupByKey(_.a).agg(NameAgg.toColumn).schema(1).nullable === true)
282+
}
274283
}

0 commit comments

Comments
 (0)