Skip to content

Commit 40dc888

Browse files
add first logic for from hub code download
1 parent e8ad2b7 commit 40dc888

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

models/vision/ddpm/modeling_ddpm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121

2222
class DDPM(DiffusionPipeline):
23+
24+
modeling_file = "modeling_ddpm.py"
25+
2326
def __init__(self, unet, noise_scheduler):
2427
super().__init__()
2528
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

src/diffusers/pipeline_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,11 @@ def register_modules(self, **kwargs):
5353
# retrive class_name
5454
class_name = module.__class__.__name__
5555

56+
register_dict = {name: (library, class_name)}
57+
register_dict["_module"] = self.__module__
58+
5659
# save model index config
57-
self.register(**{name: (library, class_name)})
60+
self.register(**register_dict)
5861

5962
# set models
6063
setattr(self, name, module)
@@ -84,7 +87,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
8487
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
8588
# use snapshot download here to get it working from from_pretrained
8689
cached_folder = snapshot_download(pretrained_model_name_or_path)
87-
config_dict, _ = cls.get_config_dict(cached_folder)
90+
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
91+
92+
module = pipeline_kwargs["_module"]
93+
# TODO(Suraj) - make from hub import work
8894

8995
init_kwargs = {}
9096

0 commit comments

Comments
 (0)