Skip to content

Commit 6dbb0fa

Browse files
committed
dataframe session changes & DF API changes
1 parent ed8024c commit 6dbb0fa

File tree

6 files changed

+914
-416
lines changed

6 files changed

+914
-416
lines changed

dataframe_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def setUp(self) -> None:
1515
self.e6x_connection = Connection(
1616
host=self._host,
1717
port=9001,
18-
username='[email protected]',
18+
username='[email protected]',
1919
password='Dummy@123',
2020
database=self._database,
2121
catalog=self._catalog

e6data_python_connector/e6data_grpc.py

Lines changed: 75 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from io import BytesIO
1717
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
1818
from typing import overload
19-
import uuid
2019

2120
import grpc
2221
from grpc._channel import _InactiveRpcError
@@ -26,6 +25,7 @@
2625
from e6data_python_connector.constants import *
2726
from e6data_python_connector.datainputstream import get_query_columns_info, read_rows_from_chunk
2827
from e6data_python_connector.server import e6x_engine_pb2_grpc, e6x_engine_pb2
28+
from e6data_python_connector.server.e6x_engine_pb2 import AggregateFunction
2929
from e6data_python_connector.typeId import *
3030

3131
apilevel = '2.0'
@@ -371,11 +371,11 @@ def close(self):
371371
372372
This method ensures that the gRPC channel is properly closed and the session ID is reset to None.
373373
"""
374+
self._dataframe_session.terminate()
374375
if self._channel is not None:
375376
self._channel.close()
376377
self._channel = None
377378
self._session_id = None
378-
self._dataframe_session.terminate()
379379

380380
def check_connection(self):
381381
"""
@@ -544,8 +544,22 @@ def load_parquet(self, parquet_path) -> "DataFrame":
544544
dataframe = DataFrame(
545545
self,
546546
file_path=parquet_path,
547-
user_uuid=self._dataframe_session.get_user_uuid,
548-
dataframe_number=self._dataframe_session.get_dataframe_number)
547+
dataframe_number=self._dataframe_session.get_dataframe_number,
548+
table_name=None
549+
)
550+
551+
self._dataframe_session.update_dataframe_map(dataframe=dataframe)
552+
return dataframe
553+
554+
def load_table(self, table_name, database = None, catalog = None) -> "DataFrame":
    """
    Create a DataFrame for an existing table and register it with the dataframe session.

    :param table_name: Name of the table the DataFrame is created from.
    :param database: Optional database name, forwarded unchanged to DataFrame.
    :param catalog: Optional catalog name, forwarded unchanged to DataFrame.
    :return: The newly created DataFrame.
    """
    session = self._dataframe_session

    # file_path is None for table-backed frames; the session's current counter
    # value is used as this frame's number.
    table_dataframe = DataFrame(
        self,
        file_path=None,
        dataframe_number=session.get_dataframe_number,
        table_name=table_name,
        database=database,
        catalog=catalog,
    )

    session.update_dataframe_map(dataframe=table_dataframe)
    return table_dataframe
@@ -1055,153 +1069,133 @@ def explain_analyse(self):
10551069

10561070
class DataFrame:
10571071

1058-
def __init__(self, connection: Connection, file_path, user_uuid, dataframe_number):
1059-
self._user_uuid = user_uuid
1072+
def __init__(self, connection: Connection, file_path, dataframe_number, table_name, database = None, catalog = None):
    """
    Store the DataFrame's source description and create it on the engine.

    :param connection: Connection used for every gRPC call made by this DataFrame.
    :param file_path: Parquet file path, or None when the frame is built from a table.
    :param dataframe_number: Number identifying this DataFrame in requests.
    :param table_name: Table name, or None when the frame is built from a parquet file.
    :param database: Database to use; when None, connection.database is used.
    :param catalog: Catalog to use; when None, connection.catalog_name is used.
    """
    self._dataframe_number = dataframe_number
    self._connection = connection

    # Explicit arguments win; the connection's values are only read as a fallback.
    if catalog is None:
        self._catalog = connection.catalog_name
    else:
        self._catalog = catalog

    if database is None:
        self._database = connection.database
    else:
        self._database = database

    self._table_name = table_name
    self._file_path = file_path
    self._engine_ip = connection.host
    self._sessionId = connection.get_session_id

    # Result-related state, all unset until later calls fill it in.
    self._is_metadata_updated = False
    self._query_id = None
    self._data = None
    self._batch = None

    # A non-None file path means the frame is created from a parquet file.
    self._create_dataframe(file_path is not None)
10701086

