Skip to content

Commit e78ee2c

Browse files
committed
[SPARK-48016][SQL] Fix a bug in try_divide function when used with decimals
Currently, the following query will throw a DIVIDE_BY_ZERO error instead of returning null: ``` SELECT try_divide(1, decimal(0)); ``` This is caused by the rule `DecimalPrecision`: ``` case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left, right) match { ... case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] && l.dataType.isInstanceOf[IntegralType] && literalPickMinimumPrecision => b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r)) ``` The result of the above makeCopy will contain `ANSI` as the `evalMode`, instead of `TRY`. This PR fixes this bug by replacing the makeCopy method calls with withNewChildren. Why are the changes needed? Bug fix in try_* functions. Does this PR introduce any user-facing change? Yes, it fixes a long-standing bug in the try_divide function. How was this patch tested? New UT. Was this patch authored or co-authored using generative AI tooling? No. Closes apache#46286 from gengliangwang/avoidMakeCopy. Authored-by: Gengliang Wang <[email protected]> Signed-off-by: Gengliang Wang <[email protected]> (cherry picked from commit 3fbcb26) Signed-off-by: Gengliang Wang <[email protected]>
1 parent 616c216 commit e78ee2c

File tree

8 files changed

+261
-13
lines changed

8 files changed

+261
-13
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ object DecimalPrecision extends TypeCoercionRule {
8383
val resultType = widerDecimalType(p1, s1, p2, s2)
8484
val newE1 = if (e1.dataType == resultType) e1 else Cast(e1, resultType)
8585
val newE2 = if (e2.dataType == resultType) e2 else Cast(e2, resultType)
86-
b.makeCopy(Array(newE1, newE2))
86+
b.withNewChildren(Seq(newE1, newE2))
8787
}
8888

