1616from io import BytesIO
1717from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED
1818from typing import overload
19- import uuid
2019
2120import grpc
2221from grpc ._channel import _InactiveRpcError
2625from e6data_python_connector .constants import *
2726from e6data_python_connector .datainputstream import get_query_columns_info , read_rows_from_chunk
2827from e6data_python_connector .server import e6x_engine_pb2_grpc , e6x_engine_pb2
28+ from e6data_python_connector .server .e6x_engine_pb2 import AggregateFunction
2929from e6data_python_connector .typeId import *
3030
# PEP 249 (DB-API 2.0) module-level attribute advertising the supported API level.
apilevel = '2.0'
@@ -371,11 +371,11 @@ def close(self):
371371
372372 This method ensures that the gRPC channel is properly closed and the session ID is reset to None.
373373 """
374+ self ._dataframe_session .terminate ()
374375 if self ._channel is not None :
375376 self ._channel .close ()
376377 self ._channel = None
377378 self ._session_id = None
378- self ._dataframe_session .terminate ()
379379
380380 def check_connection (self ):
381381 """
@@ -544,8 +544,22 @@ def load_parquet(self, parquet_path) -> "DataFrame":
544544 dataframe = DataFrame (
545545 self ,
546546 file_path = parquet_path ,
547- user_uuid = self ._dataframe_session .get_user_uuid ,
548- dataframe_number = self ._dataframe_session .get_dataframe_number )
547+ dataframe_number = self ._dataframe_session .get_dataframe_number ,
548+ table_name = None
549+ )
550+
551+ self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
552+ return dataframe
553+
554+ def load_table (self , table_name , database = None , catalog = None ) -> "DataFrame" :
555+ dataframe = DataFrame (
556+ self ,
557+ file_path = None ,
558+ dataframe_number = self ._dataframe_session .get_dataframe_number ,
559+ table_name = table_name ,
560+ database = database ,
561+ catalog = catalog
562+ )
549563
550564 self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
551565 return dataframe
@@ -1055,153 +1069,133 @@ def explain_analyse(self):
10551069
10561070class DataFrame :
10571071
1058- def __init__ (self , connection : Connection , file_path , user_uuid , dataframe_number ):
1059- self ._user_uuid = user_uuid
1072+ def __init__ (self , connection : Connection , file_path , dataframe_number , table_name , database = None , catalog = None ):
10601073 self ._dataframe_number = dataframe_number
10611074 self ._connection = connection
1075+ self ._catalog = self ._connection .catalog_name if catalog is None else catalog
1076+ self ._database = self ._connection .database if database is None else database
1077+ self ._table_name = table_name
10621078 self ._file_path = file_path
10631079 self ._engine_ip = connection .host
10641080 self ._sessionId = connection .get_session_id
10651081 self ._is_metadata_updated = False
10661082 self ._query_id = None
10671083 self ._data = None
10681084 self ._batch = None
1069- self ._create_dataframe ()
1085+ self ._create_dataframe (self . _file_path is not None )
10701086
1071- def _create_dataframe (self ):
1087+ def _create_dataframe (self , create_dataframe_from_parquet : bool ):
10721088 client = self ._connection .client
10731089
10741090 create_dataframe_request = e6x_engine_pb2 .CreateDataFrameRequest (
10751091 parquetFilePath = self ._file_path ,
1076- catalog = self ._connection .catalog_name ,
1077- schema = self ._connection .database ,
1092+ catalog = self ._catalog ,
1093+ schema = self ._database ,
1094+ table = self ._table_name ,
10781095 sessionId = self ._sessionId ,
10791096 engineIP = self ._engine_ip ,
1080- userUUID = self ._user_uuid ,
1081- dataframeNumber = self . _dataframe_number
1097+ dataframeNumber = self ._dataframe_number ,
1098+ createFromParquet = create_dataframe_from_parquet
10821099 )
10831100
10841101 create_dataframe_response = client .createDataFrame (
10851102 create_dataframe_request
10861103 )
10871104 self ._query_id = create_dataframe_response .queryId
10881105
1089- def select (self , * fields ) -> "DataFrame" :
1106+ def select (self , * fields : str ) -> "DataFrame" :
10901107 projection_fields = []
10911108 for field in fields :
10921109 projection_fields .append (field )
10931110
10941111 client = self ._connection .client
10951112 projection_on_dataframe_request = e6x_engine_pb2 .ProjectionOnDataFrameRequest (
1096- userUUID = self ._user_uuid ,
10971113 queryId = self ._query_id ,
10981114 dataframeNumber = self ._dataframe_number ,
10991115 sessionId = self ._sessionId ,
11001116 field = projection_fields
11011117 )
11021118
1103- projection_on_dataframe_response = client .projectionOnDataFrame (
1104- projection_on_dataframe_request
1105- )
1119+ client .projectionOnDataFrame (projection_on_dataframe_request )
11061120
11071121 return self
11081122
1109- def where (self , where_clause : str ) -> "DataFrame" :
1123+ def aggregate (self , agg_function : dict [str , str ], group_by : list [str ] = None ) -> "DataFrame" :
1124+ def get_agg_enum (function_name : str ) -> AggregateFunction | None :
1125+ match function_name .lower ():
1126+ case 'sum' :
1127+ return e6x_engine_pb2 .AggregateFunction .SUM
1128+ case 'count' :
1129+ return e6x_engine_pb2 .AggregateFunction .COUNT
1130+ case 'count_star' :
1131+ return e6x_engine_pb2 .AggregateFunction .COUNT_STAR
1132+ case 'count_distinct' :
1133+ return e6x_engine_pb2 .AggregateFunction .COUNT_DISTINCT
1134+ case _:
1135+ return None
1136+
1137+ agg_function_map = {}
1138+
1139+ for column in agg_function .keys ():
1140+ fun = get_agg_enum (agg_function .get (column ))
1141+ if fun is not None :
1142+ agg_function_map .update ({column : fun })
1143+
11101144 client = self ._connection .client
1111- filter_on_dataframe_request = e6x_engine_pb2 .FilterOnDataFrameRequest (
1112- userUUID = self ._user_uuid ,
1145+ aggregate_on_dataframe_request = e6x_engine_pb2 .AggregateOnDataFrameRequest (
11131146 queryId = self ._query_id ,
11141147 dataframeNumber = self ._dataframe_number ,
11151148 sessionId = self ._sessionId ,
1116- whereClause = where_clause
1149+ aggregateFunctionMap = agg_function_map ,
1150+ groupBy = group_by
11171151 )
11181152
1119- filter_on_dataframe_response = client .filterOnDataFrame (
1120- filter_on_dataframe_request
1121- )
1153+ client .aggregateOnDataFrame (aggregate_on_dataframe_request )
11221154
11231155 return self
11241156
1125- @overload
1126- def order_by (self , field_list : list , sort_direction_list = None , null_direction_list = None ) -> "DataFrame" :
1127- orderby_fields = []
1128- sort_direction_request = []
1129- null_direction_request = []
1130- for field in field_list :
1131- orderby_fields .append (field )
1132-
1133- for direction in sort_direction_list :
1134- direction = str (direction ).upper ()
1135- if direction == 'ASC' :
1136- sort_direction_request .append (e6x_engine_pb2 .SortDirection .ASC )
1137- elif direction == 'DESC' :
1138- sort_direction_request .append (e6x_engine_pb2 .SortDirection .DESC )
1139- else :
1140- sort_direction_request .append (None )
1141-
1142- for direction in null_direction_list :
1143- direction = str (direction ).upper ()
1144- if direction == 'NULLS_FIRST' :
1145- null_direction_request .append (e6x_engine_pb2 .NullDirection .FIRST )
1146- elif direction == 'NULLS_LAST' :
1147- null_direction_request .append (e6x_engine_pb2 .NullDirection .LAST )
1148- else :
1149- null_direction_request .append (None )
1150-
1151- client = self .connection .client
1152-
1153- orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
1154- userUUID = self ._user_uuid ,
1157+ def where (self , where_clause : str ) -> "DataFrame" :
1158+ client = self ._connection .client
1159+ filter_on_dataframe_request = e6x_engine_pb2 .FilterOnDataFrameRequest (
11551160 queryId = self ._query_id ,
11561161 dataframeNumber = self ._dataframe_number ,
11571162 sessionId = self ._sessionId ,
1158- field = orderby_fields ,
1159- sortDirection = sort_direction_request ,
1160- nullsDirection = null_direction_request
1163+ whereClause = where_clause
11611164 )
11621165
1163- orderby_on_dataframe_response = client .orderByOnDataFrame (
1164- orderby_on_dataframe_request
1165- )
1166+ client .filterOnDataFrame (filter_on_dataframe_request )
1167+
11661168 return self
11671169
1168- def order_by (self , * field_list ) -> "DataFrame" :
1169- orderby_fields = []
1170- sort_direction_request = []
1171- null_direction_request = []
1172- for field in field_list :
1173- orderby_fields . append ( field )
1170+ def order_by (self , * field_list : str ) -> "DataFrame" :
1171+ order_by_map = dict ()
1172+
1173+ # default sorting in ASCENDING order
1174+ for column in field_list :
1175+ order_by_map . update ({ column : e6x_engine_pb2 . SortDirection . ASC } )
11741176
11751177 client = self ._connection .client
11761178
11771179 orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
1178- userUUID = self ._user_uuid ,
11791180 queryId = self ._query_id ,
11801181 dataframeNumber = self ._dataframe_number ,
11811182 sessionId = self ._sessionId ,
1182- field = orderby_fields ,
1183- sortDirection = sort_direction_request ,
1184- nullsDirection = null_direction_request
1183+ orderByFieldMap = order_by_map
11851184 )
11861185
1187- orderby_on_dataframe_response = client .orderByOnDataFrame (
1188- orderby_on_dataframe_request
1189- )
1186+ client .orderByOnDataFrame (orderby_on_dataframe_request )
11901187 return self
11911188
11921189 def limit (self , fetch_limit : int ) -> "DataFrame" :
11931190 client = self ._connection .client
11941191 limit_on_dataframe_request = e6x_engine_pb2 .LimitOnDataFrameRequest (
1195- userUUID = self ._user_uuid ,
11961192 queryId = self ._query_id ,
11971193 dataframeNumber = self ._dataframe_number ,
11981194 sessionId = self ._sessionId ,
11991195 fetchLimit = fetch_limit
12001196 )
12011197
1202- limit_on_dataframe_response = client .limitOnDataFrame (
1203- limit_on_dataframe_request
1204- )
1198+ client .limitOnDataFrame (limit_on_dataframe_request )
12051199
12061200 return self
12071201
@@ -1212,14 +1206,11 @@ def show(self):
12121206 def execute (self ):
12131207 client = self ._connection .client
12141208 execute_dataframe_request = e6x_engine_pb2 .ExecuteDataFrameRequest (
1215- userUUID = self ._user_uuid ,
12161209 queryId = self ._query_id ,
12171210 dataframeNumber = self ._dataframe_number ,
12181211 sessionId = self ._sessionId
12191212 )
1220- execute_dataframe_response = client .executeDataFrame (
1221- execute_dataframe_request
1222- )
1213+ client .executeDataFrame (execute_dataframe_request )
12231214
12241215 def _update_meta_data (self ):
12251216 result_meta_data_request = e6x_engine_pb2 .GetResultMetadataRequest (
@@ -1265,7 +1256,6 @@ def fetchall(self):
12651256
12661257class DataFrameSession :
12671258 def __init__ (self , connection : Connection ):
1268- self ._user_uuid = str (uuid .uuid4 ())
12691259 self ._connection = connection
12701260 self ._dataframe_count = 0
12711261 self ._dataframe_map = dict ()
@@ -1278,10 +1268,6 @@ def update_dataframe_map(self, dataframe : "DataFrame"):
12781268 self ._dataframe_map .update ({self ._dataframe_count : dataframe })
12791269 self ._dataframe_count = self ._dataframe_count + 1
12801270
1281- @property
1282- def get_user_uuid (self ):
1283- return self ._user_uuid
1284-
    @property
    def get_dataframe_number(self) -> int:
        # Number that will be assigned to the next dataframe registered via
        # update_dataframe_map (i.e. the count of dataframes created so far).
        return self._dataframe_count
@@ -1291,14 +1277,13 @@ def is_terminated(self) -> bool:
12911277 return self ._is_terminated
12921278
12931279 def terminate (self ):
1294-
12951280 if not self ._is_terminated :
12961281 drop_user_context_request = e6x_engine_pb2 .DropUserContextRequest (
1297- userUUID = self .get_user_uuid
1282+ sessionId = self ._connection . get_session_id
12981283 )
12991284
1300- drop_user_context_response = self ._connection .client .dropUserContext (drop_user_context_request )
1301- self ._is_terminated = True
1285+ self ._connection .client .dropUserContext (drop_user_context_request )
1286+ self ._is_terminated = True
13021287
13031288
13041289
0 commit comments