Skip to content

Commit 1e43851

Browse files
viiryamarmbrus
authored andcommitted
[SPARK-6899][SQL] Fix type mismatch when using codegen with Average on DecimalType
JIRA https://issues.apache.org/jira/browse/SPARK-6899 Author: Liang-Chi Hsieh <[email protected]> Closes apache#5517 from viirya/fix_codegen_average and squashes the following commits: 8ae5f65 [Liang-Chi Hsieh] Add the case of DecimalType.Unlimited to Average.
1 parent d966086 commit 1e43851

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
326326

327327
override def asPartial: SplitEvaluation = {
328328
child.dataType match {
329-
case DecimalType.Fixed(_, _) =>
329+
case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
330330
// Turn the child to unlimited decimals for calculation, before going back to fixed
331331
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
332332
val partialCount = Alias(Count(child), "PartialCount")()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,4 +537,13 @@ class DataFrameSuite extends QueryTest {
537537
val df = TestSQLContext.createDataFrame(rowRDD, schema)
538538
df.rdd.collect()
539539
}
540+
541+
test("SPARK-6899") {
542+
val originalValue = TestSQLContext.conf.codegenEnabled
543+
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
544+
checkAnswer(
545+
decimalData.agg(avg('a)),
546+
Row(new java.math.BigDecimal(2.0)))
547+
TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
548+
}
540549
}

0 commit comments

Comments
 (0)