@@ -89,6 +89,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
8989
9090 self .to_json_file (output_config_file )
9191 logger .info (f"ConfigMixinuration saved in { output_config_file } " )
92+
9293
9394 @classmethod
9495 def get_config_dict (
@@ -182,35 +183,42 @@ def get_config_dict(
182183 logger .info (f"loading configuration file { config_file } " )
183184 else :
184185 logger .info (f"loading configuration file { config_file } from cache at { resolved_config_file } " )
186+
187+ return config_dict
185188
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
    """Split a loaded config into ``__init__`` arguments and leftover values.

    The parameter names of ``cls.__init__`` (minus ``self``) define which
    keys are "expected". For each expected key the value is taken from
    ``kwargs`` when present there (explicit overrides win), otherwise from
    ``config_dict``.

    Args:
        config_dict: Mapping of configuration values. It is not modified;
            the method works on a shallow copy.
        **kwargs: Override values and/or additional values that are not
            ``__init__`` parameters.

    Returns:
        A ``(init_dict, unused_kwargs)`` tuple. ``init_dict`` holds the
        values for the expected keys that were found. ``unused_kwargs`` is
        a dict with every remaining entry of ``config_dict`` and ``kwargs``
        that is not an ``__init__`` parameter; on a name clash the value
        from ``kwargs`` takes precedence.
    """
    expected_keys = set(inspect.signature(cls.__init__).parameters)
    expected_keys.discard("self")

    # Shallow copy so that popping keys does not alter the caller's dict.
    remaining_config = dict(config_dict)

    init_dict = {}
    for key in expected_keys:
        if key in kwargs:
            # An explicit keyword argument overrides the stored config value.
            init_dict[key] = kwargs.pop(key)
            # Drop the overridden config value so it is not reported as unused.
            remaining_config.pop(key, None)
        elif key in remaining_config:
            # Fall back to the value from the config dict.
            init_dict[key] = remaining_config.pop(key)

    # dict.update() returns None, so build the merged mapping explicitly.
    unused_kwargs = {**remaining_config, **kwargs}

    missing_keys = expected_keys - set(init_dict)
    if missing_keys:
        logger.warning(
            f"{missing_keys} was not found in config. Values will be initialized to default values."
        )

    return init_dict, unused_kwargs
206212
207213 @classmethod
208214 def from_config (cls , pretrained_model_name_or_path : Union [str , os .PathLike ], return_unused_kwargs = False , ** kwargs ):
209- config_dict , unused_kwargs = cls .get_config_dict (
215+ config_dict = cls .get_config_dict (
210216 pretrained_model_name_or_path = pretrained_model_name_or_path , ** kwargs
211217 )
212218
213- model = cls (** config_dict )
219+ init_dict , unused_kwargs = cls .extract_init_dict (config_dict , ** kwargs )
220+
221+ model = cls (** init_dict )
214222
215223 if return_unused_kwargs :
216224 return model , unused_kwargs
0 commit comments