Skip to content

Commit 67c7502

Browse files
HyukjinKwon authored and cloud-fan committed
[SPARK-21163][SQL] DataFrame.toPandas should respect the data type
## What changes were proposed in this pull request?

Currently we convert a Spark DataFrame to a Pandas DataFrame using `pd.DataFrame.from_records`. It infers the data types from the data and does not respect the Spark DataFrame schema. This PR fixes that.

## How was this patch tested?

A new regression test.

Author: hyukjinkwon <[email protected]>
Author: Wenchen Fan <[email protected]>
Author: Wenchen Fan <[email protected]>

Closes #18378 from cloud-fan/to_pandas.
1 parent d66b143 commit 67c7502

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
def toPandas(self):
    """
    Returns the contents of this :class:`DataFrame` as a Pandas ``pandas.DataFrame``.

    Column dtypes are corrected to match the Spark schema where Pandas'
    own inference is known to be wrong (integral widths and float32).

    .. note:: only available if Pandas is installed.
    """
    import pandas as pd

    # Map each column to the NumPy dtype Pandas should use; None means
    # Pandas' inferred dtype is already trustworthy for that Spark type.
    corrections = {
        f.name: _to_corrected_pandas_type(f.dataType)
        for f in self.schema
    }
    pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
    # Coerce only the columns that need it; copy=False avoids a second
    # materialization when the underlying buffer can be reused.
    for name, np_type in corrections.items():
        if np_type is not None:
            pdf[name] = pdf[name].astype(np_type, copy=False)
    return pdf
17251736

17261737
##########################################################################################
17271738
# Pandas compatibility
@@ -1750,6 +1761,24 @@ def _to_scala_map(sc, jm):
17501761
return sc._jvm.PythonUtils.toScalaMap(jm)
17511762

17521763

1764+
def _to_corrected_pandas_type(dt):
    """
    Return the NumPy type Pandas should use for a Spark SQL data type, or None.

    When converting Spark SQL records to a Pandas DataFrame, the inferred data
    type may be wrong (e.g. an int column widened to int64). This method gets
    the corrected data type for Pandas if that type may be inferred incorrectly;
    it returns None when the inferred type can be trusted as-is.

    :param dt: a Spark SQL ``DataType`` instance.
    :return: a NumPy scalar type, or None if no correction is needed.
    """
    import numpy as np
    # Dispatch on the exact type (not isinstance), preserving the original
    # behavior: only these four Spark types get a corrected Pandas dtype.
    corrections = {
        ByteType: np.int8,
        ShortType: np.int16,
        IntegerType: np.int32,
        FloatType: np.float32,
    }
    return corrections.get(type(dt))
1781+
17531782
class DataFrameNaFunctions(object):
17541783
"""Functionality for working with missing data in :class:`DataFrame`.
17551784

python/pyspark/sql/tests.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@
4646
else:
4747
import unittest
4848

49+
_have_pandas = False
50+
try:
51+
import pandas
52+
_have_pandas = True
53+
except:
54+
# No Pandas, but that's okay, we'll skip those tests
55+
pass
56+
4957
from pyspark import SparkContext
5058
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
5159
from pyspark.sql.types import *
@@ -2290,6 +2298,22 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
22902298
.mode("overwrite").saveAsTable("pyspark_bucket"))
22912299
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
22922300

2301+
@unittest.skipIf(not _have_pandas, "Pandas not installed")
def test_to_pandas(self):
    """Regression test for SPARK-21163: toPandas() must honor the schema dtypes."""
    import numpy as np
    schema = StructType().add("a", IntegerType()).add("b", StringType())\
        .add("c", BooleanType()).add("d", FloatType())
    data = [
        (1, "foo", True, 3.0), (2, "foo", True, 5.0),
        (3, "bar", False, -1.0), (4, "bar", False, 6.0),
    ]
    df = self.spark.createDataFrame(data, schema)
    types = df.toPandas().dtypes
    # assertEqual, not the deprecated assertEquals alias; np.object_/np.bool_
    # instead of np.object/np.bool, which were removed in NumPy >= 1.24 and
    # compare equal to the same dtypes.
    self.assertEqual(types[0], np.int32)
    self.assertEqual(types[1], np.object_)
    self.assertEqual(types[2], np.bool_)
    self.assertEqual(types[3], np.float32)
2316+
22932317

22942318
class HiveSparkSubmitTests(SparkSubmitTests):
22952319

0 commit comments

Comments
 (0)