1919
2020from .models .cross_attention import LoRACrossAttnProcessor
2121from .models .modeling_utils import _get_model_file
22- from .utils import DIFFUSERS_CACHE , HF_HUB_OFFLINE , is_safetensors_available , logging
22+ from .utils import DIFFUSERS_CACHE , HF_HUB_OFFLINE , deprecate , is_safetensors_available , logging
2323
2424
2525if is_safetensors_available ():
@@ -150,13 +150,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
150150
151151 model_file = None
152152 if not isinstance (pretrained_model_name_or_path_or_dict , dict ):
153- if (is_safetensors_available () and weight_name is None ) or weight_name .endswith (".safetensors" ):
154- if weight_name is None :
155- weight_name = LORA_WEIGHT_NAME_SAFE
153+ # Let's first try to load .safetensors weights
154+ if (is_safetensors_available () and weight_name is None ) or (
155+ weight_name is not None and weight_name .endswith (".safetensors" )
156+ ):
156157 try :
157158 model_file = _get_model_file (
158159 pretrained_model_name_or_path_or_dict ,
159- weights_name = weight_name ,
160+ weights_name = weight_name or LORA_WEIGHT_NAME_SAFE ,
160161 cache_dir = cache_dir ,
161162 force_download = force_download ,
162163 resume_download = resume_download ,
@@ -169,14 +170,13 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
169170 )
170171 state_dict = safetensors .torch .load_file (model_file , device = "cpu" )
171172 except EnvironmentError :
172- if weight_name == LORA_WEIGHT_NAME_SAFE :
173- weight_name = None
173+ # try loading non-safetensors weights
174+ pass
175+
174176 if model_file is None :
175- if weight_name is None :
176- weight_name = LORA_WEIGHT_NAME
177177 model_file = _get_model_file (
178178 pretrained_model_name_or_path_or_dict ,
179- weights_name = weight_name ,
179+ weights_name = weight_name or LORA_WEIGHT_NAME ,
180180 cache_dir = cache_dir ,
181181 force_download = force_download ,
182182 resume_download = resume_download ,
@@ -225,9 +225,10 @@ def save_attn_procs(
225225 self ,
226226 save_directory : Union [str , os .PathLike ],
227227 is_main_process : bool = True ,
228- weights_name : str = None ,
228+ weight_name : str = None ,
229229 save_function : Callable = None ,
230230 safe_serialization : bool = False ,
231+ ** kwargs ,
231232 ):
232233 r"""
233234 Save an attention processor to a directory, so that it can be re-loaded using the
@@ -245,6 +246,12 @@ def save_attn_procs(
245246 need to replace `torch.save` by another method. Can be configured with the environment variable
246247 `DIFFUSERS_SAVE_MODE`.
247248 """
249+ weight_name = weight_name or deprecate (
250+ "weights_name" ,
251+ "0.18.0" ,
252+ "`weights_name` is deprecated, please use `weight_name` instead." ,
253+ take_from = kwargs ,
254+ )
248255 if os .path .isfile (save_directory ):
249256 logger .error (f"Provided path ({ save_directory } ) should be a directory, not a file" )
250257 return
@@ -265,22 +272,13 @@ def save_function(weights, filename):
265272 # Save the model
266273 state_dict = model_to_save .state_dict ()
267274
268- # Clean the folder from a previous save
269- for filename in os .listdir (save_directory ):
270- full_filename = os .path .join (save_directory , filename )
271- # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
272- # in distributed settings to avoid race conditions.
273- weights_no_suffix = weights_name .replace (".bin" , "" )
274- if filename .startswith (weights_no_suffix ) and os .path .isfile (full_filename ) and is_main_process :
275- os .remove (full_filename )
276-
277- if weights_name is None :
275+ if weight_name is None :
278276 if safe_serialization :
279- weights_name = LORA_WEIGHT_NAME_SAFE
277+ weight_name = LORA_WEIGHT_NAME_SAFE
280278 else :
281- weights_name = LORA_WEIGHT_NAME
279+ weight_name = LORA_WEIGHT_NAME
282280
283281 # Save the model
284- save_function (state_dict , os .path .join (save_directory , weights_name ))
282+ save_function (state_dict , os .path .join (save_directory , weight_name ))
285283
286- logger .info (f"Model weights saved in { os .path .join (save_directory , weights_name )} " )
284+ logger .info (f"Model weights saved in { os .path .join (save_directory , weight_name )} " )
0 commit comments