Skip to content

Commit f021c57

Browse files
committed
add session for dataframe
1 parent f102265 commit f021c57

File tree

5 files changed

+244
-90
lines changed

5 files changed

+244
-90
lines changed

e6data_python_connector/e6data_grpc.py

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from io import BytesIO
1717
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
1818
from typing import overload
19+
import uuid
1920

2021
import grpc
2122
from grpc._channel import _InactiveRpcError
@@ -201,6 +202,9 @@ def __init__(
201202
self.grpc_prepare_timeout = self._grpc_options.get('grpc_prepare_timeout') or 10 * 60 # 10 minutes
202203
self._create_client()
203204

205+
# initialize session for dataframe
206+
self._dataframe_session = DataFrameSession(self)
207+
204208
@property
205209
def _get_grpc_options(self):
206210
"""
@@ -351,6 +355,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
351355
exc_val (BaseException): The exception instance raised (if any).
352356
exc_tb (Traceback): The traceback object of the exception (if any).
353357
"""
358+
self._dataframe_session.terminate()
354359
self.close()
355360

356361
def close(self):
@@ -363,6 +368,7 @@ def close(self):
363368
self._channel.close()
364369
self._channel = None
365370
self._session_id = None
371+
self._dataframe_session.terminate()
366372

367373
def check_connection(self):
368374
"""
@@ -527,8 +533,15 @@ def cursor(self, catalog_name=None, db_name=None):
527533
"""
528534
return Cursor(self, database=db_name, catalog_name=catalog_name)
529535

530-
def load_parquet(self, parquet_path):
531-
return DataFrame(self, file_path=parquet_path)
536+
def load_parquet(self, parquet_path) -> "DataFrame":
537+
dataframe = DataFrame(
538+
self,
539+
file_path=parquet_path,
540+
user_uuid=self._dataframe_session.get_user_uuid,
541+
dataframe_number=self._dataframe_session.get_dataframe_number)
542+
543+
self._dataframe_session.update_dataframe_map(dataframe=dataframe)
544+
return dataframe
532545

533546
def rollback(self):
534547
"""
@@ -1035,8 +1048,10 @@ def explain_analyse(self):
10351048

10361049
class DataFrame:
10371050

1038-
def __init__(self, connection: Connection, file_path):
1039-
self.connection = connection
1051+
def __init__(self, connection: Connection, file_path, user_uuid, dataframe_number):
1052+
self._user_uuid = user_uuid
1053+
self._dataframe_number = dataframe_number
1054+
self._connection = connection
10401055
self._file_path = file_path
10411056
self._engine_ip = connection.host
10421057
self._sessionId = connection.get_session_id
@@ -1046,21 +1061,17 @@ def __init__(self, connection: Connection, file_path):
10461061
self._batch = None
10471062
self._create_dataframe()
10481063

1049-
def __enter__(self):
1050-
pass
1051-
1052-
def __exit__(self, exc_type, exc_val, exc_tb):
1053-
pass
1054-
10551064
def _create_dataframe(self):
1056-
client = self.connection.client
1065+
client = self._connection.client
10571066

10581067
create_dataframe_request = e6x_engine_pb2.CreateDataFrameRequest(
10591068
parquetFilePath=self._file_path,
1060-
catalog=self.connection.catalog_name,
1061-
schema=self.connection.database,
1069+
catalog=self._connection.catalog_name,
1070+
schema=self._connection.database,
10621071
sessionId=self._sessionId,
1063-
engineIP=self._engine_ip
1072+
engineIP=self._engine_ip,
1073+
userUUID=self._user_uuid,
1074+
dataframeNumber=self._dataframe_number
10641075
)
10651076

10661077
create_dataframe_response = client.createDataFrame(
@@ -1073,9 +1084,11 @@ def select(self, *fields) -> "DataFrame":
10731084
for field in fields:
10741085
projection_fields.append(field)
10751086

1076-
client = self.connection.client
1087+
client = self._connection.client
10771088
projection_on_dataframe_request = e6x_engine_pb2.ProjectionOnDataFrameRequest(
1089+
userUUID=self._user_uuid,
10781090
queryId=self._query_id,
1091+
dataframeNumber=self._dataframe_number,
10791092
sessionId=self._sessionId,
10801093
field=projection_fields
10811094
)
@@ -1087,9 +1100,11 @@ def select(self, *fields) -> "DataFrame":
10871100
return self
10881101

10891102
def where(self, where_clause : str) -> "DataFrame":
1090-
client = self.connection.client
1103+
client = self._connection.client
10911104
filter_on_dataframe_request = e6x_engine_pb2.FilterOnDataFrameRequest(
1105+
userUUID=self._user_uuid,
10921106
queryId=self._query_id,
1107+
dataframeNumber=self._dataframe_number,
10931108
sessionId=self._sessionId,
10941109
whereClause=where_clause
10951110
)
@@ -1129,7 +1144,9 @@ def order_by(self, field_list : list, sort_direction_list = None, null_direction
11291144
client = self.connection.client
11301145

11311146
orderby_on_dataframe_request = e6x_engine_pb2.OrderByOnDataFrameRequest(
1147+
userUUID=self._user_uuid,
11321148
queryId=self._query_id,
1149+
dataframeNumber=self._dataframe_number,
11331150
sessionId=self._sessionId,
11341151
field=orderby_fields,
11351152
sortDirection=sort_direction_request,
@@ -1148,10 +1165,12 @@ def order_by(self, *field_list) -> "DataFrame":
11481165
for field in field_list:
11491166
orderby_fields.append(field)
11501167

1151-
client = self.connection.client
1168+
client = self._connection.client
11521169

11531170
orderby_on_dataframe_request = e6x_engine_pb2.OrderByOnDataFrameRequest(
1171+
userUUID=self._user_uuid,
11541172
queryId=self._query_id,
1173+
dataframeNumber=self._dataframe_number,
11551174
sessionId=self._sessionId,
11561175
field=orderby_fields,
11571176
sortDirection=sort_direction_request,
@@ -1164,9 +1183,11 @@ def order_by(self, *field_list) -> "DataFrame":
11641183
return self
11651184

11661185
def limit(self, fetch_limit : int) -> "DataFrame":
1167-
client = self.connection.client
1186+
client = self._connection.client
11681187
limit_on_dataframe_request = e6x_engine_pb2.LimitOnDataFrameRequest(
1188+
userUUID=self._user_uuid,
11691189
queryId=self._query_id,
1190+
dataframeNumber=self._dataframe_number,
11701191
sessionId=self._sessionId,
11711192
fetchLimit=fetch_limit
11721193
)
@@ -1182,9 +1203,11 @@ def show(self):
11821203
return self.fetchall()
11831204

11841205
def execute(self):
1185-
client = self.connection.client
1206+
client = self._connection.client
11861207
execute_dataframe_request = e6x_engine_pb2.ExecuteDataFrameRequest(
1208+
userUUID=self._user_uuid,
11871209
queryId=self._query_id,
1210+
dataframeNumber=self._dataframe_number,
11881211
sessionId=self._sessionId
11891212
)
11901213
execute_dataframe_response = client.executeDataFrame(
@@ -1197,15 +1220,15 @@ def _update_meta_data(self):
11971220
sessionId=self._sessionId,
11981221
queryId=self._query_id
11991222
)
1200-
get_result_metadata_response = self.connection.client.getResultMetadata(
1223+
get_result_metadata_response = self._connection.client.getResultMetadata(
12011224
result_meta_data_request,
12021225
)
12031226
buffer = BytesIO(get_result_metadata_response.resultMetaData)
12041227
self._rowcount, self._query_columns_description = get_query_columns_info(buffer)
12051228
self._is_metadata_updated = True
12061229

12071230
def _fetch_batch(self):
1208-
client = self.connection.client
1231+
client = self._connection.client
12091232
get_next_result_batch_request = e6x_engine_pb2.GetNextResultBatchRequest(
12101233
engineIP=self._engine_ip,
12111234
sessionId=self._sessionId,
@@ -1233,6 +1256,44 @@ def fetchall(self):
12331256
self._data = None
12341257
return rows
12351258

1259+
class DataFrameSession:
1260+
def __init__(self, connection: Connection):
1261+
self._user_uuid = str(uuid.uuid4())
1262+
self._connection = connection
1263+
self._dataframe_count = 0
1264+
self._dataframe_map = dict()
1265+
self._is_terminated = False
1266+
1267+
def __exit__(self, exc_type, exc_val, exc_tb):
1268+
self.terminate()
1269+
1270+
def update_dataframe_map(self, dataframe : "DataFrame"):
1271+
self._dataframe_map.update({self._dataframe_count, dataframe})
1272+
self._dataframe_count = self._dataframe_count + 1
1273+
1274+
@property
1275+
def get_user_uuid(self):
1276+
return self._user_uuid
1277+
1278+
@property
1279+
def get_dataframe_number(self) -> int:
1280+
return self._dataframe_count
1281+
1282+
@property
1283+
def is_terminated(self) -> bool:
1284+
return self._is_terminated
1285+
1286+
def terminate(self):
1287+
1288+
if not self._is_terminated:
1289+
drop_user_context_request = e6x_engine_pb2.DropUserContextRequest(
1290+
userUUID=self.get_user_uuid
1291+
)
1292+
1293+
drop_user_context_response = self._connection.client.dropUserContext(drop_user_context_request)
1294+
self._is_terminated = True
1295+
1296+
12361297

12371298
def poll(self, get_progress_update=True):
12381299
"""Poll for and return the raw status data provided by the Hive Thrift REST API.

0 commit comments

Comments
 (0)