Skip to content

Commit c4ef1ef

Browse files
[Tests] Better prints (#1043)
1 parent 8d6487f commit c4ef1ef

File tree

4 files changed

+37
-21
lines changed

4 files changed

+37
-21
lines changed

src/diffusers/utils/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,15 @@
4040

4141

4242
if is_torch_available():
43-
from .testing_utils import floats_tensor, load_image, parse_flag_from_env, require_torch_gpu, slow, torch_device
43+
from .testing_utils import (
44+
floats_tensor,
45+
load_image,
46+
parse_flag_from_env,
47+
require_torch_gpu,
48+
slow,
49+
torch_all_close,
50+
torch_device,
51+
)
4452

4553

4654
logger = get_logger(__name__)

src/diffusers/utils/testing_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
3535

3636

37+
def torch_all_close(a, b, *args, **kwargs):
    """Assert that two tensors are element-wise close (via ``torch.allclose``).

    Extra positional/keyword arguments (e.g. ``rtol``, ``atol``) are forwarded
    to ``torch.allclose``. On mismatch, raises ``AssertionError`` whose message
    includes the maximum absolute difference and the full absolute-diff tensor,
    giving far more useful test-failure output than a bare
    ``assert torch.allclose(...)``.

    Returns:
        ``True`` when the tensors are close, so the call can itself be wrapped
        in ``assert`` / ``assertTrue`` by callers.

    Raises:
        ValueError: if PyTorch is not installed.
        AssertionError: if the tensors are not close within tolerance.
    """
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")
    if not torch.allclose(a, b, *args, **kwargs):
        # Compute the diff once (the original recomputed (a - b).abs() twice),
        # and raise explicitly instead of `assert False`: a bare assert is
        # stripped under `python -O`, which would silently disable the check.
        diff = (a - b).abs()
        raise AssertionError(f"Max diff is absolute {diff.max()}. Diff tensor is {diff}.")
    return True
43+
44+
3745
def get_tests_dir(append_path=None):
3846
"""
3947
Args:

tests/models/test_models_unet_2d.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch
2222

2323
from diffusers import UNet2DConditionModel, UNet2DModel
24-
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_device
24+
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
2525
from parameterized import parameterized
2626

2727
from ..test_modeling_common import ModelTesterMixin
@@ -156,7 +156,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):
156156
model_normal_load.eval()
157157
arr_normal_load = model_normal_load(noise, time_step)["sample"]
158158

159-
assert torch.allclose(arr_accelerate, arr_normal_load, rtol=1e-3)
159+
assert torch_all_close(arr_accelerate, arr_normal_load, rtol=1e-3)
160160

161161
@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
162162
def test_memory_footprint_gets_reduced(self):
@@ -207,7 +207,7 @@ def test_output_pretrained(self):
207207
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
208208
# fmt: on
209209

210-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
210+
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
211211

212212

213213
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
@@ -287,7 +287,7 @@ def test_gradient_checkpointing(self):
287287
named_params = dict(model.named_parameters())
288288
named_params_2 = dict(model_2.named_parameters())
289289
for name, param in named_params.items():
290-
self.assertTrue(torch.allclose(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
290+
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
291291

292292

293293
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
@@ -377,7 +377,7 @@ def test_output_pretrained_ve_mid(self):
377377
expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
378378
# fmt: on
379379

380-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
380+
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
381381

382382
def test_output_pretrained_ve_large(self):
383383
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
@@ -402,7 +402,7 @@ def test_output_pretrained_ve_large(self):
402402
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
403403
# fmt: on
404404

405-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
405+
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
406406

407407
def test_forward_with_norm_groups(self):
408408
# not required for this model
@@ -464,7 +464,7 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
464464
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
465465
expected_output_slice = torch.tensor(expected_slice)
466466

467-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
467+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
468468

469469
@parameterized.expand(
470470
[
@@ -490,7 +490,7 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
490490
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
491491
expected_output_slice = torch.tensor(expected_slice)
492492

493-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
493+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
494494

495495
@parameterized.expand(
496496
[
@@ -515,7 +515,7 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
515515
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
516516
expected_output_slice = torch.tensor(expected_slice)
517517

518-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
518+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
519519

520520
@parameterized.expand(
521521
[
@@ -541,7 +541,7 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
541541
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
542542
expected_output_slice = torch.tensor(expected_slice)
543543

544-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
544+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
545545

546546
@parameterized.expand(
547547
[
@@ -566,7 +566,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
566566
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
567567
expected_output_slice = torch.tensor(expected_slice)
568568

569-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
569+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
570570

571571
@parameterized.expand(
572572
[
@@ -592,4 +592,4 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
592592
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
593593
expected_output_slice = torch.tensor(expected_slice)
594594

595-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
595+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)

tests/models/test_models_vae.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from diffusers import AutoencoderKL
2222
from diffusers.modeling_utils import ModelMixin
23-
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_device
23+
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
2424
from parameterized import parameterized
2525

2626
from ..test_modeling_common import ModelTesterMixin
@@ -131,7 +131,7 @@ def test_output_pretrained(self):
131131
[-0.2421, 0.4642, 0.2507, -0.0438, 0.0682, 0.3160, -0.2018, -0.0727, 0.2485]
132132
)
133133

134-
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
134+
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
135135

136136

137137
@slow
@@ -185,7 +185,7 @@ def test_stable_diffusion(self, seed, expected_slice):
185185
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
186186
expected_output_slice = torch.tensor(expected_slice)
187187

188-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
188+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
189189

190190
@parameterized.expand(
191191
[
@@ -209,7 +209,7 @@ def test_stable_diffusion_fp16(self, seed, expected_slice):
209209
output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
210210
expected_output_slice = torch.tensor(expected_slice)
211211

212-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
212+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
213213

214214
@parameterized.expand(
215215
[
@@ -231,7 +231,7 @@ def test_stable_diffusion_mode(self, seed, expected_slice):
231231
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
232232
expected_output_slice = torch.tensor(expected_slice)
233233

234-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
234+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
235235

236236
@parameterized.expand(
237237
[
@@ -254,7 +254,7 @@ def test_stable_diffusion_decode(self, seed, expected_slice):
254254
output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
255255
expected_output_slice = torch.tensor(expected_slice)
256256

257-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
257+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
258258

259259
@parameterized.expand(
260260
[
@@ -276,7 +276,7 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice):
276276
output_slice = sample[-1, -2:, :2, -2:].flatten().float().cpu()
277277
expected_output_slice = torch.tensor(expected_slice)
278278

279-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
279+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)
280280

281281
@parameterized.expand(
282282
[
@@ -300,4 +300,4 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
300300
output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
301301
expected_output_slice = torch.tensor(expected_slice)
302302

303-
assert torch.allclose(output_slice, expected_output_slice, atol=1e-4)
303+
assert torch_all_close(output_slice, expected_output_slice, atol=1e-4)

0 commit comments

Comments
 (0)