Skip to content

Commit 15667a0

Browse files
Davies Liu authored and rxin committed
[SPARK-9281] [SQL] use decimal or double when parsing SQL
Right now, we use double to parse all the float numbers in SQL. When it's used in an expression together with DecimalType, it will turn the decimal into double as well. Also it will lose some precision when using double. This PR changes to parse a float number to decimal or double, based on whether it uses scientific notation or not, see https://msdn.microsoft.com/en-us/library/ms179899.aspx This is a breaking change, should we doc it somewhere? Author: Davies Liu <[email protected]> Closes apache#7642 from davies/parse_decimal and squashes the following commits: 1f576d9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal 5e142b6 [Davies Liu] fix scala style eca99de [Davies Liu] fix tests 2afe702 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal f4a320b [Davies Liu] Update SqlParser.scala 1c48e34 [Davies Liu] use decimal or double when parsing SQL
1 parent 6309b93 commit 15667a0

File tree

6 files changed

+62
-37
lines changed

6 files changed

+62
-37
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
332332
protected lazy val numericLiteral: Parser[Literal] =
333333
( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
334334
| sign.? ~ unsignedFloat ^^ {
335-
// TODO(davies): some precisions may loss, we should create decimal literal
336-
case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
335+
case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f))
337336
}
338337
)
339338

@@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
420419
}
421420
}
422421

422+
private def toDecimalOrDouble(value: String): Any = {
423+
val decimal = BigDecimal(value)
424+
// follow the behavior in MS SQL Server
425+
// https://msdn.microsoft.com/en-us/library/ms179899.aspx
426+
if (value.contains('E') || value.contains('e')) {
427+
decimal.doubleValue()
428+
} else {
429+
decimal.underlying()
430+
}
431+
}
432+
423433
protected lazy val baseExpression: Parser[Expression] =
424434
( "*" ^^^ UnresolvedStar(None)
425435
| ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,35 @@ object HiveTypeCoercion {
109109
* Find the tightest common type of a set of types by continuously applying
110110
* `findTightestCommonTypeOfTwo` on these types.
111111
*/
112-
private def findTightestCommonType(types: Seq[DataType]) = {
112+
private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
113113
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
114114
case None => None
115115
case Some(d) => findTightestCommonTypeOfTwo(d, c)
116116
})
117117
}
118118

