Skip to content

Commit 5e3f8ff

Browse files
Fix some audio tests (#3841)
* Fix some audio tests
* make style
* fix
* make style
1 parent 5df2acf commit 5e3f8ff

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

tests/pipelines/audioldm/test_audioldm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
PNDMScheduler,
3737
UNet2DConditionModel,
3838
)
39-
from diffusers.utils import slow, torch_device
39+
from diffusers.utils import is_xformers_available, slow, torch_device
4040
from diffusers.utils.testing_utils import enable_full_determinism
4141

4242
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
@@ -361,9 +361,15 @@ def test_attention_slicing_forward_pass(self):
361361
def test_inference_batch_single_identical(self):
362362
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
363363

364+
@unittest.skipIf(
365+
torch_device != "cuda" or not is_xformers_available(),
366+
reason="XFormers attention is only available with CUDA and `xformers` installed",
367+
)
368+
def test_xformers_attention_forwardGenerator_pass(self):
369+
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
370+
364371

365372
@slow
366-
# @require_torch_gpu
367373
class AudioLDMPipelineSlowTests(unittest.TestCase):
368374
def tearDown(self):
369375
super().tearDown()

tests/pipelines/test_pipelines_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,9 @@ def test_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
640640
def test_xformers_attention_forwardGenerator_pass(self):
641641
self._test_xformers_attention_forwardGenerator_pass()
642642

643-
def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True, expected_max_diff=1e-4):
643+
def _test_xformers_attention_forwardGenerator_pass(
644+
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-4
645+
):
644646
if not self.test_xformers_attention:
645647
return
646648

@@ -660,7 +662,8 @@ def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=Tru
660662
max_diff = np.abs(output_with_offload - output_without_offload).max()
661663
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
662664

663-
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
665+
if test_mean_pixel_difference:
666+
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
664667

665668
def test_progress_bar(self):
666669
components = self.get_dummy_components()

0 commit comments

Comments (0)