Skip to content

Commit 419a9ed

Browse files
committed
add session for dataframe
1 parent eefdfd3 commit 419a9ed

File tree

5 files changed

+244
-90
lines changed

5 files changed

+244
-90
lines changed

e6data_python_connector/e6data_grpc.py

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from io import BytesIO
1616
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
1717
from typing import overload
18+
import uuid
1819

1920
import grpc
2021
from grpc._channel import _InactiveRpcError
@@ -181,6 +182,9 @@ def __init__(
181182
self.grpc_prepare_timeout = grpc_options.get('grpc_prepare_timeout') or self.grpc_prepare_timeout
182183
self._create_client()
183184

185+
# initialize session for dataframe
186+
self._dataframe_session = DataFrameSession(self)
187+
184188
def _create_client(self):
185189
if self._secure_channel:
186190
self._channel = grpc.secure_channel(
@@ -269,13 +273,15 @@ def __enter__(self):
269273

270274
def __exit__(self, exc_type, exc_val, exc_tb):
    """Call close"""
    # close() drops the dataframe session before it tears down the channel,
    # so no separate terminate() call is needed here.
    self.close()

def close(self):
    """Terminate the dataframe session, then close the gRPC channel.

    The session's terminate() sends a ``dropUserContext`` request through
    this connection's client, so it has to run while the channel is still
    open. The channel is closed in a ``finally`` block so that a failing
    terminate() cannot leave the channel open; the error still propagates
    to the caller afterwards.

    Calling close() on an already-closed connection is a no-op.
    """
    try:
        if self._channel is not None:
            # terminate() is itself guarded by the session's terminated flag,
            # so repeated calls do not resend the request.
            self._dataframe_session.terminate()
    finally:
        if self._channel is not None:
            self._channel.close()
            self._channel = None
            self._session_id = None
279285

280286
def check_connection(self):
281287
return self._channel is not None
@@ -362,8 +368,15 @@ def cursor(self, catalog_name=None, db_name=None):
362368
"""Return a new :py:class:`Cursor` object using the connection."""
363369
return Cursor(self, database=db_name, catalog_name=catalog_name)
364370

365-
def load_parquet(self, parquet_path):
366-
return DataFrame(self, file_path=parquet_path)
371+
def load_parquet(self, parquet_path) -> "DataFrame":
    """Build a DataFrame for ``parquet_path`` and record it in the dataframe session.

    The new DataFrame is tagged with the session's user UUID and the
    session's current dataframe number; registering it afterwards advances
    that number for the next call.
    """
    session = self._dataframe_session
    frame = DataFrame(
        self,
        file_path=parquet_path,
        user_uuid=session.get_user_uuid,
        dataframe_number=session.get_dataframe_number,
    )
    session.update_dataframe_map(dataframe=frame)
    return frame
367380

368381
def rollback(self):
369382
raise Exception("e6xdb does not support transactions") # pragma: no cover
@@ -700,8 +713,10 @@ def explain_analyse(self):
700713

701714
class DataFrame:
702715

703-
def __init__(self, connection: Connection, file_path):
704-
self.connection = connection
716+
def __init__(self, connection: Connection, file_path, user_uuid, dataframe_number):
717+
self._user_uuid = user_uuid
718+
self._dataframe_number = dataframe_number
719+
self._connection = connection
705720
self._file_path = file_path
706721
self._engine_ip = connection.host
707722
self._sessionId = connection.get_session_id
@@ -711,21 +726,17 @@ def __init__(self, connection: Connection, file_path):
711726
self._batch = None
712727
self._create_dataframe()
713728

714-
def __enter__(self):
715-
pass
716-
717-
def __exit__(self, exc_type, exc_val, exc_tb):
718-
pass
719-
720729
def _create_dataframe(self):
721-
client = self.connection.client
730+
client = self._connection.client
722731

723732
create_dataframe_request = e6x_engine_pb2.CreateDataFrameRequest(
724733
parquetFilePath=self._file_path,
725-
catalog=self.connection.catalog_name,
726-
schema=self.connection.database,
734+
catalog=self._connection.catalog_name,
735+
schema=self._connection.database,
727736
sessionId=self._sessionId,
728-
engineIP=self._engine_ip
737+
engineIP=self._engine_ip,
738+
userUUID=self._user_uuid,
739+
dataframeNumber=self._dataframe_number
729740
)
730741

731742
create_dataframe_response = client.createDataFrame(
@@ -738,9 +749,11 @@ def select(self, *fields) -> "DataFrame":
738749
for field in fields:
739750
projection_fields.append(field)
740751

741-
client = self.connection.client
752+
client = self._connection.client
742753
projection_on_dataframe_request = e6x_engine_pb2.ProjectionOnDataFrameRequest(
754+
userUUID=self._user_uuid,
743755
queryId=self._query_id,
756+
dataframeNumber=self._dataframe_number,
744757
sessionId=self._sessionId,
745758
field=projection_fields
746759
)
@@ -752,9 +765,11 @@ def select(self, *fields) -> "DataFrame":
752765
return self
753766

754767
def where(self, where_clause : str) -> "DataFrame":
755-
client = self.connection.client
768+
client = self._connection.client
756769
filter_on_dataframe_request = e6x_engine_pb2.FilterOnDataFrameRequest(
770+
userUUID=self._user_uuid,
757771
queryId=self._query_id,
772+
dataframeNumber=self._dataframe_number,
758773
sessionId=self._sessionId,
759774
whereClause=where_clause
760775
)
@@ -794,7 +809,9 @@ def order_by(self, field_list : list, sort_direction_list = None, null_direction
794809
client = self.connection.client
795810

796811
orderby_on_dataframe_request = e6x_engine_pb2.OrderByOnDataFrameRequest(
812+
userUUID=self._user_uuid,
797813
queryId=self._query_id,
814+
dataframeNumber=self._dataframe_number,
798815
sessionId=self._sessionId,
799816
field=orderby_fields,
800817
sortDirection=sort_direction_request,
@@ -813,10 +830,12 @@ def order_by(self, *field_list) -> "DataFrame":
813830
for field in field_list:
814831
orderby_fields.append(field)
815832

816-
client = self.connection.client
833+
client = self._connection.client
817834

818835
orderby_on_dataframe_request = e6x_engine_pb2.OrderByOnDataFrameRequest(
836+
userUUID=self._user_uuid,
819837
queryId=self._query_id,
838+
dataframeNumber=self._dataframe_number,
820839
sessionId=self._sessionId,
821840
field=orderby_fields,
822841
sortDirection=sort_direction_request,
@@ -829,9 +848,11 @@ def order_by(self, *field_list) -> "DataFrame":
829848
return self
830849

831850
def limit(self, fetch_limit : int) -> "DataFrame":
832-
client = self.connection.client
851+
client = self._connection.client
833852
limit_on_dataframe_request = e6x_engine_pb2.LimitOnDataFrameRequest(
853+
userUUID=self._user_uuid,
834854
queryId=self._query_id,
855+
dataframeNumber=self._dataframe_number,
835856
sessionId=self._sessionId,
836857
fetchLimit=fetch_limit
837858
)
@@ -847,9 +868,11 @@ def show(self):
847868
return self.fetchall()
848869

849870
def execute(self):
850-
client = self.connection.client
871+
client = self._connection.client
851872
execute_dataframe_request = e6x_engine_pb2.ExecuteDataFrameRequest(
873+
userUUID=self._user_uuid,
852874
queryId=self._query_id,
875+
dataframeNumber=self._dataframe_number,
853876
sessionId=self._sessionId
854877
)
855878
execute_dataframe_response = client.executeDataFrame(
@@ -862,15 +885,15 @@ def _update_meta_data(self):
862885
sessionId=self._sessionId,
863886
queryId=self._query_id
864887
)
865-
get_result_metadata_response = self.connection.client.getResultMetadata(
888+
get_result_metadata_response = self._connection.client.getResultMetadata(
866889
result_meta_data_request,
867890
)
868891
buffer = BytesIO(get_result_metadata_response.resultMetaData)
869892
self._rowcount, self._query_columns_description = get_query_columns_info(buffer)
870893
self._is_metadata_updated = True
871894

872895
def _fetch_batch(self):
873-
client = self.connection.client
896+
client = self._connection.client
874897
get_next_result_batch_request = e6x_engine_pb2.GetNextResultBatchRequest(
875898
engineIP=self._engine_ip,
876899
sessionId=self._sessionId,
@@ -898,6 +921,44 @@ def fetchall(self):
898921
self._data = None
899922
return rows
900923

924+
class DataFrameSession:
    """Bookkeeping for the DataFrames created through one connection.

    Each session owns a random user UUID, a running dataframe counter and a
    map from dataframe number to DataFrame object. terminate() sends a
    ``DropUserContextRequest`` carrying that UUID through the connection's
    client, at most once per session.
    """

    def __init__(self, connection: "Connection"):
        """Create an empty, non-terminated session bound to ``connection``."""
        self._user_uuid = str(uuid.uuid4())
        self._connection = connection
        # Number handed to the next DataFrame; also the count registered so far.
        self._dataframe_count = 0
        self._dataframe_map = {}
        self._is_terminated = False

    def __enter__(self):
        """Return the session itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Terminate the session when leaving a ``with`` block."""
        self.terminate()

    def update_dataframe_map(self, dataframe: "DataFrame"):
        """Register ``dataframe`` under the current number, then advance the counter."""
        # Plain item assignment: a ``{key, value}`` literal is a set, not a
        # dict, and dict.update() rejects it with a TypeError.
        self._dataframe_map[self._dataframe_count] = dataframe
        self._dataframe_count += 1

    @property
    def get_user_uuid(self):
        """The UUID string generated for this session."""
        return self._user_uuid

    @property
    def get_dataframe_number(self) -> int:
        """The number that the next registered DataFrame will receive."""
        return self._dataframe_count

    @property
    def is_terminated(self) -> bool:
        """True once terminate() has completed successfully."""
        return self._is_terminated

    def terminate(self):
        """Send the drop-user-context request for this session, once.

        Later calls are no-ops. The terminated flag is set only after the
        request returns, so a failed call can be retried.
        """
        if self._is_terminated:
            return

        drop_user_context_request = e6x_engine_pb2.DropUserContextRequest(
            userUUID=self._user_uuid
        )
        self._connection.client.dropUserContext(drop_user_context_request)
        self._is_terminated = True
960+
961+
901962

902963
def poll(self, get_progress_update=True):
903964
"""Poll for and return the raw status data provided by the Hive Thrift REST API.

0 commit comments

Comments
 (0)