Skip to content

Commit 936cd08

Browse files
improve loading a bit
1 parent 3a32b8c commit 936cd08

File tree

4 files changed

+22
-0
lines changed

4 files changed

+22
-0
lines changed

src/diffusers/configuration_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def get_config_dict(
208208
def extract_init_dict(cls, config_dict, **kwargs):
209209
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
210210
expected_keys.remove("self")
211+
expected_keys.remove("kwargs")
211212
init_dict = {}
212213
for key in expected_keys:
213214
if key in kwargs:

src/diffusers/modeling_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ class ModelMixin(torch.nn.Module):
147147
models, `pixel_values` for vision models and `input_values` for speech models).
148148
"""
149149
config_name = CONFIG_NAME
150+
_automatically_saved_args = ["_diffusers_version", "_class_name", "name_or_path"]
150151

151152
def __init__(self):
152153
super().__init__()

src/diffusers/models/unet_conditional.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,18 @@ def __init__(
6363
mid_block_scale_factor=1,
6464
center_input_sample=False,
6565
resnet_num_groups=30,
66+
**kwargs,
6667
):
6768
super().__init__()
69+
# remove automatically added kwargs
70+
for arg in self._automatically_saved_args:
71+
kwargs.pop(arg, None)
72+
73+
if len(kwargs) > 0:
74+
raise ValueError(
75+
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
76+
)
77+
6878
# register all __init__ params to be accessible via `self.config.<...>`
6979
# should probably be automated down the road as this is pure boiler plate code
7080
self.register_to_config(

src/diffusers/models/unet_unconditional.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,18 @@ def __init__(
5959
mid_block_scale_factor=1,
6060
center_input_sample=False,
6161
resnet_num_groups=32,
62+
**kwargs,
6263
):
6364
super().__init__()
65+
# remove automatically added kwargs
66+
for arg in self._automatically_saved_args:
67+
kwargs.pop(arg, None)
68+
69+
if len(kwargs) > 0:
70+
raise ValueError(
71+
f"The following keyword arguments do not exist for {self.__class__}: {','.join(kwargs.keys())}"
72+
)
73+
6474
# register all __init__ params to be accessible via `self.config.<...>`
6575
# should probably be automated down the road as this is pure boiler plate code
6676
self.register_to_config(

0 commit comments

Comments
 (0)