1616from io import BytesIO
1717from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED
1818from typing import overload
19+ import uuid
1920
2021import grpc
2122from grpc ._channel import _InactiveRpcError
@@ -201,6 +202,9 @@ def __init__(
201202 self .grpc_prepare_timeout = self ._grpc_options .get ('grpc_prepare_timeout' ) or 10 * 60 # 10 minutes
202203 self ._create_client ()
203204
205+ # initialize session for dataframe
206+ self ._dataframe_session = DataFrameSession (self )
207+
204208 @property
205209 def _get_grpc_options (self ):
206210 """
@@ -351,6 +355,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
351355 exc_val (BaseException): The exception instance raised (if any).
352356 exc_tb (Traceback): The traceback object of the exception (if any).
353357 """
358+ self ._dataframe_session .terminate ()
354359 self .close ()
355360
356361 def close (self ):
@@ -363,6 +368,7 @@ def close(self):
363368 self ._channel .close ()
364369 self ._channel = None
365370 self ._session_id = None
371+ self ._dataframe_session .terminate ()
366372
367373 def check_connection (self ):
368374 """
@@ -527,8 +533,15 @@ def cursor(self, catalog_name=None, db_name=None):
527533 """
528534 return Cursor (self , database = db_name , catalog_name = catalog_name )
529535
530- def load_parquet (self , parquet_path ):
531- return DataFrame (self , file_path = parquet_path )
536+ def load_parquet (self , parquet_path ) -> "DataFrame" :
537+ dataframe = DataFrame (
538+ self ,
539+ file_path = parquet_path ,
540+ user_uuid = self ._dataframe_session .get_user_uuid ,
541+ dataframe_number = self ._dataframe_session .get_dataframe_number )
542+
543+ self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
544+ return dataframe
532545
533546 def rollback (self ):
534547 """
@@ -1035,8 +1048,10 @@ def explain_analyse(self):
10351048
10361049class DataFrame :
10371050
1038- def __init__ (self , connection : Connection , file_path ):
1039- self .connection = connection
1051+ def __init__ (self , connection : Connection , file_path , user_uuid , dataframe_number ):
1052+ self ._user_uuid = user_uuid
1053+ self ._dataframe_number = dataframe_number
1054+ self ._connection = connection
10401055 self ._file_path = file_path
10411056 self ._engine_ip = connection .host
10421057 self ._sessionId = connection .get_session_id
@@ -1046,21 +1061,17 @@ def __init__(self, connection: Connection, file_path):
10461061 self ._batch = None
10471062 self ._create_dataframe ()
10481063
1049- def __enter__ (self ):
1050- pass
1051-
1052- def __exit__ (self , exc_type , exc_val , exc_tb ):
1053- pass
1054-
10551064 def _create_dataframe (self ):
1056- client = self .connection .client
1065+ client = self ._connection .client
10571066
10581067 create_dataframe_request = e6x_engine_pb2 .CreateDataFrameRequest (
10591068 parquetFilePath = self ._file_path ,
1060- catalog = self .connection .catalog_name ,
1061- schema = self .connection .database ,
1069+ catalog = self ._connection .catalog_name ,
1070+ schema = self ._connection .database ,
10621071 sessionId = self ._sessionId ,
1063- engineIP = self ._engine_ip
1072+ engineIP = self ._engine_ip ,
1073+ userUUID = self ._user_uuid ,
1074+ dataframeNumber = self ._dataframe_number
10641075 )
10651076
10661077 create_dataframe_response = client .createDataFrame (
@@ -1073,9 +1084,11 @@ def select(self, *fields) -> "DataFrame":
10731084 for field in fields :
10741085 projection_fields .append (field )
10751086
1076- client = self .connection .client
1087+ client = self ._connection .client
10771088 projection_on_dataframe_request = e6x_engine_pb2 .ProjectionOnDataFrameRequest (
1089+ userUUID = self ._user_uuid ,
10781090 queryId = self ._query_id ,
1091+ dataframeNumber = self ._dataframe_number ,
10791092 sessionId = self ._sessionId ,
10801093 field = projection_fields
10811094 )
@@ -1087,9 +1100,11 @@ def select(self, *fields) -> "DataFrame":
10871100 return self
10881101
10891102 def where (self , where_clause : str ) -> "DataFrame" :
1090- client = self .connection .client
1103+ client = self ._connection .client
10911104 filter_on_dataframe_request = e6x_engine_pb2 .FilterOnDataFrameRequest (
1105+ userUUID = self ._user_uuid ,
10921106 queryId = self ._query_id ,
1107+ dataframeNumber = self ._dataframe_number ,
10931108 sessionId = self ._sessionId ,
10941109 whereClause = where_clause
10951110 )
@@ -1129,7 +1144,9 @@ def order_by(self, field_list : list, sort_direction_list = None, null_direction
11291144 client = self .connection .client
11301145
11311146 orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
1147+ userUUID = self ._user_uuid ,
11321148 queryId = self ._query_id ,
1149+ dataframeNumber = self ._dataframe_number ,
11331150 sessionId = self ._sessionId ,
11341151 field = orderby_fields ,
11351152 sortDirection = sort_direction_request ,
@@ -1148,10 +1165,12 @@ def order_by(self, *field_list) -> "DataFrame":
11481165 for field in field_list :
11491166 orderby_fields .append (field )
11501167
1151- client = self .connection .client
1168+ client = self ._connection .client
11521169
11531170 orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
1171+ userUUID = self ._user_uuid ,
11541172 queryId = self ._query_id ,
1173+ dataframeNumber = self ._dataframe_number ,
11551174 sessionId = self ._sessionId ,
11561175 field = orderby_fields ,
11571176 sortDirection = sort_direction_request ,
@@ -1164,9 +1183,11 @@ def order_by(self, *field_list) -> "DataFrame":
11641183 return self
11651184
11661185 def limit (self , fetch_limit : int ) -> "DataFrame" :
1167- client = self .connection .client
1186+ client = self ._connection .client
11681187 limit_on_dataframe_request = e6x_engine_pb2 .LimitOnDataFrameRequest (
1188+ userUUID = self ._user_uuid ,
11691189 queryId = self ._query_id ,
1190+ dataframeNumber = self ._dataframe_number ,
11701191 sessionId = self ._sessionId ,
11711192 fetchLimit = fetch_limit
11721193 )
@@ -1182,9 +1203,11 @@ def show(self):
11821203 return self .fetchall ()
11831204
11841205 def execute (self ):
1185- client = self .connection .client
1206+ client = self ._connection .client
11861207 execute_dataframe_request = e6x_engine_pb2 .ExecuteDataFrameRequest (
1208+ userUUID = self ._user_uuid ,
11871209 queryId = self ._query_id ,
1210+ dataframeNumber = self ._dataframe_number ,
11881211 sessionId = self ._sessionId
11891212 )
11901213 execute_dataframe_response = client .executeDataFrame (
@@ -1197,15 +1220,15 @@ def _update_meta_data(self):
11971220 sessionId = self ._sessionId ,
11981221 queryId = self ._query_id
11991222 )
1200- get_result_metadata_response = self .connection .client .getResultMetadata (
1223+ get_result_metadata_response = self ._connection .client .getResultMetadata (
12011224 result_meta_data_request ,
12021225 )
12031226 buffer = BytesIO (get_result_metadata_response .resultMetaData )
12041227 self ._rowcount , self ._query_columns_description = get_query_columns_info (buffer )
12051228 self ._is_metadata_updated = True
12061229
12071230 def _fetch_batch (self ):
1208- client = self .connection .client
1231+ client = self ._connection .client
12091232 get_next_result_batch_request = e6x_engine_pb2 .GetNextResultBatchRequest (
12101233 engineIP = self ._engine_ip ,
12111234 sessionId = self ._sessionId ,
@@ -1233,6 +1256,44 @@ def fetchall(self):
12331256 self ._data = None
12341257 return rows
12351258
1259+ class DataFrameSession :
1260+ def __init__ (self , connection : Connection ):
1261+ self ._user_uuid = str (uuid .uuid4 ())
1262+ self ._connection = connection
1263+ self ._dataframe_count = 0
1264+ self ._dataframe_map = dict ()
1265+ self ._is_terminated = False
1266+
1267+ def __exit__ (self , exc_type , exc_val , exc_tb ):
1268+ self .terminate ()
1269+
1270+ def update_dataframe_map (self , dataframe : "DataFrame" ):
1271+ self ._dataframe_map .update ({self ._dataframe_count , dataframe })
1272+ self ._dataframe_count = self ._dataframe_count + 1
1273+
1274+ @property
1275+ def get_user_uuid (self ):
1276+ return self ._user_uuid
1277+
1278+ @property
1279+ def get_dataframe_number (self ) -> int :
1280+ return self ._dataframe_count
1281+
1282+ @property
1283+ def is_terminated (self ) -> bool :
1284+ return self ._is_terminated
1285+
1286+ def terminate (self ):
1287+
1288+ if not self ._is_terminated :
1289+ drop_user_context_request = e6x_engine_pb2 .DropUserContextRequest (
1290+ userUUID = self .get_user_uuid
1291+ )
1292+
1293+ drop_user_context_response = self ._connection .client .dropUserContext (drop_user_context_request )
1294+ self ._is_terminated = True
1295+
1296+
12361297
12371298def poll (self , get_progress_update = True ):
12381299 """Poll for and return the raw status data provided by the Hive Thrift REST API.
0 commit comments