From 3037d4aa6afc4d7630d86d29b8dd7d7d724cc990 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 17 Dec 2017 22:45:06 +0100 Subject: [PATCH 01/16] [SPARK-22036][SQL] Decimal multiplication with high precision/scale often returns NULL --- .../catalyst/analysis/DecimalPrecision.scala | 81 +++++++++++++------ .../sql/catalyst/expressions/literals.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 47 ++++++++++- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../analysis/DecimalPrecisionSuite.scala | 20 ++--- .../resources/sql-tests/inputs/decimals.sql | 16 ++++ .../sql-tests/results/decimals.sql.out | 72 +++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 6 +- 8 files changed, 206 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/decimals.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/decimals.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index a8100b9b24aac..d64e89513fd55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -93,41 +93,46 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + val resultScale = max(s1, s2) + val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + val resultScale = max(s1, s2) + val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 + p2 + 1, s1 + s2) + val resultType = DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) - var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) - val diff = (intDig + decDig) - DecimalType.MAX_SCALE - if (diff > 0) { - decDig -= diff / 2 + 1 - intDig = DecimalType.MAX_SCALE - decDig - } - val resultType = DecimalType.bounded(intDig + decDig, decDig) + // From https://msdn.microsoft.com/en-us/library/ms190476.aspx + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + val resultType = DecimalType.adjustPrecisionScale(prec, scale) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ 
DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), + max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), + max(s1, s2)) // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -243,17 +248,43 @@ object DecimalPrecision extends TypeCoercionRule { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - (left.dataType, right.dataType) match { - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right)) - case (DecimalType.Fixed(p, s), t: IntegralType) => - b.makeCopy(Array(left, Cast(right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(p, s)) if isFloat(t) => - b.makeCopy(Array(left, Cast(right, DoubleType))) - case (DecimalType.Fixed(p, s), t) if isFloat(t) => - b.makeCopy(Array(Cast(left, DoubleType), right)) - case _ => - b - } + nondecimalLiteralAndDecimal(b).lift((left, right)).getOrElse( + nondecimalNonliteralAndDecimal(b).applyOrElse((left.dataType, right.dataType), + (_: (DataType, DataType)) => b)) } + + /** + * Type coercion for BinaryOperator in which one side is a non-decimal literal numeric, and the + * other side is a decimal. + */ + private def nondecimalLiteralAndDecimal( + b: BinaryOperator): PartialFunction[(Expression, Expression), Expression] = { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones needed by the integer value. + case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.forLiteral(r)))) + } + + /** + * Type coercion for BinaryOperator in which one side is a non-decimal non-literal numeric, and + * the other side is a decimal. 
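+ * For example (column names here are illustrative only), `intCol > decimalCol` becomes
+ * `Cast(intCol, DecimalType(10, 0)) > decimalCol`, whereas `doubleCol > decimalCol` becomes
+ * `doubleCol > Cast(decimalCol, DoubleType)`.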
+ */ + private def nondecimalNonliteralAndDecimal( + b: BinaryOperator): PartialFunction[(DataType, DataType), Expression] = { + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (t: IntegralType, DecimalType.Fixed(p, s)) => + b.makeCopy(Array(Cast(b.left, DecimalType.forType(t)), b.right)) + case (DecimalType.Fixed(_, _), t: IntegralType) => + b.makeCopy(Array(b.left, Cast(b.right, DecimalType.forType(t)))) + case (t, DecimalType.Fixed(_, _)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(_, _), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 383203a209833..cd176d941819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -58,7 +58,7 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale)) + case d: BigDecimal => Literal(Decimal(d), DecimalType.fromBigDecimal(d)) case d: JavaBigDecimal => Literal(Decimal(d), DecimalType(Math.max(d.precision, d.scale), d.scale())) case d: Decimal => Literal(d, DecimalType(Math.max(d.precision, d.scale), d.scale)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 6e050c18b8acb..1f97480d48d1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -23,7 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} /** @@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) @@ -136,10 +137,54 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } + private[sql] def forLiteral(literal: Literal): DecimalType = literal.value match { + case v: Short => fromBigDecimal(BigDecimal(v)) + case v: Int => fromBigDecimal(BigDecimal(v)) + case v: Long => fromBigDecimal(BigDecimal(v)) + case _ => forType(literal.dataType) + } + + private[sql] def fromBigDecimal(d: BigDecimal): DecimalType = { + DecimalType(Math.max(d.precision, d.scale), d.scale) + } + private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } + // scalastyle:off line.size.limit + /** + * Decimal implementation is based on Hive's one, which is itself inspired to SQLServer's one. 
+ * In particular, when a result precision is greater than {@link #MAX_PRECISION}, the + * corresponding scale is reduced to prevent the integral part of a result from being truncated. + * + * For further reference, please see + * https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/. + * + * @param precision + * @param scale + * @return + */ + // scalastyle:on line.size.limit + private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { + // Assumptions: + // precision >= scale + // scale >= 0 + if (precision <= MAX_PRECISION) { + // Adjustment only needed when we exceed max precision + DecimalType(precision, scale) + } else { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + val intDigits = precision - scale + // If original scale less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) + + DecimalType(MAX_PRECISION, adjustedScale) + } + } + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f4514205d3ae0..cd8579584eada 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -408,8 +408,8 @@ class AnalysisSuite extends AnalysisTest with Matchers { assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) - assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) - assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(22, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(26, 6)) assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8b..c86dc18dfa680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), 
DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql new file mode 100644 index 0000000000000..d7f466c300883 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql @@ -0,0 +1,16 @@ +-- tests for decimals handling in operations +-- Spark draws its inspiration byt Hive implementation +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0); +insert into decimals_test values(2, 12345.123, 12345.123); +insert into decimals_test values(3, 0.1234567891011, 1234.1); +insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +drop table decimals_test; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out new file mode 100644 index 0000000000000..73c4ae8551027 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out @@ -0,0 +1,72 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query 0 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +insert into decimals_test values(1, 100.0, 999.0) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +insert into decimals_test values(2, 12345.123, 12345.123) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +insert into decimals_test values(3, 0.1234567891011, 1234.1) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789) +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 5 schema +struct +-- !query 5 output +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 + + +-- !query 6 +select id, a*10, b/10 from decimals_test order by id +-- !query 6 schema +struct +-- !query 6 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 + + +-- !query 7 +drop table decimals_test +-- !query 7 schema +struct<> +-- !query 7 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..7dfb3cc8aab4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1526,15 +1526,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) + Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) } test("SPARK-10215 Div of Decimal returns null") { From 6701a54068145994e10b8dd38d9a38a1be1f3674 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 9 Jan 2018 19:04:44 +0100 Subject: [PATCH 02/16] introduce spark.sql.decimalOperations.allowTruncat --- .../catalyst/analysis/DecimalPrecision.scala | 81 +++++++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 12 +++ .../apache/spark/sql/types/DecimalType.scala | 12 ++- .../analysis/DecimalPrecisionSuite.scala | 20 ++--- .../resources/sql-tests/inputs/decimals.sql | 35 ++++++-- .../sql-tests/results/decimals.sql.out | 27 ++++--- 6 files changed, 129 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index d64e89513fd55..63919dfcc7ece 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -42,8 +43,10 @@ import org.apache.spark.sql.types._ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) - * sum(e1) p1 + 10 s1 - * avg(e1) p1 + 4 s1 + 4 + * + * When `spark.sql.decimalOperations.allowTruncat` is set to true, if the precision / scale needed + * are out of the range of available values, the scale is reduced up to 6, in order to prevent the + * truncation of the integer part of the decimals. 
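+ * For example, multiplying two DECIMAL(38, 18) operands would need DECIMAL(77, 36) to be exact;
+ * with this flag enabled the result type is adjusted to DECIMAL(38, 6) and the value is rounded,
+ * instead of being declared DECIMAL(38, 36) and overflowing to NULL.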
* * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -56,6 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE + * - Literals INT and LONG gets turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { @@ -93,46 +97,76 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultScale = max(s1, s2) - val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, - resultScale) - CheckOverflow(Add(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultScale = max(s1, s2) + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + } + CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultScale = max(s1, s2) - val dt = DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, - resultScale) - CheckOverflow(Subtract(promotePrecision(e1, dt), promotePrecision(e2, dt)), dt) + val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultScale = max(s1, s2) + DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, + resultScale) + } else { + DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + } + CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), + resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) + } else { + DecimalType.bounded(p1 + p2 + 1, s1 + s2) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Multiply(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - // From https://msdn.microsoft.com/en-us/library/ms190476.aspx - // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) - // Scale: max(6, s1 + p2 + 1) - val intDig = p1 - s1 + s2 - val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) - val prec = intDig + scale - val resultType = DecimalType.adjustPrecisionScale(prec, scale) + val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + val intDig = p1 - s1 + s2 + val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1) + val prec = intDig + scale + DecimalType.adjustPrecisionScale(prec, scale) + } else { + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 
1 + intDig = DecimalType.MAX_SCALE - decDig + } + DecimalType.bounded(intDig + decDig, decDig) + } val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), - max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Remainder(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), - max(s1, s2)) + val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } else { + DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + } // resultType may have lower precision, so we cast them into wider type first. val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Pmod(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), @@ -142,9 +176,6 @@ object DecimalPrecision extends TypeCoercionRule { e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 => val resultType = widerDecimalType(p1, s1, p2, s2) b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType))) - - // TODO: MaxOf, MinOf, etc might want other rules - // SUM and AVERAGE are handled by the implementations of those expressions } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cf7e3ebce7411..00b3e877cfa55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1035,6 +1035,16 @@ object SQLConf { .booleanConf .createWithDefault(true) + val DECIMAL_OPERATIONS_ALLOW_TRUNCAT = + buildConf("spark.sql.decimalOperations.allowTruncat") + .internal() + .doc("When true, establishing the result type of an arithmetic operation happens " + + "according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the decimal " + + "part of the result if an exact representation is not possible. Otherwise, NULL is" + + "returned in those cases, as previously (default).") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1346,6 +1356,8 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) + def decimalOperationsAllowTruncat: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_TRUNCAT) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 1f97480d48d1a..31bef1ab17bcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -152,20 +152,18 @@ object DecimalType extends AbstractDataType { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } - // scalastyle:off line.size.limit /** - * Decimal implementation is based on Hive's one, which is itself inspired to SQLServer's one. - * In particular, when a result precision is greater than {@link #MAX_PRECISION}, the - * corresponding scale is reduced to prevent the integral part of a result from being truncated. + * Scale adjustment implementation is based on Hive's one, which is itself inspired to + * SQLServer's one. In particular, when a result precision is greater than + * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a + * result from being truncated. * - * For further reference, please see - * https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/. + * This method is used only when `spark.sql.decimalOperations.allowTruncat` is set to true. * * @param precision * @param scale * @return */ - // scalastyle:on line.size.limit private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { // Assumptions: // precision >= scale diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index c86dc18dfa680..60e46a9910a8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType(38, 17)) - checkType(Subtract(expr, u), DecimalType(38, 17)) + checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) } - checkType(Multiply(d1, u), DecimalType(38, 16)) - checkType(Multiply(d2, u), DecimalType(38, 14)) - checkType(Multiply(i, u), DecimalType(38, 7)) - checkType(Multiply(u, u), DecimalType(38, 6)) + checkType(Multiply(d1, u), DecimalType(38, 19)) + checkType(Multiply(d2, u), DecimalType(38, 20)) + checkType(Multiply(i, u), DecimalType(38, 18)) + checkType(Multiply(u, u), DecimalType(38, 36)) - checkType(Divide(u, d1), DecimalType(38, 17)) - checkType(Divide(u, d2), DecimalType(38, 16)) - checkType(Divide(u, i), DecimalType(38, 18)) - checkType(Divide(u, u), DecimalType(38, 6)) + checkType(Divide(u, d1), DecimalType(38, 18)) + checkType(Divide(u, d2), DecimalType(38, 19)) + checkType(Divide(u, i), DecimalType(38, 23)) + checkType(Divide(u, u), DecimalType(38, 18)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql index d7f466c300883..98ea04adde4ff 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql 
@@ -1,11 +1,34 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one or more +-- contributor license agreements. See the NOTICE file distributed with +-- this work for additional information regarding copyright ownership. +-- The ASF licenses this file to You under the Apache License, Version 2.0 +-- (the "License"); you may not use this file except in compliance with +-- the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- + -- tests for decimals handling in operations --- Spark draws its inspiration byt Hive implementation create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; -insert into decimals_test values(1, 100.0, 999.0); -insert into decimals_test values(2, 12345.123, 12345.123); -insert into decimals_test values(3, 0.1234567891011, 1234.1); -insert into decimals_test values(4, 123456789123456789.0, 1.123456789123456789); +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- use rounding instead of returning NULL, according to new Hive's behavior and SQL standard +set spark.sql.decimalOperations.allowTruncat=true; -- test decimal operations select id, a+b, a-b, a*b, a/b from decimals_test order by id; @@ -13,4 +36,4 @@ select id, a+b, a-b, a*b, a/b from decimals_test order by id; -- test operations between decimals and constants select id, a*10, b/10 from decimals_test order by id; -drop table decimals_test; \ No newline at end of file +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out index 73c4ae8551027..784405b1be30d 100644 --- a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out @@ -11,7 +11,8 @@ struct<> -- !query 1 -insert into decimals_test values(1, 100.0, 999.0) +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 1 schema struct<> -- !query 1 output @@ -19,27 +20,33 @@ struct<> -- !query 2 -insert into decimals_test values(2, 12345.123, 12345.123) +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 2 schema -struct<> +struct -- !query 2 output - +1 1099 -899 NULL 0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 -- !query 3 -insert into decimals_test values(3, 0.1234567891011, 1234.1) +select id, a*10, b/10 from decimals_test order by id -- !query 3 schema -struct<> +struct -- !query 3 output - +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.1123456789123456789 -- !query 4 -insert into 
decimals_test values(4, 123456789123456789.0, 1.123456789123456789) +set spark.sql.decimalOperations.allowTruncat=true -- !query 4 schema -struct<> +struct -- !query 4 output - +spark.sql.decimalOperations.allowTruncat true -- !query 5 From ecd6a2a0d01283b0d3a9582c93618ff1c2a772f7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 10 Jan 2018 11:52:35 +0100 Subject: [PATCH 03/16] fix UT --- .../org/apache/spark/sql/SQLQuerySuite.scala | 49 +++++++++++++------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1f44096f8f132..1c7eb7138ea52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1518,21 +1518,40 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("decimal precision with multiply/division") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_TRUNCAT.key -> "false") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), + Row(null)) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_TRUNCAT.key -> "true") { + checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) + checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) + checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) + checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + checkAnswer(sql("select 10.300000000000000000 
* 3.0000000000000000000"), + Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) + + checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) + checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) + checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) + checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) + checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), + Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) + } } test("SPARK-10215 Div of Decimal returns null") { From cb62433880b08defb4d14cc90429a4a39ab7e0be Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 15 Jan 2018 12:34:14 +0100 Subject: [PATCH 04/16] rename to allowPrecisionLoss --- .../catalyst/analysis/DecimalPrecision.scala | 18 +++++++++--------- .../apache/spark/sql/internal/SQLConf.scala | 6 +++--- .../apache/spark/sql/types/DecimalType.scala | 2 +- .../resources/sql-tests/inputs/decimals.sql | 2 +- .../sql-tests/results/decimals.sql.out | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 63919dfcc7ece..aad8f4b9d6009 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -44,9 +44,9 @@ import org.apache.spark.sql.types._ * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) * e1 union e2 max(s1, s2) + max(p1-s1, p2-s2) max(s1, s2) * - * When `spark.sql.decimalOperations.allowTruncat` is set to true, if the precision / scale needed - * are out of the range of available values, the scale is reduced up to 6, in order to prevent the - * truncation of the integer part of the decimals. + * When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale + * needed are out of the range of available values, the scale is reduced up to 6, in order to + * prevent the truncation of the integer part of the decimals. 
* * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited * precision, do the math on unlimited-precision numbers, then introduce casts back to the @@ -97,7 +97,7 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { val resultScale = max(s1, s2) DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) @@ -108,7 +108,7 @@ object DecimalPrecision extends TypeCoercionRule { resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { val resultScale = max(s1, s2) DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) @@ -119,7 +119,7 @@ object DecimalPrecision extends TypeCoercionRule { resultType) case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { DecimalType.adjustPrecisionScale(p1 + p2 + 1, s1 + s2) } else { DecimalType.bounded(p1 + p2 + 1, s1 + s2) @@ -129,7 +129,7 @@ object DecimalPrecision extends TypeCoercionRule { resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) // Scale: max(6, s1 + p2 + 1) val intDig = p1 - s1 + s2 @@ -151,7 +151,7 @@ object DecimalPrecision extends TypeCoercionRule { resultType) case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) @@ -162,7 +162,7 @@ object DecimalPrecision extends TypeCoercionRule { resultType) case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = if (SQLConf.get.decimalOperationsAllowTruncat) { + val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { DecimalType.adjustPrecisionScale(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) } else { DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f1bdab3da7e20..22eed6f6300c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1048,8 +1048,8 @@ object SQLConf { .booleanConf .createWithDefault(true) - val DECIMAL_OPERATIONS_ALLOW_TRUNCAT = - buildConf("spark.sql.decimalOperations.allowTruncat") + val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = + buildConf("spark.sql.decimalOperations.allowPrecisionLoss") .internal() 
.doc("When true, establishing the result type of an arithmetic operation happens " + "according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the decimal " + @@ -1433,7 +1433,7 @@ class SQLConf extends Serializable with Logging { def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER) - def decimalOperationsAllowTruncat: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_TRUNCAT) + def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 31bef1ab17bcd..1d5f29c35d33a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -158,7 +158,7 @@ object DecimalType extends AbstractDataType { * {@link #MAX_PRECISION}, the corresponding scale is reduced to prevent the integral part of a * result from being truncated. * - * This method is used only when `spark.sql.decimalOperations.allowTruncat` is set to true. + * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. * * @param precision * @param scale diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql index 98ea04adde4ff..d36c57f3c879f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql @@ -28,7 +28,7 @@ select id, a+b, a-b, a*b, a/b from decimals_test order by id; select id, a*10, b/10 from decimals_test order by id; -- use rounding instead of returning NULL, according to new Hive's behavior and SQL standard -set spark.sql.decimalOperations.allowTruncat=true; +set spark.sql.decimalOperations.allowPrecisionLoss=true; -- test decimal operations select id, a+b, a-b, a*b, a/b from decimals_test order by id; diff --git a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out index 784405b1be30d..1ccb870e93c8a 100644 --- a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out @@ -42,11 +42,11 @@ struct -- !query 4 output -spark.sql.decimalOperations.allowTruncat true +spark.sql.decimalOperations.allowPrecisionLoss true -- !query 5 From 77e445fa51ad413cc608f6f4421a186f7fd282ba Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 15 Jan 2018 12:40:10 +0100 Subject: [PATCH 05/16] address comments --- .../spark/sql/catalyst/analysis/DecimalPrecision.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index aad8f4b9d6009..2cedc866c455d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -97,23 +97,23 @@ object DecimalPrecision extends TypeCoercionRule { case e: BinaryArithmetic if e.left.isInstanceOf[PromotePrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val 
resultScale = max(s1, s2) val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { - val resultScale = max(s1, s2) DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { - DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } CheckOverflow(Add(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), resultType) case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + val resultScale = max(s1, s2) val resultType = if (SQLConf.get.decimalOperationsAllowPrecisionLoss) { - val resultScale = max(s1, s2) DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } else { - DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale) } CheckOverflow(Subtract(promotePrecision(e1, resultType), promotePrecision(e2, resultType)), resultType) From 519571d5df6cecc9095bb2029c370fb52c9f6b16 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 15 Jan 2018 15:45:21 +0100 Subject: [PATCH 06/16] fix typo --- .../apache/spark/sql/catalyst/analysis/DecimalPrecision.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 2cedc866c455d..07e546aa254af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -59,7 +59,7 @@ import org.apache.spark.sql.types._ * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE - * - Literals INT and LONG gets turned into DECIMAL with the precision strictly needed by the value + * - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value */ // scalastyle:on object DecimalPrecision extends TypeCoercionRule { From 1f36cf6bcf784a04d77b2b5153cd504e6155c875 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 15 Jan 2018 18:25:15 +0100 Subject: [PATCH 07/16] fix build error --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 1c7eb7138ea52..67602426e750a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1518,7 +1518,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("decimal precision with multiply/division") { - withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_TRUNCAT.key -> "false") { + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> "false") { checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) @@ -1535,7 +1535,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), 
Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) } - withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_TRUNCAT.key -> "true") { + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> "true") { checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) From 1ff819d897a815077543d4e1666830e119b731fa Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 17:42:29 +0100 Subject: [PATCH 08/16] address comments --- .../org/apache/spark/sql/types/DecimalType.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 1d5f29c35d33a..1b94cf8fd714e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -159,24 +159,24 @@ object DecimalType extends AbstractDataType { * result from being truncated. * * This method is used only when `spark.sql.decimalOperations.allowPrecisionLoss` is set to true. - * - * @param precision - * @param scale - * @return */ private[sql] def adjustPrecisionScale(precision: Int, scale: Int): DecimalType = { // Assumptions: - // precision >= scale - // scale >= 0 + assert(precision >= scale) + assert(scale >= 0) + if (precision <= MAX_PRECISION) { // Adjustment only needed when we exceed max precision DecimalType(precision, scale) } else { // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. val intDigits = precision - scale - // If original scale less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits val minScaleValue = Math.min(scale, MINIMUM_ADJUSTED_SCALE) + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above val adjustedScale = Math.max(MAX_PRECISION - intDigits, minScaleValue) DecimalType(MAX_PRECISION, adjustedScale) From c1c57814a1177144bc3e35730b74b256a83f8317 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 18:03:22 +0100 Subject: [PATCH 09/16] allowPrecisionLoss true by default --- .../apache/spark/sql/internal/SQLConf.scala | 10 +- .../analysis/DecimalPrecisionSuite.scala | 20 +-- .../resources/sql-tests/inputs/decimals.sql | 18 ++- .../sql-tests/results/decimals.sql.out | 126 ++++++++++++++---- .../decimalArithmeticOperations.sql.out | 6 +- .../native/decimalPrecision.sql.out | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 37 ----- 7 files changed, 139 insertions(+), 82 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 22eed6f6300c0..26139b80b5fa9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1051,12 +1051,12 @@ object SQLConf { val DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = buildConf("spark.sql.decimalOperations.allowPrecisionLoss") .internal() - .doc("When true, establishing the 
result type of an arithmetic operation happens " + - "according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the decimal " + - "part of the result if an exact representation is not possible. Otherwise, NULL is" + - "returned in those cases, as previously (default).") + .doc("When true (default), establishing the result type of an arithmetic operation " + + "happens according to Hive behavior and SQL ANSI 2011 specification, ie. rounding the " + + "decimal part of the result if an exact representation is not possible. Otherwise, NULL " + + "is returned in those cases, as previously.") .booleanConf - .createWithDefault(false) + .createWithDefault(true) val SQL_STRING_REDACTION_PATTERN = ConfigBuilder("spark.sql.redaction.string.regex") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 60e46a9910a8b..c86dc18dfa680 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,19 +136,19 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { test("maximum decimals") { for (expr <- Seq(d1, d2, i, u)) { - checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) - checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Add(expr, u), DecimalType(38, 17)) + checkType(Subtract(expr, u), DecimalType(38, 17)) } - checkType(Multiply(d1, u), DecimalType(38, 19)) - checkType(Multiply(d2, u), DecimalType(38, 20)) - checkType(Multiply(i, u), DecimalType(38, 18)) - checkType(Multiply(u, u), DecimalType(38, 36)) + checkType(Multiply(d1, u), DecimalType(38, 16)) + checkType(Multiply(d2, u), DecimalType(38, 14)) + checkType(Multiply(i, u), DecimalType(38, 7)) + checkType(Multiply(u, u), DecimalType(38, 6)) - checkType(Divide(u, d1), DecimalType(38, 18)) - checkType(Divide(u, d2), DecimalType(38, 19)) - checkType(Divide(u, i), DecimalType(38, 23)) - checkType(Divide(u, u), DecimalType(38, 18)) + checkType(Divide(u, d1), DecimalType(38, 17)) + checkType(Divide(u, d2), DecimalType(38, 16)) + checkType(Divide(u, i), DecimalType(38, 18)) + checkType(Divide(u, u), DecimalType(38, 6)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql index d36c57f3c879f..271a2d8adb583 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql @@ -27,8 +27,15 @@ select id, a+b, a-b, a*b, a/b from decimals_test order by id; -- test operations between decimals and constants select id, a*10, b/10 from decimals_test order by id; --- use rounding instead of returning NULL, according to new Hive's behavior and SQL standard -set spark.sql.decimalOperations.allowPrecisionLoss=true; +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; -- test decimal operations select id, a+b, a-b, a*b, a/b from decimals_test order by id; @@ -36,4 +43,11 @@ select 
id, a+b, a-b, a*b, a/b from decimals_test order by id; -- test operations between decimals and constants select id, a*10, b/10 from decimals_test order by id; +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out index 1ccb870e93c8a..c479797916f14 100644 --- a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 8 +-- Number of queries: 18 -- !query 0 @@ -22,58 +22,138 @@ struct<> -- !query 2 select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 2 schema -struct +struct -- !query 2 output -1 1099 -899 NULL 0.1001001001001001 -2 24690.246 0 NULL 1 -3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 -4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 3 select id, a*10, b/10 from decimals_test order by id -- !query 3 schema -struct +struct -- !query 3 output 1 1000 99.9 2 123451.23 1234.5123 3 1.234567891011 123.41 -4 1234567891234567890 0.1123456789123456789 +4 1234567891234567890 0.112345678912345679 -- !query 4 -set spark.sql.decimalOperations.allowPrecisionLoss=true +select 10.3 * 3.0 -- !query 4 schema -struct +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 4 output -spark.sql.decimalOperations.allowPrecisionLoss true +30.9 -- !query 5 -select id, a+b, a-b, a*b, a/b from decimals_test order by id +select 10.3000 * 3.0 -- !query 5 schema -struct +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 5 output -1 1099 -899 99900 0.1001 -2 24690.246 0 152402061.885129 1 -3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 -4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 +30.9 -- !query 6 -select id, a*10, b/10 from decimals_test order by id +select 10.30000 * 30.0 -- !query 6 schema -struct +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> -- !query 6 output +309 + + +-- !query 7 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 7 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 7 output +30.9 + + +-- !query 8 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 8 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 8 output +30.9 + + +-- !query 9 +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 9 schema +struct +-- !query 9 output +spark.sql.decimalOperations.allowPrecisionLoss false + + +-- !query 10 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 10 schema +struct +-- !query 10 output +1 1099 -899 NULL 
0.1001001001001001 +2 24690.246 0 NULL 1 +3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 +4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 + + +-- !query 11 +select id, a*10, b/10 from decimals_test order by id +-- !query 11 schema +struct +-- !query 11 output 1 1000 99.9 2 123451.23 1234.5123 3 1.234567891011 123.41 -4 1234567891234567890 0.112345678912345679 +4 1234567891234567890 0.1123456789123456789 --- !query 7 +-- !query 12 +select 10.3 * 3.0 +-- !query 12 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 12 output +30.9 + + +-- !query 13 +select 10.3000 * 3.0 +-- !query 13 schema +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 13 output +30.9 + + +-- !query 14 +select 10.30000 * 30.0 +-- !query 14 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 14 output +309 + + +-- !query 15 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 15 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> +-- !query 15 output +30.9 + + +-- !query 16 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 16 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> +-- !query 16 output +NULL + + +-- !query 17 drop table decimals_test --- !query 7 schema +-- !query 17 schema struct<> --- !query 7 output +-- !query 17 output diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index ce02f6adc456c..dbe5e3e9cf3d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -61,7 +61,7 @@ NULL -- !query 7 select 1e35 / 0.1 -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,3)> +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> -- !query 7 output NULL @@ -69,9 +69,9 @@ NULL -- !query 8 select 123456789123456789.1234567890 * 1.123456789123456789 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,28)> +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> -- !query 8 output -NULL +138698367904130467.654320988515622621 -- !query 9 diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out index ebc8201ed5a1d..6ee7f59d69877 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalPrecision.sql.out @@ -2329,7 +2329,7 @@ struct<(CAST(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) AS DECIMAL(20,0)) / CAST(C -- !query 280 SELECT cast(1 as bigint) / cast(1 as decimal(20, 0)) FROM t -- !query 280 schema -struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,19)> 
+struct<(CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0)) / CAST(1 AS DECIMAL(20,0))):decimal(38,18)> -- !query 280 output 1 @@ -2661,7 +2661,7 @@ struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DECIMAL(20,0)) / CAST(CAST(CAST(1 AS BI -- !query 320 SELECT cast(1 as decimal(20, 0)) / cast(1 as bigint) FROM t -- !query 320 schema -struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,19)> +struct<(CAST(1 AS DECIMAL(20,0)) / CAST(CAST(1 AS BIGINT) AS DECIMAL(20,0))):decimal(38,18)> -- !query 320 output 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 67602426e750a..9c3807cc180af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1517,43 +1517,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("decimal precision with multiply/division") { - withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> "false") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(null)) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) - } - withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> "true") { - checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90"))) - checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000"))) - checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000"))) - checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"), - Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38)))) - - checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333"))) - checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) - checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) - checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) - checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333", new MathContext(38)))) - } - } - test("SPARK-10215 Div of Decimal returns null") { val d = Decimal(1.12321).toBigDecimal val df = Seq((d, 1)).toDF("a", "b") From 3e79a078f0cf08729e390d9e277abdca77f387a6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 18:16:00 +0100 Subject: [PATCH 10/16] readability fix --- .../catalyst/analysis/DecimalPrecision.scala | 60 +++++++------------ 1 file changed, 22 
insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 07e546aa254af..2716a75914f4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -278,44 +278,28 @@ object DecimalPrecision extends TypeCoercionRule { private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - nondecimalLiteralAndDecimal(b).lift((left, right)).getOrElse( - nondecimalNonliteralAndDecimal(b).applyOrElse((left.dataType, right.dataType), - (_: (DataType, DataType)) => b)) - } - - /** - * Type coercion for BinaryOperator in which one side is a non-decimal literal numeric, and the - * other side is a decimal. - */ - private def nondecimalLiteralAndDecimal( - b: BinaryOperator): PartialFunction[(Expression, Expression), Expression] = { - // Promote literal integers inside a binary expression with fixed-precision decimals to - // decimals. The precision and scale are the ones needed by the integer value. - case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] - && l.dataType.isInstanceOf[IntegralType] => - b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r)) - case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] - && r.dataType.isInstanceOf[IntegralType] => - b.makeCopy(Array(l, Cast(r, DecimalType.forLiteral(r)))) - } - - /** - * Type coercion for BinaryOperator in which one side is a non-decimal non-literal numeric, and - * the other side is a decimal. - */ - private def nondecimalNonliteralAndDecimal( - b: BinaryOperator): PartialFunction[(DataType, DataType), Expression] = { - // Promote integers inside a binary expression with fixed-precision decimals to decimals, - // and fixed-precision decimals in an expression with floats / doubles to doubles - case (t: IntegralType, DecimalType.Fixed(p, s)) => - b.makeCopy(Array(Cast(b.left, DecimalType.forType(t)), b.right)) - case (DecimalType.Fixed(_, _), t: IntegralType) => - b.makeCopy(Array(b.left, Cast(b.right, DecimalType.forType(t)))) - case (t, DecimalType.Fixed(_, _)) if isFloat(t) => - b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) - case (DecimalType.Fixed(_, _), t) if isFloat(t) => - b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case b@BinaryOperator(left, right) if left.dataType != right.dataType => + (left, right) match { + // Promote literal integers inside a binary expression with fixed-precision decimals to + // decimals. The precision and scale are the ones needed by the integer value. 
+ case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] + && l.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r)) + case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] + && r.dataType.isInstanceOf[IntegralType] => + b.makeCopy(Array(l, Cast(r, DecimalType.forLiteral(r)))) + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => + b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r)) + case (l @ DecimalType.Expression(_, _), r @ IntegralType()) => + b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType)))) + case (l, r @ DecimalType.Expression(_, _)) if isFloat(l.dataType) => + b.makeCopy(Array(l, Cast(r, DoubleType))) + case (l @ DecimalType.Expression(_, _), r) if isFloat(r.dataType) => + b.makeCopy(Array(Cast(l, DoubleType), r)) + case _ => b + } } } From 090659fe5f2471462ada0d54c0c855d9fe4aba7e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 18:25:15 +0100 Subject: [PATCH 11/16] adding example in comments for literals --- .../spark/sql/catalyst/analysis/DecimalPrecision.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 2716a75914f4f..1b92cc68f5bdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -281,7 +281,15 @@ object DecimalPrecision extends TypeCoercionRule { case b@BinaryOperator(left, right) if left.dataType != right.dataType => (left, right) match { // Promote literal integers inside a binary expression with fixed-precision decimals to - // decimals. The precision and scale are the ones needed by the integer value. + // decimals. The precision and scale are the ones strictly needed by the integer value. + // Requiring more precision than necessary may lead to a useless loss of precision. + // Consider the following example: multiplying a column which is DECIMAL(38, 18) by 2. + // If we use the default precision and scale for the integer type, 2 is considered a + // DECIMAL(10, 0). According to the rules, the result would be DECIMAL(38 + 10 + 1, 18), + // which is out of range and therefore it will becomes DECIMAL(38, 7), leading to + // potentially loosing 11 digits of the fractional part. Using only the precision needed + // by the Literal, instead, the result would be DECIMAL(38 + 1 + 1, 18), which would + // become DECIMAL(38, 16), safely having a much lower precision loss. case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && l.dataType.isInstanceOf[IntegralType] => b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r)) From cf3b372ef4d2d4c798bd448f9cf2338269b6ce43 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 16 Jan 2018 19:01:46 +0100 Subject: [PATCH 12/16] update migration section --- docs/sql-programming-guide.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 72f79d6909ecc..d65aea50ec0ac 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1795,6 +1795,11 @@ options. 
- Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + - Since Spark 2.3, by default arithmetic operations return a rounded value if an exact representation is not possible. This is compliant to SQL standards and Hive's behavior introduced in HIVE-15331. This involves the following changes + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. + - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. + - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark will use the previous rules and behavior. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. From 03644fe117f56b745c344beb6ea08af5045c2dbd Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 17 Jan 2018 14:45:05 +0100 Subject: [PATCH 13/16] address comments --- docs/sql-programming-guide.md | 2 +- .../apache/spark/sql/catalyst/analysis/DecimalPrecision.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index d65aea50ec0ac..ddaa231e69148 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1795,7 +1795,7 @@ options. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, by default arithmetic operations return a rounded value if an exact representation is not possible. This is compliant to SQL standards and Hive's behavior introduced in HIVE-15331. This involves the following changes + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible. This is compliant to SQL standards and Hive's behavior introduced in HIVE-15331. 
This involves the following changes - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them. - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark will use the previous rules and behavior. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 1b92cc68f5bdc..67c0fc4e369aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -278,7 +278,7 @@ object DecimalPrecision extends TypeCoercionRule { private val nondecimalAndDecimal: PartialFunction[Expression, Expression] = { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b@BinaryOperator(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left, right) match { // Promote literal integers inside a binary expression with fixed-precision decimals to // decimals. The precision and scale are the ones strictly needed by the integer value. From 7653e6d5c427c80f925a5992b01c80a8f2d58969 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 17 Jan 2018 15:33:12 +0100 Subject: [PATCH 14/16] move tests in decimalArithmeticOperations.sql --- .../resources/sql-tests/inputs/decimals.sql | 53 ---- .../native/decimalArithmeticOperations.sql | 47 ++++ .../sql-tests/results/decimals.sql.out | 159 ------------ .../decimalArithmeticOperations.sql.out | 245 ++++++++++++++++-- 4 files changed, 274 insertions(+), 230 deletions(-) delete mode 100644 sql/core/src/test/resources/sql-tests/inputs/decimals.sql delete mode 100644 sql/core/src/test/resources/sql-tests/results/decimals.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql b/sql/core/src/test/resources/sql-tests/inputs/decimals.sql deleted file mode 100644 index 271a2d8adb583..0000000000000 --- a/sql/core/src/test/resources/sql-tests/inputs/decimals.sql +++ /dev/null @@ -1,53 +0,0 @@ --- --- Licensed to the Apache Software Foundation (ASF) under one or more --- contributor license agreements. See the NOTICE file distributed with --- this work for additional information regarding copyright ownership. --- The ASF licenses this file to You under the Apache License, Version 2.0 --- (the "License"); you may not use this file except in compliance with --- the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. 
--- - --- tests for decimals handling in operations -create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; - -insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), - (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); - --- test decimal operations -select id, a+b, a-b, a*b, a/b from decimals_test order by id; - --- test operations between decimals and constants -select id, a*10, b/10 from decimals_test order by id; - --- test operations on constants -select 10.3 * 3.0; -select 10.3000 * 3.0; -select 10.30000 * 30.0; -select 10.300000000000000000 * 3.000000000000000000; -select 10.300000000000000000 * 3.0000000000000000000; - --- return NULL instead of rounding, according to old Spark versions' behavior -set spark.sql.decimalOperations.allowPrecisionLoss=false; - --- test decimal operations -select id, a+b, a-b, a*b, a/b from decimals_test order by id; - --- test operations between decimals and constants -select id, a*10, b/10 from decimals_test order by id; - --- test operations on constants -select 10.3 * 3.0; -select 10.3000 * 3.0; -select 10.30000 * 30.0; -select 10.300000000000000000 * 3.000000000000000000; -select 10.300000000000000000 * 3.0000000000000000000; - -drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql index c8e108ac2c45e..c6d8a49d4b93a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/decimalArithmeticOperations.sql @@ -22,6 +22,51 @@ select a / b from t; select a % b from t; select pmod(a, b) from t; +-- tests for decimals handling in operations +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet; + +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789); + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + +-- arithmetic operations causing an overflow return NULL +select (5e36 + 0.1) + 5e36; +select (-4e36 - 0.1) - 7e36; +select 12345678901234567890.0 * 12345678901234567890.0; +select 1e35 / 0.1; + +-- arithmetic operations causing a precision loss are truncated +select 123456789123456789.1234567890 * 1.123456789123456789; +select 0.001 / 9876543210987654321098765432109876543.2 + +-- return NULL instead of rounding, according to old Spark versions' behavior +set spark.sql.decimalOperations.allowPrecisionLoss=false; + +-- test decimal operations +select id, a+b, a-b, a*b, a/b from decimals_test order by id; + +-- test operations between decimals and constants +select id, a*10, b/10 from decimals_test order by id; + +-- test operations on constants +select 10.3 * 3.0; +select 10.3000 * 3.0; +select 10.30000 * 30.0; +select 10.300000000000000000 * 3.000000000000000000; +select 10.300000000000000000 * 3.0000000000000000000; + -- arithmetic operations causing an overflow return NULL select (5e36 
+ 0.1) + 5e36; select (-4e36 - 0.1) - 7e36; @@ -31,3 +76,5 @@ select 1e35 / 0.1; -- arithmetic operations causing a precision loss return NULL select 123456789123456789.1234567890 * 1.123456789123456789; select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test; diff --git a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out b/sql/core/src/test/resources/sql-tests/results/decimals.sql.out deleted file mode 100644 index c479797916f14..0000000000000 --- a/sql/core/src/test/resources/sql-tests/results/decimals.sql.out +++ /dev/null @@ -1,159 +0,0 @@ --- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 - - --- !query 0 -create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet --- !query 0 schema -struct<> --- !query 0 output - - - --- !query 1 -insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), - (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) --- !query 1 schema -struct<> --- !query 1 output - - - --- !query 2 -select id, a+b, a-b, a*b, a/b from decimals_test order by id --- !query 2 schema -struct --- !query 2 output -1 1099 -899 99900 0.1001 -2 24690.246 0 152402061.885129 1 -3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 -4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 - - --- !query 3 -select id, a*10, b/10 from decimals_test order by id --- !query 3 schema -struct --- !query 3 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.112345678912345679 - - --- !query 4 -select 10.3 * 3.0 --- !query 4 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> --- !query 4 output -30.9 - - --- !query 5 -select 10.3000 * 3.0 --- !query 5 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> --- !query 5 output -30.9 - - --- !query 6 -select 10.30000 * 30.0 --- !query 6 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> --- !query 6 output -309 - - --- !query 7 -select 10.300000000000000000 * 3.000000000000000000 --- !query 7 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> --- !query 7 output -30.9 - - --- !query 8 -select 10.300000000000000000 * 3.0000000000000000000 --- !query 8 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> --- !query 8 output -30.9 - - --- !query 9 -set spark.sql.decimalOperations.allowPrecisionLoss=false --- !query 9 schema -struct --- !query 9 output -spark.sql.decimalOperations.allowPrecisionLoss false - - --- !query 10 -select id, a+b, a-b, a*b, a/b from decimals_test order by id --- !query 10 schema -struct --- !query 10 output -1 1099 -899 NULL 0.1001001001001001 -2 24690.246 0 NULL 1 -3 1234.2234567891011 -1233.9765432108989 NULL 0.000100037913541123 -4 123456789123456790.123456789123456789 123456789123456787.876543210876543211 NULL 109890109097814272.043109406191131436 - - --- !query 11 -select id, a*10, b/10 from decimals_test order by id --- !query 11 schema -struct --- !query 11 output -1 1000 99.9 -2 123451.23 1234.5123 -3 1.234567891011 123.41 -4 1234567891234567890 0.1123456789123456789 - - --- !query 12 -select 10.3 * 3.0 --- !query 12 schema -struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> --- !query 12 
output -30.9 - - --- !query 13 -select 10.3000 * 3.0 --- !query 13 schema -struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> --- !query 13 output -30.9 - - --- !query 14 -select 10.30000 * 30.0 --- !query 14 schema -struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> --- !query 14 output -309 - - --- !query 15 -select 10.300000000000000000 * 3.000000000000000000 --- !query 15 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,36)> --- !query 15 output -30.9 - - --- !query 16 -select 10.300000000000000000 * 3.0000000000000000000 --- !query 16 schema -struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,37)> --- !query 16 output -NULL - - --- !query 17 -drop table decimals_test --- !query 17 schema -struct<> --- !query 17 output - diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out index dbe5e3e9cf3d6..4d70fe19d539f 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/decimalArithmeticOperations.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 32 -- !query 0 @@ -35,48 +35,257 @@ NULL -- !query 4 -select (5e36 + 0.1) + 5e36 +create table decimals_test(id int, a decimal(38,18), b decimal(38,18)) using parquet -- !query 4 schema -struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 4 output -NULL + -- !query 5 -select (-4e36 - 0.1) - 7e36 +insert into decimals_test values(1, 100.0, 999.0), (2, 12345.123, 12345.123), + (3, 0.1234567891011, 1234.1), (4, 123456789123456789.0, 1.123456789123456789) -- !query 5 schema -struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +struct<> -- !query 5 output -NULL + -- !query 6 -select 12345678901234567890.0 * 12345678901234567890.0 +select id, a+b, a-b, a*b, a/b from decimals_test order by id -- !query 6 schema -struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +struct -- !query 6 output -NULL +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 -- !query 7 -select 1e35 / 0.1 +select id, a*10, b/10 from decimals_test order by id -- !query 7 schema -struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +struct -- !query 7 output -NULL +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 -- !query 8 -select 123456789123456789.1234567890 * 1.123456789123456789 +select 10.3 * 3.0 -- !query 8 schema -struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> -- !query 8 output -138698367904130467.654320988515622621 +30.9 -- !query 9 -select 0.001 / 
9876543210987654321098765432109876543.2 +select 10.3000 * 3.0 -- !query 9 schema -struct<(CAST(0.001 AS DECIMAL(38,3)) / CAST(9876543210987654321098765432109876543.2 AS DECIMAL(38,3))):decimal(38,37)> +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> -- !query 9 output +30.9 + + +-- !query 10 +select 10.30000 * 30.0 +-- !query 10 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 10 output +309 + + +-- !query 11 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 11 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 11 output +30.9 + + +-- !query 12 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 12 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 12 output +30.9 + + +-- !query 13 +select (5e36 + 0.1) + 5e36 +-- !query 13 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 13 output +NULL + + +-- !query 14 +select (-4e36 - 0.1) - 7e36 +-- !query 14 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 14 output NULL + + +-- !query 15 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 15 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 15 output +NULL + + +-- !query 16 +select 1e35 / 0.1 +-- !query 16 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 16 output +NULL + + +-- !query 17 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 17 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 17 output +138698367904130467.654320988515622621 + + +-- !query 18 +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'spark' expecting (line 3, pos 4) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +set spark.sql.decimalOperations.allowPrecisionLoss=false +----^^^ + + +-- !query 19 +select id, a+b, a-b, a*b, a/b from decimals_test order by id +-- !query 19 schema +struct +-- !query 19 output +1 1099 -899 99900 0.1001 +2 24690.246 0 152402061.885129 1 +3 1234.2234567891011 -1233.9765432108989 152.358023 0.0001 +4 123456789123456790.12345678912345679 123456789123456787.87654321087654321 138698367904130467.515623 109890109097814272.043109 + + +-- !query 20 +select id, a*10, b/10 from decimals_test order by id +-- !query 20 schema +struct +-- !query 20 output +1 1000 99.9 +2 123451.23 1234.5123 +3 1.234567891011 123.41 +4 1234567891234567890 0.112345678912345679 + + +-- !query 21 +select 10.3 * 3.0 +-- !query 21 schema +struct<(CAST(10.3 AS DECIMAL(3,1)) * CAST(3.0 AS DECIMAL(3,1))):decimal(6,2)> +-- !query 21 output +30.9 + + +-- !query 22 +select 10.3000 * 3.0 +-- !query 22 schema +struct<(CAST(10.3000 AS DECIMAL(6,4)) * CAST(3.0 AS DECIMAL(6,4))):decimal(9,5)> +-- !query 22 output +30.9 + + +-- !query 23 +select 10.30000 * 30.0 +-- !query 23 schema +struct<(CAST(10.30000 AS DECIMAL(7,5)) * 
CAST(30.0 AS DECIMAL(7,5))):decimal(11,6)> +-- !query 23 output +309 + + +-- !query 24 +select 10.300000000000000000 * 3.000000000000000000 +-- !query 24 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(20,18)) * CAST(3.000000000000000000 AS DECIMAL(20,18))):decimal(38,34)> +-- !query 24 output +30.9 + + +-- !query 25 +select 10.300000000000000000 * 3.0000000000000000000 +-- !query 25 schema +struct<(CAST(10.300000000000000000 AS DECIMAL(21,19)) * CAST(3.0000000000000000000 AS DECIMAL(21,19))):decimal(38,34)> +-- !query 25 output +30.9 + + +-- !query 26 +select (5e36 + 0.1) + 5e36 +-- !query 26 schema +struct<(CAST((CAST(5E+36 AS DECIMAL(38,1)) + CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) + CAST(5E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 26 output +NULL + + +-- !query 27 +select (-4e36 - 0.1) - 7e36 +-- !query 27 schema +struct<(CAST((CAST(-4E+36 AS DECIMAL(38,1)) - CAST(0.1 AS DECIMAL(38,1))) AS DECIMAL(38,1)) - CAST(7E+36 AS DECIMAL(38,1))):decimal(38,1)> +-- !query 27 output +NULL + + +-- !query 28 +select 12345678901234567890.0 * 12345678901234567890.0 +-- !query 28 schema +struct<(12345678901234567890.0 * 12345678901234567890.0):decimal(38,2)> +-- !query 28 output +NULL + + +-- !query 29 +select 1e35 / 0.1 +-- !query 29 schema +struct<(CAST(1E+35 AS DECIMAL(37,1)) / CAST(0.1 AS DECIMAL(37,1))):decimal(38,6)> +-- !query 29 output +NULL + + +-- !query 30 +select 123456789123456789.1234567890 * 1.123456789123456789 +-- !query 30 schema +struct<(CAST(123456789123456789.1234567890 AS DECIMAL(36,18)) * CAST(1.123456789123456789 AS DECIMAL(36,18))):decimal(38,18)> +-- !query 30 output +138698367904130467.654320988515622621 + + +-- !query 31 +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'table' expecting (line 3, pos 5) + +== SQL == +select 0.001 / 9876543210987654321098765432109876543.2 + +drop table decimals_test +-----^^^ From 2b6609876aa76da5a9169b11142782613ea39ab7 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 17 Jan 2018 15:37:46 +0100 Subject: [PATCH 15/16] rename to fromLiteral --- .../apache/spark/sql/catalyst/analysis/DecimalPrecision.scala | 4 ++-- .../main/scala/org/apache/spark/sql/types/DecimalType.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index 67c0fc4e369aa..ab63131b07573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -292,10 +292,10 @@ object DecimalPrecision extends TypeCoercionRule { // become DECIMAL(38, 16), safely having a much lower precision loss. 
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && l.dataType.isInstanceOf[IntegralType] => - b.makeCopy(Array(Cast(l, DecimalType.forLiteral(l)), r)) + b.makeCopy(Array(Cast(l, DecimalType.fromLiteral(l)), r)) case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] && r.dataType.isInstanceOf[IntegralType] => - b.makeCopy(Array(l, Cast(r, DecimalType.forLiteral(r)))) + b.makeCopy(Array(l, Cast(r, DecimalType.fromLiteral(r)))) // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles case (l @ IntegralType(), r @ DecimalType.Expression(_, _)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 1b94cf8fd714e..ef3b67c0d48d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -137,7 +137,7 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } - private[sql] def forLiteral(literal: Literal): DecimalType = literal.value match { + private[sql] def fromLiteral(literal: Literal): DecimalType = literal.value match { case v: Short => fromBigDecimal(BigDecimal(v)) case v: Int => fromBigDecimal(BigDecimal(v)) case v: Long => fromBigDecimal(BigDecimal(v)) From b4b0350dea09db897b70485ef1fad41a742eae30 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 18 Jan 2018 09:58:02 +0100 Subject: [PATCH 16/16] docs fix --- docs/sql-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ddaa231e69148..3e6779f87a29b 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1795,10 +1795,10 @@ options. - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. - - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible. This is compliant to SQL standards and Hive's behavior introduced in HIVE-15331. This involves the following changes - - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. + - Since Spark 2.3, by default arithmetic operations between decimals return a rounded value if an exact representation is not possible (instead of returning NULL). This is compliant to SQL ANSI 2011 specification and Hive's new behavior introduced in Hive 2.2 (HIVE-15331). This involves the following changes + - The rules to determine the result type of an arithmetic operation have been updated. In particular, if the precision / scale needed are out of the range of available values, the scale is reduced up to 6, in order to prevent the truncation of the integer part of the decimals. All the arithmetic operations are affected by the change, ie. addition (`+`), subtraction (`-`), multiplication (`*`), division (`/`), remainder (`%`) and positive module (`pmod`). 
   - Literal values used in SQL operations are converted to DECIMAL with the exact precision and scale needed by them.
-  - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark will use the previous rules and behavior.
+  - The configuration `spark.sql.decimalOperations.allowPrecisionLoss` has been introduced. It defaults to `true`, which means the new behavior described here; if set to `false`, Spark uses the previous rules, ie. it does not adjust the needed scale to represent the values and it returns NULL if an exact representation of the value is not possible.

 ## Upgrading From Spark SQL 2.1 to 2.2
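
For reference, the precision/scale adjustment introduced by this series can be reproduced outside Spark with the small, self-contained Scala sketch below. It is not code from the patch: the object and method names are invented for illustration, and the constants are the ones visible in the diff (MAX_PRECISION = 38, MINIMUM_ADJUSTED_SCALE = 6). The multiplication rule used here (precision p1 + p2 + 1, scale s1 + s2) is the one referenced by the DECIMAL(38 + 10 + 1, 18) example in the PATCH 11 comment.

// Standalone sketch, not part of the patch: it mirrors the adjustment rule described
// above using invented names; only the constants and formulas come from the diff.
object DecimalRuleSketch {

  val MaxPrecision = 38
  val MinimumAdjustedScale = 6

  // Mirrors DecimalType.adjustPrecisionScale from PATCH 08: when the required precision
  // exceeds 38, keep all integer digits and give the scale whatever room is left, but
  // never drop below min(scale, 6) fractional digits.
  def adjustPrecisionScale(precision: Int, scale: Int): (Int, Int) = {
    require(precision >= scale && scale >= 0)
    if (precision <= MaxPrecision) {
      (precision, scale)
    } else {
      val intDigits = precision - scale
      val minScale = math.min(scale, MinimumAdjustedScale)
      val adjustedScale = math.max(MaxPrecision - intDigits, minScale)
      (MaxPrecision, adjustedScale)
    }
  }

  // Result type of a decimal multiplication before adjustment, as referenced in the
  // PATCH 11 comment: precision p1 + p2 + 1, scale s1 + s2.
  def multiplyResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
    adjustPrecisionScale(p1 + p2 + 1, s1 + s2)

  def main(args: Array[String]): Unit = {
    // A DECIMAL(38,18) column multiplied by the literal 2 typed with the default
    // integer decimal type DECIMAL(10,0): the result collapses to DECIMAL(38,7).
    println(multiplyResultType(38, 18, 10, 0)) // (38,7)

    // The same multiplication with the literal typed as DECIMAL(1,0), i.e. the exact
    // precision the value 2 needs: the result keeps 16 fractional digits.
    println(multiplyResultType(38, 18, 1, 0))  // (38,16)
  }
}

Running the sketch prints (38,7) and (38,16), i.e. the DECIMAL(38, 7) and DECIMAL(38, 16) outcomes described in the PATCH 11 comment: typing the literal 2 as DECIMAL(1, 0) instead of DECIMAL(10, 0) preserves nine additional fractional digits in the result.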