Skip to content

Commit b25c63c

Browse files
Davies Liu authored and CodingCat committed
[SPARK-9759] [SQL] improve decimal.times() and cast(int, decimalType)
This patch optimizes two things: 1. passing MathContext to JavaBigDecimal.multiply/divide/remainder to do the right rounding, because java.math.BigDecimal.apply(MathContext) is expensive; 2. casting integer/short/byte to decimal directly (without going through double). These two optimizations could speed up the end-to-end time of an aggregation (SUM(short * decimal(5, 2))) by 75% (from 19s -> 10.8s). Author: Davies Liu <[email protected]> Closes apache#8052 from davies/optimize_decimal and squashes the following commits: 225efad [Davies Liu] improve decimal.times() and cast(int, decimalType)
1 parent d344213 commit b25c63c

File tree

2 files changed

+22
-32
lines changed

2 files changed

+22
-32
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
155155
case ByteType =>
156156
buildCast[Byte](_, _ != 0)
157157
case DecimalType() =>
158-
buildCast[Decimal](_, _ != Decimal.ZERO)
158+
buildCast[Decimal](_, !_.isZero)
159159
case DoubleType =>
160160
buildCast[Double](_, _ != 0)
161161
case FloatType =>
@@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType)
315315
case TimestampType =>
316316
// Note that we lose precision here.
317317
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
318-
case DecimalType() =>
318+
case dt: DecimalType =>
319319
b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
320-
case LongType =>
321-
b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
322-
case x: NumericType => // All other numeric types can be represented precisely as Doubles
320+
case t: IntegralType =>
321+
b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
322+
case x: FractionalType =>
323323
b => try {
324-
changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
324+
changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target)
325325
} catch {
326326
case _: NumberFormatException => null
327327
}
@@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType)
534534
(c, evPrim, evNull) =>
535535
s"""
536536
try {
537-
org.apache.spark.sql.types.Decimal tmpDecimal =
538-
new org.apache.spark.sql.types.Decimal().set(
539-
new scala.math.BigDecimal(
540-
new java.math.BigDecimal($c.toString())));
537+
Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString()));
541538
${changePrecision("tmpDecimal", target, evPrim, evNull)}
542539
} catch (java.lang.NumberFormatException e) {
543540
$evNull = true;
@@ -546,12 +543,7 @@ case class Cast(child: Expression, dataType: DataType)
546543
case BooleanType =>
547544
(c, evPrim, evNull) =>
548545
s"""
549-
org.apache.spark.sql.types.Decimal tmpDecimal = null;
550-
if ($c) {
551-
tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
552-
} else {
553-
tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
554-
}
546+
Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0);
555547
${changePrecision("tmpDecimal", target, evPrim, evNull)}
556548
"""
557549
case DateType =>
@@ -561,32 +553,28 @@ case class Cast(child: Expression, dataType: DataType)
561553
// Note that we lose precision here.
562554
(c, evPrim, evNull) =>
563555
s"""
564-
org.apache.spark.sql.types.Decimal tmpDecimal =
565-
new org.apache.spark.sql.types.Decimal().set(
566-
scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
556+
Decimal tmpDecimal = Decimal.apply(
557+
scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
567558
${changePrecision("tmpDecimal", target, evPrim, evNull)}
568559
"""
569560
case DecimalType() =>
570561
(c, evPrim, evNull) =>
571562
s"""
572-
org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
563+
Decimal tmpDecimal = $c.clone();
573564
${changePrecision("tmpDecimal", target, evPrim, evNull)}
574565
"""
575-
case LongType =>
566+
case x: IntegralType =>
576567
(c, evPrim, evNull) =>
577568
s"""
578-
org.apache.spark.sql.types.Decimal tmpDecimal =
579-
new org.apache.spark.sql.types.Decimal().set($c);
569+
Decimal tmpDecimal = Decimal.apply((long) $c);
580570
${changePrecision("tmpDecimal", target, evPrim, evNull)}
581571
"""
582-
case x: NumericType =>
572+
case x: FractionalType =>
583573
// All other numeric types can be represented precisely as Doubles
584574
(c, evPrim, evNull) =>
585575
s"""
586576
try {
587-
org.apache.spark.sql.types.Decimal tmpDecimal =
588-
new org.apache.spark.sql.types.Decimal().set(
589-
scala.math.BigDecimal.valueOf((double) $c));
577+
Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
590578
${changePrecision("tmpDecimal", target, evPrim, evNull)}
591579
} catch (java.lang.NumberFormatException e) {
592580
$evNull = true;

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
139139

140140
def toBigDecimal: BigDecimal = {
141141
if (decimalVal.ne(null)) {
142-
decimalVal(MATH_CONTEXT)
142+
decimalVal
143143
} else {
144-
BigDecimal(longVal, _scale)(MATH_CONTEXT)
144+
BigDecimal(longVal, _scale)
145145
}
146146
}
147147

@@ -280,13 +280,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
280280
}
281281

282282
// HiveTypeCoercion will take care of the precision, scale of result
283-
def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
283+
def * (that: Decimal): Decimal =
284+
Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
284285

285286
def / (that: Decimal): Decimal =
286-
if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
287+
if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT))
287288

288289
def % (that: Decimal): Decimal =
289-
if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
290+
if (that.isZero) null
291+
else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
290292

291293
def remainder(that: Decimal): Decimal = this % that
292294

0 commit comments

Comments
 (0)