Skip to content

Commit 21d570d

Browse files
anton-lPrathik Rao
authored andcommitted
Introduce the copy mechanism (huggingface#924)
* Introduce the copy mechanism * init tests * fix dummy tests * with * update copies tests
1 parent cfdea72 commit 21d570d

File tree

8 files changed

+284
-268
lines changed

8 files changed

+284
-268
lines changed

.github/workflows/pr_quality.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,20 @@ jobs:
3131
isort --check-only examples tests src utils scripts
3232
flake8 examples tests src utils scripts
3333
doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source
34+
35+
check_repository_consistency:
36+
runs-on: ubuntu-latest
37+
steps:
38+
- uses: actions/checkout@v3
39+
- name: Set up Python
40+
uses: actions/setup-python@v4
41+
with:
42+
python-version: "3.7"
43+
- name: Install dependencies
44+
run: |
45+
python -m pip install --upgrade pip
46+
pip install .[quality]
47+
- name: Check quality
48+
run: |
49+
python utils/check_copies.py
50+
python utils/check_dummies.py

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency
6767
# Make marked copies of snippets of codes conform to the original
6868

6969
fix-copies:
70+
python utils/check_copies.py --fix_and_overwrite
7071
python utils/check_dummies.py --fix_and_overwrite
7172

7273
# Run tests for the library

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929

3030
@dataclass
31+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
3132
class DDIMSchedulerOutput(BaseOutput):
3233
"""
3334
Output class for the scheduler's step function output.

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727

2828
@dataclass
29+
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->LMSDiscrete
2930
class LMSDiscreteSchedulerOutput(BaseOutput):
3031
"""
3132
Output class for the scheduler's step function output.
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2022 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import re
17+
import shutil
18+
import sys
19+
import tempfile
20+
import unittest
21+
22+
import black
23+
24+
25+
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
26+
sys.path.append(os.path.join(git_repo_path, "utils"))
27+
28+
import check_copies # noqa: E402
29+
30+
31+
# This is the reference code that will be used in the tests.
32+
# If DDPMSchedulerOutput is changed in scheduling_ddpm.py, this code needs to be manually updated.
33+
REFERENCE_CODE = """ \"""
34+
Output class for the scheduler's step function output.
35+
36+
Args:
37+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
39+
denoising loop.
40+
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41+
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
42+
`pred_original_sample` can be used to preview progress or for guidance.
43+
\"""
44+
45+
prev_sample: torch.FloatTensor
46+
pred_original_sample: Optional[torch.FloatTensor] = None
47+
"""
48+
49+
50+
class CopyCheckTester(unittest.TestCase):
51+
def setUp(self):
52+
self.diffusers_dir = tempfile.mkdtemp()
53+
os.makedirs(os.path.join(self.diffusers_dir, "schedulers/"))
54+
check_copies.DIFFUSERS_PATH = self.diffusers_dir
55+
shutil.copy(
56+
os.path.join(git_repo_path, "src/diffusers/schedulers/scheduling_ddpm.py"),
57+
os.path.join(self.diffusers_dir, "schedulers/scheduling_ddpm.py"),
58+
)
59+
60+
def tearDown(self):
61+
check_copies.DIFFUSERS_PATH = "src/diffusers"
62+
shutil.rmtree(self.diffusers_dir)
63+
64+
def check_copy_consistency(self, comment, class_name, class_code, overwrite_result=None):
65+
code = comment + f"\nclass {class_name}(nn.Module):\n" + class_code
66+
if overwrite_result is not None:
67+
expected = comment + f"\nclass {class_name}(nn.Module):\n" + overwrite_result
68+
mode = black.Mode(target_versions={black.TargetVersion.PY35}, line_length=119)
69+
code = black.format_str(code, mode=mode)
70+
fname = os.path.join(self.diffusers_dir, "new_code.py")
71+
with open(fname, "w", newline="\n") as f:
72+
f.write(code)
73+
if overwrite_result is None:
74+
self.assertTrue(len(check_copies.is_copy_consistent(fname)) == 0)
75+
else:
76+
check_copies.is_copy_consistent(f.name, overwrite=True)
77+
with open(fname, "r") as f:
78+
self.assertTrue(f.read(), expected)
79+
80+
def test_find_code_in_diffusers(self):
81+
code = check_copies.find_code_in_diffusers("schedulers.scheduling_ddpm.DDPMSchedulerOutput")
82+
self.assertEqual(code, REFERENCE_CODE)
83+
84+
def test_is_copy_consistent(self):
85+
# Base copy consistency
86+
self.check_copy_consistency(
87+
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
88+
"DDPMSchedulerOutput",
89+
REFERENCE_CODE + "\n",
90+
)
91+
92+
# With no empty line at the end
93+
self.check_copy_consistency(
94+
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput",
95+
"DDPMSchedulerOutput",
96+
REFERENCE_CODE,
97+
)
98+
99+
# Copy consistency with rename
100+
self.check_copy_consistency(
101+
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
102+
"TestSchedulerOutput",
103+
re.sub("DDPM", "Test", REFERENCE_CODE),
104+
)
105+
106+
# Copy consistency with a really long name
107+
long_class_name = "TestClassWithAReallyLongNameBecauseSomePeopleLikeThatForSomeReason"
108+
self.check_copy_consistency(
109+
f"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->{long_class_name}",
110+
f"{long_class_name}SchedulerOutput",
111+
re.sub("Bert", long_class_name, REFERENCE_CODE),
112+
)
113+
114+
# Copy consistency with overwrite
115+
self.check_copy_consistency(
116+
"# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->Test",
117+
"TestSchedulerOutput",
118+
REFERENCE_CODE,
119+
overwrite_result=re.sub("DDPM", "Test", REFERENCE_CODE),
120+
)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2022 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
import unittest
18+
19+
20+
git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
21+
sys.path.append(os.path.join(git_repo_path, "utils"))
22+
23+
import check_dummies
24+
from check_dummies import create_dummy_files, create_dummy_object, find_backend, read_init # noqa: E402
25+
26+
27+
# Align TRANSFORMERS_PATH in check_dummies with the current path
28+
check_dummies.PATH_TO_DIFFUSERS = os.path.join(git_repo_path, "src", "diffusers")
29+
30+
31+
class CheckDummiesTester(unittest.TestCase):
32+
def test_find_backend(self):
33+
simple_backend = find_backend(" if not is_torch_available():")
34+
self.assertEqual(simple_backend, "torch")
35+
36+
# backend_with_underscore = find_backend(" if not is_tensorflow_text_available():")
37+
# self.assertEqual(backend_with_underscore, "tensorflow_text")
38+
39+
double_backend = find_backend(" if not (is_torch_available() and is_transformers_available()):")
40+
self.assertEqual(double_backend, "torch_and_transformers")
41+
42+
# double_backend_with_underscore = find_backend(
43+
# " if not (is_sentencepiece_available() and is_tensorflow_text_available()):"
44+
# )
45+
# self.assertEqual(double_backend_with_underscore, "sentencepiece_and_tensorflow_text")
46+
47+
triple_backend = find_backend(
48+
" if not (is_torch_available() and is_transformers_available() and is_onnx_available()):"
49+
)
50+
self.assertEqual(triple_backend, "torch_and_transformers_and_onnx")
51+
52+
def test_read_init(self):
53+
objects = read_init()
54+
# We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects
55+
self.assertIn("torch", objects)
56+
self.assertIn("torch_and_transformers", objects)
57+
self.assertIn("flax_and_transformers", objects)
58+
self.assertIn("torch_and_transformers_and_onnx", objects)
59+
60+
# Likewise, we can't assert on the exact content of a key
61+
self.assertIn("UNet2DModel", objects["torch"])
62+
self.assertIn("FlaxUNet2DConditionModel", objects["flax"])
63+
self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"])
64+
self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])
65+
self.assertIn("LMSDiscreteScheduler", objects["torch_and_scipy"])
66+
self.assertIn("OnnxStableDiffusionPipeline", objects["torch_and_transformers_and_onnx"])
67+
68+
def test_create_dummy_object(self):
69+
dummy_constant = create_dummy_object("CONSTANT", "'torch'")
70+
self.assertEqual(dummy_constant, "\nCONSTANT = None\n")
71+
72+
dummy_function = create_dummy_object("function", "'torch'")
73+
self.assertEqual(
74+
dummy_function, "\ndef function(*args, **kwargs):\n requires_backends(function, 'torch')\n"
75+
)
76+
77+
expected_dummy_class = """
78+
class FakeClass(metaclass=DummyObject):
79+
_backends = 'torch'
80+
81+
def __init__(self, *args, **kwargs):
82+
requires_backends(self, 'torch')
83+
84+
@classmethod
85+
def from_config(cls, *args, **kwargs):
86+
requires_backends(cls, 'torch')
87+
88+
@classmethod
89+
def from_pretrained(cls, *args, **kwargs):
90+
requires_backends(cls, 'torch')
91+
"""
92+
dummy_class = create_dummy_object("FakeClass", "'torch'")
93+
self.assertEqual(dummy_class, expected_dummy_class)
94+
95+
def test_create_dummy_files(self):
96+
expected_dummy_pytorch_file = """# This file is autogenerated by the command `make fix-copies`, do not edit.
97+
# flake8: noqa
98+
99+
from ..utils import DummyObject, requires_backends
100+
101+
102+
CONSTANT = None
103+
104+
105+
def function(*args, **kwargs):
106+
requires_backends(function, ["torch"])
107+
108+
109+
class FakeClass(metaclass=DummyObject):
110+
_backends = ["torch"]
111+
112+
def __init__(self, *args, **kwargs):
113+
requires_backends(self, ["torch"])
114+
115+
@classmethod
116+
def from_config(cls, *args, **kwargs):
117+
requires_backends(cls, ["torch"])
118+
119+
@classmethod
120+
def from_pretrained(cls, *args, **kwargs):
121+
requires_backends(cls, ["torch"])
122+
"""
123+
dummy_files = create_dummy_files({"torch": ["CONSTANT", "function", "FakeClass"]})
124+
self.assertEqual(dummy_files["torch"], expected_dummy_pytorch_file)

0 commit comments

Comments
 (0)