2121import torch
2222
2323from diffusers import UNet2DConditionModel , UNet2DModel
24- from diffusers .utils import floats_tensor , require_torch_gpu , slow , torch_device
24+ from diffusers .utils import floats_tensor , require_torch_gpu , slow , torch_all_close , torch_device
2525from parameterized import parameterized
2626
2727from ..test_modeling_common import ModelTesterMixin
@@ -156,7 +156,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):
156156 model_normal_load .eval ()
157157 arr_normal_load = model_normal_load (noise , time_step )["sample" ]
158158
159- assert torch . allclose (arr_accelerate , arr_normal_load , rtol = 1e-3 )
159+ assert torch_all_close (arr_accelerate , arr_normal_load , rtol = 1e-3 )
160160
161161 @unittest .skipIf (torch_device != "cuda" , "This test is supposed to run on GPU" )
162162 def test_memory_footprint_gets_reduced (self ):
@@ -207,7 +207,7 @@ def test_output_pretrained(self):
207207 expected_output_slice = torch .tensor ([- 13.3258 , - 20.1100 , - 15.9873 , - 17.6617 , - 23.0596 , - 17.9419 , - 13.3675 , - 16.1889 , - 12.3800 ])
208208 # fmt: on
209209
210- self .assertTrue (torch . allclose (output_slice , expected_output_slice , rtol = 1e-3 ))
210+ self .assertTrue (torch_all_close (output_slice , expected_output_slice , rtol = 1e-3 ))
211211
212212
213213class UNet2DConditionModelTests (ModelTesterMixin , unittest .TestCase ):
@@ -287,7 +287,7 @@ def test_gradient_checkpointing(self):
287287 named_params = dict (model .named_parameters ())
288288 named_params_2 = dict (model_2 .named_parameters ())
289289 for name , param in named_params .items ():
290- self .assertTrue (torch . allclose (param .grad .data , named_params_2 [name ].grad .data , atol = 5e-5 ))
290+ self .assertTrue (torch_all_close (param .grad .data , named_params_2 [name ].grad .data , atol = 5e-5 ))
291291
292292
293293class NCSNppModelTests (ModelTesterMixin , unittest .TestCase ):
@@ -377,7 +377,7 @@ def test_output_pretrained_ve_mid(self):
377377 expected_output_slice = torch .tensor ([- 4836.2231 , - 6487.1387 , - 3816.7969 , - 7964.9253 , - 10966.2842 , - 20043.6016 , 8137.0571 , 2340.3499 , 544.6114 ])
378378 # fmt: on
379379
380- self .assertTrue (torch . allclose (output_slice , expected_output_slice , rtol = 1e-2 ))
380+ self .assertTrue (torch_all_close (output_slice , expected_output_slice , rtol = 1e-2 ))
381381
382382 def test_output_pretrained_ve_large (self ):
383383 model = UNet2DModel .from_pretrained ("fusing/ncsnpp-ffhq-ve-dummy-update" )
@@ -402,7 +402,7 @@ def test_output_pretrained_ve_large(self):
402402 expected_output_slice = torch .tensor ([- 0.0325 , - 0.0900 , - 0.0869 , - 0.0332 , - 0.0725 , - 0.0270 , - 0.0101 , 0.0227 , 0.0256 ])
403403 # fmt: on
404404
405- self .assertTrue (torch . allclose (output_slice , expected_output_slice , rtol = 1e-2 ))
405+ self .assertTrue (torch_all_close (output_slice , expected_output_slice , rtol = 1e-2 ))
406406
407407 def test_forward_with_norm_groups (self ):
408408 # not required for this model
@@ -464,7 +464,7 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
464464 output_slice = sample [- 1 , - 2 :, - 2 :, :2 ].flatten ().float ().cpu ()
465465 expected_output_slice = torch .tensor (expected_slice )
466466
467- assert torch . allclose (output_slice , expected_output_slice , atol = 1e-4 )
467+ assert torch_all_close (output_slice , expected_output_slice , atol = 1e-4 )
468468
469469 @parameterized .expand (
470470 [
@@ -490,7 +490,7 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
490490 output_slice = sample [- 1 , - 2 :, - 2 :, :2 ].flatten ().float ().cpu ()
491491 expected_output_slice = torch .tensor (expected_slice )
492492
493- assert torch . allclose (output_slice , expected_output_slice , atol = 1e-4 )
493+ assert torch_all_close (output_slice , expected_output_slice , atol = 1e-4 )
494494
495495 @parameterized .expand (
496496 [
@@ -515,7 +515,7 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
515515 output_slice = sample [- 1 , - 2 :, - 2 :, :2 ].flatten ().float ().cpu ()
516516 expected_output_slice = torch .tensor (expected_slice )
517517
518- assert torch . allclose (output_slice , expected_output_slice , atol = 1e-4 )
518+ assert torch_all_close (output_slice , expected_output_slice , atol = 1e-4 )
519519
520520 @parameterized .expand (
521521 [
@@ -541,7 +541,7 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
541541 output_slice = sample [- 1 , - 2 :, - 2 :, :2 ].flatten ().float ().cpu ()
542542 expected_output_slice = torch .tensor (expected_slice )
543543
544- assert torch . allclose (output_slice , expected_output_slice , atol = 1e-4 )
544+ assert torch_all_close (output_slice , expected_output_slice , atol = 1e-4 )
545545
546546 @parameterized .expand (
547547 [
@@ -566,7 +566,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
566566 output_slice = sample [- 1 , - 2 :, - 2 :, :2 ].flatten ().float ().cpu ()
567567 expected_output_slice = torch .tensor (expected_slice )
568568
569- assert torch . allclose (output_slice , expected_output_slice , atol = 1e-4 )
569+ assert torch_all_close (output_slice , expected_output_slice , atol = 1e-4 )
570570
571571 @parameterized .expand (
572572 [
@@ -592,4 +592,4 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
592592 output_slice = sample [- 1 , - 2 :, - 2 :, :2 ].flatten ().float ().cpu ()
593593 expected_output_slice = torch .tensor (expected_slice )
594594
595- assert torch . allclose (output_slice , expected_output_slice , atol = 1e-4 )
595+ assert torch_all_close (output_slice , expected_output_slice , atol = 1e-4 )
0 commit comments