Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
}

/** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
private val decimalAndDecimal: PartialFunction[Expression, Expression] = {
private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
// Skip nodes whose children have not been resolved yet
case e if !e.childrenResolved => e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
Expand Down Expand Up @@ -77,10 +77,9 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit

// If all input are nulls, count will be 0 and we will get null after the division.
override lazy val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
case _: DecimalType =>
Cast(
DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)),
resultType)
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
)
)
}

test("SPARK-24957: average with decimal followed by aggregation returning wrong result") {
val df = Seq(("a", BigDecimal("12.0")),
("a", BigDecimal("12.0")),
("a", BigDecimal("11.9999999988")),
("a", BigDecimal("12.0")),
("a", BigDecimal("12.0")),
("a", BigDecimal("11.9999999988")),
("a", BigDecimal("11.9999999988"))).toDF("text", "number")
val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res"))
val agg2 = agg1.groupBy($"text").agg(sum($"avg_res"))
checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142857143")))
}
}


Expand Down