1414from decimal import Decimal
1515from io import BytesIO
1616from ssl import CERT_NONE , CERT_OPTIONAL , CERT_REQUIRED
17+ from typing import overload
1718
1819import grpc
1920from grpc ._channel import _InactiveRpcError
@@ -710,6 +711,12 @@ def __init__(self, connection: Connection, file_path):
710711 self ._batch = None
711712 self ._create_dataframe ()
712713
714+ def __enter__ (self ):
715+ pass
716+
717+ def __exit__ (self , exc_type , exc_val , exc_tb ):
718+ pass
719+
713720 def _create_dataframe (self ):
714721 client = self .connection .client
715722
@@ -726,7 +733,7 @@ def _create_dataframe(self):
726733 )
727734 self ._query_id = create_dataframe_response .queryId
728735
729- def select (self , * fields ):
736+ def select (self , * fields ) -> "DataFrame" :
730737 projection_fields = []
731738 for field in fields :
732739 projection_fields .append (field )
@@ -744,7 +751,7 @@ def select(self, *fields):
744751
745752 return self
746753
747- def where (self , where_clause : str ):
754+ def where (self , where_clause : str ) -> "DataFrame" :
748755 client = self .connection .client
749756 filter_on_dataframe_request = e6x_engine_pb2 .FilterOnDataFrameRequest (
750757 queryId = self ._query_id ,
@@ -756,11 +763,14 @@ def where(self, where_clause : str):
756763 filter_on_dataframe_request
757764 )
758765
759- def order_by (self , fields_list : list , sort_direction_list = None , null_direction_list = None ):
766+ return self
767+
768+ @overload
769+ def order_by (self , field_list : list , sort_direction_list = None , null_direction_list = None ) -> "DataFrame" :
760770 orderby_fields = []
761771 sort_direction_request = []
762772 null_direction_request = []
763- for field in fields_list :
773+ for field in field_list :
764774 orderby_fields .append (field )
765775
766776 for direction in sort_direction_list :
@@ -796,7 +806,29 @@ def order_by(self, fields_list : list, sort_direction_list = None, null_directio
796806 )
797807 return self
798808
799- def limit (self , fetch_limit : int ):
809+ def order_by (self , * field_list ) -> "DataFrame" :
810+ orderby_fields = []
811+ sort_direction_request = []
812+ null_direction_request = []
813+ for field in field_list :
814+ orderby_fields .append (field )
815+
816+ client = self .connection .client
817+
818+ orderby_on_dataframe_request = e6x_engine_pb2 .OrderByOnDataFrameRequest (
819+ queryId = self ._query_id ,
820+ sessionId = self ._sessionId ,
821+ field = orderby_fields ,
822+ sortDirection = sort_direction_request ,
823+ nullsDirection = null_direction_request
824+ )
825+
826+ orderby_on_dataframe_response = client .orderByOnDataFrame (
827+ orderby_on_dataframe_request
828+ )
829+ return self
830+
831+ def limit (self , fetch_limit : int ) -> "DataFrame" :
800832 client = self .connection .client
801833 limit_on_dataframe_request = e6x_engine_pb2 .LimitOnDataFrameRequest (
802834 queryId = self ._query_id ,
@@ -808,6 +840,8 @@ def limit(self, fetch_limit : int):
808840 limit_on_dataframe_request
809841 )
810842
843+ return self
844+
811845 def show (self ):
812846 self .execute ()
813847 return self .fetchall ()
0 commit comments