@@ -110,7 +110,7 @@ class _Process(multiprocessing.Process):
110110 and send finished samples using shared memory.
111111 """
112112
113- def __init__ (self , name , msg_pipe , step_method , shared_point , draws , tune , seed ):
113+ def __init__ (self , name : str , msg_pipe , step_method , shared_point , draws : int , tune : int , seed ):
114114 super ().__init__ (daemon = True , name = name )
115115 self ._msg_pipe = msg_pipe
116116 self ._step_method = step_method
@@ -173,6 +173,10 @@ def _start_loop(self):
173173 raise ValueError ("Unexpected msg " + msg [0 ])
174174
175175 while True :
176+ if draw == self ._tune :
177+ self ._step_method .stop_tuning ()
178+ tuning = False
179+
176180 if draw < self ._draws + self ._tune :
177181 try :
178182 point , stats = self ._compute_point ()
@@ -183,10 +187,6 @@ def _start_loop(self):
183187 else :
184188 return
185189
186- if draw == self ._tune :
187- self ._step_method .stop_tuning ()
188- tuning = False
189-
190190 msg = self ._recv_msg ()
191191 if msg [0 ] == "abort" :
192192 raise KeyboardInterrupt ()
@@ -222,7 +222,7 @@ def _collect_warnings(self):
222222class ProcessAdapter :
223223 """Control a Chain process from the main thread."""
224224
225- def __init__ (self , draws , tune , step_method , chain , seed , start ):
225+ def __init__ (self , draws : int , tune : int , step_method , chain : int , seed , start ):
226226 self .chain = chain
227227 process_name = "worker_chain_%s" % chain
228228 self ._msg_pipe , remote_conn = multiprocessing .Pipe ()
@@ -353,15 +353,15 @@ def terminate_all(processes, patience=2):
353353class ParallelSampler :
354354 def __init__ (
355355 self ,
356- draws ,
357- tune ,
358- chains ,
359- cores ,
360- seeds ,
361- start_points ,
356+ draws : int ,
357+ tune : int ,
358+ chains : int ,
359+ cores : int ,
360+ seeds : list ,
361+ start_points : list ,
362362 step_method ,
363- start_chain_num = 0 ,
364- progressbar = True ,
363+ start_chain_num : int = 0 ,
364+ progressbar : bool = True ,
365365 ):
366366
367367 if any (len (arg ) != chains for arg in [seeds , start_points ]):
0 commit comments