Commit aa8f10e

Davies Liu authored and marmbrus committed
[SPARK-5722] [SQL] [PySpark] infer int as LongType
Python's `int` is 64-bit on 64-bit machines (very common now), so we should infer it as LongType in Spark SQL. Also, LongType in SQL will come back as `int`.

Author: Davies Liu <[email protected]>

Closes #4666 from davies/long and squashes the following commits:

6bc6cc4 [Davies Liu] infer int as LongType
1 parent f0e3b71 commit aa8f10e
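
A quick way to see the new inference in action; a minimal sketch, assuming a Spark 1.3-era PySpark shell with a live SparkContext `sc` (the `_infer_type` helper is the same private one the new tests below import):

    from pyspark.sql import Row
    from pyspark.sql.types import LongType, _infer_type

    # Plain Python ints now map to LongType instead of IntegerType.
    assert _infer_type(1) == LongType()

    # So a DataFrame built from Rows of ints gets a LongType `age` column.
    df = sc.parallelize([Row(name='Alice', age=2)]).toDF()
    print(df.schema.fields)  # the age field's dataType is now LongType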

File tree

5 files changed, +35 -11 lines changed

python/pyspark/sql/dataframe.py

Lines changed: 8 additions & 6 deletions
@@ -803,7 +803,7 @@ def mean(self, *cols):
         >>> df.groupBy().mean('age').collect()
         [Row(AVG(age#0)=3.5)]
         >>> df3.groupBy().mean('age', 'height').collect()
-        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
         """

     @df_varargs_api
@@ -814,7 +814,7 @@ def avg(self, *cols):
         >>> df.groupBy().avg('age').collect()
         [Row(AVG(age#0)=3.5)]
         >>> df3.groupBy().avg('age', 'height').collect()
-        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
         """

     @df_varargs_api
@@ -825,7 +825,7 @@ def max(self, *cols):
         >>> df.groupBy().max('age').collect()
         [Row(MAX(age#0)=5)]
         >>> df3.groupBy().max('age', 'height').collect()
-        [Row(MAX(age#4)=5, MAX(height#5)=85)]
+        [Row(MAX(age#4L)=5, MAX(height#5L)=85)]
         """

     @df_varargs_api
@@ -836,7 +836,7 @@ def min(self, *cols):
         >>> df.groupBy().min('age').collect()
         [Row(MIN(age#0)=2)]
         >>> df3.groupBy().min('age', 'height').collect()
-        [Row(MIN(age#4)=2, MIN(height#5)=80)]
+        [Row(MIN(age#4L)=2, MIN(height#5L)=80)]
         """

     @df_varargs_api
@@ -847,7 +847,7 @@ def sum(self, *cols):
         >>> df.groupBy().sum('age').collect()
         [Row(SUM(age#0)=7)]
         >>> df3.groupBy().sum('age', 'height').collect()
-        [Row(SUM(age#4)=7, SUM(height#5)=165)]
+        [Row(SUM(age#4L)=7, SUM(height#5L)=165)]
         """
@@ -1051,7 +1051,9 @@ def _test():
     sc = SparkContext('local[4]', 'PythonTest')
     globs['sc'] = sc
     globs['sqlCtx'] = SQLContext(sc)
-    globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
+        .toDF(StructType([StructField('age', IntegerType()),
+                          StructField('name', StringType())]))
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
     globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                    Row(name='Bob', age=5, height=85)]).toDF()
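
The hunk above pins `df` to an explicit IntegerType schema, which is why the `df` doctests keep their non-suffixed `age#0` column ids while the inferred `df3` picks up the `L`-suffixed Long columns. The same opt-out works in user code; a minimal sketch, assuming a live `sc` as above:

    from pyspark.sql.types import StructType, StructField, IntegerType, StringType

    # An explicit schema keeps `age` as IntegerType; plain inference on the
    # same tuples would now yield LongType.
    schema = StructType([StructField('age', IntegerType()),
                         StructField('name', StringType())])
    df = sc.parallelize([(2, 'Alice'), (5, 'Bob')]).toDF(schema)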

python/pyspark/sql/tests.py

Lines changed: 21 additions & 1 deletion
@@ -38,7 +38,7 @@

 from pyspark.sql import SQLContext, HiveContext, Column
 from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType, LongType, StringType
+    UserDefinedType, DoubleType, LongType, StringType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
@@ -324,6 +324,26 @@ def test_help_command(self):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))

+    def test_infer_long_type(self):
+        longrow = [Row(f1='a', f2=100000000000000)]
+        df = self.sc.parallelize(longrow).toDF()
+        self.assertEqual(df.schema.fields[1].dataType, LongType())
+
+        # this saving as Parquet caused issues as well.
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        df.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        self.assertEquals('a', df1.first().f1)
+        self.assertEquals(100000000000000, df1.first().f2)
+
+        self.assertEqual(_infer_type(1), LongType())
+        self.assertEqual(_infer_type(2**10), LongType())
+        self.assertEqual(_infer_type(2**20), LongType())
+        self.assertEqual(_infer_type(2**31 - 1), LongType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+

 class HiveContextSQLTests(ReusedPySparkTestCase):
python/pyspark/sql/types.py

Lines changed: 4 additions & 4 deletions
@@ -583,7 +583,7 @@ def _parse_datatype_json_value(json_value):
 _type_mappings = {
     type(None): NullType,
     bool: BooleanType,
-    int: IntegerType,
+    int: LongType,
     long: LongType,
     float: DoubleType,
     str: StringType,
@@ -933,11 +933,11 @@ def _infer_schema_type(obj, dataType):
     >>> schema = _parse_schema_abstract("a b c d")
     >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
     >>> _infer_schema_type(row, schema)
-    StructType...IntegerType...DoubleType...StringType...DateType...
+    StructType...LongType...DoubleType...StringType...DateType...
     >>> row = [[1], {"key": (1, 2.0)}]
     >>> schema = _parse_schema_abstract("a[] b{c d}")
     >>> _infer_schema_type(row, schema)
-    StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
     """
     if dataType is None:
         return _infer_type(obj)
@@ -992,7 +992,7 @@ def _verify_type(obj, dataType):

     >>> _verify_type(None, StructType([]))
     >>> _verify_type("", StringType())
-    >>> _verify_type(0, IntegerType())
+    >>> _verify_type(0, LongType())
     >>> _verify_type(range(3), ArrayType(ShortType()))
     >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 0 deletions
@@ -1130,6 +1130,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     def needsConversion(dataType: DataType): Boolean = dataType match {
       case ByteType => true
       case ShortType => true
+      case LongType => true
       case FloatType => true
       case DateType => true
       case TimestampType => true
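
`needsConversion` decides which column types must be converted when Python-built rows cross into the JVM; with ints now inferred as LongType, Long columns join that list. Per the commit message, the values round-trip back to Python as plain `int`. A minimal round-trip sketch, assuming a live `sc` as in the doctests above:

    from pyspark.sql import Row

    df = sc.parallelize([Row(n=2**40)]).toDF()  # n is inferred as LongType
    v = df.first().n
    print(type(v), v)  # comes back as a plain Python int on a 64-bit build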

sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ object EvaluatePython {
     case (c: Int, ShortType) => c.toShort
     case (c: Long, ShortType) => c.toShort
     case (c: Long, IntegerType) => c.toInt
+    case (c: Int, LongType) => c.toLong
     case (c: Double, FloatType) => c.toFloat
     case (c, StringType) if !c.isInstanceOf[String] => c.toString
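
EvaluatePython maps Python results onto Catalyst types, so a small value that arrives JVM-side as a boxed Int must be widened when the target column is LongType; the new case handles exactly that. One place this path shows up is a UDF declared with a LongType return type; a hedged sketch, assuming the Spark 1.3-era `sqlCtx.registerFunction` API and a `df` with an `age` column (`twice` and `people` are hypothetical names):

    from pyspark.sql.types import LongType

    # Small results may cross Py4J as Ints even though the declared return
    # type is LongType; the new `case (c: Int, LongType)` widens them.
    sqlCtx.registerFunction("twice", lambda x: x * 2, LongType())
    df.registerTempTable("people")
    sqlCtx.sql("SELECT twice(age) AS age2 FROM people").collect()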
