@@ -346,7 +346,7 @@ def collect(self):
     @ignore_unicode_prefix
     @since(2.0)
     def collectAsArrow(self):
-        """Returns all the records as an Arrow
+        """Returns all the records as an ArrowRecordBatch
         """
         with SCCallSiteSync(self._sc) as css:
            port = self._jdf.collectAsArrowToPython()
@@ -1531,8 +1531,54 @@ def toPandas(self, useArrow=True):
         1    5  Bob
         """
         import pandas as pd
+
         if useArrow:
-            return self.collectAsArrow().to_pandas()
+            import io
+            from pyarrow.array import from_pylist
+            from pyarrow.table import RecordBatch
+            from pyarrow.ipc import ArrowFileReader, ArrowFileWriter
+
+            names = self.columns  # capture for closure
+
+            # reduce a partition to a serialized ArrowRecordBatch
+            def reducePartition(iterator):
+                cols = [[] for _ in range(len(names))]
+                for row in iterator:
+                    for i in range(len(row)):
+                        cols[i].append(row[i])
+
+                arrs = list(map(lambda c: from_pylist(c), cols))
+                batch = RecordBatch.from_arrays(names, arrs)
+                sink = io.BytesIO()
+                writer = ArrowFileWriter(sink, batch.schema)
+                writer.write_record_batch(batch)
+                writer.close()
+                yield sink.getvalue()
+
+            # convert partitions to serialized ArrowRecordBatches and collect byte arrays
+            batch_bytes = self.rdd.mapPartitions(reducePartition).collect()
+
+            def read_batch(b):
+                reader = ArrowFileReader(bytes(b))
+                return reader.get_record_batch(0)
+
+            # deserialize ArrowRecordBatch and create a Pandas DataFrame for each batch
+            frames = list(map(lambda b: read_batch(b).to_pandas(), batch_bytes))
+
+            # merge all DataFrames to one
+            return pd.concat(frames, ignore_index=True)
+
+            # ~ alternate to concat ~
+            # batch = read_batch(batch_bytes[0])
+            # pdf = batch.to_pandas()
+            # for i in range(1, len(batch_bytes)):
+            #     batch = read_batch(batch_bytes[i])
+            #     pdf = pdf.append(batch.to_pandas(), ignore_index=True)
+            #
+            # return pdf
+
+            # TODO - Uses Arrow hybrid (Java -> C++) pipeline
+            #return self.collectAsArrow().to_pandas()
         else:
             return pd.DataFrame.from_records(self.collect(), columns=self.columns)
0 commit comments