From 05ef1c866b721b16b3402896988b441da4e08966 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 18 Feb 2015 15:06:08 -0800
Subject: [PATCH] infer LongType for int in Python

---
 python/pyspark/sql.py                              |  8 +++----
 python/pyspark/tests.py                            | 23 ++++++++++++++++++-
 .../org/apache/spark/sql/SQLContext.scala          |  1 +
 .../spark/sql/execution/pythonUdfs.scala           |  1 +
 4 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ae288471b0e51..aa5af1bd40497 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -577,7 +577,7 @@ def _parse_datatype_json_value(json_value):
 _type_mappings = {
     type(None): NullType,
     bool: BooleanType,
-    int: IntegerType,
+    int: LongType,
     long: LongType,
     float: DoubleType,
     str: StringType,
@@ -926,11 +926,11 @@ def _infer_schema_type(obj, dataType):
     >>> schema = _parse_schema_abstract("a b c d")
     >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
     >>> _infer_schema_type(row, schema)
-    StructType...IntegerType...DoubleType...StringType...DateType...
+    StructType...LongType...DoubleType...StringType...DateType...
     >>> row = [[1], {"key": (1, 2.0)}]
     >>> schema = _parse_schema_abstract("a[] b{c d}")
     >>> _infer_schema_type(row, schema)
-    StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
""" if dataType is None: return _infer_type(obj) @@ -985,7 +985,7 @@ def _verify_type(obj, dataType): >>> _verify_type(None, StructType([])) >>> _verify_type("", StringType()) - >>> _verify_type(0, IntegerType()) + >>> _verify_type(0, LongType()) >>> _verify_type(range(3), ArrayType(ShortType())) >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1349384d0fadc..1fc690a649d0c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -51,7 +51,7 @@ CloudPickleSerializer, CompressedSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType + UserDefinedType, DoubleType, LongType, _infer_type from pyspark import shuffle _have_scipy = False @@ -985,6 +985,27 @@ def test_parquet_with_udt(self): point = srdd1.first().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_infer_long_type(self): + longrow = [Row(f1='a', f2=100000000000000)] + rdd = self.sc.parallelize(longrow) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual(srdd.schema().fields[1].dataType, LongType()) + + # this saving as Parquet caused issues as well. 
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        srdd.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        self.assertEquals('a', df1.first().f1)
+        self.assertEquals(100000000000000, df1.first().f2)
+
+        self.assertEqual(_infer_type(1), LongType())
+        self.assertEqual(_infer_type(2**10), LongType())
+        self.assertEqual(_infer_type(2**20), LongType())
+        self.assertEqual(_infer_type(2**31 - 1), LongType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+
 
 class InputFormatTests(ReusedPySparkTestCase):
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 832d5b9938489..6d5d84560b3aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -479,6 +479,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     case ByteType => true
     case ShortType => true
     case FloatType => true
+    case LongType => true
     case DateType => true
     case TimestampType => true
     case ArrayType(_, _) => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 2b4a88d5e864e..fe02302b428b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -187,6 +187,7 @@ object EvaluatePython {
     case (c: Int, ShortType) => c.toShort
     case (c: Long, ShortType) => c.toShort
     case (c: Long, IntegerType) => c.toInt
+    case (c: Int, LongType) => c.toLong
     case (c: Double, FloatType) => c.toFloat
     case (c, StringType) if !c.isInstanceOf[String] => c.toString