From 1e98c494e0c414ca218b029bfc1a9d9faf3c2960 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 21 Jun 2017 23:44:13 +0800
Subject: [PATCH 1/5] DataFrame.toPandas should respect the data type

---
 python/pyspark/sql/dataframe.py | 27 ++++++++++++++++++++++++++-
 python/pyspark/sql/tests.py     | 24 ++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8541403dfe2f..6f6a3006dfc0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1721,7 +1721,14 @@ def toPandas(self):
         1    5    Bob
         """
         import pandas as pd
-        return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        dtype = {}
+        for field in self.schema:
+            pandas_type = _to_corrected_pandas_type(field.dataType)
+            if (pandas_type):
+                dtype[field.name] = pandas_type
+
+        return pd.DataFrame.from_records(self.collect(), columns=self.columns).astype(dtype)
 
     ##########################################################################################
     # Pandas compatibility
@@ -1750,6 +1757,24 @@ def _to_scala_map(sc, jm):
     return sc._jvm.PythonUtils.toScalaMap(jm)
 
 
+def _to_corrected_pandas_type(dt):
+    """
+    When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
+    This method gets the correted data type for Pandas if that type may be inferred uncorrectly.
+    """
+    import numpy as np
+    if type(dt) == ByteType:
+        return np.int8
+    elif type(dt) == ShortType:
+        return np.int16
+    elif type(dt) == IntegerType:
+        return np.int32
+    elif type(dt) == FloatType:
+        return np.float32
+    else:
+        return None
+
+
 class DataFrameNaFunctions(object):
     """Functionality for working with missing data in :class:`DataFrame`.
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 31f932a36322..daee059a9277 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -46,6 +46,14 @@
 else:
     import unittest
 
+_have_pandas = False
+try:
+    import pandas
+    _have_pandas = True
+except:
+    # No Pandas, but that's okay, we'll skip those tests
+    pass
+
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
@@ -2274,6 +2282,22 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
                 .mode("overwrite").saveAsTable("pyspark_bucket"))
         self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
 
+    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    def test_to_pandas(self):
+        import numpy as np
+        schema = StructType().add("a", IntegerType()).add("b", StringType())\
+            .add("c", BooleanType()).add("d", FloatType())
+        data = [
+            (1, "foo", True, 3.0), (2, "foo", True, 5.0),
+            (3, "bar", False, -1.0), (4, "bar", False, 6.0),
+        ]
+        df = self.spark.createDataFrame(data, schema)
+        types = df.toPandas().dtypes
+        self.assertEquals(types[0], np.int32)
+        self.assertEquals(types[1], np.object)
+        self.assertEquals(types[2], np.bool)
+        self.assertEquals(types[3], np.float32)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

From dfaa392c6d64a6e906c8d383b56fca9bb5c40327 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 22 Jun 2017 13:58:09 +0800
Subject: [PATCH 2/5] do not copy

---
 python/pyspark/sql/dataframe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6f6a3006dfc0..0c0ead9c5df7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1728,7 +1728,8 @@ def toPandas(self):
             if (pandas_type):
                 dtype[field.name] = pandas_type
 
-        return pd.DataFrame.from_records(self.collect(), columns=self.columns).astype(dtype)
+        df = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        return df.astype(dtype, copy=False)
 
     ##########################################################################################
     # Pandas compatibility

From e903cd2361b43e35d596e8619ee7844bb5cb33bf Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 22 Jun 2017 15:51:41 +0900
Subject: [PATCH 3/5] Work around astype with columns in Pandas < 0.19.0

---
 python/pyspark/sql/dataframe.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 0c0ead9c5df7..2daf94d9f33e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1725,11 +1725,14 @@ def toPandas(self):
         dtype = {}
         for field in self.schema:
             pandas_type = _to_corrected_pandas_type(field.dataType)
-            if (pandas_type):
+            if pandas_type is not None:
                 dtype[field.name] = pandas_type
 
-        df = pd.DataFrame.from_records(self.collect(), columns=self.columns)
-        return df.astype(dtype, copy=False)
+        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        for f, t in dtype.items():
+            pdf[f] = pdf[f].astype(t)
+        return pdf
 
     ##########################################################################################
     # Pandas compatibility

From 6702ad131fe0c982b38ae5a0d55e38a9bd604353 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 22 Jun 2017 15:58:40 +0900
Subject: [PATCH 4/5] No copy

---
 python/pyspark/sql/dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2daf94d9f33e..e03a4c86a2ab 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1731,7 +1731,7 @@ def toPandas(self):
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
         for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t)
+            pdf[f] = pdf[f].astype(t, copy=False)
         return pdf
 
     ##########################################################################################

From d8ba5452539c5fd5b650b7f5e51e467aabc33739 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 22 Jun 2017 16:19:45 +0800
Subject: [PATCH 5/5] fix typo

---
 python/pyspark/sql/dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index e03a4c86a2ab..0649271ed224 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1764,7 +1764,7 @@ def _to_scala_map(sc, jm):
 def _to_corrected_pandas_type(dt):
     """
     When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
-    This method gets the correted data type for Pandas if that type may be inferred uncorrectly.
+    This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
     """
     import numpy as np
     if type(dt) == ByteType:
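
For illustration, below is a minimal standalone sketch of the per-column dtype coercion these patches converge on. It uses only pandas and NumPy (no Spark); the "corrected" mapping is a hypothetical stand-in for what _to_corrected_pandas_type derives from the Spark SQL schema, so it shows the idea rather than the actual PySpark code path.

# Minimal sketch, assuming only pandas and NumPy are installed.
# The `corrected` dict is a hypothetical stand-in for the mapping that
# _to_corrected_pandas_type() builds from the Spark SQL schema.
import numpy as np
import pandas as pd

records = [(1, "foo", True, 3.0), (2, "bar", False, -1.0)]
columns = ["a", "b", "c", "d"]

# from_records infers widened types: column "a" becomes int64 and "d" float64,
# even if the source schema declared them as IntegerType and FloatType.
pdf = pd.DataFrame.from_records(records, columns=columns)

# Target dtypes for the columns whose inferred type needs correcting.
corrected = {"a": np.int32, "d": np.float32}

# Coerce column by column rather than passing the whole dict to DataFrame.astype,
# mirroring the workaround in patch 3 for pandas versions older than 0.19.0.
for name, target in corrected.items():
    pdf[name] = pdf[name].astype(target, copy=False)

print(pdf.dtypes)  # a: int32, b: object, c: bool, d: float32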