@@ -71,6 +71,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
7171 for name , (library_name , class_name ) in self ._dict_to_save .items ():
7272 importable_classes = LOADABLE_CLASSES [library_name ]
7373
74+ # TODO: Suraj
75+ if library_name == self .__module__ :
76+ library_name = self
77+
7478 library = importlib .import_module (library_name )
7579 class_obj = getattr (library , class_name )
7680 class_candidates = {c : getattr (library , c ) for c in importable_classes .keys ()}
@@ -91,12 +95,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
9195
9296 module = pipeline_kwargs ["_module" ]
9397 # TODO(Suraj) - make from hub import work
98+ # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
99+ # Add Sylvains code from transformers
94100
95101 init_kwargs = {}
96102
97103 for name , (library_name , class_name ) in config_dict .items ():
98104 importable_classes = LOADABLE_CLASSES [library_name ]
99105
106+ if library_name == module :
107+ # TODO(Suraj)
108+ pass
109+
100110 library = importlib .import_module (library_name )
101111 class_obj = getattr (library , class_name )
102112 class_candidates = {c : getattr (library , c ) for c in importable_classes .keys ()}
@@ -110,7 +120,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
110120
111121 loaded_sub_model = load_method (os .path .join (cached_folder , name ))
112122
113- init_kwargs [name ] = loaded_sub_model
123+ init_kwargs [name ] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
114124
115125 model = cls (** init_kwargs )
116126 return model
0 commit comments