1515from io import BytesIO
1616from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED
1717from typing import overload
18+ import uuid
1819
1920import grpc
2021from grpc ._channel import _InactiveRpcError
@@ -181,6 +182,9 @@ def __init__(
181182 self .grpc_prepare_timeout = grpc_options .get ('grpc_prepare_timeout' ) or self .grpc_prepare_timeout
182183 self ._create_client ()
183184
185+ # initialize session for dataframe
186+ self ._dataframe_session = DataFrameSession (self )
187+
184188 def _create_client (self ):
185189 if self ._secure_channel :
186190 self ._channel = grpc .secure_channel (
@@ -269,13 +273,15 @@ def __enter__(self):
269273
def __exit__(self, exc_type, exc_val, exc_tb):
    """Terminate the dataframe session, then close the connection.

    Args:
        exc_type: Exception class raised inside the ``with`` body, or None.
        exc_val: Exception instance raised inside the ``with`` body, or None.
        exc_tb: Traceback of that exception, or None.

    Returns:
        None, so an exception raised inside the ``with`` body is never
        suppressed.
    """
    try:
        self._dataframe_session.terminate()
    finally:
        # Always release the connection, even when dropping the dataframe
        # session fails; otherwise the channel would be left open.
        self.close()
273278
def close(self):
    """Drop the dataframe session and release the gRPC channel.

    The dataframe session is terminated *before* the channel is closed:
    ``DataFrameSession.terminate`` sends its request through this
    connection's client, which presumably needs the channel to still be
    open (NOTE(review): confirm against the ``client`` property).

    The channel and session id are cleared even if terminating the
    dataframe session raises; in that case the exception propagates after
    the cleanup.
    """
    channel = self._channel
    try:
        self._dataframe_session.terminate()
    finally:
        if channel is not None:
            channel.close()
        self._channel = None
        self._session_id = None
279285
280286 def check_connection (self ):
281287 return self ._channel is not None
@@ -362,8 +368,15 @@ def cursor(self, catalog_name=None, db_name=None):
362368 """Return a new :py:class:`Cursor` object using the connection."""
363369 return Cursor (self , database = db_name , catalog_name = catalog_name )
364370
365- def load_parquet (self , parquet_path ):
366- return DataFrame (self , file_path = parquet_path )
def load_parquet(self, parquet_path) -> "DataFrame":
    """Create a DataFrame over a parquet file and register it with the session.

    Args:
        parquet_path: Path of the parquet file, forwarded to ``DataFrame``
            as ``file_path``.

    Returns:
        The newly created ``DataFrame``, built with this connection's
        dataframe-session user UUID and the session's current dataframe
        number.
    """
    session = self._dataframe_session
    frame = DataFrame(
        self,
        file_path=parquet_path,
        user_uuid=session.get_user_uuid,
        dataframe_number=session.get_dataframe_number,
    )
    # Record the frame under its number; this also advances the counter.
    session.update_dataframe_map(dataframe=frame)
    return frame
367380
368381 def rollback (self ):
369382 raise Exception ("e6xdb does not support transactions" ) # pragma: no cover
@@ -700,8 +713,10 @@ def explain_analyse(self):
700713
701714class DataFrame :
702715
703- def __init__ (self , connection : Connection , file_path ):
704- self .connection = connection
716+ def __init__ (self , connection : Connection , file_path , user_uuid , dataframe_number ):
717+ self ._user_uuid = user_uuid
718+ self ._dataframe_number = dataframe_number
719+ self ._connection = connection
705720 self ._file_path = file_path
706721 self ._engine_ip = connection .host
707722 self ._sessionId = connection .get_session_id
@@ -711,21 +726,17 @@ def __init__(self, connection: Connection, file_path):
711726 self ._batch = None
712727 self ._create_dataframe ()
713728
714- def __enter__ (self ):
715- pass
716-
717- def __exit__ (self , exc_type , exc_val , exc_tb ):
718- pass
719-
720729 def _create_dataframe (self ):
721- client = self .connection .client
730+ client = self ._connection .client
722731
723732 create_dataframe_request = e6x_engine_pb2 .CreateDataFrameRequest (
724733 parquetFilePath = self ._file_path ,
725- catalog = self .connection .catalog_name ,
726- schema = self .connection .database ,
734+ catalog = self ._connection .catalog_name ,
735+ schema = self ._connection .database ,
727736 sessionId = self ._sessionId ,
728- engineIP = self ._engine_ip
737+ engineIP = self ._engine_ip ,
738+ userUUID = self ._user_uuid ,
739+ dataframeNumber = self ._dataframe_number
729740 )
730741
731742 create_dataframe_response = client .createDataFrame (
@@ -738,9 +749,11 @@ def select(self, *fields) -> "DataFrame":
738749 for field in fields :
739750 projection_fields .append (field )
740751
741- client = self .connection .client
752+ client = self ._connection .client
742753 projection_on_dataframe_request = e6x_engine_pb2 .ProjectionOnDataFrameRequest (
754+ userUUID = self ._user_uuid ,
743755 queryId = self ._query_id ,
756+ dataframeNumber = self ._dataframe_number ,
744757 sessionId = self ._sessionId ,
745758 field = projection_fields
746759 )
@@ -752,9 +765,11 @@ def select(self, *fields) -> "DataFrame":
752765 return self
753766
754767 def where (self , where_clause : str ) -> "DataFrame" :
755- client = self .connection .client
768+ client = self ._connection .client
756769 filter_on_dataframe_request = e6x_engine_pb2 .FilterOnDataFrameRequest (
770+ userUUID = self ._user_uuid ,
757771 queryId = self ._query_id ,
772+ dataframeNumber = self ._dataframe_number ,
758773 sessionId = self ._sessionId ,
759774 whereClause = where_clause
760775 )
@@ -794,7 +809,9 @@ def order_by(self, field_list : list, sort_direction_list = None, null_direction
794809 client = self .connection .client
795810
796811 orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
812+ userUUID = self ._user_uuid ,
797813 queryId = self ._query_id ,
814+ dataframeNumber = self ._dataframe_number ,
798815 sessionId = self ._sessionId ,
799816 field = orderby_fields ,
800817 sortDirection = sort_direction_request ,
@@ -813,10 +830,12 @@ def order_by(self, *field_list) -> "DataFrame":
813830 for field in field_list :
814831 orderby_fields .append (field )
815832
816- client = self .connection .client
833+ client = self ._connection .client
817834
818835 orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
836+ userUUID = self ._user_uuid ,
819837 queryId = self ._query_id ,
838+ dataframeNumber = self ._dataframe_number ,
820839 sessionId = self ._sessionId ,
821840 field = orderby_fields ,
822841 sortDirection = sort_direction_request ,
@@ -829,9 +848,11 @@ def order_by(self, *field_list) -> "DataFrame":
829848 return self
830849
831850 def limit (self , fetch_limit : int ) -> "DataFrame" :
832- client = self .connection .client
851+ client = self ._connection .client
833852 limit_on_dataframe_request = e6x_engine_pb2 .LimitOnDataFrameRequest (
853+ userUUID = self ._user_uuid ,
834854 queryId = self ._query_id ,
855+ dataframeNumber = self ._dataframe_number ,
835856 sessionId = self ._sessionId ,
836857 fetchLimit = fetch_limit
837858 )
@@ -847,9 +868,11 @@ def show(self):
847868 return self .fetchall ()
848869
849870 def execute (self ):
850- client = self .connection .client
871+ client = self ._connection .client
851872 execute_dataframe_request = e6x_engine_pb2 .ExecuteDataFrameRequest (
873+ userUUID = self ._user_uuid ,
852874 queryId = self ._query_id ,
875+ dataframeNumber = self ._dataframe_number ,
853876 sessionId = self ._sessionId
854877 )
855878 execute_dataframe_response = client .executeDataFrame (
@@ -862,15 +885,15 @@ def _update_meta_data(self):
862885 sessionId = self ._sessionId ,
863886 queryId = self ._query_id
864887 )
865- get_result_metadata_response = self .connection .client .getResultMetadata (
888+ get_result_metadata_response = self ._connection .client .getResultMetadata (
866889 result_meta_data_request ,
867890 )
868891 buffer = BytesIO (get_result_metadata_response .resultMetaData )
869892 self ._rowcount , self ._query_columns_description = get_query_columns_info (buffer )
870893 self ._is_metadata_updated = True
871894
872895 def _fetch_batch (self ):
873- client = self .connection .client
896+ client = self ._connection .client
874897 get_next_result_batch_request = e6x_engine_pb2 .GetNextResultBatchRequest (
875898 engineIP = self ._engine_ip ,
876899 sessionId = self ._sessionId ,
@@ -898,6 +921,44 @@ def fetchall(self):
898921 self ._data = None
899922 return rows
900923
class DataFrameSession:
    """Book-keeping for the DataFrames created through one connection.

    Each session owns a random user UUID (generated with ``uuid.uuid4``) and
    a counter that numbers the DataFrames registered with it. Calling
    :meth:`terminate` sends a ``DropUserContextRequest`` for that UUID
    through the connection's client.
    """

    def __init__(self, connection: "Connection"):
        """Start a session bound to ``connection``.

        Args:
            connection: Connection whose ``client`` is used by
                :meth:`terminate`.
        """
        self._user_uuid = str(uuid.uuid4())
        self._connection = connection
        # Number that the next registered DataFrame will receive.
        self._dataframe_count = 0
        # Maps dataframe number -> DataFrame.
        self._dataframe_map = {}
        self._is_terminated = False

    def __enter__(self):
        """Return the session so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Terminate the session when leaving a ``with`` block."""
        self.terminate()

    def update_dataframe_map(self, dataframe: "DataFrame"):
        """Register ``dataframe`` under the current number and advance it.

        Args:
            dataframe: The DataFrame to record.
        """
        # A keyed assignment is required here: passing a set literal such as
        # ``{count, dataframe}`` to ``dict.update`` raises TypeError.
        self._dataframe_map[self._dataframe_count] = dataframe
        self._dataframe_count += 1

    @property
    def get_user_uuid(self):
        """The UUID string identifying this session."""
        return self._user_uuid

    @property
    def get_dataframe_number(self) -> int:
        """The number the next registered DataFrame will receive."""
        return self._dataframe_count

    @property
    def is_terminated(self) -> bool:
        """True once :meth:`terminate` has completed successfully."""
        return self._is_terminated

    def terminate(self):
        """Send the drop-user-context request for this session, once.

        Subsequent calls after a successful termination are no-ops. If the
        request raises, the session stays un-terminated and the exception
        propagates.
        """
        if self._is_terminated:
            return

        drop_user_context_request = e6x_engine_pb2.DropUserContextRequest(
            userUUID=self._user_uuid
        )
        self._connection.client.dropUserContext(drop_user_context_request)
        self._is_terminated = True
960+
961+
901962
902963def poll (self , get_progress_update = True ):
903964 """Poll for and return the raw status data provided by the Hive Thrift REST API.
0 commit comments