Skip to content

Commit ef663bc

Browse files
committed
fixes for comments
1 parent 9144f5f commit ef663bc

File tree

1 file changed

+12
-20
lines changed

1 file changed

+12
-20
lines changed

tests/test_scheduler.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,7 @@ def test_betas(self):
850850
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
851851

852852
def test_schedules(self):
853-
for schedule in ["linear"]:
853+
for schedule in ["linear", "scaled_linear"]:
854854
self.check_over_configs(beta_schedule=schedule)
855855

856856
def test_time_indices(self):
@@ -865,30 +865,22 @@ def test_full_loop_no_noise(self):
865865
scheduler_config = self.get_scheduler_config()
866866
scheduler = scheduler_class(**scheduler_config)
867867

868-
num_trained_timesteps = len(scheduler)
868+
num_inference_steps = 10
869+
scheduler.set_timesteps(num_inference_steps)
869870

870871
model = self.dummy_model()
871-
sample = self.dummy_sample_deter
872+
sample = self.dummy_sample_deter * scheduler.sigmas[0]
872873

873-
for t in reversed(range(num_trained_timesteps - 1)):
874-
# 1. predict noise residual
875-
residual = model(sample, t)
876-
# print("residual: ")
877-
# print(residual)
874+
for i, t in enumerate(scheduler.timesteps):
875+
sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
878876

879-
# 2. predict previous mean of sample x_t-1
880-
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
877+
model_output = model(sample, t)
878+
879+
output = scheduler.step(model_output, i, sample)
880+
sample = output.prev_sample
881881

882-
# if t > 0:
883-
# noise = self.dummy_sample_deter
884-
# variance = scheduler.get_variance(t) ** (0.5) * noise
885-
#
886-
# sample = pred_prev_sample + variance
887-
sample = pred_prev_sample
888-
print("Result sample: ")
889-
print(sample)
890882
result_sum = torch.sum(torch.abs(sample))
891883
result_mean = torch.mean(torch.abs(sample))
892884

893-
assert abs(result_sum.item() - 259.0883) < 1e-2
894-
assert abs(result_mean.item() - 0.3374) < 1e-3
885+
assert abs(result_sum.item() - 1006.388) < 1e-2
886+
assert abs(result_mean.item() - 1.31) < 1e-3

0 commit comments

Comments
 (0)