@@ -188,6 +188,8 @@ def __init__(
188188 self .cluster_uuid = cluster_uuid
189189 self ._session_id = None
190190 self ._host = host
191+ # engine ip for stickiness
192+ self ._engine_ip = None
191193 self ._port = port
192194
193195 self ._secure_channel = secure
@@ -322,6 +324,7 @@ def get_session_id(self):
322324 metadata = _get_grpc_header (cluster = self .cluster_uuid )
323325 )
324326 self ._session_id = authenticate_response .sessionId
327+ self ._engine_ip = authenticate_response .engineIP
325328 else :
326329 raise e
327330 else :
@@ -534,25 +537,23 @@ def cursor(self, catalog_name=None, db_name=None):
534537 return Cursor (self , database = db_name , catalog_name = catalog_name )
535538
536539 def load_parquet (self , parquet_path ) -> "DataFrame" :
537- dataframe = DataFrame (
538- self ,
539- file_path = parquet_path ,
540- dataframe_number = self ._dataframe_session .get_dataframe_number ,
541- table_name = None
542- )
540+ dataframe = DataFrame (self ,
541+ file_path = parquet_path ,
542+ dataframe_number = self ._dataframe_session .get_dataframe_number ,
543+ table_name = None ,
544+ engine_ip = self ._engine_ip )
543545
544546 self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
545547 return dataframe
546548
547549 def load_table (self , table_name , database = None , catalog = None ) -> "DataFrame" :
548- dataframe = DataFrame (
549- self ,
550- file_path = None ,
551- dataframe_number = self ._dataframe_session .get_dataframe_number ,
552- table_name = table_name ,
553- database = database ,
554- catalog = catalog
555- )
550+ dataframe = DataFrame (self ,
551+ file_path = None ,
552+ dataframe_number = self ._dataframe_session .get_dataframe_number ,
553+ table_name = table_name ,
554+ engine_ip = self ._engine_ip ,
555+ database = database ,
556+ catalog = catalog )
556557
557558 self ._dataframe_session .update_dataframe_map (dataframe = dataframe )
558559 return dataframe
@@ -1062,15 +1063,15 @@ def explain_analyse(self):
10621063
10631064class DataFrame :
10641065
1065- def __init__ (self , connection : Connection , file_path , dataframe_number , table_name , database = None , catalog = None ):
1066+ def __init__ (self , connection : Connection , file_path , dataframe_number , table_name , engine_ip , database = None , catalog = None ):
10661067 self ._dataframe_number = dataframe_number
10671068 self ._connection = connection
10681069 self ._catalog = self ._connection .catalog_name if catalog is None else catalog
10691070 self ._database = self ._connection .database if database is None else database
10701071 self ._table_name = table_name
10711072 self ._file_path = file_path
1072- self ._engine_ip = connection .host
10731073 self ._sessionId = connection .get_session_id
1074+ self ._engine_ip = engine_ip
10741075 self ._is_metadata_updated = False
10751076 self ._query_id = None
10761077 self ._data = None
@@ -1106,6 +1107,7 @@ def select(self, *fields : str) -> "DataFrame":
11061107 queryId = self ._query_id ,
11071108 dataframeNumber = self ._dataframe_number ,
11081109 sessionId = self ._sessionId ,
1110+ engineIP = self ._engine_ip ,
11091111 field = projection_fields
11101112 )
11111113
@@ -1139,6 +1141,7 @@ def get_agg_enum(function_name : str) -> AggregateFunction | None:
11391141 queryId = self ._query_id ,
11401142 dataframeNumber = self ._dataframe_number ,
11411143 sessionId = self ._sessionId ,
1144+ engineIP = self ._engine_ip ,
11421145 aggregateFunctionMap = agg_function_map ,
11431146 groupBy = group_by
11441147 )
@@ -1153,6 +1156,7 @@ def where(self, where_clause : str) -> "DataFrame":
11531156 queryId = self ._query_id ,
11541157 dataframeNumber = self ._dataframe_number ,
11551158 sessionId = self ._sessionId ,
1159+ engineIP = self ._engine_ip ,
11561160 whereClause = where_clause
11571161 )
11581162
@@ -1173,6 +1177,7 @@ def order_by(self, *field_list : str) -> "DataFrame":
11731177 queryId = self ._query_id ,
11741178 dataframeNumber = self ._dataframe_number ,
11751179 sessionId = self ._sessionId ,
1180+ engineIP = self ._engine_ip ,
11761181 orderByFieldMap = order_by_map
11771182 )
11781183
@@ -1185,6 +1190,7 @@ def limit(self, fetch_limit : int) -> "DataFrame":
11851190 queryId = self ._query_id ,
11861191 dataframeNumber = self ._dataframe_number ,
11871192 sessionId = self ._sessionId ,
1193+ engineIP = self ._engine_ip ,
11881194 fetchLimit = fetch_limit
11891195 )
11901196
@@ -1201,7 +1207,8 @@ def execute(self):
12011207 execute_dataframe_request = e6x_engine_pb2 .ExecuteDataFrameRequest (
12021208 queryId = self ._query_id ,
12031209 dataframeNumber = self ._dataframe_number ,
1204- sessionId = self ._sessionId
1210+ sessionId = self ._sessionId ,
1211+ engineIP = self ._engine_ip ,
12051212 )
12061213 client .executeDataFrame (execute_dataframe_request )
12071214
@@ -1248,11 +1255,13 @@ def fetchall(self):
12481255 return rows
12491256
12501257class DataFrameSession :
1251- def __init__ (self , connection : Connection ):
1258+ def __init__ (self , connection : Connection , planner_ip ):
12521259 self ._connection = connection
12531260 self ._dataframe_count = 0
12541261 self ._dataframe_map = dict ()
12551262 self ._is_terminated = False
1263+ self ._session_id = connection .get_session_id
1264+ self ._planner_ip = planner_ip
12561265
12571266 def __exit__ (self , exc_type , exc_val , exc_tb ):
12581267 self .terminate ()
@@ -1269,10 +1278,15 @@ def get_dataframe_number(self) -> int:
12691278 def is_terminated (self ) -> bool :
12701279 return self ._is_terminated
12711280
1281+ @property
1282+ def planner_ip (self ):
1283+ return self ._planner_ip
1284+
12721285 def terminate (self ):
12731286 if not self ._is_terminated :
12741287 drop_user_context_request = e6x_engine_pb2 .DropUserContextRequest (
1275- sessionId = self ._connection .get_session_id
1288+ sessionId = self ._session_id ,
1289+ engineIP = self ._planner_ip
12761290 )
12771291
12781292 self ._connection .client .dropUserContext (drop_user_context_request )
0 commit comments