Commit c81a88b

[Core] LoRA improvements pt. 3 (#4842)
* throw warning when more than one lora is attempted to be fused. * introduce support of lora scale during fusion. * change test name * changes * change to _lora_scale * lora_scale to call whenever applicable. * debugging * lora_scale additional. * cross_attention_kwargs * lora_scale -> scale. * lora_scale fix * lora_scale in patched projection. * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * styling. * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * remove unneeded prints. * remove unneeded prints. * assign cross_attention_kwargs. * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * clean up. * refactor scale retrieval logic a bit. * fix nonetypw * fix: tests * add more tests * more fixes. * figure out a way to pass lora_scale. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * unify the retrieval logic of lora_scale. * move adjust_lora_scale_text_encoder to lora.py. * introduce dynamic adjustment lora scale support to sd * fix up copies * Empty-Commit * add: test to check fusion equivalence on different scales. * handle lora fusion warning. * make lora smaller * make lora smaller * make lora smaller --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2c1677e commit c81a88b
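At a glance, the user-facing change in this commit is that `fuse_lora()` now accepts a `lora_scale` argument, and the un-fused code path reads the same scale from `cross_attention_kwargs`. A minimal usage sketch (the model and LoRA checkpoint ids below are placeholders, not part of this commit):

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint ids; any SD pipeline with a compatible LoRA works the same way.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")

# Fuse the LoRA weights into the UNet and text encoder at half strength.
pipe.fuse_lora(lora_scale=0.5)
image = pipe("a pokemon with blue eyes", num_inference_steps=25).images[0]

# Undo the fusion; the same scaled delta is subtracted back out.
pipe.unfuse_lora()
```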

File tree: 45 files changed (+588 / -172 lines)


src/diffusers/loaders.py

Lines changed: 31 additions & 14 deletions
@@ -95,7 +95,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
 
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self):
+    def _fuse_lora(self, lora_scale=1.0):
         if self.lora_linear_layer is None:
             return
 
@@ -108,7 +108,7 @@ def _fuse_lora(self):
         if self.lora_linear_layer.network_alpha is not None:
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
-        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -117,6 +117,7 @@ def _fuse_lora(self):
         # offload the up and down matrices to CPU to not blow the memory
         self.w_up = w_up.cpu()
         self.w_down = w_down.cpu()
+        self.lora_scale = lora_scale
 
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
@@ -128,16 +129,19 @@ def _unfuse_lora(self):
         w_up = self.w_up.to(device=device).float()
         w_down = self.w_down.to(device).float()
 
-        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
+        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
         self.w_down = None
 
     def forward(self, input):
+        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
+        if self.lora_scale is None:
+            self.lora_scale = 1.0
         if self.lora_linear_layer is None:
             return self.regular_linear_layer(input)
-        return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
 
 
 def text_encoder_attn_modules(text_encoder):
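The arithmetic behind `_fuse_lora(lora_scale)` / `_unfuse_lora()` above, shown as a standalone sketch (plain tensors and illustrative names, not the actual `PatchedLoraProjection` class):

```python
import torch
import torch.nn as nn

linear = nn.Linear(16, 16, bias=False)
rank = 4
w_up = torch.randn(16, rank) * 0.1    # LoRA "up" matrix
w_down = torch.randn(rank, 16) * 0.1  # LoRA "down" matrix
lora_scale = 0.5

w_orig = linear.weight.data.clone()
delta = torch.mm(w_up, w_down)

# Fuse: add the *scaled* low-rank delta into the base weight.
linear.weight.data = w_orig + lora_scale * delta

# Unfuse: subtract the same scaled delta; this is why the layer must remember
# the scale after fusing (the `self.lora_scale = lora_scale` line above).
linear.weight.data = linear.weight.data - lora_scale * delta
assert torch.allclose(linear.weight.data, w_orig, atol=1e-6)
```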
@@ -576,12 +580,13 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self):
+    def fuse_lora(self, lora_scale=1.0):
+        self.lora_scale = lora_scale
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora()
+            module._fuse_lora(self.lora_scale)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
@@ -924,6 +929,7 @@ class LoraLoaderMixin:
     """
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    num_fused_loras = 0
 
     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         """
@@ -1807,7 +1813,7 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
+    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
 
@@ -1822,22 +1828,31 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
             fuse_text_encoder (`bool`, defaults to `True`):
                 Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
         """
+        if fuse_unet or fuse_text_encoder:
+            self.num_fused_loras += 1
+            if self.num_fused_loras > 1:
+                logger.warn(
+                    "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
+                )
+
         if fuse_unet:
-            self.unet.fuse_lora()
+            self.unet.fuse_lora(lora_scale)
 
         def fuse_text_encoder_lora(text_encoder):
             for _, attn_module in text_encoder_attn_modules(text_encoder):
                 if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._fuse_lora()
-                    attn_module.k_proj._fuse_lora()
-                    attn_module.v_proj._fuse_lora()
-                    attn_module.out_proj._fuse_lora()
+                    attn_module.q_proj._fuse_lora(lora_scale)
+                    attn_module.k_proj._fuse_lora(lora_scale)
+                    attn_module.v_proj._fuse_lora(lora_scale)
+                    attn_module.out_proj._fuse_lora(lora_scale)
 
             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                 if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._fuse_lora()
-                    mlp_module.fc2._fuse_lora()
+                    mlp_module.fc1._fuse_lora(lora_scale)
+                    mlp_module.fc2._fuse_lora(lora_scale)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
@@ -1884,6 +1899,8 @@ def unfuse_text_encoder_lora(text_encoder):
             if hasattr(self, "text_encoder_2"):
                 unfuse_text_encoder_lora(self.text_encoder_2)
 
+        self.num_fused_loras -= 1
+
 
 class FromSingleFileMixin:
     """
src/diffusers/models/attention.py

Lines changed: 21 additions & 12 deletions
@@ -177,7 +177,7 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
     ):
         # Notice that normalization is always applied before the real computation in the following blocks.
-        # 1. Self-Attention
+        # 0. Self-Attention
         if self.use_ada_layer_norm:
             norm_hidden_states = self.norm1(hidden_states, timestep)
         elif self.use_ada_layer_norm_zero:
@@ -187,7 +187,10 @@ def forward(
         else:
             norm_hidden_states = self.norm1(hidden_states)
 
-        # 0. Prepare GLIGEN inputs
+        # 1. Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        # 2. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
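Because the scale is read from `cross_attention_kwargs` here, the un-fused LoRA strength can be chosen per call. Assuming a pipeline with LoRA weights loaded (as in the sketch near the top of this page), usage looks roughly like this:

```python
# "scale" is forwarded through the UNet down to each transformer block,
# where the line above picks it up as lora_scale.
image = pipe(
    "a pokemon with blue eyes",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```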
@@ -201,12 +204,12 @@ def forward(
             attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = attn_output + hidden_states
 
-        # 1.5 GLIGEN Control
+        # 2.5 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
-        # 1.5 ends
+        # 2.5 ends
 
-        # 2. Cross-Attention
+        # 3. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
@@ -220,7 +223,7 @@ def forward(
             )
             hidden_states = attn_output + hidden_states
 
-        # 3. Feed-forward
+        # 4. Feed-forward
         norm_hidden_states = self.norm3(hidden_states)
 
         if self.use_ada_layer_norm_zero:
@@ -235,11 +238,14 @@ def forward(
 
             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
             ff_output = torch.cat(
-                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                [
+                    self.ff(hid_slice, scale=lora_scale)
+                    for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
+                ],
                 dim=self._chunk_dim,
             )
         else:
-            ff_output = self.ff(norm_hidden_states)
+            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -295,9 +301,12 @@ def __init__(
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, scale: float = 1.0):
         for module in self.net:
-            hidden_states = module(hidden_states)
+            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
+                hidden_states = module(hidden_states, scale)
+            else:
+                hidden_states = module(hidden_states)
         return hidden_states
 
 
@@ -342,8 +351,8 @@ def gelu(self, gate):
         # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
 
-    def forward(self, hidden_states):
-        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+    def forward(self, hidden_states, scale: float = 1.0):
+        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
 
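Note that `FeedForward` and `GEGLU` only pass `scale` on to submodules that understand it. A minimal sketch of such a scale-aware linear layer (an illustrative stand-in for diffusers' `LoRACompatibleLinear`, not the real implementation):

```python
import torch
import torch.nn as nn

class ScaledLoRALinear(nn.Linear):
    """nn.Linear plus a low-rank correction weighted by `scale` at call time."""

    def __init__(self, in_features, out_features, rank=4, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_down = nn.Linear(in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # starts as a no-op

    def forward(self, hidden_states, scale: float = 1.0):
        base = super().forward(hidden_states)
        return base + scale * self.lora_up(self.lora_down(hidden_states))

layer = ScaledLoRALinear(32, 32)
out = layer(torch.randn(2, 32), scale=0.5)  # scale=0.0 disables the LoRA path
```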
src/diffusers/models/attention_processor.py

Lines changed: 22 additions & 22 deletions
@@ -570,15 +570,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)
 
         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -589,7 +589,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -722,17 +722,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
         query = attn.head_to_batch_dim(query)
 
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, lora_scale=scale)
-            value = attn.to_v(hidden_states, lora_scale=scale)
+            key = attn.to_k(hidden_states, scale=scale)
+            value = attn.to_v(hidden_states, scale=scale)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -746,7 +746,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -782,7 +782,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
 
         hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
         query = attn.head_to_batch_dim(query, out_dim=4)
 
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -791,8 +791,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
 
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, lora_scale=scale)
-            value = attn.to_v(hidden_states, lora_scale=scale)
+            key = attn.to_k(hidden_states, scale=scale)
+            value = attn.to_v(hidden_states, scale=scale)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -809,7 +809,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -937,15 +937,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)
 
         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -958,7 +958,7 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
@@ -1015,15 +1015,15 @@ def __call__(
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
 
-        query = attn.to_q(hidden_states, lora_scale=scale)
+        query = attn.to_q(hidden_states, scale=scale)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        key = attn.to_k(encoder_hidden_states, lora_scale=scale)
-        value = attn.to_v(encoder_hidden_states, lora_scale=scale)
+        key = attn.to_k(encoder_hidden_states, scale=scale)
+        value = attn.to_v(encoder_hidden_states, scale=scale)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1043,7 +1043,7 @@ def __call__(
         hidden_states = hidden_states.to(query.dtype)
 
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, scale=scale)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
 
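With the keyword unified to `scale` in these processors, the value supplied through `cross_attention_kwargs={"scale": ...}` reaches every LoRA-compatible projection. The commit message also mentions a test that fusing at a given scale matches running un-fused with that same scale; conceptually it looks like the sketch below (reusing the placeholder pipeline from the earlier example; the tolerance is illustrative):

```python
import numpy as np
import torch

prompt = "a pokemon with blue eyes"

# Un-fused: the scale is applied on the fly via cross_attention_kwargs.
unfused = pipe(
    prompt, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}, output_type="np"
).images

# Fused: the same scale is baked into the weights, so no kwargs are needed.
pipe.fuse_lora(lora_scale=0.5)
fused = pipe(prompt, generator=torch.manual_seed(0), output_type="np").images
pipe.unfuse_lora()

assert np.allclose(unfused, fused, atol=1e-3)
```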