@@ -826,6 +826,7 @@ def test_step_shape(self):
826826
827827class LMSDiscreteSchedulerTest (SchedulerCommonTest ):
828828 scheduler_classes = (LMSDiscreteScheduler ,)
829+ num_inference_steps = 10
829830
830831 def get_scheduler_config (self , ** kwargs ):
831832 config = {
@@ -858,15 +859,39 @@ def test_time_indices(self):
858859 self .check_over_forward (time_step = t )
859860
860861 def test_pytorch_equal_numpy (self ):
861- pass
862+ for scheduler_class in self .scheduler_classes :
863+ sample_pt = self .dummy_sample
864+ residual_pt = 0.1 * sample_pt
865+ dummy_past_residuals_pt = [residual_pt + 0.2 , residual_pt + 0.15 , residual_pt + 0.1 , residual_pt + 0.05 ]
866+
867+ sample = sample_pt .numpy ()
868+ residual = 0.1 * sample
869+ dummy_past_residuals = [residual + 0.2 , residual + 0.15 , residual + 0.1 , residual + 0.05 ]
870+
871+ scheduler_config = self .get_scheduler_config ()
872+ scheduler_config ["tensor_format" ] = "np"
873+ scheduler = scheduler_class (** scheduler_config )
874+
875+ scheduler_config ["tensor_format" ] = "pt"
876+ scheduler_pt = scheduler_class (** scheduler_config )
877+
878+ scheduler .set_timesteps (self .num_inference_steps )
879+ scheduler_pt .set_timesteps (self .num_inference_steps )
880+
881+ # copy over dummy past residuals (must be done after set_timesteps)
882+ scheduler .ets = dummy_past_residuals [:]
883+ scheduler_pt .ets = dummy_past_residuals_pt [:]
884+
885+ output = scheduler .step (residual , 1 , sample ).prev_sample
886+ output_pt = scheduler_pt .step (residual_pt , 1 , sample_pt ).prev_sample
887+ assert np .sum (np .abs (output - output_pt .numpy ())) < 1e-4 , "Scheduler outputs are not identical"
862888
863889 def test_full_loop_no_noise (self ):
864890 scheduler_class = self .scheduler_classes [0 ]
865891 scheduler_config = self .get_scheduler_config ()
866892 scheduler = scheduler_class (** scheduler_config )
867893
868- num_inference_steps = 10
869- scheduler .set_timesteps (num_inference_steps )
894+ scheduler .set_timesteps (self .num_inference_steps )
870895
871896 model = self .dummy_model ()
872897 sample = self .dummy_sample_deter * scheduler .sigmas [0 ]
0 commit comments