Skip to content

Commit 7545c75

Browse files
Enabling compression
Signed-off-by: Mohit Singla <[email protected]>
1 parent b9645f9 commit 7545c75

File tree

3 files changed

+92
-33
lines changed

3 files changed

+92
-33
lines changed

poetry.lock

Lines changed: 67 additions & 25 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ python = "^3.7.1"
1313
thrift = "^0.13.0"
1414
pandas = "^1.3.0"
1515
pyarrow = "^9.0.0"
16+
lz4 = "^4.0.2"
1617

1718
[tool.poetry.dev-dependencies]
1819
pytest = "^7.1.2"

src/databricks/sql/thrift_backend.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import math
55
import time
66
import threading
7+
import lz4.frame
78
from uuid import uuid4
89
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
910

@@ -435,7 +436,7 @@ def open_session(self, session_configuration, catalog, schema):
435436
initial_namespace = None
436437

437438
open_session_req = ttypes.TOpenSessionReq(
438-
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
439+
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V6,
439440
client_protocol=None,
440441
initialNamespace=initial_namespace,
441442
canUseMultipleCatalogs=True,
@@ -491,7 +492,9 @@ def _poll_for_status(self, op_handle):
491492
)
492493
return self.make_request(self._client.GetOperationStatus, req)
493494

494-
def _create_arrow_table(self, t_row_set, schema_bytes, description):
495+
def _create_arrow_table(
496+
self, t_row_set, are_arrow_results_compressed, schema_bytes, description
497+
):
495498
if t_row_set.columns is not None:
496499
(
497500
arrow_table,
@@ -504,7 +507,7 @@ def _create_arrow_table(self, t_row_set, schema_bytes, description):
504507
arrow_table,
505508
num_rows,
506509
) = ThriftBackend._convert_arrow_based_set_to_arrow_table(
507-
t_row_set.arrowBatches, schema_bytes
510+
t_row_set.arrowBatches, are_arrow_results_compressed, schema_bytes
508511
)
509512
else:
510513
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
@@ -529,13 +532,18 @@ def _convert_decimals_in_arrow_table(table, description):
529532
return table
530533

531534
@staticmethod
532-
def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
535+
def _convert_arrow_based_set_to_arrow_table(
536+
arrow_batches, are_arrow_results_compressed, schema_bytes
537+
):
533538
ba = bytearray()
534539
ba += schema_bytes
535540
n_rows = 0
536541
for arrow_batch in arrow_batches:
537542
n_rows += arrow_batch.rowCount
538-
ba += arrow_batch.batch
543+
if are_arrow_results_compressed:
544+
ba += lz4.frame.decompress(arrow_batch.batch)
545+
else:
546+
ba += arrow_batch.batch
539547
arrow_table = pyarrow.ipc.open_stream(ba).read_all()
540548
return arrow_table, n_rows
541549

@@ -710,11 +718,19 @@ def _results_message_to_execute_response(self, resp, operation_state):
710718
.to_pybytes()
711719
)
712720

721+
are_arrow_results_compressed = (
722+
t_result_set_metadata_resp and t_result_set_metadata_resp.lz4Compressed
723+
)
724+
713725
if direct_results and direct_results.resultSet:
714726
assert direct_results.resultSet.results.startRowOffset == 0
715727
assert direct_results.resultSetMetadata
728+
716729
arrow_results, n_rows = self._create_arrow_table(
717-
direct_results.resultSet.results, schema_bytes, description
730+
direct_results.resultSet.results,
731+
are_arrow_results_compressed,
732+
schema_bytes,
733+
description,
718734
)
719735
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
720736
else:
@@ -786,7 +802,7 @@ def execute_command(self, operation, session_handle, max_rows, max_bytes, cursor
786802
maxRows=max_rows, maxBytes=max_bytes
787803
),
788804
canReadArrowResult=True,
789-
canDecompressLZ4Result=False,
805+
canDecompressLZ4Result=True,
790806
canDownloadResult=False,
791807
confOverlay={
792808
# We want to receive proper Timestamp arrow types.
@@ -925,7 +941,7 @@ def fetch_results(
925941
)
926942
)
927943
arrow_results, n_rows = self._create_arrow_table(
928-
resp.results, arrow_schema_bytes, description
944+
resp.results, are_arrow_results_compressed, arrow_schema_bytes, description
929945
)
930946
arrow_queue = ArrowQueue(arrow_results, n_rows)
931947

0 commit comments

Comments
 (0)