1071-
def _create_dataframe(self):
1087+
def _create_dataframe(self, create_dataframe_from_parquet : bool):
    """
    Send a CreateDataFrameRequest and remember the query id from the response.

    :param create_dataframe_from_parquet: Value sent as the request's
        createFromParquet field.
    """
    grpc_client = self._connection.client

    request = e6x_engine_pb2.CreateDataFrameRequest(
        parquetFilePath=self._file_path,
        catalog=self._catalog,
        schema=self._database,
        table=self._table_name,
        sessionId=self._sessionId,
        engineIP=self._engine_ip,
        dataframeNumber=self._dataframe_number,
        createFromParquet=create_dataframe_from_parquet,
    )

    response = grpc_client.createDataFrame(request)

    # Every later request on this DataFrame carries this query id.
    self._query_id = response.queryId
10881105

1089-
def select(self, *fields) -> "DataFrame":
1106+
def select(self, *fields : str) -> "DataFrame":
    """
    Send a projection request for the given fields.

    :param fields: Field names, sent in the order given as the request's
        ``field`` list.
    :return: This DataFrame, so calls can be chained.
    """
    grpc_client = self._connection.client

    request = e6x_engine_pb2.ProjectionOnDataFrameRequest(
        queryId=self._query_id,
        dataframeNumber=self._dataframe_number,
        sessionId=self._sessionId,
        field=list(fields),
    )

    # The response is not used.
    grpc_client.projectionOnDataFrame(request)

    return self
11081122

1109-
def where(self, where_clause : str) -> "DataFrame":
1123+
def aggregate(self, agg_function : dict[str, str], group_by : list[str] = None) -> "DataFrame":
    """
    Send an aggregation request for this DataFrame.

    :param agg_function: Mapping of column name to aggregate function name.
        Names are matched case-insensitively; supported names are ``sum``,
        ``count``, ``count_star`` and ``count_distinct``.
    :param group_by: Optional list of column names, sent as the request's
        ``groupBy`` field.
    :return: This DataFrame, so calls can be chained.

    Columns mapped to an unsupported function name are left out of the
    request; a ``UserWarning`` is emitted for each so the omission is visible.
    """
    import warnings

    # A plain lookup table replaces the nested helper. The helper's
    # ``AggregateFunction | None`` return annotation was evaluated on each call
    # and can raise TypeError if AggregateFunction is a protobuf enum wrapper
    # rather than a class (NOTE(review): confirm against the .proto).
    supported_functions = {
        'sum': e6x_engine_pb2.AggregateFunction.SUM,
        'count': e6x_engine_pb2.AggregateFunction.COUNT,
        'count_star': e6x_engine_pb2.AggregateFunction.COUNT_STAR,
        'count_distinct': e6x_engine_pb2.AggregateFunction.COUNT_DISTINCT,
    }

    agg_function_map = {}
    for column, function_name in agg_function.items():
        function_enum = supported_functions.get(function_name.lower())
        # Compare against None explicitly: an enum value may be 0.
        if function_enum is None:
            warnings.warn(
                f"Unsupported aggregate function {function_name!r} "
                f"for column {column!r}; it will be ignored.",
                stacklevel=2,
            )
            continue
        agg_function_map[column] = function_enum

    client = self._connection.client
    aggregate_on_dataframe_request = e6x_engine_pb2.AggregateOnDataFrameRequest(
        queryId=self._query_id,
        dataframeNumber=self._dataframe_number,
        sessionId=self._sessionId,
        aggregateFunctionMap=agg_function_map,
        groupBy=group_by,
    )

    # The response is not used.
    client.aggregateOnDataFrame(aggregate_on_dataframe_request)

    return self
11241156

