
Commit bff9c84

Support to load Kohya-ss style LoRA file format (without restrictions)
Heavily based on huggingface#3756 by @takuma104.

Co-Authored-By: Takuma Mori <[email protected]>
1 parent 3eb498e commit bff9c84

7 files changed: +221 −45 lines
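Before the per-file diffs, a quick sketch of how the new code path is exercised from a pipeline. The Hub model id is real, but the LoRA file name and the prompt are placeholders rather than anything referenced by this commit.

# Minimal usage sketch (not part of this commit); the LoRA file name and prompt
# are placeholders. load_lora_weights() recognizes the Kohya-ss key layout
# (lora_unet_* / lora_te_*) and converts it, including the feed-forward and
# projection layers this commit adds support for.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights(".", weight_name="some_kohya_style_lora.safetensors")

image = pipe("a prompt matching the LoRA's training style").images[0]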

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 1 addition & 1 deletion
@@ -924,7 +924,7 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        lora_state_dict, network_alpha, _ = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
         LoraLoaderMixin.load_lora_into_text_encoder(
             lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 1 addition & 1 deletion
@@ -836,7 +836,7 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        lora_state_dict, network_alpha, _ = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
         LoraLoaderMixin.load_lora_into_text_encoder(
             lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_

src/diffusers/loaders.py

Lines changed: 101 additions & 9 deletions
@@ -32,11 +32,11 @@
     LoRAAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
-    LoRALinearLayer,
     LoRAXFormersAttnProcessor,
     SlicedAttnAddedKVProcessor,
     XFormersAttnProcessor,
 )
+from .models.lora import Conv2dWithLoRA, LinearWithLoRA, LoRAConv2dLayer, LoRALinearLayer
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -464,6 +464,36 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def _load_lora_aux(self, state_dict, network_alpha=None):
+        lora_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        for key, value_dict in lora_grouped_dict.items():
+            rank = value_dict["lora.down.weight"].shape[0]
+            hidden_size = value_dict["lora.up.weight"].shape[0]
+            target_modules = [module for name, module in self.named_modules() if name == key]
+            if len(target_modules) == 0:
+                logger.warning(f"Could not find module {key} in the model. Skipping.")
+                continue
+
+            target_module = target_modules[0]
+            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+
+            lora = None
+            if isinstance(target_module, Conv2dWithLoRA):
+                lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
+            elif isinstance(target_module, LinearWithLoRA):
+                lora = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
+            else:
+                raise ValueError(f"Module {key} is not a Conv2dWithLoRA or LinearWithLoRA module.")
+            lora.load_state_dict(value_dict)
+            lora.to(device=self.device, dtype=self.dtype)
+
+            # install lora
+            target_module.lora_layer = lora
+
 
 class TextualInversionLoaderMixin:
     r"""
@@ -825,10 +855,18 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             kwargs:
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
-        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
+        state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux) = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict, **kwargs
+        )
+        self.load_lora_into_unet(
+            state_dict, network_alpha=network_alpha, unet=self.unet, state_dict_aux=unet_state_dict_aux
+        )
         self.load_lora_into_text_encoder(
-            state_dict, network_alpha=network_alpha, text_encoder=self.text_encoder, lora_scale=self.lora_scale
+            state_dict,
+            network_alpha=network_alpha,
+            text_encoder=self.text_encoder,
+            lora_scale=self.lora_scale,
+            state_dict_aux=te_state_dict_aux,
         )
 
     @classmethod
@@ -962,13 +1000,14 @@ def lora_state_dict(
 
         # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
         network_alpha = None
+        auxilary_states = ({}, {})
         if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
-            state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
+            state_dict, network_alpha, auxilary_states = cls._convert_kohya_lora_to_diffusers(state_dict)
 
-        return state_dict, network_alpha
+        return state_dict, network_alpha, auxilary_states
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alpha, unet):
+    def load_lora_into_unet(cls, state_dict, network_alpha, unet, aux_state_dict=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`
 
@@ -981,6 +1020,8 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet):
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            aux_state_dict (`dict`, *optional*):
+                A dictionary containing the auxilary state (additional lora state) dict for the unet.
         """
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1005,8 +1046,11 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet):
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
+        if aux_state_dict:
+            unet._load_lora_aux(aux_state_dict, network_alpha=network_alpha)
+
     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
+    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0, state_dict_aux=None):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
 
@@ -1021,6 +1065,8 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
+            state_dict_aux (`dict`, *optional*):
+                A dictionary containing the auxilary state dict (additional lora state) for the text encoder.
         """
 
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1078,6 +1124,8 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lo
             ].shape[1]
 
             cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank)
+            if state_dict_aux:
+                cls._load_lora_aux_for_text_encoder(text_encoder, state_dict_aux, network_alpha=network_alpha)
 
             # set correct dtype & device
             text_encoder_lora_state_dict = {
@@ -1109,6 +1157,37 @@ def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
             attn_module.v_proj = attn_module.v_proj.regular_linear_layer
             attn_module.out_proj = attn_module.out_proj.regular_linear_layer
 
+    @classmethod
+    def _load_lora_aux_for_text_encoder(cls, text_encoder, state_dict, network_alpha=None):
+        lora_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        for key, value_dict in lora_grouped_dict.items():
+            rank = value_dict["lora.down.weight"].shape[0]
+            target_modules = [module for name, module in text_encoder.named_modules() if name == key]
+            if len(target_modules) == 0:
+                logger.warning(f"Could not find module {key} in the model. Skipping.")
+                continue
+
+            target_module = target_modules[0]
+            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+            lora_layer = LoRALinearLayer(target_module.in_features, target_module.out_features, rank, network_alpha)
+            lora_layer.load_state_dict(value_dict)
+            lora_layer.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+            old_forward = target_module.forward
+
+            def make_new_forward(old_forward, lora_layer):
+                def new_forward(x):
+                    return old_forward(x) + lora_layer(x)
+
+                return new_forward
+
+            # Monkey-patch.
+            target_module.forward = make_new_forward(old_forward, lora_layer)
+
     @classmethod
     def _modify_text_encoder(cls, text_encoder, lora_scale=1, network_alpha=None, rank=4, dtype=None):
         r"""
@@ -1225,6 +1304,8 @@ def save_function(weights, filename):
     def _convert_kohya_lora_to_diffusers(cls, state_dict):
         unet_state_dict = {}
         te_state_dict = {}
+        unet_state_dict_aux = {}
+        te_state_dict_aux = {}
         network_alpha = None
 
         for key, value in state_dict.items():
@@ -1249,12 +1330,20 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
                     diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
                     diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+                    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
                     if "transformer_blocks" in diffusers_name:
                         if "attn1" in diffusers_name or "attn2" in diffusers_name:
                             diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
                             diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
                             unet_state_dict[diffusers_name] = value
                             unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                        elif "ff" in diffusers_name:
+                            unet_state_dict_aux[diffusers_name] = value
+                            unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
+                        unet_state_dict_aux[diffusers_name] = value
+                        unet_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
                 elif lora_name.startswith("lora_te_"):
                     diffusers_name = key.replace("lora_te_", "").replace("_", ".")
                     diffusers_name = diffusers_name.replace("text.model", "text_model")
@@ -1266,11 +1355,14 @@ def _convert_kohya_lora_to_diffusers(cls, state_dict):
                     if "self_attn" in diffusers_name:
                         te_state_dict[diffusers_name] = value
                         te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                    elif "mlp" in diffusers_name:
+                        te_state_dict_aux[diffusers_name] = value
+                        te_state_dict_aux[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
 
         unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
         te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
         new_state_dict = {**unet_state_dict, **te_state_dict}
-        return new_state_dict, network_alpha
+        return new_state_dict, network_alpha, (unet_state_dict_aux, te_state_dict_aux)
 
     def unload_lora_weights(self):
         """

src/diffusers/models/attention.py

Lines changed: 3 additions & 2 deletions
@@ -21,6 +21,7 @@
 from .activations import get_activation
 from .attention_processor import Attention
 from .embeddings import CombinedTimestepLabelEmbeddings
+from .lora import LinearWithLoRA
 
 
 @maybe_allow_in_graph
@@ -245,7 +246,7 @@ def __init__(
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(nn.Linear(inner_dim, dim_out))
+        self.net.append(LinearWithLoRA(inner_dim, dim_out))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
@@ -289,7 +290,7 @@ class GEGLU(nn.Module):
 
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
+        self.proj = LinearWithLoRA(dim_in, dim_out * 2)
 
     def gelu(self, gate):
         if gate.device.type != "mps":
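`LinearWithLoRA` itself is defined in src/diffusers/models/lora.py, which is among the seven changed files but is not shown in this excerpt. Based on how loaders.py assigns `target_module.lora_layer`, a wrapper of this kind presumably behaves roughly like the sketch below; treat it as an assumption, not a copy of the actual class.

# Hedged sketch of what a LinearWithLoRA-style wrapper looks like; the real
# class lives in src/diffusers/models/lora.py (not shown in this diff).
import torch.nn as nn


class LinearWithLoRASketch(nn.Linear):
    """nn.Linear drop-in whose output gains a LoRA correction once lora_layer is set."""

    def __init__(self, *args, lora_layer=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer  # installed later, e.g. by _load_lora_aux

    def forward(self, hidden_states):
        if self.lora_layer is None:
            return super().forward(hidden_states)
        return super().forward(hidden_states) + self.lora_layer(hidden_states)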

src/diffusers/models/attention_processor.py

Lines changed: 1 addition & 30 deletions
@@ -19,6 +19,7 @@
 
 from ..utils import deprecate, logging, maybe_allow_in_graph
 from ..utils.import_utils import is_xformers_available
+from .lora import LoRALinearLayer
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -505,36 +506,6 @@ def __call__(
         return hidden_states
 
 
-class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
-        super().__init__()
-
-        if rank > min(in_features, out_features):
-            raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
-
-        self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
-        self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
-        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
-        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        self.network_alpha = network_alpha
-        self.rank = rank
-
-        nn.init.normal_(self.down.weight, std=1 / rank)
-        nn.init.zeros_(self.up.weight)
-
-    def forward(self, hidden_states):
-        orig_dtype = hidden_states.dtype
-        dtype = self.down.weight.dtype
-
-        down_hidden_states = self.down(hidden_states.to(dtype))
-        up_hidden_states = self.up(down_hidden_states)
-
-        if self.network_alpha is not None:
-            up_hidden_states *= self.network_alpha / self.rank
-
-        return up_hidden_states.to(orig_dtype)
-
-
 class LoRAAttnProcessor(nn.Module):
     r"""
     Processor for implementing the LoRA attention mechanism.
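Note that the class removed here is not dropped from the library: it moves to src/diffusers/models/lora.py and is re-imported at the top of this file (see the `from .lora import LoRALinearLayer` hunk above), so the old import path should keep resolving. A quick check one could run against this commit:

# Both import paths should refer to the same class after this commit, because
# attention_processor.py re-imports LoRALinearLayer from its new module.
from diffusers.models.attention_processor import LoRALinearLayer as via_old_path
from diffusers.models.lora import LoRALinearLayer as via_new_path

assert via_old_path is via_new_path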
