1616""" ConfigMixin base class and utilities."""
1717import dataclasses
1818import functools
19+ import importlib
1920import inspect
2021import json
2122import os
@@ -48,9 +49,13 @@ class ConfigMixin:
4849 [`~ConfigMixin.save_config`] (should be overridden by parent class).
4950 - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
5051 overridden by parent class).
52+ - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
53+ `from_config` can be used from a class different than the one used to save the config (should be overridden
54+ by parent class).
5155 """
5256 config_name = None
5357 ignore_for_config = []
58+ _compatible_classes = []
5459
5560 def register_to_config (self , ** kwargs ):
5661 if self .config_name is None :
@@ -280,9 +285,14 @@ def get_config_dict(
280285
281286 return config_dict
282287
288+ @staticmethod
289+ def _get_init_keys (cls ):
290+ return set (dict (inspect .signature (cls .__init__ ).parameters ).keys ())
291+
283292 @classmethod
284293 def extract_init_dict (cls , config_dict , ** kwargs ):
285- expected_keys = set (dict (inspect .signature (cls .__init__ ).parameters ).keys ())
294+ # 1. Retrieve expected config attributes from __init__ signature
295+ expected_keys = cls ._get_init_keys (cls )
286296 expected_keys .remove ("self" )
287297 # remove general kwargs if present in dict
288298 if "kwargs" in expected_keys :
@@ -292,9 +302,36 @@ def extract_init_dict(cls, config_dict, **kwargs):
292302 for arg in cls ._flax_internal_args :
293303 expected_keys .remove (arg )
294304
305+ # 2. Remove attributes that cannot be expected from expected config attributes
295306 # remove keys to be ignored
296307 if len (cls .ignore_for_config ) > 0 :
297308 expected_keys = expected_keys - set (cls .ignore_for_config )
309+
310+ # load diffusers library to import compatible and original scheduler
311+ diffusers_library = importlib .import_module (__name__ .split ("." )[0 ])
312+
313+ # remove attributes from compatible classes that orig cannot expect
314+ compatible_classes = [getattr (diffusers_library , c , None ) for c in cls ._compatible_classes ]
315+ # filter out None potentially undefined dummy classes
316+ compatible_classes = [c for c in compatible_classes if c is not None ]
317+ expected_keys_comp_cls = set ()
318+ for c in compatible_classes :
319+ expected_keys_c = cls ._get_init_keys (c )
320+ expected_keys_comp_cls = expected_keys_comp_cls .union (expected_keys_c )
321+ expected_keys_comp_cls = expected_keys_comp_cls - cls ._get_init_keys (cls )
322+ config_dict = {k : v for k , v in config_dict .items () if k not in expected_keys_comp_cls }
323+
324+ # remove attributes from orig class that cannot be expected
325+ orig_cls_name = config_dict .pop ("_class_name" , cls .__name__ )
326+ if orig_cls_name != cls .__name__ :
327+ orig_cls = getattr (diffusers_library , orig_cls_name )
328+ unexpected_keys_from_orig = cls ._get_init_keys (orig_cls ) - expected_keys
329+ config_dict = {k : v for k , v in config_dict .items () if k not in unexpected_keys_from_orig }
330+
331+ # remove private attributes
332+ config_dict = {k : v for k , v in config_dict .items () if not k .startswith ("_" )}
333+
334+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
298335 init_dict = {}
299336 for key in expected_keys :
300337 if key in kwargs :
@@ -304,23 +341,24 @@ def extract_init_dict(cls, config_dict, **kwargs):
304341 # use value from config dict
305342 init_dict [key ] = config_dict .pop (key )
306343
307- config_dict = {k : v for k , v in config_dict .items () if not k .startswith ("_" )}
308-
344+ # 4. Give nice warning if unexpected values have been passed
309345 if len (config_dict ) > 0 :
310346 logger .warning (
311347 f"The config attributes { config_dict } were passed to { cls .__name__ } , "
312348 "but are not expected and will be ignored. Please verify your "
313349 f"{ cls .config_name } configuration file."
314350 )
315351
316- unused_kwargs = {** config_dict , ** kwargs }
317-
352+ # 5. Give nice info if config attributes are initiliazed to default because they have not been passed
318353 passed_keys = set (init_dict .keys ())
319354 if len (expected_keys - passed_keys ) > 0 :
320355 logger .info (
321356 f"{ expected_keys - passed_keys } was not found in config. Values will be initialized to default values."
322357 )
323358
359+ # 6. Define unused keyword arguments
360+ unused_kwargs = {** config_dict , ** kwargs }
361+
324362 return init_dict , unused_kwargs
325363
326364 @classmethod
0 commit comments