Skip to content

Commit 1c48e34

Browse files
author
Davies Liu
committed
use decimal or double when parsing SQL
1 parent d4d762f commit 1c48e34

File tree

5 files changed

+60
-35
lines changed

5 files changed

+60
-35
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
323323
protected lazy val numericLiteral: Parser[Literal] =
324324
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
325325
| sign.? ~ unsignedFloat ^^ {
326-
// TODO(davies): some precisions may loss, we should create decimal literal
327-
case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
326+
case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f))
328327
}
329328
)
330329

@@ -411,6 +410,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
411410
}
412411
}
413412

413+
/**
 * Turns the text of a (possibly signed) floating point literal into the value used for the
 * SQL literal.
 *
 * A literal written in scientific notation, i.e. containing an `E` or `e` exponent marker,
 * yields a `Double`; every other literal yields the exact `java.math.BigDecimal`.
 *
 * @param value the literal text, for example "-0.8" or "1.5e3"
 * @return a `Double` for scientific notation, otherwise a `java.math.BigDecimal`
 */
private def toDecimalOrDouble(value: String): Any = {
  val decimal = BigDecimal(value)
  // follow the behavior in MS SQL Server
  // https://msdn.microsoft.com/en-us/library/ms179899.aspx
  // Both exponent markers are compared as Char values; a Symbol literal such as 'e (missing
  // closing quote) would never match a character of the string.
  val isScientific = value.exists(c => c == 'E' || c == 'e')
  if (isScientific) {
    decimal.doubleValue
  } else {
    decimal.underlying()
  }
}
423+
414424
protected lazy val baseExpression: Parser[Expression] =
415425
( "*" ^^^ UnresolvedStar(None)
416426
| ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,35 @@ object HiveTypeCoercion {
111111
* Find the tightest common type of a set of types by continuously applying
112112
* `findTightestCommonTypeOfTwo` on these types.
113113
*/
114-
private def findTightestCommonType(types: Seq[DataType]) = {
114+
private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
  // Start from NullType and combine it with each element in order.
  val seed: Option[DataType] = Some(NullType)
  types.foldLeft(seed) { (acc, next) =>
    // Once a pair has no common type the accumulator is None and stays None.
    acc.flatMap(current => findTightestCommonTypeOfTwo(current, next))
  }
}
120120

121+
/**
 * Picks a type that both `t1` and `t2` can be widened to.
 *
 * Decimal/decimal and integral/decimal pairs are resolved through
 * `DecimalPrecision.widerDecimalType`; any other fractional/decimal pair becomes `DoubleType`;
 * everything else is delegated to `findTightestCommonTypeToString`.
 *
 * @return the widened type, or whatever `findTightestCommonTypeToString` returns for pairs
 *         that do not involve a decimal
 */
private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
  (t1, t2) match {
    case (left: DecimalType, right: DecimalType) =>
      Some(DecimalPrecision.widerDecimalType(left, right))

    // An integral operand is first mapped to its decimal counterpart; the converted integral
    // side is always passed as the first argument, regardless of operand order.
    case (integral: IntegralType, decimal: DecimalType) =>
      Some(DecimalPrecision.widerDecimalType(DecimalType.forType(integral), decimal))
    case (decimal: DecimalType, integral: IntegralType) =>
      Some(DecimalPrecision.widerDecimalType(DecimalType.forType(integral), decimal))

    // Remaining fractional/decimal combinations, in either order.
    case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
      Some(DoubleType)

    case _ =>
      findTightestCommonTypeToString(t1, t2)
  }
}
135+
136+
/**
 * Folds `findWiderTypeForTwo` over `types`, starting from `NullType`. The result is `None` as
 * soon as one step yields `None`.
 */
private def findWiderCommonType(types: Seq[DataType]) = {
  val seed: Option[DataType] = Some(NullType)
  types.foldLeft(seed) { (acc, next) =>
    acc.flatMap(current => findWiderTypeForTwo(current, next))
  }
}
142+
121143
/**
122144
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
123145
* instances higher in the query tree.
@@ -184,20 +206,7 @@ object HiveTypeCoercion {
184206

185207
val castedTypes = left.output.zip(right.output).map {
186208
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
187-
(lhs.dataType, rhs.dataType) match {
188-
case (t1: DecimalType, t2: DecimalType) =>
189-
Some(DecimalPrecision.widerDecimalType(t1, t2))
190-
case (t: IntegralType, d: DecimalType) =>
191-
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
192-
case (d: DecimalType, t: IntegralType) =>
193-
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
194-
case (t: FractionalType, d: DecimalType) =>
195-
Some(DoubleType)
196-
case (d: DecimalType, t: FractionalType) =>
197-
Some(DoubleType)
198-
case _ =>
199-
findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
200-
}
209+
findWiderTypeForTwo(lhs.dataType, rhs.dataType)
201210
case other => None
202211
}
203212

@@ -238,8 +247,13 @@ object HiveTypeCoercion {
238247
// Skip nodes whose children have not been resolved yet.
239248
case e if !e.childrenResolved => e
240249

241-
case a @ BinaryArithmetic(left @ StringType(), r) =>
242-
a.makeCopy(Array(Cast(left, DoubleType), r))
250+
case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
251+
a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
252+
case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
253+
a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
254+
255+
case a @ BinaryArithmetic(left @ StringType(), right) =>
256+
a.makeCopy(Array(Cast(left, DoubleType), right))
243257
case a @ BinaryArithmetic(left, right @ StringType()) =>
244258
a.makeCopy(Array(left, Cast(right, DoubleType)))
245259

@@ -557,7 +571,7 @@ object HiveTypeCoercion {
557571
// compatible with every child column.
558572
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
559573
val types = es.map(_.dataType)
560-
findTightestCommonTypeAndPromoteToString(types) match {
574+
findWiderCommonType(types) match {
561575
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
562576
case None => c
563577
}

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest {
216216
checkAnswer(
217217
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
218218
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
219-
Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
219+
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
220+
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
220221
)
221222
}
222223

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
300300
Row(1))
301301
checkAnswer(
302302
sql("SELECT COALESCE(null, 1, 1.5)"),
303-
Row(1.toDouble))
303+
Row(BigDecimal(1)))
304304
checkAnswer(
305305
sql("SELECT COALESCE(null, null, null)"),
306306
Row(null))
@@ -1149,19 +1149,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
11491149

11501150
test("Floating point number format") {
11511151
checkAnswer(
1152-
sql("SELECT 0.3"), Row(0.3)
1152+
sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
11531153
)
11541154

11551155
checkAnswer(
1156-
sql("SELECT -0.8"), Row(-0.8)
1156+
sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
11571157
)
11581158

11591159
checkAnswer(
1160-
sql("SELECT .5"), Row(0.5)
1160+
sql("SELECT .5"), Row(BigDecimal(0.5))
11611161
)
11621162

11631163
checkAnswer(
1164-
sql("SELECT -.18"), Row(-0.18)
1164+
sql("SELECT -.18"), Row(BigDecimal(-0.18))
11651165
)
11661166
}
11671167

@@ -1194,11 +1194,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
11941194
)
11951195

11961196
checkAnswer(
1197-
sql("SELECT -5.2"), Row(-5.2)
1197+
sql("SELECT -5.2"), Row(BigDecimal(-5.2))
11981198
)
11991199

12001200
checkAnswer(
1201-
sql("SELECT +6.8"), Row(6.8)
1201+
sql("SELECT +6.8"), Row(BigDecimal(6.8))
12021202
)
12031203

12041204
checkAnswer(

sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData {
422422
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
423423
)
424424

425-
// Widening to DoubleType
425+
// Widening to DecimalType
426426
checkAnswer(
427427
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
428-
Row(21474836472.2) ::
429-
Row(92233720368547758071.3) :: Nil
428+
Row(BigDecimal("21474836472.2")) ::
429+
Row(BigDecimal("92233720368547758071.3")) :: Nil
430430
)
431431

432-
// Widening to DoubleType
432+
// Widening to DoubleType
433433
checkAnswer(
434434
sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
435435
Row(101.2) :: Row(21474836471.2) :: Nil
@@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData {
438438
// Number and String conflict: resolve the type as number in this query.
439439
checkAnswer(
440440
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
441-
Row(92233720368547758071.2)
441+
Row(BigDecimal("92233720368547758071.2"))
442442
)
443443

444444
// Number and String conflict: resolve the type as number in this query.
445445
checkAnswer(
446446
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
447-
Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
447+
Row(new java.math.BigDecimal("92233720368547758071.2"))
448448
)
449449

450450
// String and Boolean conflict: resolve the type as string.
@@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData {
503503
// Number and String conflict: resolve the type as number in this query.
504504
checkAnswer(
505505
sql("select num_str + 1.2 from jsonTable where num_str > 13"),
506-
Row(14.3) :: Row(92233720368547758071.2) :: Nil
506+
Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil
507507
)
508508
}
509509

0 commit comments

Comments
 (0)