Skip to content

Commit 4f1c989

Browse files
authored
Add smoke tests for the training examples (#585)
* Add smoke tests for the training examples * upd * use a dummy dataset * mark as slow * cleanup * Update test cases * naming
1 parent 3fc8ef7 commit 4f1c989

File tree

3 files changed

+221
-1
lines changed

3 files changed

+221
-1
lines changed

.github/workflows/push_tests.yml

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,56 @@ jobs:
5959
if: ${{ always() }}
6060
uses: actions/upload-artifact@v2
6161
with:
62-
name: push_torch_test_reports
62+
name: torch_test_reports
63+
path: reports
64+
65+
66+
67+
run_examples_single_gpu:
68+
name: Examples tests
69+
strategy:
70+
fail-fast: false
71+
matrix:
72+
machine_type: [ single-gpu ]
73+
runs-on: [ self-hosted, docker-gpu, '${{ matrix.machine_type }}' ]
74+
container:
75+
image: nvcr.io/nvidia/pytorch:22.07-py3
76+
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
77+
78+
steps:
79+
- name: Checkout diffusers
80+
uses: actions/checkout@v3
81+
with:
82+
fetch-depth: 2
83+
84+
- name: NVIDIA-SMI
85+
run: |
86+
nvidia-smi
87+
88+
- name: Install dependencies
89+
run: |
90+
python -m pip install --upgrade pip
91+
python -m pip uninstall -y torch torchvision torchtext
92+
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
93+
python -m pip install -e .[quality,test,training]
94+
95+
- name: Environment
96+
run: |
97+
python utils/print_env.py
98+
99+
- name: Run example tests on GPU
100+
env:
101+
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
102+
run: |
103+
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/
104+
105+
- name: Failure short reports
106+
if: ${{ failure() }}
107+
run: cat reports/examples_torch_gpu_failures_short.txt
108+
109+
- name: Test suite reports artifacts
110+
if: ${{ always() }}
111+
uses: actions/upload-artifact@v2
112+
with:
113+
name: examples_test_reports
63114
path: reports

examples/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2022 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# tests directory-specific settings - this file is run automatically
16+
# by pytest before any tests are run
17+
18+
import sys
19+
import warnings
20+
from os.path import abspath, dirname, join
21+
22+
23+
# allow having multiple repository checkouts and not needing to remember to rerun
24+
# 'pip install -e .[dev]' when switching between checkouts and running tests.
25+
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
26+
sys.path.insert(1, git_repo_path)
27+
28+
29+
# silence FutureWarning warnings in tests since often we can't act on them until
30+
# they become normal warnings - i.e. the tests still need to test the current functionality
31+
warnings.simplefilter(action="ignore", category=FutureWarning)
32+
33+
34+
def pytest_addoption(parser):
35+
from diffusers.testing_utils import pytest_addoption_shared
36+
37+
pytest_addoption_shared(parser)
38+
39+
40+
def pytest_terminal_summary(terminalreporter):
41+
from diffusers.testing_utils import pytest_terminal_summary_main
42+
43+
make_reports = terminalreporter.config.getoption("--make-reports")
44+
if make_reports:
45+
pytest_terminal_summary_main(terminalreporter, id=make_reports)

examples/test_examples.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# coding=utf-8
2+
# Copyright 2022 HuggingFace Inc..
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import logging
18+
import os
19+
import shutil
20+
import subprocess
21+
import sys
22+
import tempfile
23+
import unittest
24+
from typing import List
25+
26+
from accelerate.utils import write_basic_config
27+
from diffusers.testing_utils import slow
28+
29+
30+
logging.basicConfig(level=logging.DEBUG)
31+
32+
logger = logging.getLogger()
33+
34+
35+
# These utils relate to ensuring the right error message is received when running scripts
36+
class SubprocessCallException(Exception):
37+
pass
38+
39+
40+
def run_command(command: List[str], return_stdout=False):
41+
"""
42+
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
43+
if an error occured while running `command`
44+
"""
45+
try:
46+
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
47+
if return_stdout:
48+
if hasattr(output, "decode"):
49+
output = output.decode("utf-8")
50+
return output
51+
except subprocess.CalledProcessError as e:
52+
raise SubprocessCallException(
53+
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
54+
) from e
55+
56+
57+
stream_handler = logging.StreamHandler(sys.stdout)
58+
logger.addHandler(stream_handler)
59+
60+
61+
class ExamplesTestsAccelerate(unittest.TestCase):
62+
@classmethod
63+
def setUpClass(cls):
64+
super().setUpClass()
65+
cls._tmpdir = tempfile.mkdtemp()
66+
cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
67+
68+
write_basic_config(save_location=cls.configPath)
69+
cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
70+
71+
@classmethod
72+
def tearDownClass(cls):
73+
super().tearDownClass()
74+
shutil.rmtree(cls._tmpdir)
75+
76+
@slow
77+
def test_train_unconditional(self):
78+
with tempfile.TemporaryDirectory() as tmpdir:
79+
test_args = f"""
80+
examples/unconditional_image_generation/train_unconditional.py
81+
--dataset_name huggan/few-shot-aurora
82+
--resolution 64
83+
--output_dir {tmpdir}
84+
--train_batch_size 4
85+
--num_epochs 1
86+
--gradient_accumulation_steps 1
87+
--learning_rate 1e-3
88+
--lr_warmup_steps 5
89+
--mixed_precision fp16
90+
""".split()
91+
92+
run_command(self._launch_args + test_args, return_stdout=True)
93+
# save_pretrained smoke test
94+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
95+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
96+
# logging test
97+
self.assertTrue(len(os.listdir(os.path.join(tmpdir, "logs", "train_unconditional"))) > 0)
98+
99+
@slow
100+
def test_textual_inversion(self):
101+
with tempfile.TemporaryDirectory() as tmpdir:
102+
test_args = f"""
103+
examples/textual_inversion/textual_inversion.py
104+
--pretrained_model_name_or_path CompVis/stable-diffusion-v1-4
105+
--use_auth_token
106+
--train_data_dir docs/source/imgs
107+
--learnable_property object
108+
--placeholder_token <cat-toy>
109+
--initializer_token toy
110+
--resolution 64
111+
--train_batch_size 1
112+
--gradient_accumulation_steps 2
113+
--max_train_steps 10
114+
--learning_rate 5.0e-04
115+
--scale_lr
116+
--lr_scheduler constant
117+
--lr_warmup_steps 0
118+
--output_dir {tmpdir}
119+
--mixed_precision fp16
120+
""".split()
121+
122+
run_command(self._launch_args + test_args)
123+
# save_pretrained smoke test
124+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin")))

0 commit comments

Comments
 (0)