Commit d185c0d

[Lora] correct lora saving & loading (#2655)

* [Lora] correct lora saving & loading
* fix final
* Apply suggestions from code review

1 parent 7c1b347 commit d185c0d
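
For context: the two methods touched by this commit are `load_attn_procs` and `save_attn_procs` on the UNet loaders mixin in `src/diffusers/loaders.py`. A minimal round-trip sketch of the corrected API, assuming a standard Stable Diffusion checkpoint (the model id and output directory below are placeholders, not part of the commit):

import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint; any pipeline whose UNet carries the loaders mixin works.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Saving: `weight_name` is the keyword after this commit; the old
# `weights_name` keyword still works but now emits a deprecation warning.
pipe.unet.save_attn_procs("./lora-out", weight_name="pytorch_lora_weights.bin")

# Loading: with no explicit `weight_name`, the loader tries the
# .safetensors default first and falls back to the .bin default.
pipe.unet.load_attn_procs("./lora-out")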

1 file changed (+23, -25)

src/diffusers/loaders.py

Lines changed: 23 additions & 25 deletions
@@ -19,7 +19,7 @@
 
 from .models.cross_attention import LoRACrossAttnProcessor
 from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
 
 
 if is_safetensors_available():
@@ -150,13 +150,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
 
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
-            if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
-                if weight_name is None:
-                    weight_name = LORA_WEIGHT_NAME_SAFE
+            # Let's first try to load .safetensors weights
+            if (is_safetensors_available() and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path_or_dict,
-                        weights_name=weight_name,
+                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                         cache_dir=cache_dir,
                         force_download=force_download,
                         resume_download=resume_download,
@@ -169,14 +170,13 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     )
                     state_dict = safetensors.torch.load_file(model_file, device="cpu")
                 except EnvironmentError:
-                    if weight_name == LORA_WEIGHT_NAME_SAFE:
-                        weight_name = None
+                    # try loading non-safetensors weights
+                    pass
+
             if model_file is None:
-                if weight_name is None:
-                    weight_name = LORA_WEIGHT_NAME
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
-                    weights_name=weight_name,
+                    weights_name=weight_name or LORA_WEIGHT_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     resume_download=resume_download,
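
The two hunks above fix the weight-file resolution in `load_attn_procs`. The old condition evaluated `weight_name.endswith(".safetensors")` even when `weight_name` was `None` (an `AttributeError` whenever safetensors was not installed), and the `except` branch mutated `weight_name` instead of falling back cleanly. The new control flow, extracted into a standalone sketch (the `LORA_WEIGHT_NAME*` constants come from the diff; the wrapper function and the local-directory stand-in for the private `_get_model_file` helper are illustrative):

import os

import safetensors.torch
import torch

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

def _get_model_file(repo_id, *, weights_name):
    # Stand-in for the private diffusers helper of the same name: locate the
    # file and raise EnvironmentError if it is missing (local-directory case only).
    path = os.path.join(repo_id, weights_name)
    if not os.path.isfile(path):
        raise EnvironmentError(f"{weights_name} not found in {repo_id}")
    return path

def resolve_lora_state_dict(repo_id, weight_name=None, safetensors_available=True):
    # Try .safetensors first when no explicit name was given, or when the
    # caller explicitly asked for a .safetensors file.
    if (safetensors_available and weight_name is None) or (
        weight_name is not None and weight_name.endswith(".safetensors")
    ):
        try:
            model_file = _get_model_file(repo_id, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE)
            return safetensors.torch.load_file(model_file, device="cpu")
        except EnvironmentError:
            pass  # fall through and try the non-safetensors default instead
    model_file = _get_model_file(repo_id, weights_name=weight_name or LORA_WEIGHT_NAME)
    return torch.load(model_file, map_location="cpu")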
@@ -225,9 +225,10 @@ def save_attn_procs(
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        weights_name: str = None,
+        weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = False,
+        **kwargs,
     ):
         r"""
         Save an attention processor to a directory, so that it can be re-loaded using the
@@ -245,6 +246,12 @@ def save_attn_procs(
             need to replace `torch.save` by another method. Can be configured with the environment variable
             `DIFFUSERS_SAVE_MODE`.
         """
+        weight_name = weight_name or deprecate(
+            "weights_name",
+            "0.18.0",
+            "`weights_name` is deprecated, please use `weight_name` instead.",
+            take_from=kwargs,
+        )
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
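
`deprecate` is the existing helper from `diffusers.utils` (hence the new import in the first hunk). Called with `take_from=kwargs`, it pops the old `weights_name` keyword out of `**kwargs`, warns if the caller actually passed it, and returns its value, so `weight_name = weight_name or deprecate(...)` keeps old call sites working. A minimal stand-in covering just the behavior this hunk relies on (the real helper does more):

import warnings

def deprecate(name, removal_version, message, take_from=None):
    # Pop the deprecated kwarg if the caller supplied it, warn, and return its
    # value so the new kwarg can adopt it; return None when it was not passed.
    value = take_from.pop(name, None) if take_from is not None else None
    if value is not None:
        warnings.warn(
            f"`{name}` is deprecated and will be removed in {removal_version}. {message}",
            FutureWarning,
        )
    return value

With that in place, a legacy call such as `save_attn_procs(save_directory, weights_name="custom.bin")` still resolves to `weight_name == "custom.bin"`.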
@@ -265,22 +272,13 @@ def save_function(weights, filename):
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        # Clean the folder from a previous save
-        for filename in os.listdir(save_directory):
-            full_filename = os.path.join(save_directory, filename)
-            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
-            # in distributed settings to avoid race conditions.
-            weights_no_suffix = weights_name.replace(".bin", "")
-            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
-                os.remove(full_filename)
-
-        if weights_name is None:
+        if weight_name is None:
             if safe_serialization:
-                weights_name = LORA_WEIGHT_NAME_SAFE
+                weight_name = LORA_WEIGHT_NAME_SAFE
             else:
-                weights_name = LORA_WEIGHT_NAME
+                weight_name = LORA_WEIGHT_NAME
 
         # Save the model
-        save_function(state_dict, os.path.join(save_directory, weights_name))
+        save_function(state_dict, os.path.join(save_directory, weight_name))
 
-        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
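
The last hunk also deletes the "clean the folder" loop, which dereferenced `weights_name` (`weights_name.replace(".bin", "")`) before the `None` default had been filled in, so saving without an explicit name into a non-empty directory raised an `AttributeError`. After the fix the saved file name is chosen like this (sketch; `unet` stands for any model with the mixin, the directory is a placeholder):

unet.save_attn_procs("./lora-out")                            # -> ./lora-out/pytorch_lora_weights.bin
unet.save_attn_procs("./lora-out", safe_serialization=True)   # -> ./lora-out/pytorch_lora_weights.safetensors
unet.save_attn_procs("./lora-out", weight_name="custom.bin")  # an explicit weight_name always wins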
