
Commit c583f3b

Fuse loras (#4473)
* Fuse loras
* initial implementation.
* add slow test one.
* styling
* add: test for checking efficiency
* print
* position
* place model offload correctly
* style
* unfuse test.
* final checks
* remove warning test
* remove warnings altogether
* debugging
* tighten up tests.
* suit up the generator initialization a bit.
* remove print
* update assertion.
* fix: assertions.
* can generator be a problem?
* correct tests.
* support text encoder lora fusion.
* tighten up tests.

Co-authored-by: Sayak Paul <[email protected]>
1 parent 12358b9 commit c583f3b

3 files changed: 456 additions, 5 deletions


src/diffusers/loaders.py

Lines changed: 133 additions & 0 deletions
```diff
@@ -85,7 +85,49 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=
 
         self.lora_scale = lora_scale
 
+    def _fuse_lora(self):
+        if self.lora_linear_layer is None:
+            return
+
+        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
+        logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+        w_orig = self.regular_linear_layer.weight.data.float()
+        w_up = self.lora_linear_layer.up.weight.data.float()
+        w_down = self.lora_linear_layer.down.weight.data.float()
+
+        if self.lora_linear_layer.network_alpha is not None:
+            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
+
+        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+        # we can drop the lora layer now
+        self.lora_linear_layer = None
+
+        # offload the up and down matrices to CPU to not blow the memory
+        self.w_up = w_up.cpu()
+        self.w_down = w_down.cpu()
+
+    def _unfuse_lora(self):
+        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+            return
+        logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+        fused_weight = self.regular_linear_layer.weight.data
+        dtype, device = fused_weight.dtype, fused_weight.device
+
+        self.w_up = self.w_up.to(device=device, dtype=dtype)
+        self.w_down = self.w_down.to(device, dtype=dtype)
+        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+        self.w_up = None
+        self.w_down = None
+
     def forward(self, input):
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer(input)
         return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
 
 
```
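The `_fuse_lora` above folds the low-rank product into `regular_linear_layer.weight` once, so the extra LoRA matmul in `forward` disappears after fusing, and `_unfuse_lora` only needs the cached `w_up`/`w_down` to subtract it back out. A minimal standalone sketch of that arithmetic on plain `torch.nn.Linear` layers (the sizes, `alpha` value, and variable names are illustrative assumptions, not taken from the commit):

```python
import torch

torch.manual_seed(0)

in_features, out_features, rank = 8, 16, 4
alpha = 8.0  # stands in for network_alpha; illustrative value

base = torch.nn.Linear(in_features, out_features, bias=False)
down = torch.nn.Linear(in_features, rank, bias=False)  # LoRA "down" projection
up = torch.nn.Linear(rank, out_features, bias=False)   # LoRA "up" projection

x = torch.randn(2, in_features)

# Unfused path: base output plus the low-rank update, scaled by alpha / rank.
unfused_out = base(x) + (alpha / rank) * up(down(x))

# Fused path: fold (alpha / rank) * up @ down into the base weight once,
# mirroring the torch.bmm(...) line in the patch.
w_up = up.weight.data.float() * alpha / rank
w_down = down.weight.data.float()
fused_weight = base.weight.data.float() + torch.bmm(w_up[None, :], w_down[None, :])[0]

fused = torch.nn.Linear(in_features, out_features, bias=False)
fused.weight.data = fused_weight

# The fused layer reproduces the unfused output up to float rounding.
assert torch.allclose(unfused_out, fused(x), atol=1e-5)
```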

```diff
@@ -525,6 +567,20 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def fuse_lora(self):
+        self.apply(self._fuse_lora_apply)
+
+    def _fuse_lora_apply(self, module):
+        if hasattr(module, "_fuse_lora"):
+            module._fuse_lora()
+
+    def unfuse_lora(self):
+        self.apply(self._unfuse_lora_apply)
+
+    def _unfuse_lora_apply(self, module):
+        if hasattr(module, "_unfuse_lora"):
+            module._unfuse_lora()
+
 
 class TextualInversionLoaderMixin:
     r"""
```
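On the model class, `fuse_lora` is just a broadcast: `torch.nn.Module.apply` visits every submodule and `_fuse_lora` is invoked wherever a layer defines it, while plain layers are skipped. A small sketch of that dispatch pattern, using a hypothetical `ToyLoraLinear` in place of the LoRA-compatible diffusers layers:

```python
import torch


class ToyLoraLinear(torch.nn.Linear):
    """Hypothetical layer that can fold an extra low-rank-style delta into its weight."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.delta = 0.01 * torch.randn_like(self.weight)

    def _fuse_lora(self):
        self.weight.data += self.delta
        self.delta = None


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = ToyLoraLinear(4, 4)
        self.plain = torch.nn.Linear(4, 4)  # has no _fuse_lora, so it is skipped

    def fuse_lora(self):
        # Same pattern as the mixin: visit every submodule, fuse where supported.
        self.apply(self._fuse_lora_apply)

    @staticmethod
    def _fuse_lora_apply(module):
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora()


model = ToyModel()
model.fuse_lora()
assert model.proj.delta is None  # _fuse_lora was reached through apply()
```

The `hasattr` check keeps the walk tolerant of models where only some layers carry LoRA weights.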
```diff
@@ -1712,6 +1768,83 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
+    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
+            fuse_text_encoder (`bool`, defaults to `True`):
+                Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        if fuse_unet:
+            self.unet.fuse_lora()
+
+        def fuse_text_encoder_lora(text_encoder):
+            for _, attn_module in text_encoder_attn_modules(text_encoder):
+                if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                    attn_module.q_proj._fuse_lora()
+                    attn_module.k_proj._fuse_lora()
+                    attn_module.v_proj._fuse_lora()
+                    attn_module.out_proj._fuse_lora()
+
+            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                    mlp_module.fc1._fuse_lora()
+                    mlp_module.fc2._fuse_lora()
+
+        if fuse_text_encoder:
+            if hasattr(self, "text_encoder"):
+                fuse_text_encoder_lora(self.text_encoder)
+            if hasattr(self, "text_encoder_2"):
+                fuse_text_encoder_lora(self.text_encoder_2)
+
+    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters then it won't have any effect.
+        """
+        if unfuse_unet:
+            self.unet.unfuse_lora()
+
+        def unfuse_text_encoder_lora(text_encoder):
+            for _, attn_module in text_encoder_attn_modules(text_encoder):
+                if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                    attn_module.q_proj._unfuse_lora()
+                    attn_module.k_proj._unfuse_lora()
+                    attn_module.v_proj._unfuse_lora()
+                    attn_module.out_proj._unfuse_lora()
+
+            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                    mlp_module.fc1._unfuse_lora()
+                    mlp_module.fc2._unfuse_lora()
+
+        if unfuse_text_encoder:
+            if hasattr(self, "text_encoder"):
+                unfuse_text_encoder_lora(self.text_encoder)
+            if hasattr(self, "text_encoder_2"):
+                unfuse_text_encoder_lora(self.text_encoder_2)
+
 
 class FromSingleFileMixin:
     """
```

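Taken together, these changes give the user-facing `fuse_lora()` / `unfuse_lora()` pair on LoRA-capable pipelines. A hedged usage sketch; the base checkpoint ID, LoRA repo name, prompt, and device are placeholders, while `load_lora_weights`, `fuse_lora`, and `unfuse_lora` are the library calls touched by this commit:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder base checkpoint
    torch_dtype=torch.float16,
).to("cuda")

# Placeholder LoRA repo; anything loadable via load_lora_weights works the same way.
pipe.load_lora_weights("some-user/some-sdxl-lora")

# Fold the LoRA weights into the UNet (and the text encoder(s) if they were patched),
# so sampling runs without the extra per-step LoRA matmuls.
pipe.fuse_lora()
image = pipe("a pokemon with blue eyes", num_inference_steps=25).images[0]

# Restore the original, unfused weights when the LoRA is no longer wanted.
pipe.unfuse_lora()
```

Fusion trades a one-time weight rewrite at load time for cheaper denoising steps; `unfuse_lora()` reverses it without reloading the base checkpoint.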
src/diffusers/models/lora.py

Lines changed: 92 additions & 1 deletion
```diff
@@ -14,9 +14,15 @@
 
 from typing import Optional
 
+import torch
 import torch.nn.functional as F
 from torch import nn
 
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
 
 class LoRALinearLayer(nn.Module):
     def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
```
```diff
@@ -91,6 +97,51 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
     def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
         self.lora_layer = lora_layer
 
+    def _fuse_lora(self):
+        if self.lora_layer is None:
+            return
+
+        dtype, device = self.weight.data.dtype, self.weight.data.device
+        logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+        w_orig = self.weight.data.float()
+        w_up = self.lora_layer.up.weight.data.float()
+        w_down = self.lora_layer.down.weight.data.float()
+
+        if self.lora_layer.network_alpha is not None:
+            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
+        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
+        fusion = fusion.reshape((w_orig.shape))
+        fused_weight = w_orig + fusion
+        self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+        # we can drop the lora layer now
+        self.lora_layer = None
+
+        # offload the up and down matrices to CPU to not blow the memory
+        self.w_up = w_up.cpu()
+        self.w_down = w_down.cpu()
+
+    def _unfuse_lora(self):
+        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+            return
+        logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+        fused_weight = self.weight.data
+        dtype, device = fused_weight.data.dtype, fused_weight.data.device
+
+        self.w_up = self.w_up.to(device=device, dtype=dtype)
+        self.w_down = self.w_down.to(device, dtype=dtype)
+
+        fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
+        fusion = fusion.reshape((fused_weight.shape))
+        unfused_weight = fused_weight - fusion
+        self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+        self.w_up = None
+        self.w_down = None
+
     def forward(self, x):
         if self.lora_layer is None:
             # make sure to the functional Conv2D function as otherwise torch.compile's graph will break
```
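In `LoRACompatibleConv` the up projection is a 1x1 convolution, so chaining it after the down convolution is equivalent to a single convolution whose kernel is the matrix product of the flattened factors; that is what the `torch.mm(...).reshape(...)` above builds. A standalone check of that equivalence with toy shapes (channel counts, kernel size, and `alpha` are illustrative assumptions, not values from the commit):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

in_ch, out_ch, rank, k = 4, 8, 2, 3
alpha = 4.0  # stands in for network_alpha; illustrative value

w_orig = torch.randn(out_ch, in_ch, k, k)  # base conv kernel
w_down = torch.randn(rank, in_ch, k, k)    # LoRA "down" conv (same kernel size as the base)
w_up = torch.randn(out_ch, rank, 1, 1)     # LoRA "up" conv (1x1)

x = torch.randn(1, in_ch, 16, 16)

# Unfused path: base conv plus up(down(x)), scaled by alpha / rank.
unfused = F.conv2d(x, w_orig, padding=1) + (alpha / rank) * F.conv2d(
    F.conv2d(x, w_down, padding=1), w_up
)

# Fused path: flatten everything but the output channels, multiply, reshape back.
w_up_scaled = w_up * alpha / rank
fusion = torch.mm(w_up_scaled.flatten(start_dim=1), w_down.flatten(start_dim=1))
fused_kernel = w_orig + fusion.reshape(w_orig.shape)

fused = F.conv2d(x, fused_kernel, padding=1)
assert torch.allclose(unfused, fused, atol=1e-4)
```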
```diff
@@ -109,9 +160,49 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
         super().__init__(*args, **kwargs)
         self.lora_layer = lora_layer
 
-    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
         self.lora_layer = lora_layer
 
+    def _fuse_lora(self):
+        if self.lora_layer is None:
+            return
+
+        dtype, device = self.weight.data.dtype, self.weight.data.device
+        logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+        w_orig = self.weight.data.float()
+        w_up = self.lora_layer.up.weight.data.float()
+        w_down = self.lora_layer.down.weight.data.float()
+
+        if self.lora_layer.network_alpha is not None:
+            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
+        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+        # we can drop the lora layer now
+        self.lora_layer = None
+
+        # offload the up and down matrices to CPU to not blow the memory
+        self.w_up = w_up.cpu()
+        self.w_down = w_down.cpu()
+
+    def _unfuse_lora(self):
+        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+            return
+        logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+        fused_weight = self.weight.data
+        dtype, device = fused_weight.dtype, fused_weight.device
+
+        self.w_up = self.w_up.to(device=device, dtype=dtype)
+        self.w_down = self.w_down.to(device, dtype=dtype)
+        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+        self.w_up = None
+        self.w_down = None
+
     def forward(self, hidden_states, lora_scale: int = 1):
         if self.lora_layer is None:
             return super().forward(hidden_states)
```
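Because `_unfuse_lora` subtracts exactly the `w_up @ w_down` product cached at fusion time, a fuse/unfuse round trip restores the original weight up to floating-point rounding (only approximately when the module stores its weights in half precision, since the fused result is cast back to the module dtype). A tiny sketch of that round trip on a bare weight tensor, with illustrative shapes:

```python
import torch

torch.manual_seed(0)

w_orig = torch.randn(16, 8)
w_up = torch.randn(16, 4)
w_down = torch.randn(4, 8)

# Fuse: fold the low-rank product into the weight.
fused = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]

# Unfuse: subtract the same cached product from the fused weight.
restored = fused - torch.bmm(w_up[None, :], w_down[None, :])[0]

assert torch.allclose(restored, w_orig, atol=1e-5)
```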
