
Commit 17591d9

kevinyu98 authored and cloud-fan committed
[SPARK-11827][SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections
Hello: Can you help check this PR? I am adding support for java.math.BigInteger in the Java bean code path. I saw that internally Spark already converts BigInteger to BigDecimal in ColumnType.scala and CatalystRowConverter.scala. I use a similar approach and convert the BigInteger to a BigDecimal. Author: Kevin Yu <[email protected]> Closes #10125 from kevinyu98/working_on_spark-11827.
1 parent d5c47f8 commit 17591d9
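
The idea behind the patch, as a minimal standalone sketch in plain Scala (no Spark APIs, illustrative values): a BigInteger holds the same digits as a scale-0 BigDecimal, so routing it through BigDecimal is lossless, and Catalyst already models such values as a decimal type.

import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger}

// A BigInteger converts to a BigDecimal with scale 0 and the same
// unscaled value, so the round-trip back to BigInteger is exact.
val bi = new JBigInteger("1234567")
val bd = new JBigDecimal(bi)
assert(bd.scale == 0 && bd.toBigInteger == bi)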

File tree

8 files changed: +76 -6 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
 import java.lang.{Iterable => JavaIterable}
 import java.math.{BigDecimal => JavaBigDecimal}
+import java.math.{BigInteger => JavaBigInteger}
 import java.sql.{Date, Timestamp}
 import java.util.{Map => JavaMap}
 import javax.annotation.Nullable
@@ -326,6 +327,7 @@ object CatalystTypeConverters {
     val decimal = scalaValue match {
       case d: BigDecimal => Decimal(d)
       case d: JavaBigDecimal => Decimal(d)
+      case d: JavaBigInteger => Decimal(d)
       case d: Decimal => d
     }
     if (decimal.changePrecision(dataType.precision, dataType.scale)) {
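
For context, a standalone sketch of the widened match this hunk produces (mirroring the diff, not invoking Spark's private converter): all four numeric representations now funnel into a Catalyst Decimal.

import java.math.{BigDecimal => JavaBigDecimal, BigInteger => JavaBigInteger}
import org.apache.spark.sql.types.Decimal

// Same shape as the converter's match above; the JavaBigInteger arm is new
// and relies on the Decimal.apply(BigInteger) overload added in this commit.
def toDecimal(scalaValue: Any): Decimal = scalaValue match {
  case d: BigDecimal => Decimal(d)
  case d: JavaBigDecimal => Decimal(d)
  case d: JavaBigInteger => Decimal(d)
  case d: Decimal => d
}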

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@ object JavaTypeInference {
     case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

     case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
+    case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
     case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
     case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
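
A usage sketch of what this inference change enables (hypothetical bean and session setup; the names are illustrative, not from the patch): a BigInteger bean property now surfaces as Decimal(38, 0) when a DataFrame is built from beans.

import java.math.BigInteger
import java.util.Arrays
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DecimalType

// Hypothetical bean; only the BigInteger property is of interest here.
class Balance extends java.io.Serializable {
  private var amount: BigInteger = new BigInteger("1234567")
  def getAmount: BigInteger = amount
  def setAmount(a: BigInteger): Unit = { amount = a }
}

val spark = SparkSession.builder().master("local").appName("sketch").getOrCreate()
val df = spark.createDataFrame(Arrays.asList(new Balance), classOf[Balance])
assert(df.schema("amount").dataType == DecimalType(38, 0))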

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 24 additions & 0 deletions
@@ -259,6 +259,12 @@ object ScalaReflection extends ScalaReflection {
     case t if t <:< localTypeOf[BigDecimal] =>
       Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))

+    case t if t <:< localTypeOf[java.math.BigInteger] =>
+      Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
+
+    case t if t <:< localTypeOf[scala.math.BigInt] =>
+      Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
+
     case t if t <:< localTypeOf[Array[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t

@@ -592,6 +598,20 @@ object ScalaReflection extends ScalaReflection {
         "apply",
         inputObject :: Nil)

+    case t if t <:< localTypeOf[java.math.BigInteger] =>
+      StaticInvoke(
+        Decimal.getClass,
+        DecimalType.BigIntDecimal,
+        "apply",
+        inputObject :: Nil)
+
+    case t if t <:< localTypeOf[scala.math.BigInt] =>
+      StaticInvoke(
+        Decimal.getClass,
+        DecimalType.BigIntDecimal,
+        "apply",
+        inputObject :: Nil)
+
     case t if t <:< localTypeOf[java.lang.Integer] =>
       Invoke(inputObject, "intValue", IntegerType)
     case t if t <:< localTypeOf[java.lang.Long] =>

@@ -736,6 +756,10 @@ object ScalaReflection extends ScalaReflection {
     case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
     case t if t <:< localTypeOf[java.math.BigDecimal] =>
       Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+    case t if t <:< localTypeOf[java.math.BigInteger] =>
+      Schema(DecimalType.BigIntDecimal, nullable = true)
+    case t if t <:< localTypeOf[scala.math.BigInt] =>
+      Schema(DecimalType.BigIntDecimal, nullable = true)
    case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
    case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
    case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
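
The net effect on the Scala side, as a sketch (hypothetical case class, illustrative field names): both java.math.BigInteger and scala.math.BigInt fields now reflect to DecimalType.BigIntDecimal, i.e. Decimal(38, 0) per the DecimalType.scala hunk below.

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.{DecimalType, StructType}

// Hypothetical case class mixing the two big-integer flavors.
case class Ledger(id: Long, javaAmount: java.math.BigInteger, scalaAmount: scala.math.BigInt)

val schema = ScalaReflection.schemaFor[Ledger].dataType.asInstanceOf[StructType]
assert(schema("javaAmount").dataType == DecimalType(38, 0))
assert(schema("scalaAmount").dataType == DecimalType(38, 0))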

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 28 additions & 1 deletion
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.types

-import java.math.{MathContext, RoundingMode}
+import java.math.{BigInteger, MathContext, RoundingMode}

 import org.apache.spark.annotation.DeveloperApi

@@ -128,6 +128,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     this
   }

+  /**
+   * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0.
+   */
+  def set(bigintval: BigInteger): Decimal = {
+    try {
+      this.decimalVal = null
+      this.longVal = bigintval.longValueExact()
+      this._precision = DecimalType.MAX_PRECISION
+      this._scale = 0
+      this
+    } catch {
+      case e: ArithmeticException =>
+        throw new IllegalArgumentException(s"BigInteger ${bigintval} too large for decimal")
+    }
+  }
+
   /**
    * Set this Decimal to the given Decimal value.
    */

@@ -155,6 +172,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     }
   }

+  def toScalaBigInt: BigInt = BigInt(toLong)
+
+  def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong)
+
   def toUnscaledLong: Long = {
     if (decimalVal.ne(null)) {
       decimalVal.underlying().unscaledValue().longValue()

@@ -371,6 +392,10 @@ object Decimal {
   def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)

+  def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value)
+
+  def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger)
+
   def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
     new Decimal().set(value, precision, scale)

@@ -387,6 +412,8 @@ object Decimal {
     value match {
       case j: java.math.BigDecimal => apply(j)
       case d: BigDecimal => apply(d)
+      case k: scala.math.BigInt => apply(k)
+      case l: java.math.BigInteger => apply(l)
       case d: Decimal => d
     }
   }
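
A sketch of the new Decimal entry points, with one caveat visible in the patch itself: set(BigInteger) stores the value through longValueExact(), so a BigInteger outside the Long range is rejected with an IllegalArgumentException rather than widened into an unlimited-precision decimal.

import java.math.BigInteger
import org.apache.spark.sql.types.Decimal

// Round-trip through the new apply/to* methods; precision 38, scale 0.
val d = Decimal(new BigInteger("1234567"))
assert(d.toJavaBigInteger == new BigInteger("1234567"))
assert(d.toScalaBigInt == BigInt(1234567))

// Beyond Long.MaxValue the setter fails fast:
// Decimal(new BigInteger("99999999999999999999"))  // IllegalArgumentException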

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
   private[sql] val LongDecimal = DecimalType(20, 0)
   private[sql] val FloatDecimal = DecimalType(14, 7)
   private[sql] val DoubleDecimal = DecimalType(30, 15)
+  private[sql] val BigIntDecimal = DecimalType(38, 0)

   private[sql] def forType(dataType: DataType): DecimalType = dataType match {
     case ByteType => ByteDecimal

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 3 additions & 1 deletion
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.encoders

+import java.math.BigInteger
 import java.sql.{Date, Timestamp}
 import java.util.Arrays

@@ -109,7 +110,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {

   encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
   encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
-
+  encodeDecodeTest(BigInt("23134123123"), "scala biginteger")
+  encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger")
   encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")

   encodeDecodeTest("hello", "string")

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 10 additions & 1 deletion
@@ -21,6 +21,8 @@
 import java.net.URISyntaxException;
 import java.net.URL;
 import java.util.*;
+import java.math.BigInteger;
+import java.math.BigDecimal;

 import scala.collection.JavaConverters;
 import scala.collection.Seq;

@@ -130,6 +132,7 @@ public static class Bean implements Serializable {
   private Integer[] b = { 0, 1 };
   private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
   private List<String> d = Arrays.asList("floppy", "disk");
+  private BigInteger e = new BigInteger("1234567");

   public double getA() {
     return a;

@@ -146,6 +149,8 @@ public Map<String, int[]> getC() {
   public List<String> getD() {
     return d;
   }
+
+  public BigInteger getE() { return e; }
 }

 void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {

@@ -163,7 +168,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
   Assert.assertEquals(
     new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
     schema.apply("d"));
-  Row first = df.select("a", "b", "c", "d").first();
+  Assert.assertEquals(
+    new StructField("e", DataTypes.createDecimalType(38, 0), true, Metadata.empty()),
+    schema.apply("e"));
+  Row first = df.select("a", "b", "c", "d", "e").first();
   Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
   // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
   // verify that it has the expected length, and contains expected elements.

@@ -182,6 +189,8 @@ void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
   for (int i = 0; i < d.length(); i++) {
     Assert.assertEquals(bean.getD().get(i), d.apply(i));
   }
+  // java.math.BigInteger is equivalent to Spark Decimal(38, 0).
+  Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
 }

 @Test

sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala

Lines changed: 7 additions & 3 deletions
@@ -34,7 +34,9 @@ case class ReflectData(
   decimalField: java.math.BigDecimal,
   date: Date,
   timestampField: Timestamp,
-  seqInt: Seq[Int])
+  seqInt: Seq[Int],
+  javaBigInt: java.math.BigInteger,
+  scalaBigInt: scala.math.BigInt)

 case class NullReflectData(
   intField: java.lang.Integer,

@@ -77,13 +79,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {

   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
-      new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
+      new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3),
+      new java.math.BigInteger("1"), scala.math.BigInt(1))
     Seq(data).toDF().createOrReplaceTempView("reflectData")

     assert(sql("SELECT * FROM reflectData").collect().head ===
       Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
         new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
-        new Timestamp(12345), Seq(1, 2, 3)))
+        new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1),
+        new java.math.BigDecimal(1)))
   }

   test("query case class RDD with nulls") {
