import math
import time
import threading
+import lz4.frame
from uuid import uuid4
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context

@@ -435,7 +436,7 @@ def open_session(self, session_configuration, catalog, schema):
                initial_namespace = None

            open_session_req = ttypes.TOpenSessionReq(
-                client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
+                client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6,
                client_protocol=None,
                initialNamespace=initial_namespace,
                canUseMultipleCatalogs=True,
@@ -491,7 +492,9 @@ def _poll_for_status(self, op_handle):
        )
        return self.make_request(self._client.GetOperationStatus, req)

-    def _create_arrow_table(self, t_row_set, schema_bytes, description):
+    def _create_arrow_table(
+        self, t_row_set, are_arrow_results_compressed, schema_bytes, description
+    ):
        if t_row_set.columns is not None:
            (
                arrow_table,
@@ -504,7 +507,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
                arrow_table,
                num_rows,
            ) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
-                t_row_set.arrowBatches, schema_bytes
+                t_row_set.arrowBatches, are_arrow_results_compressed, schema_bytes
            )
        else:
            raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
@@ -529,13 +532,18 @@ def _convert_decimals_in_arrow_table(table, description):
        return table

    @staticmethod
-    def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
+    def _convert_arrow_based_set_to_arrow_table(
+        arrow_batches, are_arrow_results_compressed, schema_bytes
+    ):
        ba = bytearray()
        ba += schema_bytes
        n_rows = 0
        for arrow_batch in arrow_batches:
            n_rows += arrow_batch.rowCount
-            ba += arrow_batch.batch
+            if are_arrow_results_compressed:
+                ba += lz4.frame.decompress(arrow_batch.batch)
+            else:
+                ba += arrow_batch.batch
        arrow_table = pyarrow.ipc.open_stream(ba).read_all()
        return arrow_table, n_rows

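Not part of the diff: a minimal, self-contained sketch of the round trip the new branch relies on, assuming each arrowBatch.batch is an LZ4-frame-compressed Arrow IPC record-batch message whose decompressed bytes are appended directly after the schema message. The toy batch and variable names are illustrative only.

import lz4.frame
import pyarrow

# Toy record batch standing in for one server-produced arrow batch.
batch = pyarrow.record_batch([pyarrow.array([1, 2, 3])], names=["x"])
schema_bytes = batch.schema.serialize().to_pybytes()    # plays the role of schema_bytes
batch_bytes = batch.serialize().to_pybytes()            # plays the role of arrowBatch.batch

compressed = lz4.frame.compress(batch_bytes)            # what a compressing server would send
assert lz4.frame.decompress(compressed) == batch_bytes  # LZ4 frame decompression is lossless

ba = bytearray()
ba += schema_bytes
ba += lz4.frame.decompress(compressed)                  # the same append the new branch performs
table = pyarrow.ipc.open_stream(ba).read_all()
assert table.num_rows == 3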
@@ -710,11 +718,19 @@ def _results_message_to_execute_response(self, resp, operation_state):
            .to_pybytes()
        )

+        are_arrow_results_compressed = (
+            t_result_set_metadata_resp and t_result_set_metadata_resp.lz4Compressed
+        )
+
        if direct_results and direct_results.resultSet:
            assert direct_results.resultSet.results.startRowOffset == 0
            assert direct_results.resultSetMetadata
+
            arrow_results, n_rows = self._create_arrow_table(
-                direct_results.resultSet.results, schema_bytes, description
+                direct_results.resultSet.results,
+                are_arrow_results_compressed,
+                schema_bytes,
+                description,
            )
            arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
        else:
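A small aside (not in the diff) on the guard above: because "and" short-circuits, a missing metadata object yields a falsy value instead of raising AttributeError, so the batches are simply treated as uncompressed. FakeMetadata below is a hypothetical stand-in.

class FakeMetadata:
    lz4Compressed = True

for t_result_set_metadata_resp in (FakeMetadata(), None):
    are_arrow_results_compressed = (
        t_result_set_metadata_resp and t_result_set_metadata_resp.lz4Compressed
    )
    # Truthy when the metadata reports LZ4, falsy (None) when the metadata is absent.
    print(bool(are_arrow_results_compressed))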
@@ -786,7 +802,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
                maxRows=max_rows, maxBytes=max_bytes
            ),
            canReadArrowResult=True,
-            canDecompressLZ4Result=False,
+            canDecompressLZ4Result=True,
            canDownloadResult=False,
            confOverlay={
                # We want to receive proper Timestamp arrow types.
@@ -925,7 +941,7 @@ def fetch_results(
                )
            )
        arrow_results, n_rows = self._create_arrow_table(
-            resp.results, arrow_schema_bytes, description
+            resp.results, are_arrow_results_compressed, arrow_schema_bytes, description
        )
        arrow_queue = ArrowQueue(arrow_results, n_rows)
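For context, a hedged end-to-end sketch: assuming this ThriftBackend sits behind the databricks-sql-connector's databricks.sql.connect entry point (an assumption; the diff only touches the Thrift layer), callers see no API change, since negotiation via canDecompressLZ4Result and per-batch decompression both stay internal.

# Hedged usage sketch; connection values are placeholders, and the databricks.sql
# entry point is assumed rather than shown in this diff.
from databricks import sql

with sql.connect(
    server_hostname="<workspace-host>",
    http_path="<warehouse-http-path>",
    access_token="<personal-access-token>",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1 AS x")
        print(cursor.fetchall())  # Arrow batches are decompressed transparently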