Skip to content

Commit 32c5ea0

Browse files
DN6sayakpaul
andcommitted
Add decorator for compile tests (#8703)
* update * update --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 3ecdab0 commit 32c5ea0

File tree

9 files changed

+27
-23
lines changed

9 files changed

+27
-23
lines changed

.github/workflows/push_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ jobs:
330330
- name: Run example tests on GPU
331331
env:
332332
HF_TOKEN: ${{ secrets.HF_TOKEN }}
333+
RUN_COMPILE: yes
333334
run: |
334335
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
335336
- name: Failure short reports

src/diffusers/utils/testing_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def parse_flag_from_env(key, default=False):
187187

188188
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
189189
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
190+
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
190191

191192

192193
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -225,6 +226,16 @@ def nightly(test_case):
225226
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
226227

227228

229+
def is_torch_compile(test_case):
230+
"""
231+
Decorator marking a test that runs compile tests in the diffusers CI.
232+
233+
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
234+
235+
"""
236+
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
237+
238+
228239
def require_torch(test_case):
229240
"""
230241
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
@@ -390,14 +401,6 @@ def get_python_version():
390401
return major, minor
391402

392403

393-
def require_python39_or_higher(test_case):
394-
def python39_available():
395-
major, minor = get_python_version()
396-
return major == 3 and minor >= 9
397-
398-
return unittest.skipUnless(python39_available(), "test requires Python 3.9 or higher")(test_case)
399-
400-
401404
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
402405
if isinstance(arry, str):
403406
if local_path is not None:

tests/models/test_modeling_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from diffusers.utils.testing_utils import (
4444
CaptureLogger,
4545
get_python_version,
46-
require_python39_or_higher,
46+
is_torch_compile,
4747
require_torch_2,
4848
require_torch_accelerator_with_training,
4949
require_torch_gpu,
@@ -512,7 +512,7 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
512512
max_diff = (image - new_image).abs().max().item()
513513
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
514514

515-
@require_python39_or_higher
515+
@is_torch_compile
516516
@require_torch_2
517517
@unittest.skipIf(
518518
get_python_version == (3, 12),

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from diffusers.utils.testing_utils import (
3737
enable_full_determinism,
3838
get_python_version,
39+
is_torch_compile,
3940
load_image,
4041
load_numpy,
41-
require_python39_or_higher,
4242
require_torch_2,
4343
require_torch_gpu,
4444
run_test_in_subprocess,
@@ -1022,7 +1022,7 @@ def test_canny_guess_mode_euler(self):
10221022
expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494])
10231023
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
10241024

1025-
@require_python39_or_higher
1025+
@is_torch_compile
10261026
@require_torch_2
10271027
@unittest.skipIf(
10281028
get_python_version == (3, 12),

tests/pipelines/controlnet_xs/test_controlnetxs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
from diffusers.utils.import_utils import is_xformers_available
3636
from diffusers.utils.testing_utils import (
3737
enable_full_determinism,
38+
is_torch_compile,
3839
load_image,
3940
load_numpy,
40-
require_python39_or_higher,
4141
require_torch_2,
4242
require_torch_gpu,
4343
run_test_in_subprocess,
@@ -392,7 +392,7 @@ def test_depth(self):
392392
expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
393393
assert np.allclose(original_image, expected_image, atol=1e-04)
394394

395-
@require_python39_or_higher
395+
@is_torch_compile
396396
@require_torch_2
397397
def test_stable_diffusion_compile(self):
398398
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,12 @@
4545
from diffusers.utils.testing_utils import (
4646
CaptureLogger,
4747
enable_full_determinism,
48+
is_torch_compile,
4849
load_image,
4950
load_numpy,
5051
nightly,
5152
numpy_cosine_similarity_distance,
5253
require_accelerate_version_greater,
53-
require_python39_or_higher,
5454
require_torch_2,
5555
require_torch_gpu,
5656
require_torch_multi_gpu,
@@ -1282,7 +1282,7 @@ def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self):
12821282
max_diff = np.abs(expected_image - image).max()
12831283
assert max_diff < 8e-1
12841284

1285-
@require_python39_or_higher
1285+
@is_torch_compile
12861286
@require_torch_2
12871287
def test_stable_diffusion_compile(self):
12881288
seed = 0

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@
3737
from diffusers.utils.testing_utils import (
3838
enable_full_determinism,
3939
floats_tensor,
40+
is_torch_compile,
4041
load_image,
4142
load_numpy,
4243
nightly,
43-
require_python39_or_higher,
4444
require_torch_2,
4545
require_torch_gpu,
4646
run_test_in_subprocess,
@@ -643,7 +643,7 @@ def test_img2img_safety_checker_works(self):
643643
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
644644
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
645645

646-
@require_python39_or_higher
646+
@is_torch_compile
647647
@require_torch_2
648648
def test_img2img_compile(self):
649649
seed = 0

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939
from diffusers.utils.testing_utils import (
4040
enable_full_determinism,
4141
floats_tensor,
42+
is_torch_compile,
4243
load_image,
4344
load_numpy,
4445
nightly,
45-
require_python39_or_higher,
4646
require_torch_2,
4747
require_torch_gpu,
4848
run_test_in_subprocess,
@@ -715,7 +715,7 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
715715
# make sure that less than 2.2 GB is allocated
716716
assert mem_bytes < 2.2 * 10**9
717717

718-
@require_python39_or_higher
718+
@is_torch_compile
719719
@require_torch_2
720720
def test_inpaint_compile(self):
721721
seed = 0
@@ -920,7 +920,7 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
920920
# make sure that less than 2.45 GB is allocated
921921
assert mem_bytes < 2.45 * 10**9
922922

923-
@require_python39_or_higher
923+
@is_torch_compile
924924
@require_torch_2
925925
def test_inpaint_compile(self):
926926
pass

tests/pipelines/test_pipelines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,12 @@
6969
floats_tensor,
7070
get_python_version,
7171
get_tests_dir,
72+
is_torch_compile,
7273
load_numpy,
7374
nightly,
7475
require_compel,
7576
require_flax,
7677
require_onnxruntime,
77-
require_python39_or_higher,
7878
require_torch_2,
7979
require_torch_gpu,
8080
run_test_in_subprocess,
@@ -1761,7 +1761,7 @@ def test_from_save_pretrained(self):
17611761

17621762
assert np.abs(image - new_image).max() < 1e-5, "Models don't give the same forward pass"
17631763

1764-
@require_python39_or_higher
1764+
@is_torch_compile
17651765
@require_torch_2
17661766
@unittest.skipIf(
17671767
get_python_version == (3, 12),

0 commit comments

Comments
 (0)