
Commit 3f855ec

changed toPandas to use Arrow with a pure Python pipeline, since hybrid IPC in Arrow is not ready
1 parent 4227ec6 commit 3f855ec

File tree

2 files changed: +50 -30 lines

python/pyspark/sql/dataframe.py

Lines changed: 48 additions & 2 deletions
@@ -346,7 +346,7 @@ def collect(self):
     @ignore_unicode_prefix
     @since(2.0)
     def collectAsArrow(self):
-        """Returns all the records as an Arrow
+        """Returns all the records as an ArrowRecordBatch
         """
         with SCCallSiteSync(self._sc) as css:
             port = self._jdf.collectAsArrowToPython()
@@ -1531,8 +1531,54 @@ def toPandas(self, useArrow=True):
         1    5  Bob
         """
         import pandas as pd
+
         if useArrow:
-            return self.collectAsArrow().to_pandas()
+            import io
+            from pyarrow.array import from_pylist
+            from pyarrow.table import RecordBatch
+            from pyarrow.ipc import ArrowFileReader, ArrowFileWriter
+
+            names = self.columns  # capture for closure
+
+            # reduce a partition to a serialized ArrowRecordBatch
+            def reducePartition(iterator):
+                cols = [[] for _ in range(len(names))]
+                for row in iterator:
+                    for i in range(len(row)):
+                        cols[i].append(row[i])
+
+                arrs = [from_pylist(c) for c in cols]
+                batch = RecordBatch.from_arrays(names, arrs)
+                sink = io.BytesIO()
+                writer = ArrowFileWriter(sink, batch.schema)
+                writer.write_record_batch(batch)
+                writer.close()
+                yield sink.getvalue()
+
+            # convert partitions to serialized ArrowRecordBatches and collect the byte arrays
+            batch_bytes = self.rdd.mapPartitions(reducePartition).collect()
+
+            def read_batch(b):
+                reader = ArrowFileReader(bytes(b))
+                return reader.get_record_batch(0)
+
+            # deserialize each ArrowRecordBatch and create a Pandas DataFrame from it
+            frames = [read_batch(b).to_pandas() for b in batch_bytes]
+
+            # merge all DataFrames into one
+            return pd.concat(frames, ignore_index=True)
+
+            # ~ alternate to concat ~
+            # batch = read_batch(batch_bytes[0])
+            # pdf = batch.to_pandas()
+            # for i in range(1, len(batch_bytes)):
+            #     batch = read_batch(batch_bytes[i])
+            #     pdf = pdf.append(batch.to_pandas(), ignore_index=True)
+            #
+            # return pdf
+
+            # TODO - use the Arrow hybrid (Java -> C++) pipeline once it is ready
+            # return self.collectAsArrow().to_pandas()
         else:
             return pd.DataFrame.from_records(self.collect(), columns=self.columns)

python/pyspark/sql/tests.py

Lines changed: 2 additions & 28 deletions
@@ -1971,37 +1971,11 @@ def setUpClass(cls):
         ReusedPySparkTestCase.setUpClass()
         cls.spark = SparkSession(cls.sc)

-    '''
-    # TODO - remove, just testing pyarrow api
-    def test_no_ser(self):
-        import io
-        import pandas as pd
-        from pandas.util.testing import assert_frame_equal
-        from pyarrow.ipc import ArrowFileReader, ArrowFileWriter
-        pdf = pd.DataFrame({'test': [1.5]})
-        batch = pyarrow.RecordBatch.from_pandas(pdf)
-        sink = io.BytesIO()
-        writer = ArrowFileWriter(sink, batch.schema)
-        writer.write_record_batch(batch)
-        writer.close()
-        data = [[bytearray(sink.getvalue())]]
-        schema = StructType([StructField('test', BinaryType())])
-        df = self.spark.createDataFrame(data, schema=schema)
-        rows = df.collect()
-        reader = ArrowFileReader(bytes(rows[0][0]))
-        batch_rt = reader.get_record_batch(0)
-        pdf_rt = batch_rt.to_pandas()
-        assert_frame_equal(pdf, pdf_rt)
-    '''
-
-    def test_arrow_round_trip(self):
+    def test_arrow_toPandas(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         pdf = df.toPandas(useArrow=False)
         pdf_arrow = df.toPandas(useArrow=True)
-        # TODO - compare Pandas DataFrames
-        print(pdf)
-        print(pdf_arrow)
-        self.assertTrue(False)
+        self.assertTrue(pdf.equals(pdf_arrow))


 if __name__ == "__main__":
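
Note on the comparison: pdf.equals(pdf_arrow) yields only a pass/fail boolean. A minimal alternative sketch using assert_frame_equal (the same helper the deleted test imported), which reports the differing values on failure; it assumes the same SparkSession fixture as the test above.

    from pandas.util.testing import assert_frame_equal  # moved to pandas.testing in later pandas

    def test_arrow_toPandas(self):
        # build a small DataFrame and convert it both with and without Arrow
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        pdf = df.toPandas(useArrow=False)
        pdf_arrow = df.toPandas(useArrow=True)
        # unlike assertTrue(pdf.equals(...)), this prints which cells differ on failure
        assert_frame_equal(pdf, pdf_arrow)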
