Skip to content

Commit d8287fc

Browse files
committed
fix issues with loading, add test for pipeline
1 parent fe99460 commit d8287fc

File tree

4 files changed

+59
-7
lines changed

4 files changed

+59
-7
lines changed

src/diffusers/configuration_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ def get_config_dict(
190190
def extract_init_dict(cls, config_dict, **kwargs):
191191
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
192192
expected_keys.remove("self")
193-
import ipdb; ipdb.set_trace()
194193
init_dict = {}
195194
for key in expected_keys:
196195
if key in kwargs:

src/diffusers/pipeline_utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,23 @@ def register_modules(self, **kwargs):
5656
class_name = module.__class__.__name__
5757

5858
register_dict = {name: (library, class_name)}
59-
register_dict["_module"] = self.__module__
59+
6060

6161
# save model index config
6262
self.register(**register_dict)
6363

6464
# set models
6565
setattr(self, name, module)
66+
67+
register_dict = {"_module" : self.__module__.split(".")[-1] + ".py"}
68+
self.register(**register_dict)
6669

6770
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
6871
self.save_config(save_directory)
6972

7073
model_index_dict = self._dict_to_save
7174
model_index_dict.pop("_class_name")
75+
model_index_dict.pop("_module")
7276

7377
for name, (library_name, class_name) in self._dict_to_save.items():
7478
importable_classes = LOADABLE_CLASSES[library_name]
@@ -98,12 +102,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
98102
cached_folder = pretrained_model_name_or_path
99103

100104
config_dict = cls.get_config_dict(cached_folder)
105+
101106
module = config_dict["_module"]
102107
class_name_ = config_dict["_class_name"]
103-
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
108+
109+
if class_name_ == cls.__name__:
110+
pipeline_class = cls
111+
else:
112+
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
113+
104114

105-
init_dict, unused = class_obj.extract_init_dict(config_dict, **kwargs)
106-
import ipdb; ipdb.set_trace()
115+
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
107116

108117
init_kwargs = {}
109118

@@ -132,6 +141,5 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
132141

133142
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
134143

135-
136-
model = class_obj(**init_kwargs)
144+
model = pipeline_class(**init_kwargs)
137145
return model

tests/__init__.py

Whitespace-only changes.

tests/test_modeling_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import torch
2323

2424
from diffusers import GaussianDDPMScheduler, UNetModel
25+
from diffusers.pipeline_utils import DiffusionPipeline
26+
from models.vision.ddpm.modeling_ddpm import DDPM
2527

2628

2729
global_rng = random.Random()
@@ -199,3 +201,46 @@ def test_sample_fast(self):
199201
assert image.shape == (1, 3, 256, 256)
200202
image_slice = image[0, -1, -3:, -3:].cpu()
201203
assert (image_slice - torch.tensor([[0.1746, 0.5125, -0.7920], [-0.5734, -0.2910, -0.1984], [0.4090, -0.7740, -0.3941]])).abs().sum() < 1e-3
204+
205+
206+
class PipelineTesterMixin(unittest.TestCase):
207+
def test_from_pretrained_save_pretrained(self):
208+
# 1. Load models
209+
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
210+
schedular = GaussianDDPMScheduler(timesteps=10)
211+
212+
ddpm = DDPM(model, schedular)
213+
214+
with tempfile.TemporaryDirectory() as tmpdirname:
215+
ddpm.save_pretrained(tmpdirname)
216+
new_ddpm = DDPM.from_pretrained(tmpdirname)
217+
218+
generator = torch.Generator()
219+
generator = generator.manual_seed(669472945848556)
220+
221+
image = ddpm(generator)
222+
generator = generator.manual_seed(669472945848556)
223+
new_image = new_ddpm(generator)
224+
225+
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
226+
227+
228+
@slow
229+
def test_from_pretrained_hub(self):
230+
model_path = "fusing/ddpm-cifar10"
231+
232+
ddpm = DDPM.from_pretrained(model_path)
233+
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
234+
235+
ddpm.noise_scheduler.num_timesteps = 10
236+
ddpm_from_hub.noise_scheduler.num_timesteps = 10
237+
238+
239+
generator = torch.Generator(device=torch_device)
240+
generator = generator.manual_seed(669472945848556)
241+
242+
image = ddpm(generator)
243+
generator = generator.manual_seed(669472945848556)
244+
new_image = ddpm_from_hub(generator)
245+
246+
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"

0 commit comments

Comments
 (0)