8989
/**
@@ -202,21 +202,21 @@ object DecimalPrecision extends TypeCoercionRule {
202202
case (l: Literal, r) if r.dataType.isInstanceOf[DecimalType] &&
203203
l.dataType.isInstanceOf[IntegralType] &&
204204
literalPickMinimumPrecision =>
205-
b.makeCopy(Array(Cast(l, DataTypeUtils.fromLiteral(l)), r))
205+
b.withNewChildren(Seq(Cast(l, DataTypeUtils.fromLiteral(l)), r))
206206
case (l, r: Literal) if l.dataType.isInstanceOf[DecimalType] &&
207207
r.dataType.isInstanceOf[IntegralType] &&
208208
literalPickMinimumPrecision =>
209-
b.makeCopy(Array(l, Cast(r, DataTypeUtils.fromLiteral(r))))
209+
b.withNewChildren(Seq(l, Cast(r, DataTypeUtils.fromLiteral(r))))
210210
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
211211
// and fixed-precision decimals in an expression with floats / doubles to doubles
212212
case (l @ IntegralTypeExpression(), r @ DecimalExpression(_, _)) =>
213-
b.makeCopy(Array(Cast(l, DecimalType.forType(l.dataType)), r))
213+
b.withNewChildren(Seq(Cast(l, DecimalType.forType(l.dataType)), r))
214214
case (l @ DecimalExpression(_, _), r @ IntegralTypeExpression()) =>
215-
b.makeCopy(Array(l, Cast(r, DecimalType.forType(r.dataType))))
215+
b.withNewChildren(Seq(l, Cast(r, DecimalType.forType(r.dataType))))
216216
case (l, r @ DecimalExpression(_, _)) if isFloat(l.dataType) =>
217-
b.makeCopy(Array(l, Cast(r, DoubleType)))
217+
b.withNewChildren(Seq(l, Cast(r, DoubleType)))
218218
case (l @ DecimalExpression(_, _), r) if isFloat(r.dataType) =>
219-
b.makeCopy(Array(Cast(l, DoubleType), r))
219+
b.withNewChildren(Seq(Cast(l, DoubleType), r))
220220
case _ => b
221221
}
222222
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,22 +1102,22 @@ object TypeCoercion extends TypeCoercionBase {
11021102

11031103
case a @ BinaryArithmetic(left @ StringTypeExpression(), right)
11041104
if right.dataType != CalendarIntervalType =>
1105-
a.makeCopy(Array(Cast(left, DoubleType), right))
1105+
a.withNewChildren(Seq(Cast(left, DoubleType), right))
11061106
case a @ BinaryArithmetic(left, right @ StringTypeExpression())
11071107
if left.dataType != CalendarIntervalType =>
1108-
a.makeCopy(Array(left, Cast(right, DoubleType)))
1108+
a.withNewChildren(Seq(left, Cast(right, DoubleType)))
11091109

11101110
// For equality between string and timestamp we cast the string to a timestamp
11111111
// so that things like rounding of subsecond precision does not affect the comparison.
11121112
case p @ Equality(left @ StringTypeExpression(), right @ TimestampTypeExpression()) =>
1113-
p.makeCopy(Array(Cast(left, TimestampType), right))
1113+
p.withNewChildren(Seq(Cast(left, TimestampType), right))
11141114
case p @ Equality(left @ TimestampTypeExpression(), right @ StringTypeExpression()) =>
1115-
p.makeCopy(Array(left, Cast(right, TimestampType)))
1115+
p.withNewChildren(Seq(left, Cast(right, TimestampType)))
11161116

11171117
case p @ BinaryComparison(left, right)
11181118
if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
11191119
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
1120-
p.makeCopy(Array(castExpr(left, commonType), castExpr(right, commonType)))
1120+
p.withNewChildren(Seq(castExpr(left, commonType), castExpr(right, commonType)))
11211121
}
11221122
}
11231123

sql/core/src/test/resources/log4j2.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ logger.parquet_recordwriter.name = org.apache.parquet.hadoop.InternalParquetReco
5050
logger.parquet_recordwriter.additivity = false
5151
logger.parquet_recordwriter.level = off
5252

53-
logger.parquet_outputcommitter.name = org.apache.parquet.hadoop.ParquetOutputCommitter
53+
logger.parquet_outputcommitter.name = org.apache.parquet.hadoop.ParquetOutputCommitter
5454
logger.parquet_outputcommitter.additivity = false
5555
logger.parquet_outputcommitter.level = off
5656

sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
1313
+- OneRowRelation
1414

1515

16+
-- !query
17+
SELECT try_add(2147483647, decimal(1))
18+
-- !query analysis
19+
Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
20+
+- OneRowRelation
21+
22+
23+
-- !query
24+
SELECT try_add(2147483647, "1")
25+
-- !query analysis
26+
Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#xL]
27+
+- OneRowRelation
28+
29+
1630
-- !query
1731
SELECT try_add(-2147483648, -1)
1832
-- !query analysis
@@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
211225
+- OneRowRelation
212226

213227

228+
-- !query
229+
SELECT try_divide(1, decimal(0))
230+
-- !query analysis
231+
Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
232+
+- OneRowRelation
233+
234+
235+
-- !query
236+
SELECT try_divide(1, "0")
237+
-- !query analysis
238+
Project [try_divide(1, 0) AS try_divide(1, 0)#x]
239+
+- OneRowRelation
240+
241+
214242
-- !query
215243
SELECT try_divide(interval 2 year, 2)
216244
-- !query analysis
@@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
267295
+- OneRowRelation
268296

269297

298+
-- !query
299+
SELECT try_subtract(2147483647, decimal(-1))
300+
-- !query analysis
301+
Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
302+
+- OneRowRelation
303+
304+
305+
-- !query
306+
SELECT try_subtract(2147483647, "-1")
307+
-- !query analysis
308+
Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#xL]
309+
+- OneRowRelation
310+
311+
270312
-- !query
271313
SELECT try_subtract(-2147483648, 1)
272314
-- !query analysis
@@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
351393
+- OneRowRelation
352394

353395

396+
-- !query
397+
SELECT try_multiply(2147483647, decimal(-2))
398+
-- !query analysis
399+
Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
400+
+- OneRowRelation
401+
402+
403+
-- !query
404+
SELECT try_multiply(2147483647, "-2")
405+
-- !query analysis
406+
Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#xL]
407+
+- OneRowRelation
408+
409+
354410
-- !query
355411
SELECT try_multiply(-2147483648, 2)
356412
-- !query analysis

sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,20 @@ Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
1313
+- OneRowRelation
1414

1515

16+
-- !query
17+
SELECT try_add(2147483647, decimal(1))
18+
-- !query analysis
19+
Project [try_add(2147483647, cast(1 as decimal(10,0))) AS try_add(2147483647, 1)#x]
20+
+- OneRowRelation
21+
22+
23+
-- !query
24+
SELECT try_add(2147483647, "1")
25+
-- !query analysis
26+
Project [try_add(2147483647, 1) AS try_add(2147483647, 1)#x]
27+
+- OneRowRelation
28+
29+
1630
-- !query
1731
SELECT try_add(-2147483648, -1)
1832
-- !query analysis
@@ -211,6 +225,20 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x]
211225
+- OneRowRelation
212226

213227

228+
-- !query
229+
SELECT try_divide(1, decimal(0))
230+
-- !query analysis
231+
Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x]
232+
+- OneRowRelation
233+
234+
235+
-- !query
236+
SELECT try_divide(1, "0")
237+
-- !query analysis
238+
Project [try_divide(1, 0) AS try_divide(1, 0)#x]
239+
+- OneRowRelation
240+
241+
214242
-- !query
215243
SELECT try_divide(interval 2 year, 2)
216244
-- !query analysis
@@ -267,6 +295,20 @@ Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
267295
+- OneRowRelation
268296

269297

298+
-- !query
299+
SELECT try_subtract(2147483647, decimal(-1))
300+
-- !query analysis
301+
Project [try_subtract(2147483647, cast(-1 as decimal(10,0))) AS try_subtract(2147483647, -1)#x]
302+
+- OneRowRelation
303+
304+
305+
-- !query
306+
SELECT try_subtract(2147483647, "-1")
307+
-- !query analysis
308+
Project [try_subtract(2147483647, -1) AS try_subtract(2147483647, -1)#x]
309+
+- OneRowRelation
310+
311+
270312
-- !query
271313
SELECT try_subtract(-2147483648, 1)
272314
-- !query analysis
@@ -351,6 +393,20 @@ Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
351393
+- OneRowRelation
352394

353395

396+
-- !query
397+
SELECT try_multiply(2147483647, decimal(-2))
398+
-- !query analysis
399+
Project [try_multiply(2147483647, cast(-2 as decimal(10,0))) AS try_multiply(2147483647, -2)#x]
400+
+- OneRowRelation
401+
402+
403+
-- !query
404+
SELECT try_multiply(2147483647, "-2")
405+
-- !query analysis
406+
Project [try_multiply(2147483647, -2) AS try_multiply(2147483647, -2)#x]
407+
+- OneRowRelation
408+
409+
354410
-- !query
355411
SELECT try_multiply(-2147483648, 2)
356412
-- !query analysis

sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
-- Numeric + Numeric
22
SELECT try_add(1, 1);
33
SELECT try_add(2147483647, 1);
4+
SELECT try_add(2147483647, decimal(1));
5+
SELECT try_add(2147483647, "1");
46
SELECT try_add(-2147483648, -1);
57
SELECT try_add(9223372036854775807L, 1);
68
SELECT try_add(-9223372036854775808L, -1);
@@ -38,6 +40,8 @@ SELECT try_divide(0, 0);
3840
SELECT try_divide(1, (2147483647 + 1));
3941
SELECT try_divide(1L, (9223372036854775807L + 1L));
4042
SELECT try_divide(1, 1.0 / 0.0);
43+
SELECT try_divide(1, decimal(0));
44+
SELECT try_divide(1, "0");
4145

4246
-- Interval / Numeric
4347
SELECT try_divide(interval 2 year, 2);
@@ -50,6 +54,8 @@ SELECT try_divide(interval 106751991 day, 0.5);
5054
-- Numeric - Numeric
5155
SELECT try_subtract(1, 1);
5256
SELECT try_subtract(2147483647, -1);
57+
SELECT try_subtract(2147483647, decimal(-1));
58+
SELECT try_subtract(2147483647, "-1");
5359
SELECT try_subtract(-2147483648, 1);
5460
SELECT try_subtract(9223372036854775807L, -1);
5561
SELECT try_subtract(-9223372036854775808L, 1);
@@ -66,6 +72,8 @@ SELECT try_subtract(interval 106751991 day, interval -3 day);
6672
-- Numeric * Numeric
6773
SELECT try_multiply(2, 3);
6874
SELECT try_multiply(2147483647, -2);
75+
SELECT try_multiply(2147483647, decimal(-2));
76+
SELECT try_multiply(2147483647, "-2");
6977
SELECT try_multiply(-2147483648, 2);
7078
SELECT try_multiply(9223372036854775807L, 2);
7179
SELECT try_multiply(-9223372036854775808L, -2);

sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,22 @@ struct<try_add(2147483647, 1):int>
1515
NULL
1616

1717

18+
-- !query
19+
SELECT try_add(2147483647, decimal(1))
20+
-- !query schema
21+
struct<try_add(2147483647, 1):decimal(11,0)>
22+
-- !query output
23+
2147483648
24+
25+
26+
-- !query
27+
SELECT try_add(2147483647, "1")
28+
-- !query schema
29+
struct<try_add(2147483647, 1):bigint>
30+
-- !query output
31+
2147483648
32+
33+
1834
-- !query
1935
SELECT try_add(-2147483648, -1)
2036
-- !query schema
@@ -341,6 +357,22 @@ org.apache.spark.SparkArithmeticException
341357
}
342358

343359

360+
-- !query
361+
SELECT try_divide(1, decimal(0))
362+
-- !query schema
363+
struct<try_divide(1, 0):decimal(12,11)>
364+
-- !query output
365+
NULL
366+
367+
368+
-- !query
369+
SELECT try_divide(1, "0")
370+
-- !query schema
371+
struct<try_divide(1, 0):double>
372+
-- !query output
373+
NULL
374+
375+
344376
-- !query
345377
SELECT try_divide(interval 2 year, 2)
346378
-- !query schema
@@ -405,6 +437,22 @@ struct<try_subtract(2147483647, -1):int>
405437
NULL
406438

407439

440+
-- !query
441+
SELECT try_subtract(2147483647, decimal(-1))
442+
-- !query schema
443+
struct<try_subtract(2147483647, -1):decimal(11,0)>
444+
-- !query output
445+
2147483648
446+
447+
448+
-- !query
449+
SELECT try_subtract(2147483647, "-1")
450+
-- !query schema
451+
struct<try_subtract(2147483647, -1):bigint>
452+
-- !query output
453+
2147483648
454+
455+
408456
-- !query
409457
SELECT try_subtract(-2147483648, 1)
410458
-- !query schema
@@ -547,6 +595,22 @@ struct<try_multiply(2147483647, -2):int>
547595
NULL
548596

549597

598+
-- !query
599+
SELECT try_multiply(2147483647, decimal(-2))
600+
-- !query schema
601+
struct<try_multiply(2147483647, -2):decimal(21,0)>
602+
-- !query output
603+
-4294967294
604+
605+
606+
-- !query
607+
SELECT try_multiply(2147483647, "-2")
608+
-- !query schema
609+
struct<try_multiply(2147483647, -2):bigint>
610+
-- !query output
611+
-4294967294
612+
613+
550614
-- !query
551615
SELECT try_multiply(-2147483648, 2)
552616
-- !query schema

0 commit comments

Comments
 (0)