33 changes: 16 additions & 17 deletions python/pyspark/sql/dataframe.py
@@ -220,8 +220,8 @@ def schema(self):
"""Returns the schema of this DataFrame (represented by
a L{StructType}).

- >>> df.schema
- StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+ >>> df.schema()
+ StructType(List(StructField(age,LongType,true),StructField(name,StringType,true)))
"""
if self._schema is None:
self._schema = _parse_datatype_json_string(self._jdf.schema().json())
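
The dataframe.py hunks only update doctest output: once Python int maps to LongType, the example age column prints as LongType/long everywhere. A quick way to reproduce the effect (a minimal sketch, assuming the 1.3-era API used in this diff, where SQLContext.inferSchema exists and schema() is a method):

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext("local", "long-type-demo")
    sqlCtx = SQLContext(sc)
    df = sqlCtx.inferSchema(sc.parallelize([Row(age=2, name='Alice'),
                                            Row(age=5, name='Bob')]))
    # Before this patch age was inferred as IntegerType; with int mapped
    # to LongType in _type_mappings it now comes back as LongType.
    print(df.schema())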
@@ -232,7 +232,7 @@ def printSchema(self):

>>> df.printSchema()
root
- |-- age: integer (nullable = true)
+ |-- age: long (nullable = true)
|-- name: string (nullable = true)
<BLANKLINE>
"""
@@ -397,7 +397,7 @@ def dtypes(self):
"""Return all column names and their data types as a list.

>>> df.dtypes
- [('age', 'int'), ('name', 'string')]
+ [('age', 'long'), ('name', 'string')]
"""
return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]

@@ -568,11 +568,11 @@ def groupBy(self, *cols):
for all the available aggregate functions.

>>> df.groupBy().avg().collect()
- [Row(AVG(age#0)=3.5)]
+ [Row(AVG(age#0L)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
- [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ [Row(name=u'Bob', AVG(age#0L)=5.0), Row(name=u'Alice', AVG(age#0L)=2.0)]
>>> df.groupBy(df.name).avg().collect()
- [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ [Row(name=u'Bob', AVG(age#0L)=5.0), Row(name=u'Alice', AVG(age#0L)=2.0)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
@@ -584,10 +584,10 @@ def agg(self, *exprs):
(shorthand for df.groupBy.agg()).

>>> df.agg({"age": "max"}).collect()
- [Row(MAX(age#0)=5)]
- >>> from pyspark.sql import functions as F
- >>> df.agg(F.min(df.age)).collect()
- [Row(MIN(age#0)=2)]
+ [Row(MAX(age#0L)=5)]
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0L)=2)]
"""
return self.groupBy().agg(*exprs)

@@ -698,12 +698,11 @@ def agg(self, *exprs):
name to aggregate methods.

>>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"*": "count"}).collect()
- [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
-
- >>> from pyspark.sql import functions as F
- >>> gdf.agg(F.min(df.age)).collect()
- [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
+ >>> gdf.agg({"age": "max"}).collect()
+ [Row(name=u'Bob', MAX(age#0L)=5), Row(name=u'Alice', MAX(age#0L)=2)]
+ >>> from pyspark.sql import Dsl
+ >>> gdf.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0L)=5), Row(MIN(age#0L)=2)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
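The remaining doctest churn above is the "L" suffix: Catalyst prints an attribute reference as name#id, appending "L" when the attribute is Long-typed, so every aggregate over age becomes AVG(age#0L), MAX(age#0L), and so on. The aggregation API itself is unchanged; reusing the df from the sketch above:

    from pyspark.sql import Dsl

    print(df.groupBy(df.name).agg(Dsl.min(df.age)).collect())
    print(df.agg({"age": "max"}).collect())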
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests.py
@@ -209,6 +209,28 @@ def test_struct_in_map(self):
self.assertEqual(1, k.i)
self.assertEqual("", v.s)

+ # SPARK-5722
+ def test_infer_long_type(self):
+     longrow = [Row(f1='a', f2=100000000000000)]
+     lrdd = self.sc.parallelize(longrow)
+     slrdd = self.sqlCtx.inferSchema(lrdd)
+     self.assertEqual(slrdd.schema().fields[1].dataType, LongType())
+
+     # this saving as Parquet caused issues as well.
+     output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+     slrdd.saveAsParquetFile(output_dir)
+     df1 = self.sqlCtx.parquetFile(output_dir)
+     self.assertEquals('a', df1.first().f1)
+     self.assertEquals(100000000000000, df1.first().f2)
+
+     self.assertEqual(_infer_type(1), LongType())
+     self.assertEqual(_infer_type(2**10), LongType())
+     self.assertEqual(_infer_type(2**20), LongType())
+     self.assertEqual(_infer_type(2**31 - 1), LongType())
+     self.assertEqual(_infer_type(2**31), LongType())
+     self.assertEqual(_infer_type(2**61), LongType())
+     self.assertEqual(_infer_type(2**71), LongType())
+
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
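The boundary values in the new test are the substance of SPARK-5722: numbers such as 2**31 or 100000000000000 do not fit in a 32-bit IntegerType, so inferring Python int as IntegerType broke schema inference and the Parquet round trip. A condensed check (a sketch, assuming the private helper _infer_type stays importable as it is in tests.py):

    from pyspark.sql.types import LongType, _infer_type

    # Every plain Python int, small or large, now infers as the 64-bit type.
    for v in (1, 2**31 - 1, 2**31, 2**61):
        assert _infer_type(v) == LongType()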
10 changes: 7 additions & 3 deletions python/pyspark/sql/types.py
@@ -583,7 +583,7 @@ def _parse_datatype_json_value(json_value):
_type_mappings = {
type(None): NullType,
bool: BooleanType,
- int: IntegerType,
+ int: LongType,
long: LongType,
float: DoubleType,
str: StringType,
@@ -687,6 +687,8 @@ def _need_python_to_sql_conversion(dataType):
_need_python_to_sql_conversion(dataType.valueType)
elif isinstance(dataType, UserDefinedType):
return True
+ elif isinstance(dataType, LongType):
+     return True
else:
return False

@@ -740,6 +742,8 @@ def converter(obj):
return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
elif isinstance(dataType, UserDefinedType):
return lambda obj: dataType.serialize(obj)
+ elif isinstance(dataType, LongType):
+     return lambda x: long(x)
else:
raise ValueError("Unexpected type %r" % dataType)

@@ -933,11 +937,11 @@ def _infer_schema_type(obj, dataType):
>>> schema = _parse_schema_abstract("a b c d")
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
- StructType...IntegerType...DoubleType...StringType...DateType...
+ StructType...LongType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
- StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+ StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
if dataType is None:
return _infer_type(obj)
@@ -137,6 +137,8 @@ object EvaluatePython {

case (date: Int, DateType) => DateUtils.toJavaDate(date)

+ case (_, LongType) => obj.asInstanceOf[Long]

// Pyrolite can handle Timestamp and Decimal
case (other, _) => other
}
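The Scala case covers the return direction: when results flow back to Python, a value for a field declared LongType is handed to the pickler as a JVM Long rather than whatever narrower integer the row happened to hold. End to end (reusing sc and sqlCtx from the first sketch; this mirrors the new test):

    longs = sc.parallelize([Row(f1='a', f2=100000000000000)])
    df2 = sqlCtx.inferSchema(longs)
    assert df2.first().f2 == 100000000000000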