@@ -188,6 +188,8 @@ def __init__(
188188 self .cluster_uuid = cluster_uuid
189189 self ._session_id = None
190190 self ._host = host
191+ # engine ip for stickiness
192+ self ._engine_ip = None
191193 self ._port = port
192194
193195 self ._secure_channel = secure
@@ -329,6 +331,7 @@ def get_session_id(self):
329331 metadata = _get_grpc_header (cluster = self .cluster_uuid )
330332 )
331333 self ._session_id = authenticate_response .sessionId
334+ self ._engine_ip = authenticate_response .engineIP
332335 else :
333336 raise e
334337 else :
@@ -541,25 +544,23 @@ def cursor(self, catalog_name=None, db_name=None):
541544 return Cursor (self , database = db_name , catalog_name = catalog_name )
542545
543546 def load_parquet (self , parquet_path ) -> "DataFrame" :
544- dataframe = DataFrame (
545- self ,
546- file_path = parquet_path ,
547- dataframe_number = self ._dataframe_session .get_dataframe_number ,
548- table_name = None
549- )
547+ dataframe = DataFrame (self ,
548+ file_path = parquet_path ,
549+ dataframe_number = self ._dataframe_session .get_dataframe_number ,
550+ table_name = None ,
551+ engine_ip = self ._engine_ip )
550552
551553 self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
552554 return dataframe
553555
554556 def load_table (self , table_name , database = None , catalog = None ) -> "DataFrame" :
555- dataframe = DataFrame (
556- self ,
557- file_path = None ,
558- dataframe_number = self ._dataframe_session .get_dataframe_number ,
559- table_name = table_name ,
560- database = database ,
561- catalog = catalog
562- )
557+ dataframe = DataFrame (self ,
558+ file_path = None ,
559+ dataframe_number = self ._dataframe_session .get_dataframe_number ,
560+ table_name = table_name ,
561+ engine_ip = self ._engine_ip ,
562+ database = database ,
563+ catalog = catalog )
563564
564565 self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
565566 return dataframe
@@ -1069,15 +1070,15 @@ def explain_analyse(self):
10691070
10701071class DataFrame :
10711072
1072- def __init__ (self , connection : Connection , file_path , dataframe_number , table_name , database = None , catalog = None ):
1073+ def __init__ (self , connection : Connection , file_path , dataframe_number , table_name , engine_ip , database = None , catalog = None ):
10731074 self ._dataframe_number = dataframe_number
10741075 self ._connection = connection
10751076 self ._catalog = self ._connection .catalog_name if catalog is None else catalog
10761077 self ._database = self ._connection .database if database is None else database
10771078 self ._table_name = table_name
10781079 self ._file_path = file_path
1079- self ._engine_ip = connection .host
10801080 self ._sessionId = connection .get_session_id
1081+ self ._engine_ip = engine_ip
10811082 self ._is_metadata_updated = False
10821083 self ._query_id = None
10831084 self ._data = None
@@ -1113,6 +1114,7 @@ def select(self, *fields : str) -> "DataFrame":
11131114 queryId = self ._query_id ,
11141115 dataframeNumber = self ._dataframe_number ,
11151116 sessionId = self ._sessionId ,
1117+ engineIP = self ._engine_ip ,
11161118 field = projection_fields
11171119 )
11181120
@@ -1146,6 +1148,7 @@ def get_agg_enum(function_name : str) -> AggregateFunction | None:
11461148 queryId = self ._query_id ,
11471149 dataframeNumber = self ._dataframe_number ,
11481150 sessionId = self ._sessionId ,
1151+ engineIP = self ._engine_ip ,
11491152 aggregateFunctionMap = agg_function_map ,
11501153 groupBy = group_by
11511154 )
@@ -1160,6 +1163,7 @@ def where(self, where_clause : str) -> "DataFrame":
11601163 queryId = self ._query_id ,
11611164 dataframeNumber = self ._dataframe_number ,
11621165 sessionId = self ._sessionId ,
1166+ engineIP = self ._engine_ip ,
11631167 whereClause = where_clause
11641168 )
11651169
@@ -1180,6 +1184,7 @@ def order_by(self, *field_list : str) -> "DataFrame":
11801184 queryId = self ._query_id ,
11811185 dataframeNumber = self ._dataframe_number ,
11821186 sessionId = self ._sessionId ,
1187+ engineIP = self ._engine_ip ,
11831188 orderByFieldMap = order_by_map
11841189 )
11851190
@@ -1192,6 +1197,7 @@ def limit(self, fetch_limit : int) -> "DataFrame":
11921197 queryId = self ._query_id ,
11931198 dataframeNumber = self ._dataframe_number ,
11941199 sessionId = self ._sessionId ,
1200+ engineIP = self ._engine_ip ,
11951201 fetchLimit = fetch_limit
11961202 )
11971203
@@ -1208,7 +1214,8 @@ def execute(self):
12081214 execute_dataframe_request = e6x_engine_pb2 .ExecuteDataFrameRequest (
12091215 queryId = self ._query_id ,
12101216 dataframeNumber = self ._dataframe_number ,
1211- sessionId = self ._sessionId
1217+ sessionId = self ._sessionId ,
1218+ engineIP = self ._engine_ip ,
12121219 )
12131220 client .executeDataFrame (execute_dataframe_request )
12141221
@@ -1255,11 +1262,13 @@ def fetchall(self):
12551262 return rows
12561263
12571264class DataFrameSession :
1258- def __init__ (self , connection : Connection ):
1265+ def __init__ (self , connection : Connection , planner_ip ):
12591266 self ._connection = connection
12601267 self ._dataframe_count = 0
12611268 self ._dataframe_map = dict ()
12621269 self ._is_terminated = False
1270+ self ._session_id = connection .get_session_id
1271+ self ._planner_ip = planner_ip
12631272
12641273 def __exit__ (self , exc_type , exc_val , exc_tb ):
12651274 self .terminate ()
@@ -1276,10 +1285,15 @@ def get_dataframe_number(self) -> int:
12761285 def is_terminated (self ) -> bool :
12771286 return self ._is_terminated
12781287
1288+ @property
1289+ def planner_ip (self ):
1290+ return self ._planner_ip
1291+
12791292 def terminate (self ):
12801293 if not self ._is_terminated :
12811294 drop_user_context_request = e6x_engine_pb2 .DropUserContextRequest (
1282- sessionId = self ._connection .get_session_id
1295+ sessionId = self ._session_id ,
1296+ engineIP = self ._planner_ip
12831297 )
12841298
12851299 self ._connection .client .dropUserContext (drop_user_context_request )
0 commit comments