119+
private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
120+
case (t1: DecimalType, t2: DecimalType) =>
121+
Some(DecimalPrecision.widerDecimalType(t1, t2))
122+
case (t: IntegralType, d: DecimalType) =>
123+
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
124+
case (d: DecimalType, t: IntegralType) =>
125+
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
126+
case (t: FractionalType, d: DecimalType) =>
127+
Some(DoubleType)
128+
case (d: DecimalType, t: FractionalType) =>
129+
Some(DoubleType)
130+
case _ =>
131+
findTightestCommonTypeToString(t1, t2)
132+
}
133+
134+
private def findWiderCommonType(types: Seq[DataType]) = {
135+
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
136+
case Some(d) => findWiderTypeForTwo(d, c)
137+
case None => None
138+
})
139+
}
140+
119141
/**
120142
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
121143
* instances higher in the query tree.
@@ -182,20 +204,7 @@ object HiveTypeCoercion {
182204

183205
val castedTypes = left.output.zip(right.output).map {
184206
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
185-
(lhs.dataType, rhs.dataType) match {
186-
case (t1: DecimalType, t2: DecimalType) =>
187-
Some(DecimalPrecision.widerDecimalType(t1, t2))
188-
case (t: IntegralType, d: DecimalType) =>
189-
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
190-
case (d: DecimalType, t: IntegralType) =>
191-
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
192-
case (t: FractionalType, d: DecimalType) =>
193-
Some(DoubleType)
194-
case (d: DecimalType, t: FractionalType) =>
195-
Some(DoubleType)
196-
case _ =>
197-
findTightestCommonTypeToString(lhs.dataType, rhs.dataType)
198-
}
207+
findWiderTypeForTwo(lhs.dataType, rhs.dataType)
199208
case other => None
200209
}
201210

@@ -236,8 +245,13 @@ object HiveTypeCoercion {
236245
// Skip nodes who's children have not been resolved yet.
237246
case e if !e.childrenResolved => e
238247

239-
case a @ BinaryArithmetic(left @ StringType(), r) =>
240-
a.makeCopy(Array(Cast(left, DoubleType), r))
248+
case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) =>
249+
a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right))
250+
case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) =>
251+
a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT)))
252+
253+
case a @ BinaryArithmetic(left @ StringType(), right) =>
254+
a.makeCopy(Array(Cast(left, DoubleType), right))
241255
case a @ BinaryArithmetic(left, right @ StringType()) =>
242256
a.makeCopy(Array(left, Cast(right, DoubleType)))
243257

@@ -543,7 +557,7 @@ object HiveTypeCoercion {
543557
// compatible with every child column.
544558
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
545559
val types = es.map(_.dataType)
546-
findTightestCommonTypeAndPromoteToString(types) match {
560+
findWiderCommonType(types) match {
547561
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
548562
case None => c
549563
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ class AnalysisSuite extends AnalysisTest {
145145
'e / 'e as 'div5))
146146
val pl = plan.asInstanceOf[Project].projectList
147147

148-
// StringType will be promoted into Double
149148
assert(pl(0).dataType == DoubleType)
150149
assert(pl(1).dataType == DoubleType)
151150
assert(pl(2).dataType == DoubleType)
152-
assert(pl(3).dataType == DoubleType)
151+
// StringType will be promoted into Decimal(38, 18)
152+
assert(pl(3).dataType == DecimalType(38, 29))
153153
assert(pl(4).dataType == DoubleType)
154154
}
155155

sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest {
216216
checkAnswer(
217217
ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
218218
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
219-
Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142))
219+
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
220+
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
220221
)
221222
}
222223

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
368368
Row(1))
369369
checkAnswer(
370370
sql("SELECT COALESCE(null, 1, 1.5)"),
371-
Row(1.toDouble))
371+
Row(BigDecimal(1)))
372372
checkAnswer(
373373
sql("SELECT COALESCE(null, null, null)"),
374374
Row(null))
@@ -1234,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
12341234

12351235
test("Floating point number format") {
12361236
checkAnswer(
1237-
sql("SELECT 0.3"), Row(0.3)
1237+
sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
12381238
)
12391239

12401240
checkAnswer(
1241-
sql("SELECT -0.8"), Row(-0.8)
1241+
sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
12421242
)
12431243

12441244
checkAnswer(
1245-
sql("SELECT .5"), Row(0.5)
1245+
sql("SELECT .5"), Row(BigDecimal(0.5))
12461246
)
12471247

12481248
checkAnswer(
1249-
sql("SELECT -.18"), Row(-0.18)
1249+
sql("SELECT -.18"), Row(BigDecimal(-0.18))
12501250
)
12511251
}
12521252

@@ -1279,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
12791279
)
12801280

12811281
checkAnswer(
1282-
sql("SELECT -5.2"), Row(-5.2)
1282+
sql("SELECT -5.2"), Row(BigDecimal(-5.2))
12831283
)
12841284

12851285
checkAnswer(
1286-
sql("SELECT +6.8"), Row(6.8)
1286+
sql("SELECT +6.8"), Row(BigDecimal(6.8))
12871287
)
12881288

12891289
checkAnswer(

sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData {
422422
Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil
423423
)
424424

425-
// Widening to DoubleType
425+
// Widening to DecimalType
426426
checkAnswer(
427427
sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"),
428-
Row(21474836472.2) ::
429-
Row(92233720368547758071.3) :: Nil
428+
Row(BigDecimal("21474836472.2")) ::
429+
Row(BigDecimal("92233720368547758071.3")) :: Nil
430430
)
431431

432-
// Widening to DoubleType
432+
// Widening to Double
433433
checkAnswer(
434434
sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"),
435435
Row(101.2) :: Row(21474836471.2) :: Nil
@@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData {
438438
// Number and String conflict: resolve the type as number in this query.
439439
checkAnswer(
440440
sql("select num_str + 1.2 from jsonTable where num_str > 14"),
441-
Row(92233720368547758071.2)
441+
Row(BigDecimal("92233720368547758071.2"))
442442
)
443443

444444
// Number and String conflict: resolve the type as number in this query.
445445
checkAnswer(
446446
sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
447-
Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue)
447+
Row(new java.math.BigDecimal("92233720368547758071.2"))
448448
)
449449

450450
// String and Boolean conflict: resolve the type as string.
@@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData {
503503
// Number and String conflict: resolve the type as number in this query.
504504
checkAnswer(
505505
sql("select num_str + 1.2 from jsonTable where num_str > 13"),
506-
Row(14.3) :: Row(92233720368547758071.2) :: Nil
506+
Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil
507507
)
508508
}
509509

0 commit comments

Comments
 (0)