19 changes: 19 additions & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -240,6 +240,25 @@ The least common type resolution is used to:
- Derive the result type for expressions such as the case expression.
- Derive the element, key, or value types for array and map constructors.
Special rules are applied if the least common type resolves to FLOAT. With float type values, if any of the types is INT, BIGINT, or DECIMAL, the least common type is pushed to DOUBLE to avoid potential loss of digits.

The decimal type is a bit more complicated here, as it is not a simple type but is parameterized by precision and scale.
A `decimal(precision, scale)` value can have at most `precision - scale` digits in the integral part and `scale` digits in the fractional part.
A least common type between decimal types should have enough digits in both the integral and fractional parts to represent all values.
More precisely, the least common type between `decimal(p1, s1)` and `decimal(p2, s2)` has a scale of `max(s1, s2)` and a precision of `max(s1, s2) + max(p1 - s1, p2 - s2)`.
However, decimal types in Spark have a maximum precision of 38. If the resulting decimal type needs more precision than that, Spark must truncate it.
Since the digits in the integral part are more significant, Spark truncates the digits in the fractional part first. For example, `decimal(48, 20)` is reduced to `decimal(38, 10)`.
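
For example, the least common type of `decimal(36, 0)` and `decimal(36, 35)` would need scale 35 and precision 71; Spark keeps the 36 integral digits and reduces the scale, giving `decimal(38, 2)`. A small sketch of this, where the expected type comes from the widening tests in this change and `coalesce`/`typeof` are only used here to surface the least common type:

```sql
-- Least common type of decimal(36, 0) and decimal(36, 35):
-- scale = max(0, 35) = 35, precision = 35 + max(36, 1) = 71 > 38,
-- so the scale is reduced by 71 - 38 = 33, giving decimal(38, 2).
SELECT typeof(coalesce(CAST(1 AS DECIMAL(36, 0)), CAST(1 AS DECIMAL(36, 35))));
-- decimal(38,2)
```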

Note that arithmetic operations have special rules to calculate the least common type for decimal inputs:

| Operation | Result precision | Result scale |
|------------|------------------------------------------|---------------------|
| e1 + e2 | max(s1, s2) + max(p1 - s1, p2 - s2) + 1 | max(s1, s2) |
| e1 - e2 | max(s1, s2) + max(p1 - s1, p2 - s2) + 1 | max(s1, s2) |
| e1 * e2 | p1 + p2 + 1 | s1 + s2 |
| e1 / e2 | p1 - s1 + s2 + max(6, s1 + p2 + 1) | max(6, s1 + p2 + 1) |
| e1 % e2 | min(p1 - s1, p2 - s2) + max(s1, s2) | max(s1, s2) |
Contributor

AFAIK, the arithmetic operations did not strictly follow this rule.

Contributor Author

which one does not follow? The final decimal type can be different as there is one more truncation step.

Contributor

`*` and `/`. For example:

val a = Decimal(100)  // p: 10, s: 0
val b = Decimal(-100) // p: 10, s: 0
val c = a * b  // Decimal(-10000), p: 5, s: 0

Contributor Author

This is not how Spark SQL multiplies decimals. Please take a look at `Multiply#resultDecimalType`.
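
For reference, a REPL-style sketch of the multiplication row from the table above applied to this example; it only illustrates the documented rule, not the actual `Multiply#resultDecimalType` implementation:

```scala
// e1 * e2 per the table: precision = p1 + p2 + 1, scale = s1 + s2
// (before the maximum precision of 38 is enforced).
def multiplyResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) =
  (p1 + p2 + 1, s1 + s2)

// Both operands above are decimal(10, 0), so the result type is decimal(21, 0),
// not decimal(5, 0); the value -10000 simply needs fewer digits than the type allows.
println(multiplyResultType(10, 0, 10, 0)) // (21,0)
```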


The truncation rule is also different for arithmetic operations: they retain at least 6 digits in the fractional part, which means the `scale` can only be reduced down to 6. Overflow may happen in this case.
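
For example, assuming the default `spark.sql.decimalOperations.allowPrecisionLoss=true`, dividing two `decimal(38, 10)` values asks for precision 87 and scale 49, so the scale is reduced, but only down to 6:

```sql
-- e1 / e2 with p1 = p2 = 38 and s1 = s2 = 10:
-- scale = max(6, 10 + 38 + 1) = 49, precision = 38 - 10 + 10 + 49 = 87.
-- 87 > 38, so the result type is truncated to decimal(38, 6),
-- keeping at least 6 digits in the fractional part.
SELECT CAST(1 AS DECIMAL(38, 10)) / CAST(3 AS DECIMAL(38, 10));
-- 0.333333
```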

```sql
-- The coalesce function accepts any set of argument types as long as they share a least common type.
-- ...
```
@@ -146,6 +146,18 @@ object DecimalType extends AbstractDataType {
DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE))
}

private[sql] def boundedPreferIntegralDigits(precision: Int, scale: Int): DecimalType = {
if (precision <= MAX_PRECISION) {
DecimalType(precision, scale)
} else {
// If we have to reduce the precision, we should retain the digits in the integral part first,
// as they are more significant to the value. Here we reduce the scale as well to drop the
// digits in the fractional part.
Contributor

Looks good.

val diff = precision - MAX_PRECISION
DecimalType(MAX_PRECISION, math.max(0, scale - diff))
}
}

private[sql] def checkNegativeScale(scale: Int): Unit = {
if (scale < 0 && !SqlApiConf.get.allowNegativeScaleOfDecimalEnabled) {
throw DataTypeErrors.negativeScaleNotAllowedError(scale)
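
As a quick illustration of the truncation logic in `boundedPreferIntegralDigits` above, here is a minimal standalone sketch; it is not the Spark source, it just mirrors the branch shown in the diff and assumes the maximum precision of 38:

```scala
// Self-contained sketch of the "prefer integral digits" truncation rule.
object PreferIntegralDigitsSketch {
  val MaxPrecision = 38

  // Returns (precision, scale) after capping precision at MaxPrecision while
  // keeping as many integral digits as possible.
  def boundedPreferIntegralDigits(precision: Int, scale: Int): (Int, Int) = {
    if (precision <= MaxPrecision) {
      (precision, scale)
    } else {
      val diff = precision - MaxPrecision
      (MaxPrecision, math.max(0, scale - diff))
    }
  }

  def main(args: Array[String]): Unit = {
    println(boundedPreferIntegralDigits(48, 20)) // (38,10), the doc example above
    println(boundedPreferIntegralDigits(71, 35)) // (38,2), the widening case in the tests
  }
}
```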
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


@@ -64,7 +65,11 @@ object DecimalPrecision extends TypeCoercionRule {
def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
val scale = max(s1, s2)
val range = max(p1 - s1, p2 - s2)
DecimalType.bounded(range + scale, scale)
if (conf.getConf(SQLConf.LEGACY_RETAIN_FRACTION_DIGITS_FIRST)) {
DecimalType.bounded(range + scale, scale)
Member
@gengliangwang, Nov 14, 2023

There are many usages of DecimalType.bounded.
Why do we only change the behavior here?

Contributor Author

To limit the scope to type coercion only. Some arithmetic operations also call it to determine the result decimal type and I don't want to change that part.

} else {
DecimalType.boundedPreferIntegralDigits(range + scale, scale)
}
}

override def transform: PartialFunction[Expression, Expression] = {
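
For illustration, a sketch of how the two branches of `widerDecimalType` differ for the widening case exercised in the test changes further down; the expected types come from those tests, and the config name is the one added in this PR:

```sql
-- New default: prefer integral digits, so the scale shrinks to fit precision 38.
SET spark.sql.legacy.decimal.retainFractionDigitsOnTruncate=false;
SELECT typeof(array(CAST(1 AS DECIMAL(36, 0)), CAST(1 AS DECIMAL(36, 35))));
-- array<decimal(38,2)>

-- Legacy behavior: retain the fraction digits, squeezing the integral part.
SET spark.sql.legacy.decimal.retainFractionDigitsOnTruncate=true;
SELECT typeof(array(CAST(1 AS DECIMAL(36, 0)), CAST(1 AS DECIMAL(36, 35))));
-- array<decimal(38,35)>
```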
@@ -4541,6 +4541,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_RETAIN_FRACTION_DIGITS_FIRST =
buildConf("spark.sql.legacy.decimal.retainFractionDigitsOnTruncate")
.internal()
.doc("When set to true, we will try to retain the fraction digits first rather than " +
"integral digits as prior Spark 4.0, when getting a least common type between decimal " +
"types, and the result decimal precision exceeds the max precision.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
@@ -5425,7 +5435,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
}

def legacyRaiseErrorWithoutErrorClass: Boolean =
getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)
getConf(SQLConf.LEGACY_RAISE_ERROR_WITHOUT_ERROR_CLASS)
Contributor
@ryan-johnson-databricks, Nov 13, 2023

noise? (whitespace changes best made in a non-bugfix PR?)

Contributor Author

since I touched this file, I just fixed the wrong indentation.


/** ********************** SQLConf functionality methods ************ */

@@ -481,9 +481,9 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
:: Nil),
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 38))
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 26))
:: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 26))
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 26))
:: Nil))
}

@@ -530,9 +530,9 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
:: Literal.create(null, DecimalType(38, 38))
:: Nil),
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
:: Literal.create(null, DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 0))
:: Nil))
// type coercion for both map keys and values
ruleTest(AnsiTypeCoercion.FunctionArgumentConversion,
@@ -782,7 +782,7 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
widenTestWithStringPromotion(
ArrayType(DecimalType(36, 0), containsNull = false),
ArrayType(DecimalType(36, 35), containsNull = false),
Some(ArrayType(DecimalType(38, 35), containsNull = true)))
Some(ArrayType(DecimalType(38, 2), containsNull = false)))

// MapType
widenTestWithStringPromotion(
@@ -808,15 +808,15 @@
widenTestWithStringPromotion(
MapType(StringType, DecimalType(36, 0), valueContainsNull = false),
MapType(StringType, DecimalType(36, 35), valueContainsNull = false),
Some(MapType(StringType, DecimalType(38, 35), valueContainsNull = true)))
Some(MapType(StringType, DecimalType(38, 2), valueContainsNull = false)))
widenTestWithStringPromotion(
MapType(IntegerType, StringType, valueContainsNull = false),
MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false),
Some(MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false)))
widenTestWithStringPromotion(
MapType(DecimalType(36, 0), StringType, valueContainsNull = false),
MapType(DecimalType(36, 35), StringType, valueContainsNull = false),
None)
Some(MapType(DecimalType(38, 2), StringType, valueContainsNull = false)))

// StructType
widenTestWithStringPromotion(
@@ -847,7 +847,7 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
widenTestWithStringPromotion(
new StructType().add("num", DecimalType(36, 0), nullable = false),
new StructType().add("num", DecimalType(36, 35), nullable = false),
Some(new StructType().add("num", DecimalType(38, 35), nullable = true)))
Some(new StructType().add("num", DecimalType(38, 2), nullable = false)))

widenTestWithStringPromotion(
new StructType().add("num", IntegerType),
@@ -1046,9 +1046,9 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
:: Nil),
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 38))
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 26))
:: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 26))
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 26))
:: Nil))
}

@@ -1095,9 +1095,9 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
:: Literal.create(null, DecimalType(38, 38))
:: Nil),
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
:: Literal.create(null, DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 0))
:: Nil))
// type coercion for both map keys and values
ruleTest(TypeCoercion.FunctionArgumentConversion,
@@ -128,24 +128,13 @@ Project [map_zip_with(double_map#x, cast(float_map#x as map<double,float>), lamb
SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m
FROM various_maps
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
"sqlState" : "42K09",
"messageParameters" : {
"functionName" : "`map_zip_with`",
"leftType" : "\"DECIMAL(36,0)\"",
"rightType" : "\"DECIMAL(36,35)\"",
"sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 81,
"fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2))"
} ]
}
Project [map_zip_with(cast(decimal_map1#x as map<decimal(38,2),decimal(36,0)>), cast(decimal_map2#x as map<decimal(38,2),decimal(36,35)>), lambdafunction(struct(k, lambda k#x, v1, lambda v1#x, v2, lambda v2#x), lambda k#x, lambda v1#x, lambda v2#x, false)) AS m#x]
+- SubqueryAlias various_maps
+- View (`various_maps`, [boolean_map#x,tinyint_map#x,smallint_map#x,int_map#x,bigint_map#x,decimal_map1#x,decimal_map2#x,double_map#x,float_map#x,date_map#x,timestamp_map#x,string_map1#x,string_map2#x,string_map3#x,string_map4#x,array_map1#x,array_map2#x,struct_map1#x,struct_map2#x])
+- Project [cast(boolean_map#x as map<boolean,boolean>) AS boolean_map#x, cast(tinyint_map#x as map<tinyint,tinyint>) AS tinyint_map#x, cast(smallint_map#x as map<smallint,smallint>) AS smallint_map#x, cast(int_map#x as map<int,int>) AS int_map#x, cast(bigint_map#x as map<bigint,bigint>) AS bigint_map#x, cast(decimal_map1#x as map<decimal(36,0),decimal(36,0)>) AS decimal_map1#x, cast(decimal_map2#x as map<decimal(36,35),decimal(36,35)>) AS decimal_map2#x, cast(double_map#x as map<double,double>) AS double_map#x, cast(float_map#x as map<float,float>) AS float_map#x, cast(date_map#x as map<date,date>) AS date_map#x, cast(timestamp_map#x as map<timestamp,timestamp>) AS timestamp_map#x, cast(string_map1#x as map<string,string>) AS string_map1#x, cast(string_map2#x as map<string,string>) AS string_map2#x, cast(string_map3#x as map<string,string>) AS string_map3#x, cast(string_map4#x as map<string,string>) AS string_map4#x, cast(array_map1#x as map<array<bigint>,array<bigint>>) AS array_map1#x, cast(array_map2#x as map<array<int>,array<int>>) AS array_map2#x, cast(struct_map1#x as map<struct<col1:smallint,col2:bigint>,struct<col1:smallint,col2:bigint>>) AS struct_map1#x, cast(struct_map2#x as map<struct<col1:int,col2:int>,struct<col1:int,col2:int>>) AS struct_map2#x]
+- Project [boolean_map#x, tinyint_map#x, smallint_map#x, int_map#x, bigint_map#x, decimal_map1#x, decimal_map2#x, double_map#x, float_map#x, date_map#x, timestamp_map#x, string_map1#x, string_map2#x, string_map3#x, string_map4#x, array_map1#x, array_map2#x, struct_map1#x, struct_map2#x]
+- SubqueryAlias various_maps
+- LocalRelation [boolean_map#x, tinyint_map#x, smallint_map#x, int_map#x, bigint_map#x, decimal_map1#x, decimal_map2#x, double_map#x, float_map#x, date_map#x, timestamp_map#x, string_map1#x, string_map2#x, string_map3#x, string_map4#x, array_map1#x, array_map2#x, struct_map1#x, struct_map2#x]


-- !query
@@ -178,24 +167,13 @@ Project [map_zip_with(cast(decimal_map1#x as map<double,decimal(36,0)>), double_
SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
FROM various_maps
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
"sqlState" : "42K09",
"messageParameters" : {
"functionName" : "`map_zip_with`",
"leftType" : "\"DECIMAL(36,35)\"",
"rightType" : "\"INT\"",
"sqlExpr" : "\"map_zip_with(decimal_map2, int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 76,
"fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2))"
} ]
}
Project [map_zip_with(cast(decimal_map2#x as map<decimal(38,28),decimal(36,35)>), cast(int_map#x as map<decimal(38,28),int>), lambdafunction(struct(k, lambda k#x, v1, lambda v1#x, v2, lambda v2#x), lambda k#x, lambda v1#x, lambda v2#x, false)) AS m#x]
+- SubqueryAlias various_maps
+- View (`various_maps`, [boolean_map#x,tinyint_map#x,smallint_map#x,int_map#x,bigint_map#x,decimal_map1#x,decimal_map2#x,double_map#x,float_map#x,date_map#x,timestamp_map#x,string_map1#x,string_map2#x,string_map3#x,string_map4#x,array_map1#x,array_map2#x,struct_map1#x,struct_map2#x])
+- Project [cast(boolean_map#x as map<boolean,boolean>) AS boolean_map#x, cast(tinyint_map#x as map<tinyint,tinyint>) AS tinyint_map#x, cast(smallint_map#x as map<smallint,smallint>) AS smallint_map#x, cast(int_map#x as map<int,int>) AS int_map#x, cast(bigint_map#x as map<bigint,bigint>) AS bigint_map#x, cast(decimal_map1#x as map<decimal(36,0),decimal(36,0)>) AS decimal_map1#x, cast(decimal_map2#x as map<decimal(36,35),decimal(36,35)>) AS decimal_map2#x, cast(double_map#x as map<double,double>) AS double_map#x, cast(float_map#x as map<float,float>) AS float_map#x, cast(date_map#x as map<date,date>) AS date_map#x, cast(timestamp_map#x as map<timestamp,timestamp>) AS timestamp_map#x, cast(string_map1#x as map<string,string>) AS string_map1#x, cast(string_map2#x as map<string,string>) AS string_map2#x, cast(string_map3#x as map<string,string>) AS string_map3#x, cast(string_map4#x as map<string,string>) AS string_map4#x, cast(array_map1#x as map<array<bigint>,array<bigint>>) AS array_map1#x, cast(array_map2#x as map<array<int>,array<int>>) AS array_map2#x, cast(struct_map1#x as map<struct<col1:smallint,col2:bigint>,struct<col1:smallint,col2:bigint>>) AS struct_map1#x, cast(struct_map2#x as map<struct<col1:int,col2:int>,struct<col1:int,col2:int>>) AS struct_map2#x]
+- Project [boolean_map#x, tinyint_map#x, smallint_map#x, int_map#x, bigint_map#x, decimal_map1#x, decimal_map2#x, double_map#x, float_map#x, date_map#x, timestamp_map#x, string_map1#x, string_map2#x, string_map3#x, string_map4#x, array_map1#x, array_map2#x, struct_map1#x, struct_map2#x]
+- SubqueryAlias various_maps
+- LocalRelation [boolean_map#x, tinyint_map#x, smallint_map#x, int_map#x, bigint_map#x, decimal_map1#x, decimal_map2#x, double_map#x, float_map#x, date_map#x, timestamp_map#x, string_map1#x, string_map2#x, string_map3#x, string_map4#x, array_map1#x, array_map2#x, struct_map1#x, struct_map2#x]


-- !query
@@ -79,26 +79,9 @@ struct<m:map<double,struct<k:double,v1:double,v2:float>>>
SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m
FROM various_maps
-- !query schema
struct<>
struct<m:map<decimal(38,2),struct<k:decimal(38,2),v1:decimal(36,0),v2:decimal(36,35)>>>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
"sqlState" : "42K09",
"messageParameters" : {
"functionName" : "`map_zip_with`",
"leftType" : "\"DECIMAL(36,0)\"",
"rightType" : "\"DECIMAL(36,35)\"",
"sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 81,
"fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2))"
} ]
}
{9.22:{"k":9.22,"v1":null,"v2":9.22337203685477897945456575809789456},922337203685477897945456575809789456.00:{"k":922337203685477897945456575809789456.00,"v1":922337203685477897945456575809789456,"v2":null}}


-- !query
Expand All @@ -123,26 +106,9 @@ struct<m:map<double,struct<k:double,v1:decimal(36,0),v2:double>>>
SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
FROM various_maps
-- !query schema
struct<>
struct<m:map<decimal(38,28),struct<k:decimal(38,28),v1:decimal(36,35),v2:int>>>
-- !query output
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
"sqlState" : "42K09",
"messageParameters" : {
"functionName" : "`map_zip_with`",
"leftType" : "\"DECIMAL(36,35)\"",
"rightType" : "\"INT\"",
"sqlExpr" : "\"map_zip_with(decimal_map2, int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 8,
"stopIndex" : 76,
"fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2))"
} ]
}
{2.0000000000000000000000000000:{"k":2.0000000000000000000000000000,"v1":null,"v2":1},9.2233720368547789794545657581:{"k":9.2233720368547789794545657581,"v1":9.22337203685477897945456575809789456,"v2":null}}


-- !query