@@ -17,15 +17,15 @@ class BaseProfiler(ABC):
1717 """
1818
1919 @abstractmethod
20- def start (self , action_name ) :
20+ def start (self , action_name : str ) -> None :
2121 """Defines how to start recording an action."""
2222
2323 @abstractmethod
24- def stop (self , action_name ) :
24+ def stop (self , action_name : str ) -> None :
2525 """Defines how to record the duration once an action is complete."""
2626
2727 @contextmanager
28- def profile (self , action_name ) :
28+ def profile (self , action_name : str ) -> None :
2929 """
3030 Yields a context manager to encapsulate the scope of a profiled action.
3131
@@ -43,7 +43,7 @@ def profile(self, action_name):
4343 finally :
4444 self .stop (action_name )
4545
46- def profile_iterable (self , iterable , action_name ) :
46+ def profile_iterable (self , iterable , action_name : str ) -> None :
4747 iterator = iter (iterable )
4848 while True :
4949 try :
@@ -55,7 +55,7 @@ def profile_iterable(self, iterable, action_name):
5555 self .stop (action_name )
5656 break
5757
58- def describe (self ):
58+ def describe (self ) -> None :
5959 """Logs a profile report after the conclusion of the training run."""
6060 pass
6161
@@ -69,10 +69,10 @@ class PassThroughProfiler(BaseProfiler):
6969 def __init__ (self ):
7070 pass
7171
72- def start (self , action_name ) :
72+ def start (self , action_name : str ) -> None :
7373 pass
7474
75- def stop (self , action_name ) :
75+ def stop (self , action_name : str ) -> None :
7676 pass
7777
7878
@@ -86,14 +86,14 @@ def __init__(self):
8686 self .current_actions = {}
8787 self .recorded_durations = defaultdict (list )
8888
89- def start (self , action_name ) :
89+ def start (self , action_name : str ) -> None :
9090 if action_name in self .current_actions :
9191 raise ValueError (
9292 f"Attempted to start { action_name } which has already started."
9393 )
9494 self .current_actions [action_name ] = time .monotonic ()
9595
96- def stop (self , action_name ) :
96+ def stop (self , action_name : str ) -> None :
9797 end_time = time .monotonic ()
9898 if action_name not in self .current_actions :
9999 raise ValueError (
@@ -103,7 +103,7 @@ def stop(self, action_name):
103103 duration = end_time - start_time
104104 self .recorded_durations [action_name ].append (duration )
105105
106- def describe (self ):
106+ def describe (self ) -> None :
107107 output_string = "\n \n Profiler Report\n "
108108
109109 def log_row (action , mean , total ):
@@ -126,32 +126,33 @@ class AdvancedProfiler(BaseProfiler):
126126 verbose and you should only use this if you want very detailed reports.
127127 """
128128
129- def __init__ (self , output_filename = None , line_count_restriction = 1.0 ):
129+ def __init__ (self , output_filename : str = None , line_count_restriction : float = 1.0 ):
130130 """
131- :param output_filename (str): optionally save profile results to file instead of printing
132- to std out when training is finished.
133- :param line_count_restriction (int|float): this can be used to limit the number of functions
134- reported for each action. either an integer (to select a count of lines),
135- or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
131+ Args:
132+ output_filename: optionally save profile results to file instead of printing
133+ to std out when training is finished.
134+ line_count_restriction: this can be used to limit the number of functions
135+ reported for each action. either an integer (to select a count of lines),
136+ or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
136137 """
137138 self .profiled_actions = {}
138139 self .output_filename = output_filename
139140 self .line_count_restriction = line_count_restriction
140141
141- def start (self , action_name ) :
142+ def start (self , action_name : str ) -> None :
142143 if action_name not in self .profiled_actions :
143144 self .profiled_actions [action_name ] = cProfile .Profile ()
144145 self .profiled_actions [action_name ].enable ()
145146
146- def stop (self , action_name ) :
147+ def stop (self , action_name : str ) -> None :
147148 pr = self .profiled_actions .get (action_name )
148149 if pr is None :
149150 raise ValueError ( # pragma: no-cover
150151 f"Attempting to stop recording an action ({ action_name } ) which was never started."
151152 )
152153 pr .disable ()
153154
154- def describe (self ):
155+ def describe (self ) -> None :
155156 self .recorded_stats = {}
156157 for action_name , pr in self .profiled_actions .items ():
157158 s = io .StringIO ()
0 commit comments