1125-
@overload
1126-
def order_by(self, field_list : list, sort_direction_list = None, null_direction_list = None) -> "DataFrame":
1127-
orderby_fields = []
1128-
sort_direction_request = []
1129-
null_direction_request = []
1130-
for field in field_list:
1131-
orderby_fields.append(field)
1132-
1133-
for direction in sort_direction_list:
1134-
direction = str(direction).upper()
1135-
if direction == 'ASC':
1136-
sort_direction_request.append(e6x_engine_pb2.SortDirection.ASC)
1137-
elif direction == 'DESC':
1138-
sort_direction_request.append(e6x_engine_pb2.SortDirection.DESC)
1139-
else:
1140-
sort_direction_request.append(None)
1141-
1142-
for direction in null_direction_list:
1143-
direction = str(direction).upper()
1144-
if direction == 'NULLS_FIRST':
1145-
null_direction_request.append(e6x_engine_pb2.NullDirection.FIRST)
1146-
elif direction == 'NULLS_LAST':
1147-
null_direction_request.append(e6x_engine_pb2.NullDirection.LAST)
1148-
else:
1149-
null_direction_request.append(None)
1150-
1151-
client = self.connection.client
1152-
1153-
orderby_on_dataframe_request = e6x_engine_pb2.OrderByOnDataFrameRequest(
1154-
userUUID=self._user_uuid,
1157+
def where(self, where_clause : str) -> "DataFrame":
    """
    Send a filter request for this DataFrame.

    :param where_clause: Filter expression, sent unchanged as the request's
        ``whereClause`` field.
    :return: This DataFrame, so calls can be chained.
    """
    grpc_client = self._connection.client

    request = e6x_engine_pb2.FilterOnDataFrameRequest(
        queryId=self._query_id,
        dataframeNumber=self._dataframe_number,
        sessionId=self._sessionId,
        whereClause=where_clause,
    )

    # The response is not used.
    grpc_client.filterOnDataFrame(request)

    return self
11671169

1168-
def order_by(self, *field_list) -> "DataFrame":
1169-
orderby_fields = []
1170-
sort_direction_request = []
1171-
null_direction_request = []
1172-
for field in field_list:
1173-
orderby_fields.append(field)
1170+
def order_by(self, *field_list : str) -> "DataFrame":
    """
    Send an order-by request for the given columns.

    :param field_list: Column names to sort on. Every column is sent with
        ascending sort direction.
    :return: This DataFrame, so calls can be chained.
    """
    # Every requested column is mapped to ASC; no other direction is offered here.
    ascending_by_column = {
        column: e6x_engine_pb2.SortDirection.ASC for column in field_list
    }

    grpc_client = self._connection.client

    request = e6x_engine_pb2.OrderByOnDataFrameRequest(
        queryId=self._query_id,
        dataframeNumber=self._dataframe_number,
        sessionId=self._sessionId,
        orderByFieldMap=ascending_by_column,
    )

    # The response is not used.
    grpc_client.orderByOnDataFrame(request)
    return self
11911188

11921189
def limit(self, fetch_limit : int) -> "DataFrame":
    """
    Send a limit request for this DataFrame.

    :param fetch_limit: Value sent as the request's ``fetchLimit`` field.
    :return: This DataFrame, so calls can be chained.
    """
    grpc_client = self._connection.client

    request = e6x_engine_pb2.LimitOnDataFrameRequest(
        queryId=self._query_id,
        dataframeNumber=self._dataframe_number,
        sessionId=self._sessionId,
        fetchLimit=fetch_limit,
    )

    # The response is not used.
    grpc_client.limitOnDataFrame(request)

    return self
12071201

@@ -1212,14 +1206,11 @@ def show(self):
12121206
def execute(self):
    """
    Send an ExecuteDataFrameRequest for this DataFrame.

    The request carries the DataFrame's query id, number and session id.
    The response is not used and nothing is returned.
    """
    grpc_client = self._connection.client

    request = e6x_engine_pb2.ExecuteDataFrameRequest(
        queryId=self._query_id,
        dataframeNumber=self._dataframe_number,
        sessionId=self._sessionId,
    )

    grpc_client.executeDataFrame(request)
12231214

12241215
def _update_meta_data(self):
12251216
result_meta_data_request = e6x_engine_pb2.GetResultMetadataRequest(
@@ -1265,7 +1256,6 @@ def fetchall(self):
12651256

12661257
class DataFrameSession:
12671258
def __init__(self, connection: Connection):
1268-
self._user_uuid = str(uuid.uuid4())
12691259
self._connection = connection
12701260
self._dataframe_count = 0
12711261
self._dataframe_map = dict()
@@ -1278,10 +1268,6 @@ def update_dataframe_map(self, dataframe : "DataFrame"):
12781268
self._dataframe_map.update({self._dataframe_count : dataframe})
12791269
self._dataframe_count = self._dataframe_count + 1
12801270

1281-
@property
1282-
def get_user_uuid(self):
1283-
return self._user_uuid
1284-
12851271
@property
12861272
def get_dataframe_number(self) -> int:
12871273
return self._dataframe_count
@@ -1291,14 +1277,13 @@ def is_terminated(self) -> bool:
12911277
return self._is_terminated
12921278

12931279
def terminate(self):
    """
    Send a DropUserContextRequest for the connection's session, at most once.

    Later calls return immediately once the session is marked terminated.
    The flag is only set after the call returns, so a failed call leaves the
    session un-terminated.
    """
    if self._is_terminated:
        return

    request = e6x_engine_pb2.DropUserContextRequest(
        sessionId=self._connection.get_session_id
    )

    # The response is not used.
    self._connection.client.dropUserContext(request)
    self._is_terminated = True
13021287

13031288

13041289

0 commit comments

Comments
 (0)