@@ -850,7 +850,7 @@ def test_betas(self):
850850 self .check_over_configs (beta_start = beta_start , beta_end = beta_end )
851851
852852 def test_schedules (self ):
853- for schedule in ["linear" ]:
853+ for schedule in ["linear" , "scaled_linear" ]:
854854 self .check_over_configs (beta_schedule = schedule )
855855
856856 def test_time_indices (self ):
@@ -865,30 +865,22 @@ def test_full_loop_no_noise(self):
865865 scheduler_config = self .get_scheduler_config ()
866866 scheduler = scheduler_class (** scheduler_config )
867867
868- num_trained_timesteps = len (scheduler )
868+ num_inference_steps = 10
869+ scheduler .set_timesteps (num_inference_steps )
869870
870871 model = self .dummy_model ()
871- sample = self .dummy_sample_deter
872+ sample = self .dummy_sample_deter * scheduler . sigmas [ 0 ]
872873
873- for t in reversed (range (num_trained_timesteps - 1 )):
874- # 1. predict noise residual
875- residual = model (sample , t )
876- # print("residual: ")
877- # print(residual)
874+ for i , t in enumerate (scheduler .timesteps ):
875+ sample = sample / ((scheduler .sigmas [i ] ** 2 + 1 ) ** 0.5 )
878876
879- # 2. predict previous mean of sample x_t-1
880- pred_prev_sample = scheduler .step (residual , t , sample ).prev_sample
877+ model_output = model (sample , t )
878+
879+ output = scheduler .step (model_output , i , sample )
880+ sample = output .prev_sample
881881
882- # if t > 0:
883- # noise = self.dummy_sample_deter
884- # variance = scheduler.get_variance(t) ** (0.5) * noise
885- #
886- # sample = pred_prev_sample + variance
887- sample = pred_prev_sample
888- print ("Result sample: " )
889- print (sample )
890882 result_sum = torch .sum (torch .abs (sample ))
891883 result_mean = torch .mean (torch .abs (sample ))
892884
893- assert abs (result_sum .item () - 259.0883 ) < 1e-2
894- assert abs (result_mean .item () - 0.3374 ) < 1e-3
885+ assert abs (result_sum .item () - 1006.388 ) < 1e-2
886+ assert abs (result_mean .item () - 1.31 ) < 1e-3
0 commit comments