Skip to content

Commit 81125d8

Browse files
Make dynamo wrapped modules work with save_pretrained (#2726)
* Workaround for saving dynamo-wrapped models. * Accept suggestion from code review Co-authored-by: Patrick von Platen <[email protected]> * Apply workaround when overriding pipeline components. * Ensure the correct config.json is saved to disk. Instead of the dynamo class. * Save correct module (not compiled one) * Add test * style * fix docstrings * Go back to using string comparisons. PyTorch CPU does not have _dynamo. * Simple test for save_pretrained of compiled models. * Helper function to test whether module is compiled. --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent d4f846f commit 81125d8

File tree

6 files changed

+99
-6
lines changed

6 files changed

+99
-6
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
get_class_from_dynamic_module,
5151
is_accelerate_available,
5252
is_accelerate_version,
53+
is_compiled_module,
5354
is_safetensors_available,
5455
is_torch_version,
5556
is_transformers_available,
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
255256
if class_candidate is not None and issubclass(class_obj, class_candidate):
256257
expected_class_obj = class_candidate
257258

258-
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
259+
# Dynamo wraps the original model in a private class.
260+
# I didn't find a public API to get the original class.
261+
sub_model = passed_class_obj[name]
262+
model_cls = sub_model.__class__
263+
if is_compiled_module(sub_model):
264+
model_cls = sub_model._orig_mod.__class__
265+
266+
if not issubclass(model_cls, expected_class_obj):
259267
raise ValueError(
260268
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
261269
f" {expected_class_obj}"
@@ -419,6 +427,10 @@ def register_modules(self, **kwargs):
419427
if module is None:
420428
register_dict = {name: (None, None)}
421429
else:
430+
# register the original module, not the dynamo compiled one
431+
if is_compiled_module(module):
432+
module = module._orig_mod
433+
422434
library = module.__module__.split(".")[0]
423435

424436
# check if the module is a pipeline module
@@ -484,6 +496,12 @@ def is_saveable_module(name, value):
484496
sub_model = getattr(self, pipeline_component_name)
485497
model_cls = sub_model.__class__
486498

499+
# Dynamo wraps the original model in a private class.
500+
# I didn't find a public API to get the original class.
501+
if is_compiled_module(sub_model):
502+
sub_model = sub_model._orig_mod
503+
model_cls = sub_model.__class__
504+
487505
save_method_name = None
488506
# search for the model's base class in LOADABLE_CLASSES
489507
for library_name, library_classes in LOADABLE_CLASSES.items():

src/diffusers/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
from .logging import get_logger
7575
from .outputs import BaseOutput
7676
from .pil_utils import PIL_INTERPOLATION
77-
from .torch_utils import randn_tensor
77+
from .torch_utils import is_compiled_module, randn_tensor
7878

7979

8080
if is_torch_available():
@@ -86,6 +86,7 @@
8686
nightly,
8787
parse_flag_from_env,
8888
print_tensor_test,
89+
require_torch_2,
8990
require_torch_gpu,
9091
skip_mps,
9192
slow,

src/diffusers/utils/testing_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
is_onnx_available,
2626
is_opencv_available,
2727
is_torch_available,
28+
is_torch_version,
2829
)
2930
from .logging import get_logger
3031

@@ -165,6 +166,15 @@ def require_torch(test_case):
165166
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
166167

167168

169+
def require_torch_2(test_case):
170+
"""
171+
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
172+
"""
173+
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
174+
test_case
175+
)
176+
177+
168178
def require_torch_gpu(test_case):
169179
"""Decorator marking a test that requires CUDA and PyTorch."""
170180
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(

src/diffusers/utils/torch_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import List, Optional, Tuple, Union
1818

1919
from . import logging
20-
from .import_utils import is_torch_available
20+
from .import_utils import is_torch_available, is_torch_version
2121

2222

2323
if is_torch_available():
@@ -68,3 +68,10 @@ def randn_tensor(
6868
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
6969

7070
return latents
71+
72+
73+
def is_compiled_module(module):
74+
"""Check whether the module was compiled with torch.compile()"""
75+
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
76+
return False
77+
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)

tests/test_modeling_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from diffusers.models import UNet2DConditionModel
2828
from diffusers.training_utils import EMAModel
2929
from diffusers.utils import torch_device
30+
from diffusers.utils.testing_utils import require_torch_gpu
3031

3132

3233
class ModelUtilsTest(unittest.TestCase):
@@ -167,6 +168,21 @@ def test_from_save_pretrained_variant(self):
167168
max_diff = (image - new_image).abs().sum().item()
168169
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
169170

171+
@require_torch_gpu
172+
def test_from_save_pretrained_dynamo(self):
173+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
174+
175+
model = self.model_class(**init_dict)
176+
model.to(torch_device)
177+
model = torch.compile(model)
178+
179+
with tempfile.TemporaryDirectory() as tmpdirname:
180+
model.save_pretrained(tmpdirname)
181+
new_model = self.model_class.from_pretrained(tmpdirname)
182+
new_model.to(torch_device)
183+
184+
assert new_model.__class__ == self.model_class
185+
170186
def test_from_save_pretrained_dtype(self):
171187
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
172188

tests/test_pipelines.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,16 @@
5454
logging,
5555
)
5656
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
57-
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
57+
from diffusers.utils import (
58+
CONFIG_NAME,
59+
WEIGHTS_NAME,
60+
floats_tensor,
61+
is_flax_available,
62+
nightly,
63+
require_torch_2,
64+
slow,
65+
torch_device,
66+
)
5867
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
5968

6069

@@ -966,9 +975,41 @@ def test_from_save_pretrained(self):
966975
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
967976
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
968977
)
969-
schedular = DDPMScheduler(num_train_timesteps=10)
978+
scheduler = DDPMScheduler(num_train_timesteps=10)
979+
980+
ddpm = DDPMPipeline(model, scheduler)
981+
ddpm.to(torch_device)
982+
ddpm.set_progress_bar_config(disable=None)
983+
984+
with tempfile.TemporaryDirectory() as tmpdirname:
985+
ddpm.save_pretrained(tmpdirname)
986+
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
987+
new_ddpm.to(torch_device)
988+
989+
generator = torch.Generator(device=torch_device).manual_seed(0)
990+
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
991+
992+
generator = torch.Generator(device=torch_device).manual_seed(0)
993+
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
994+
995+
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
996+
997+
@require_torch_2
998+
def test_from_save_pretrained_dynamo(self):
999+
# 1. Load models
1000+
model = UNet2DModel(
1001+
block_out_channels=(32, 64),
1002+
layers_per_block=2,
1003+
sample_size=32,
1004+
in_channels=3,
1005+
out_channels=3,
1006+
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
1007+
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
1008+
)
1009+
model = torch.compile(model)
1010+
scheduler = DDPMScheduler(num_train_timesteps=10)
9701011

971-
ddpm = DDPMPipeline(model, schedular)
1012+
ddpm = DDPMPipeline(model, scheduler)
9721013
ddpm.to(torch_device)
9731014
ddpm.set_progress_bar_config(disable=None)
9741015

0 commit comments

Comments
 (0)