@@ -78,11 +78,10 @@ def test_dummy_all_tpus(self):
7878
7979 assert images .shape == (num_samples , 1 , 64 , 64 , 3 )
8080 if jax .device_count () == 8 :
81- assert np .abs (np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 3.1111548 ) < 1e-3
82- assert np .abs (np .abs (images , dtype = np .float32 ).sum () - 199746.95 ) < 5e-1
81+ assert np .abs (np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 4.1514745 ) < 1e-3
82+ assert np .abs (np .abs (images , dtype = np .float32 ).sum () - 49947.875 ) < 5e-1
8383
8484 images_pil = pipeline .numpy_to_pil (np .asarray (images .reshape ((num_samples ,) + images .shape [- 3 :])))
85-
8685 assert len (images_pil ) == num_samples
8786
8887 def test_stable_diffusion_v1_4 (self ):
@@ -140,8 +139,8 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
140139
141140 assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
142141 if jax .device_count () == 8 :
143- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
144- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
142+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.04003906 )) < 1e-3
143+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2373516.75 )) < 5e-1
145144
146145 def test_stable_diffusion_v1_4_bfloat_16_with_safety (self ):
147146 pipeline , params = FlaxStableDiffusionPipeline .from_pretrained (
@@ -169,8 +168,8 @@ def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
169168
170169 assert images .shape == (num_samples , 1 , 512 , 512 , 3 )
171170 if jax .device_count () == 8 :
172- assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.06652832 )) < 1e-3
173- assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2384849.8 )) < 5e-1
171+ assert np .abs ((np .abs (images [0 , 0 , :2 , :2 , - 2 :], dtype = np .float32 ).sum () - 0.04003906 )) < 1e-3
172+ assert np .abs ((np .abs (images , dtype = np .float32 ).sum () - 2373516.75 )) < 5e-1
174173
175174 def test_stable_diffusion_v1_4_bfloat_16_ddim (self ):
176175 scheduler = FlaxDDIMScheduler (
0 commit comments