diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8541403dfe2f..0649271ed224 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1721,7 +1721,18 @@ def toPandas(self):
         1    5    Bob
         """
         import pandas as pd
-        return pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        dtype = {}
+        for field in self.schema:
+            pandas_type = _to_corrected_pandas_type(field.dataType)
+            if pandas_type is not None:
+                dtype[field.name] = pandas_type
+
+        pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+
+        for f, t in dtype.items():
+            pdf[f] = pdf[f].astype(t, copy=False)
+        return pdf

    ##########################################################################################
    # Pandas compatibility
@@ -1750,6 +1761,24 @@ def _to_scala_map(sc, jm):
     return sc._jvm.PythonUtils.toScalaMap(jm)


+def _to_corrected_pandas_type(dt):
+    """
+    When converting Spark SQL records to a Pandas DataFrame, the inferred data type may be
+    wrong. This method returns the corrected data type for Pandas if the type may be
+    inferred incorrectly.
+    """
+    import numpy as np
+    if type(dt) == ByteType:
+        return np.int8
+    elif type(dt) == ShortType:
+        return np.int16
+    elif type(dt) == IntegerType:
+        return np.int32
+    elif type(dt) == FloatType:
+        return np.float32
+    else:
+        return None
+
+
 class DataFrameNaFunctions(object):
     """Functionality for working with missing data in :class:`DataFrame`.

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 31f932a36322..daee059a9277 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -46,6 +46,14 @@
 else:
     import unittest

+_have_pandas = False
+try:
+    import pandas
+    _have_pandas = True
+except ImportError:
+    # No Pandas, but that's okay, we'll skip those tests
+    pass
+
 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
@@ -2274,6 +2282,22 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
             .mode("overwrite").saveAsTable("pyspark_bucket"))
         self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

+    @unittest.skipIf(not _have_pandas, "Pandas not installed")
+    def test_to_pandas(self):
+        import numpy as np
+        schema = StructType().add("a", IntegerType()).add("b", StringType())\
+            .add("c", BooleanType()).add("d", FloatType())
+        data = [
+            (1, "foo", True, 3.0), (2, "foo", True, 5.0),
+            (3, "bar", False, -1.0), (4, "bar", False, 6.0),
+        ]
+        df = self.spark.createDataFrame(data, schema)
+        types = df.toPandas().dtypes
+        self.assertEqual(types[0], np.int32)
+        self.assertEqual(types[1], np.object_)
+        self.assertEqual(types[2], np.bool_)
+        self.assertEqual(types[3], np.float32)
+

 class HiveSparkSubmitTests(SparkSubmitTests):
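
For illustration, here is a minimal standalone sketch (not part of the patch) of the behavior this change produces, assuming the patch is applied and that pandas and NumPy are installed. The column names, schema, and local session setup below are hypothetical:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, IntegerType, FloatType

    spark = SparkSession.builder.master("local[1]").appName("toPandas-dtypes").getOrCreate()

    schema = StructType().add("a", IntegerType()).add("d", FloatType())
    df = spark.createDataFrame([(1, 3.0), (2, 5.0)], schema)

    pdf = df.toPandas()
    # Before this patch, pandas inferred int64/float64 for these columns from the
    # collected Row objects; with _to_corrected_pandas_type, IntegerType maps to
    # int32 and FloatType to float32, matching the Spark SQL column widths.
    print(pdf.dtypes)
    # a      int32
    # d    float32

    spark.stop()

The mapping only overrides the fixed-width numeric types (ByteType, ShortType, IntegerType, FloatType) whose widths pandas would otherwise widen; other types such as StringType and BooleanType keep pandas' default inference, which is why the test asserts object and bool dtypes for those columns.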