From 79f136d5ab347bf2c9afe8d0c5c29bcdc214e634 Mon Sep 17 00:00:00 2001
From: Don Drake
Date: Mon, 16 Feb 2015 22:11:31 -0600
Subject: [PATCH] SPARK-5722 fixes for inferring LongType

---
 python/pyspark/sql/dataframe.py                    | 20 ++++++++---------
 python/pyspark/sql/tests.py                        | 22 +++++++++++++++++++
 python/pyspark/sql/types.py                        | 10 ++++++---
 .../spark/sql/execution/pythonUdfs.scala           |  2 ++
 4 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3eef0cc376a2d..b1c83b132d600 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -221,7 +221,7 @@ def schema(self):
         a L{StructType}).
 
         >>> df.schema()
-        StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+        StructType(List(StructField(age,LongType,true),StructField(name,StringType,true)))
         """
         return _parse_datatype_json_string(self._jdf.schema().json())
 
@@ -230,7 +230,7 @@ def printSchema(self):
 
         >>> df.printSchema()
         root
-         |-- age: integer (nullable = true)
+         |-- age: long (nullable = true)
          |-- name: string (nullable = true)
 
         """
@@ -380,7 +380,7 @@ def dtypes(self):
         """Return all column names and their data types as a list.
 
         >>> df.dtypes
-        [('age', 'integer'), ('name', 'string')]
+        [('age', 'long'), ('name', 'string')]
         """
         return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
 
@@ -551,11 +551,11 @@ def groupBy(self, *cols):
         for all the available aggregate functions.
 
         >>> df.groupBy().avg().collect()
-        [Row(AVG(age#0)=3.5)]
+        [Row(AVG(age#0L)=3.5)]
         >>> df.groupBy('name').agg({'age': 'mean'}).collect()
-        [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+        [Row(name=u'Bob', AVG(age#0L)=5.0), Row(name=u'Alice', AVG(age#0L)=2.0)]
         >>> df.groupBy(df.name).avg().collect()
-        [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+        [Row(name=u'Bob', AVG(age#0L)=5.0), Row(name=u'Alice', AVG(age#0L)=2.0)]
         """
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
@@ -567,10 +567,10 @@ def agg(self, *exprs):
         (shorthand for df.groupBy.agg()).
 
         >>> df.agg({"age": "max"}).collect()
-        [Row(MAX(age#0)=5)]
+        [Row(MAX(age#0L)=5)]
         >>> from pyspark.sql import Dsl
         >>> df.agg(Dsl.min(df.age)).collect()
-        [Row(MIN(age#0)=2)]
+        [Row(MIN(age#0L)=2)]
         """
         return self.groupBy().agg(*exprs)
 
@@ -659,10 +659,10 @@ def agg(self, *exprs):
 
         >>> gdf = df.groupBy(df.name)
         >>> gdf.agg({"age": "max"}).collect()
-        [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
+        [Row(name=u'Bob', MAX(age#0L)=5), Row(name=u'Alice', MAX(age#0L)=2)]
         >>> from pyspark.sql import Dsl
         >>> gdf.agg(Dsl.min(df.age)).collect()
-        [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
+        [Row(MIN(age#0L)=5), Row(MIN(age#0L)=2)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5e41e36897b5d..0c399ff7d3302 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -210,6 +210,28 @@ def test_struct_in_map(self):
         self.assertEqual(1, k.i)
         self.assertEqual("", v.s)
 
+    # SPARK-5722
+    def test_infer_long_type(self):
+        longrow = [Row(f1='a', f2=100000000000000)]
+        lrdd = self.sc.parallelize(longrow)
+        slrdd = self.sqlCtx.inferSchema(lrdd)
+        self.assertEqual(slrdd.schema().fields[1].dataType, LongType())
+
+        # this saving as Parquet caused issues as well.
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        slrdd.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        self.assertEqual('a', df1.first().f1)
+        self.assertEqual(100000000000000, df1.first().f2)
+
+        self.assertEqual(_infer_type(1), LongType())
+        self.assertEqual(_infer_type(2**10), LongType())
+        self.assertEqual(_infer_type(2**20), LongType())
+        self.assertEqual(_infer_type(2**31 - 1), LongType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41afefe48ee5e..dec89162f9d3e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -551,7 +551,7 @@ def _parse_datatype_json_value(json_value):
 _type_mappings = {
     type(None): NullType,
     bool: BooleanType,
-    int: IntegerType,
+    int: LongType,
     long: LongType,
     float: DoubleType,
     str: StringType,
@@ -655,6 +655,8 @@ def _need_python_to_sql_conversion(dataType):
             _need_python_to_sql_conversion(dataType.valueType)
     elif isinstance(dataType, UserDefinedType):
         return True
+    elif isinstance(dataType, LongType):
+        return True
     else:
         return False
 
@@ -708,6 +710,8 @@ def converter(obj):
         return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
     elif isinstance(dataType, UserDefinedType):
         return lambda obj: dataType.serialize(obj)
+    elif isinstance(dataType, LongType):
+        return lambda x: long(x)
     else:
         raise ValueError("Unexpected type %r" % dataType)
 
@@ -901,11 +905,11 @@ def _infer_schema_type(obj, dataType):
     >>> schema = _parse_schema_abstract("a b c d")
     >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
     >>> _infer_schema_type(row, schema)
-    StructType...IntegerType...DoubleType...StringType...DateType...
+    StructType...LongType...DoubleType...StringType...DateType...
     >>> row = [[1], {"key": (1, 2.0)}]
     >>> schema = _parse_schema_abstract("a[] b{c d}")
     >>> _infer_schema_type(row, schema)
-    StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
     """
     if dataType is None:
         return _infer_type(obj)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3a2f8d75dac5e..44441e6720527 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -137,6 +137,8 @@ object EvaluatePython {
 
     case (date: Int, DateType) => DateUtils.toJavaDate(date)
 
+    case (c: Int, LongType) => c.toLong
+
     // Pyrolite can handle Timestamp and Decimal
     case (other, _) => other
   }
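
A quick interactive check of the behavior this patch targets (a minimal sketch, not part of the patch itself; it assumes a running SparkContext `sc` and SQLContext `sqlCtx`, the same fixtures the doctests and tests above rely on):

    from pyspark.sql import Row

    # With Python int now mapped to LongType, a value that overflows a 32-bit int
    # is inferred as a long column instead of tripping over IntegerType.
    rdd = sc.parallelize([Row(f1='a', f2=100000000000000)])
    df = sqlCtx.inferSchema(rdd)
    df.printSchema()
    # root
    #  |-- f1: string (nullable = true)
    #  |-- f2: long (nullable = true)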