Skip to content

Commit d340b18

Browse files
committed
add torch numpy test
1 parent ef663bc commit d340b18

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

tests/test_scheduler.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ def test_step_shape(self):
826826

827827
class LMSDiscreteSchedulerTest(SchedulerCommonTest):
828828
scheduler_classes = (LMSDiscreteScheduler,)
829+
num_inference_steps = 10
829830

830831
def get_scheduler_config(self, **kwargs):
831832
config = {
@@ -858,15 +859,39 @@ def test_time_indices(self):
858859
self.check_over_forward(time_step=t)
859860

860861
def test_pytorch_equal_numpy(self):
861-
pass
862+
for scheduler_class in self.scheduler_classes:
863+
sample_pt = self.dummy_sample
864+
residual_pt = 0.1 * sample_pt
865+
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
866+
867+
sample = sample_pt.numpy()
868+
residual = 0.1 * sample
869+
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
870+
871+
scheduler_config = self.get_scheduler_config()
872+
scheduler_config["tensor_format"] = "np"
873+
scheduler = scheduler_class(**scheduler_config)
874+
875+
scheduler_config["tensor_format"] = "pt"
876+
scheduler_pt = scheduler_class(**scheduler_config)
877+
878+
scheduler.set_timesteps(self.num_inference_steps)
879+
scheduler_pt.set_timesteps(self.num_inference_steps)
880+
881+
# copy over dummy past residuals (must be done after set_timesteps)
882+
scheduler.ets = dummy_past_residuals[:]
883+
scheduler_pt.ets = dummy_past_residuals_pt[:]
884+
885+
output = scheduler.step(residual, 1, sample).prev_sample
886+
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
887+
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
862888

863889
def test_full_loop_no_noise(self):
864890
scheduler_class = self.scheduler_classes[0]
865891
scheduler_config = self.get_scheduler_config()
866892
scheduler = scheduler_class(**scheduler_config)
867893

868-
num_inference_steps = 10
869-
scheduler.set_timesteps(num_inference_steps)
894+
scheduler.set_timesteps(self.num_inference_steps)
870895

871896
model = self.dummy_model()
872897
sample = self.dummy_sample_deter * scheduler.sigmas[0]

0 commit comments

Comments
 (0)