Skip to content

Commit 1a6196e

Browse files
add more logic for dynamic loading
1 parent 40dc888 commit 1a6196e

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

models/vision/ddpm/modeling_ddpm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
2323

2424
modeling_file = "modeling_ddpm.py"
2525

26-
def __init__(self, unet, noise_scheduler):
26+
def __init__(self, unet, noise_scheduler, vqvae):
2727
super().__init__()
2828
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
2929

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#!/usr/bin/env python3

src/diffusers/pipeline_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
7171
for name, (library_name, class_name) in self._dict_to_save.items():
7272
importable_classes = LOADABLE_CLASSES[library_name]
7373

74+
# TODO: Suraj
75+
if library_name == self.__module__:
76+
library_name = self
77+
7478
library = importlib.import_module(library_name)
7579
class_obj = getattr(library, class_name)
7680
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
@@ -91,12 +95,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
9195

9296
module = pipeline_kwargs["_module"]
9397
# TODO(Suraj) - make from hub import work
98+
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
99+
# Add Sylvains code from transformers
94100

95101
init_kwargs = {}
96102

97103
for name, (library_name, class_name) in config_dict.items():
98104
importable_classes = LOADABLE_CLASSES[library_name]
99105

106+
if library_name == module:
107+
# TODO(Suraj)
108+
pass
109+
100110
library = importlib.import_module(library_name)
101111
class_obj = getattr(library, class_name)
102112
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
@@ -110,7 +120,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
110120

111121
loaded_sub_model = load_method(os.path.join(cached_folder, name))
112122

113-
init_kwargs[name] = loaded_sub_model
123+
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
114124

115125
model = cls(**init_kwargs)
116126
return model

0 commit comments

Comments
 (0)