From 3b44c5978bd44db986621d3e8511e9165b66926b Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 20 Apr 2016 11:06:30 -0700 Subject: [PATCH 01/10] adding testcase --- .../org/apache/spark/sql/DataFrameSuite.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e953a6e8ef0c..009c101e746d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1429,4 +1429,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { getMessage() assert(e1.startsWith("Path does not exist")) } + + test("SPARK-12987: drop column ") { + val df = Seq((1, 2)).toDF("a_b", "a.c") + val df1 = df.drop("a_b") + checkAnswer(df1, Row(2)) + assert(df1.schema.map(_.name) === Seq("a.c")) + } + + test("SPARK-14759: drop column ") { + val df1 = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("any", "hour") + val df2 = sqlContext.createDataFrame(Seq((1, 3))).toDF("any").withColumn("hour", lit(10)) + val j = df1.join(df2, $"df1.hour" === $"df2.hour", "left") + assert(j.schema.map(_.name) === Seq("any", "hour", "any", "hour")) + println(s"Columns after join: ${j.columns.mkString(", ")}") + val jj = j.drop($"df2.hour") + assert(jj.schema.map(_.name) === Seq("any")) + println(s"Columns after drop 'hour': ${jj.columns.mkString(", ")}") + } + } From ae0be70734003cb0281f07a08f6af7c030475360 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Fri, 13 May 2016 10:58:37 -0700 Subject: [PATCH 02/10] fix comments --- .../sql/catalyst/CatalystTypeConverters.scala | 6 ++++- .../sql/catalyst/JavaTypeInference.scala | 1 + .../spark/sql/catalyst/ScalaReflection.scala | 22 ++++++++++++++++++ .../org/apache/spark/sql/types/Decimal.scala | 23 ++++++++++++++++++- .../apache/spark/sql/types/DecimalType.scala | 3 +++ .../apache/spark/sql/types/DecimalSuite.scala | 2 ++ .../apache/spark/sql/JavaDataFrameSuite.java | 11 ++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 19 --------------- .../sql/ScalaReflectionRelationSuite.scala | 14 ++++++++--- 9 files changed, 76 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9bfc38163914..82d567e0e783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} +import java.math.{BigInteger => JavaBigInteger} import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable import scala.language.existentials +import scala.math.{BigInt => ScalaBigInt} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ @@ -321,11 +323,13 @@ object CatalystTypeConverters { } private class DecimalConverter(dataType: DecimalType) - extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = { val decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) + case d: JavaBigInteger => Decimal(d) + case d: 
ScalaBigInt => Decimal(d) case d: Decimal => d } if (decimal.changePrecision(dataType.precision, dataType.scale)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 6f9fbbbead47..271f4c1ddc2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -83,6 +83,7 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) + case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BIGINT_DEFAULT, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 79bb7a701baf..9625aea31de4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -258,6 +258,12 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[BigDecimal] => Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + case t if t <:< localTypeOf[java.math.BigInteger] => + Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger])) + + case t if t <:< localTypeOf[scala.math.BigInt] => + Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt])) + case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -590,6 +596,18 @@ object ScalaReflection extends ScalaReflection { DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) + case t if t <:< localTypeOf[java.math.BigInteger] => + StaticInvoke( + Decimal.getClass, + DecimalType.BIGINT_DEFAULT, + "apply", + inputObject :: Nil) + case t if t <:< localTypeOf[scala.math.BigInt] => + StaticInvoke( + Decimal.getClass, + DecimalType.BIGINT_DEFAULT, + "apply", + inputObject :: Nil) case t if t <:< localTypeOf[java.lang.Integer] => Invoke(inputObject, "intValue", IntegerType) @@ -735,6 +753,10 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if t <:< localTypeOf[java.math.BigInteger] => + Schema(DecimalType.BIGINT_DEFAULT, nullable = true) + case t if t <:< localTypeOf[scala.math.BigInt] => + Schema(DecimalType.BIGINT_DEFAULT, nullable = true) case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6f4ec6b70191..7e684ad50d3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import 
java.math.{MathContext, RoundingMode} +import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.DeveloperApi @@ -128,6 +128,23 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } + /** + * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + */ + def set(BigIntVal: BigInteger): Decimal = { + try { + this.decimalVal = null + this.longVal = BigIntVal.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } + catch { + case e: ArithmeticException => + throw new IllegalArgumentException(s"BigInteger ${BigIntVal} too large for decimal") + } + } + /** * Set this Decimal to the given Decimal value. */ @@ -371,6 +388,10 @@ object Decimal { def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) + def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value) + + def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger) + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 9c1319c1c5e6..e8c1b00f9301 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.BigInteger + import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi @@ -109,6 +111,7 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) + val BIGINT_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 0) // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index e1675c95907a..22068b792133 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.BigInteger + import scala.language.postfixOps import org.scalatest.PrivateMethodTester diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 324ebbae3876..c81785cf05ef 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -21,6 +21,8 @@ import java.net.URISyntaxException; import java.net.URL; import java.util.*; +import java.math.BigInteger; +import java.math.BigDecimal; import scala.collection.JavaConverters; import scala.collection.Seq; @@ -130,6 +132,7 @@ public static class Bean implements Serializable { private Integer[] b = { 0, 1 }; private Map c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List d = Arrays.asList("floppy", "disk"); + private BigInteger e = new BigInteger("1234567"); public double getA() { return a; @@ -146,6 +149,8 @@ public Map getC() { public List getD() { return d; } + + public BigInteger getE() { return e; } } void 
validateDataFrameWithBeans(Bean bean, Dataset df) { @@ -163,7 +168,9 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { Assert.assertEquals( new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()), schema.apply("d")); - Row first = df.select("a", "b", "c", "d").first(); + Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()), + schema.apply("e")); + Row first = df.select("a", "b", "c", "d","e").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -182,6 +189,8 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { for (int i = 0; i < d.length(); i++) { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } + // java.math.BigInteger is equivalent to Spark Decimal(38,0) + Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4).setScale(0)); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index c9f5a0dc33e2..f77403c13e7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1475,23 +1475,4 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { getMessage() assert(e1.startsWith("Path does not exist")) } - - test("SPARK-12987: drop column ") { - val df = Seq((1, 2)).toDF("a_b", "a.c") - val df1 = df.drop("a_b") - checkAnswer(df1, Row(2)) - assert(df1.schema.map(_.name) === Seq("a.c")) - } - - test("SPARK-14759: drop column ") { - val df1 = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("any", "hour") - val df2 = sqlContext.createDataFrame(Seq((1, 3))).toDF("any").withColumn("hour", lit(10)) - val j = df1.join(df2, $"df1.hour" === $"df2.hour", "left") - assert(j.schema.map(_.name) === Seq("any", "hour", "any", "hour")) - println(s"Columns after join: ${j.columns.mkString(", ")}") - val jj = j.drop($"df2.hour") - assert(jj.schema.map(_.name) === Seq("any")) - println(s"Columns after drop 'hour': ${jj.columns.mkString(", ")}") - } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 295f02f9a7b5..87671e39443e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -34,7 +34,13 @@ case class ReflectData( decimalField: java.math.BigDecimal, date: Date, timestampField: Timestamp, - seqInt: Seq[Int]) + seqInt: Seq[Int], + javaBigInt: java.math.BigInteger, + scalaBigInt: scala.math.BigInt) + +case class ReflectData3( + scalaBigInt: scala.math.BigInt + ) case class NullReflectData( intField: java.lang.Integer, @@ -77,13 +83,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3), + new java.math.BigInteger("1"), scala.math.BigInt(1)) Seq(data).toDF().registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head 
=== Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), - new Timestamp(12345), Seq(1, 2, 3))) + new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1), + new java.math.BigDecimal(1))) } test("query case class RDD with nulls") { From 741daffd4db0462792a92e0fb3097a461777f3e2 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Fri, 13 May 2016 11:12:52 -0700 Subject: [PATCH 03/10] fixing style --- .../scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 2 ++ .../org/apache/spark/sql/ScalaReflectionRelationSuite.scala | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9625aea31de4..2abb46f7ae48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -596,12 +596,14 @@ object ScalaReflection extends ScalaReflection { DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) + case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, DecimalType.BIGINT_DEFAULT, "apply", inputObject :: Nil) + case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 87671e39443e..1873d11d54b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -38,10 +38,6 @@ case class ReflectData( javaBigInt: java.math.BigInteger, scalaBigInt: scala.math.BigInt) -case class ReflectData3( - scalaBigInt: scala.math.BigInt - ) - case class NullReflectData( intField: java.lang.Integer, longField: java.lang.Long, From bbed47aabe572d7f91d0d011249df4a9624f0f97 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Tue, 17 May 2016 14:20:20 -0700 Subject: [PATCH 04/10] address comments --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 6 +----- .../scala/org/apache/spark/sql/types/Decimal.scala | 13 ++++++++++--- .../catalyst/encoders/ExpressionEncoderSuite.scala | 4 ++++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 82d567e0e783..9bfc38163914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} -import java.math.{BigInteger => JavaBigInteger} import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable import scala.language.existentials -import scala.math.{BigInt => ScalaBigInt} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ @@ -323,13 +321,11 @@ object CatalystTypeConverters { } private class DecimalConverter(dataType: DecimalType) - extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + extends CatalystTypeConverter[Any, 
JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = { val decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) - case d: JavaBigInteger => Decimal(d) - case d: ScalaBigInt => Decimal(d) case d: Decimal => d } if (decimal.changePrecision(dataType.precision, dataType.scale)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 7e684ad50d3c..c84dfa7cd3de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -131,17 +131,17 @@ final class Decimal extends Ordered[Decimal] with Serializable { /** * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. */ - def set(BigIntVal: BigInteger): Decimal = { + def set(bigintval: BigInteger): Decimal = { try { this.decimalVal = null - this.longVal = BigIntVal.longValueExact() + this.longVal = bigintval.longValueExact() this._precision = DecimalType.MAX_PRECISION this._scale = 0 this } catch { case e: ArithmeticException => - throw new IllegalArgumentException(s"BigInteger ${BigIntVal} too large for decimal") + throw new IllegalArgumentException(s"BigInteger ${bigintval} too large for decimal") } } @@ -172,6 +172,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } + def toScalaBigInt: BigInt = BigInt(toLong) + + def toJavaBigInteger: java.math.BigInteger = + java.math.BigInteger.valueOf(toLong) + def toUnscaledLong: Long = { if (decimalVal.ne(null)) { decimalVal.underlying().unscaledValue().longValue() @@ -407,6 +412,8 @@ object Decimal { def fromDecimal(value: Any): Decimal = { value match { case j: java.math.BigDecimal => apply(j) + case k: scala.math.BigInt => apply(k) + case l: java.math.BigInteger => apply(l) case d: Decimal => d } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index c3b20e2cc00a..632e0d8c71a7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.encoders +import java.math.BigInteger import java.sql.{Date, Timestamp} import java.util.Arrays @@ -109,6 +110,9 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal") // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal") + encodeDecodeTest(BigInt("23134123123"), "scala biginteger") + encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger") + encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal") From 536d20cde27ee2f91d7b83f32a90ba5e30d92ea6 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Tue, 17 May 2016 14:48:28 -0700 Subject: [PATCH 05/10] clean the codes --- .../src/main/scala/org/apache/spark/sql/types/Decimal.scala | 3 +-- .../spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 1 - .../test/scala/org/apache/spark/sql/types/DecimalSuite.scala | 2 -- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala 
index c84dfa7cd3de..07f46c343b03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -174,8 +174,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toScalaBigInt: BigInt = BigInt(toLong) - def toJavaBigInteger: java.math.BigInteger = - java.math.BigInteger.valueOf(toLong) + def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong) def toUnscaledLong: Long = { if (decimalVal.ne(null)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 632e0d8c71a7..8ae4b6740c53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -113,7 +113,6 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { encodeDecodeTest(BigInt("23134123123"), "scala biginteger") encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger") - encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal") encodeDecodeTest("hello", "string") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 22068b792133..e1675c95907a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.math.BigInteger - import scala.language.postfixOps import org.scalatest.PrivateMethodTester From 54cfc24be949dabae0bc62f3171f485e6d7d6ba4 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Tue, 17 May 2016 14:54:56 -0700 Subject: [PATCH 06/10] remove unused import --- .../src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index e8c1b00f9301..ba8f780243af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.types -import java.math.BigInteger - import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi From db4bb48fead3b8324e291ad7f5cf776fd88c9def Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Wed, 18 May 2016 00:14:29 -0700 Subject: [PATCH 07/10] adding JavaBigInteger in converters --- .../org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 9bfc38163914..9cc7b2ac7920 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import java.lang.{Iterable => JavaIterable} import java.math.{BigDecimal => JavaBigDecimal} +import java.math.{BigInteger => JavaBigInteger} import 
java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable @@ -326,6 +327,7 @@ object CatalystTypeConverters { val decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) + case d: JavaBigInteger => Decimal(d) case d: Decimal => d } if (decimal.changePrecision(dataType.precision, dataType.scale)) { From b26412e6cd6fa48840d5f957300aea172f996996 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 19 May 2016 07:36:14 -0700 Subject: [PATCH 08/10] address comments --- .../org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2 +- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++---- .../scala/org/apache/spark/sql/types/DecimalType.scala | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 15b6dc3a579a..1fe143494aba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -89,7 +89,7 @@ object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) - case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BIGINT_DEFAULT, true) + case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5c1b6090520a..58df651da294 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -601,14 +601,14 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigInteger] => StaticInvoke( Decimal.getClass, - DecimalType.BIGINT_DEFAULT, + DecimalType.BigIntDecimal, "apply", inputObject :: Nil) case t if t <:< localTypeOf[scala.math.BigInt] => StaticInvoke( Decimal.getClass, - DecimalType.BIGINT_DEFAULT, + DecimalType.BigIntDecimal, "apply", inputObject :: Nil) @@ -757,9 +757,9 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.math.BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigInteger] => - Schema(DecimalType.BIGINT_DEFAULT, nullable = true) + Schema(DecimalType.BigIntDecimal, nullable = true) case t if t <:< localTypeOf[scala.math.BigInt] => - Schema(DecimalType.BIGINT_DEFAULT, nullable = true) + Schema(DecimalType.BigIntDecimal, nullable = true) case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ba8f780243af..cf667977ef0e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -118,6 +118,7 @@ object DecimalType extends AbstractDataType { private[sql] val LongDecimal = DecimalType(20, 0) private[sql] val FloatDecimal = DecimalType(14, 7) private[sql] val DoubleDecimal = DecimalType(30, 15) + private[sql] val BigIntDecimal = DecimalType(38, 0) private[sql] def forType(dataType: DataType): DecimalType = dataType match { case ByteType => ByteDecimal From 43faed35bc8ca3c5bcff573509d45bfecc583831 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 19 May 2016 07:42:24 -0700 Subject: [PATCH 09/10] delete BIGINT_DEFAULT --- .../src/main/scala/org/apache/spark/sql/types/DecimalType.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index cf667977ef0e..6b7e3714e0b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -109,7 +109,6 @@ object DecimalType extends AbstractDataType { val MAX_SCALE = 38 val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) - val BIGINT_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 0) // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) From 3b4e3608b2c4eb69940ceb1a4949b97814534b51 Mon Sep 17 00:00:00 2001 From: Kevin Yu Date: Thu, 19 May 2016 09:16:13 -0700 Subject: [PATCH 10/10] address comments --- .../java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index c81785cf05ef..35a9f44feca6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -170,7 +170,7 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { schema.apply("d")); Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()), schema.apply("e")); - Row first = df.select("a", "b", "c", "d","e").first(); + Row first = df.select("a", "b", "c", "d", "e").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -190,7 +190,7 @@ void validateDataFrameWithBeans(Bean bean, Dataset df) { Assert.assertEquals(bean.getD().get(i), d.apply(i)); } // java.math.BigInteger is equivalent to Spark Decimal(38,0) - Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4).setScale(0)); + Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
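
A minimal end-to-end sketch of the behavior this series adds (illustrative only, not part of the patches; the Ledger case class, the BigIntRoundTripDemo object, and the local SparkSession setup are assumptions of the example). With the series applied, scala.math.BigInt and java.math.BigInteger fields are inferred as DecimalType(38, 0); note that Decimal.set(BigInteger) stores the value via longValueExact(), so values outside the Long range raise IllegalArgumentException:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.Decimal

// Both big-integer fields should map to DecimalType(38, 0) under this series.
case class Ledger(id: Int, javaBalance: java.math.BigInteger, scalaBalance: scala.math.BigInt)

object BigIntRoundTripDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("BigIntRoundTripDemo").getOrCreate()
    import spark.implicits._

    // A 19-digit value still fits in a Long, so it passes Decimal.set(BigInteger)'s
    // longValueExact() check; a 20-digit value would throw IllegalArgumentException.
    val df = Seq(Ledger(1,
      new java.math.BigInteger("1234567890123456789"),
      scala.math.BigInt("1234567890123456789"))).toDF()

    df.printSchema()
    // root
    //  |-- id: integer (nullable = false)
    //  |-- javaBalance: decimal(38,0) (nullable = true)
    //  |-- scalaBalance: decimal(38,0) (nullable = true)

    // The Decimal overloads added in Decimal.scala round-trip directly.
    val d = Decimal(scala.math.BigInt(42))   // precision 38, scale 0
    assert(d.toScalaBigInt == scala.math.BigInt(42))
    assert(d.toJavaBigInteger == new java.math.BigInteger("42"))

    spark.stop()
  }
}

The scale-0 mapping is also why PATCH 10 drops the setScale(0) call in JavaDataFrameSuite: with the bean's BigInteger column typed as Decimal(38,0), first.getDecimal(4) already returns a java.math.BigDecimal with scale 0, which equals new BigDecimal(bean.getE()) directly.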