@@ -85,7 +85,49 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=
 
         self.lora_scale = lora_scale
 
+    def _fuse_lora(self):
+        if self.lora_linear_layer is None:
+            return
+
+        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
+        logger.info(f"Fusing LoRA weights for {self.__class__}")
+
+        w_orig = self.regular_linear_layer.weight.data.float()
+        w_up = self.lora_linear_layer.up.weight.data.float()
+        w_down = self.lora_linear_layer.down.weight.data.float()
+
+        if self.lora_linear_layer.network_alpha is not None:
+            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
+
+        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+        # we can drop the lora layer now
+        self.lora_linear_layer = None
+
+        # offload the up and down matrices to CPU to not blow the memory
+        self.w_up = w_up.cpu()
+        self.w_down = w_down.cpu()
+
+    def _unfuse_lora(self):
+        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
+            return
+        logger.info(f"Unfusing LoRA weights for {self.__class__}")
+
+        fused_weight = self.regular_linear_layer.weight.data
+        dtype, device = fused_weight.dtype, fused_weight.device
+
+        self.w_up = self.w_up.to(device=device, dtype=dtype)
+        self.w_down = self.w_down.to(device, dtype=dtype)
+        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+        self.w_up = None
+        self.w_down = None
+
     def forward(self, input):
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer(input)
         return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
 
 
@@ -525,6 +567,20 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def fuse_lora(self):
+        self.apply(self._fuse_lora_apply)
+
+    def _fuse_lora_apply(self, module):
+        if hasattr(module, "_fuse_lora"):
+            module._fuse_lora()
+
+    def unfuse_lora(self):
+        self.apply(self._unfuse_lora_apply)
+
+    def _unfuse_lora_apply(self, module):
+        if hasattr(module, "_unfuse_lora"):
+            module._unfuse_lora()
+
 
 class TextualInversionLoaderMixin:
     r"""
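The mixin-level `fuse_lora` above relies on `torch.nn.Module.apply`, which calls the given function on every submodule recursively, so any layer that defines `_fuse_lora` (such as the patched projections from the first hunk) fuses itself in a single pass. A toy illustration of that dispatch pattern, using a hypothetical `Fusable` layer as a stand-in:

import torch

class Fusable(torch.nn.Linear):
    # Hypothetical stand-in for a module that knows how to fuse its own LoRA weights.
    def _fuse_lora(self):
        print(f"fusing {self.__class__.__name__}")

model = torch.nn.Sequential(Fusable(4, 4), torch.nn.ReLU(), Fusable(4, 2))

# Same dispatch as _fuse_lora_apply: only modules exposing _fuse_lora respond.
model.apply(lambda m: m._fuse_lora() if hasattr(m, "_fuse_lora") else None)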
@@ -1712,6 +1768,83 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
+    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
+            fuse_text_encoder (`bool`, defaults to `True`):
+                Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters, then it won't have any effect.
+        """
+        if fuse_unet:
+            self.unet.fuse_lora()
+
+        def fuse_text_encoder_lora(text_encoder):
+            for _, attn_module in text_encoder_attn_modules(text_encoder):
+                if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                    attn_module.q_proj._fuse_lora()
+                    attn_module.k_proj._fuse_lora()
+                    attn_module.v_proj._fuse_lora()
+                    attn_module.out_proj._fuse_lora()
+
+            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                    mlp_module.fc1._fuse_lora()
+                    mlp_module.fc2._fuse_lora()
+
+        if fuse_text_encoder:
+            if hasattr(self, "text_encoder"):
+                fuse_text_encoder_lora(self.text_encoder)
+            if hasattr(self, "text_encoder_2"):
+                fuse_text_encoder_lora(self.text_encoder_2)
+
+    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+            unfuse_text_encoder (`bool`, defaults to `True`):
+                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
+                LoRA parameters, then it won't have any effect.
+        """
+        if unfuse_unet:
+            self.unet.unfuse_lora()
+
+        def unfuse_text_encoder_lora(text_encoder):
+            for _, attn_module in text_encoder_attn_modules(text_encoder):
+                if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                    attn_module.q_proj._unfuse_lora()
+                    attn_module.k_proj._unfuse_lora()
+                    attn_module.v_proj._unfuse_lora()
+                    attn_module.out_proj._unfuse_lora()
+
+            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                    mlp_module.fc1._unfuse_lora()
+                    mlp_module.fc2._unfuse_lora()
+
+        if unfuse_text_encoder:
+            if hasattr(self, "text_encoder"):
+                unfuse_text_encoder_lora(self.text_encoder)
+            if hasattr(self, "text_encoder_2"):
+                unfuse_text_encoder_lora(self.text_encoder_2)
+
 
 class FromSingleFileMixin:
     """