@@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
6565 @register_to_config
6666 def __init__ (
6767 self ,
68- num_train_timesteps = 2000 ,
69- snr = 0.15 ,
70- sigma_min = 0.01 ,
71- sigma_max = 1348 ,
72- sampling_eps = 1e-5 ,
73- correct_steps = 1 ,
74- tensor_format = "pt" ,
68+ num_train_timesteps : int = 2000 ,
69+ snr : float = 0.15 ,
70+ sigma_min : float = 0.01 ,
71+ sigma_max : float = 1348.0 ,
72+ sampling_eps : float = 1e-5 ,
73+ correct_steps : int = 1 ,
74+ tensor_format : str = "pt" ,
7575 ):
7676 # setable values
7777 self .timesteps = None
@@ -81,7 +81,7 @@ def __init__(
8181 self .tensor_format = tensor_format
8282 self .set_format (tensor_format = tensor_format )
8383
84- def set_timesteps (self , num_inference_steps , sampling_eps = None ):
84+ def set_timesteps (self , num_inference_steps : int , sampling_eps : float = None ):
8585 """
8686 Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
8787
@@ -100,7 +100,9 @@ def set_timesteps(self, num_inference_steps, sampling_eps=None):
100100 else :
101101 raise ValueError (f"`self.tensor_format`: { self .tensor_format } is not valid." )
102102
103- def set_sigmas (self , num_inference_steps , sigma_min = None , sigma_max = None , sampling_eps = None ):
103+ def set_sigmas (
104+ self , num_inference_steps : int , sigma_min : float = None , sigma_max : float = None , sampling_eps : float = None
105+ ):
104106 """
105107 Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
106108
0 commit comments