Fix Unfuse Lora #4833
Changes from all commits: 27d1ad4, aec0239, 682f664, e796874, 3bccb98, 8878915, 0e63378, 17474c4, 3c2889f
First changed file:
```diff
@@ -85,12 +85,21 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=

         self.lora_scale = lora_scale

+    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
+    # when saving the whole text encoder model and when LoRA is unloaded or fused
+    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+        if self.lora_linear_layer is None:
+            return self.regular_linear_layer.state_dict(
+                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
+            )
+
+        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
+
     def _fuse_lora(self):
         if self.lora_linear_layer is None:
            return

         dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")

         w_orig = self.regular_linear_layer.weight.data.float()
         w_up = self.lora_linear_layer.up.weight.data.float()
```
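For context, here is a minimal sketch of what the new `state_dict` override buys. The class below is a toy stand-in, not the actual diffusers `PatchedLoraProjection`: once the LoRA layer is gone, the wrapper serializes under the wrapped linear layer's original key names, so a saved text encoder checkpoint keeps its expected keys.

```python
import torch.nn as nn


class PatchedProjectionSketch(nn.Module):
    """Toy stand-in for the patched projection, only to illustrate the `state_dict` trick."""

    def __init__(self, regular_linear_layer, lora_linear_layer=None):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_linear_layer = lora_linear_layer

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # With no LoRA attached, serialize as if we were the plain wrapped nn.Linear,
        # so the checkpoint keys stay 'weight' / 'bias' instead of 'regular_linear_layer.*'.
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)


patched = PatchedProjectionSketch(nn.Linear(4, 4))
print(list(patched.state_dict().keys()))  # ['weight', 'bias'] rather than ['regular_linear_layer.weight', ...]
```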
```diff
@@ -112,14 +121,14 @@ def _fuse_lora(self):
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")

         fused_weight = self.regular_linear_layer.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device

-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
-        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

         self.w_up = None
```

Author comment on this change: Improve precision to fp32.
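To see why the subtraction is now done in float32, here is a small, self-contained numerical illustration. It is simplified relative to the diffusers code (the LoRA update is precomputed as a plain matrix product instead of going through `torch.bmm` on half tensors), and the exact error values depend on the seed, but the fp32 path typically lands closer to the original weight.

```python
import torch

torch.manual_seed(0)
w = torch.randn(64, 64)        # "original" fp32 weight, kept as a reference
w_up = torch.randn(64, 4)
w_down = torch.randn(4, 64)
delta = w_up @ w_down          # low-rank LoRA update to fuse in

fused16 = (w + delta).half()   # fused weight as it would be stored in float16

# old behaviour: subtract the update directly in float16
unfused_old = fused16 - delta.half()
# new behaviour: upcast, subtract in float32, cast the result back
unfused_new = (fused16.float() - delta).half()

print((unfused_old.float() - w).abs().max())
print((unfused_new.float() - w).abs().max())  # usually the smaller of the two
```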
```diff
@@ -1405,15 +1414,15 @@ def _remove_text_encoder_monkey_patch(self):
     def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
         for _, attn_module in text_encoder_attn_modules(text_encoder):
             if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
-                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
-                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
-                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+                attn_module.q_proj.lora_linear_layer = None
+                attn_module.k_proj.lora_linear_layer = None
+                attn_module.v_proj.lora_linear_layer = None
+                attn_module.out_proj.lora_linear_layer = None

         for _, mlp_module in text_encoder_mlp_modules(text_encoder):
             if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
-                mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
+                mlp_module.fc1.lora_linear_layer = None
+                mlp_module.fc2.lora_linear_layer = None

     @classmethod
     def _modify_text_encoder(
```

Author comment on this change: If we remove the whole patched layer when unloading LoRAs we cannot unfuse it anymore. Let's just set the LoRA linear layer to `None`.
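A toy illustration of the author's point, using hypothetical `Wrapper`/`Block` classes rather than the diffusers ones: swapping the bare linear layer back in would also discard the `w_up`/`w_down` tensors the wrapper carries after fusing, whereas nulling `lora_linear_layer` keeps the wrapper, and that state, available for a later unfuse.

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Toy stand-in for the patched projection."""

    def __init__(self, linear):
        super().__init__()
        self.regular_linear_layer = linear
        self.lora_linear_layer = nn.Linear(4, 4, bias=False)

    def fake_fuse(self):
        # stand-ins for the `w_up` / `w_down` tensors that `_fuse_lora` stashes on the module
        self.w_up = torch.randn(4, 1)
        self.w_down = torch.randn(1, 4)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = Wrapper(nn.Linear(4, 4))


block = Block()
block.q_proj.fake_fuse()

# old behaviour (drops the wrapper and, with it, w_up / w_down, so unfusing becomes impossible):
# block.q_proj = block.q_proj.regular_linear_layer

# new behaviour: keep the wrapper, only detach the LoRA layer
block.q_proj.lora_linear_layer = None
print(hasattr(block.q_proj, "w_up"))  # True -> a later `_unfuse_lora` still has what it needs
```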
```diff
@@ -1447,23 +1456,43 @@ def _modify_text_encoder(
             else:
                 current_rank = rank

+            q_linear_layer = (
+                attn_module.q_proj.regular_linear_layer
+                if isinstance(attn_module.q_proj, PatchedLoraProjection)
+                else attn_module.q_proj
+            )
             attn_module.q_proj = PatchedLoraProjection(
-                attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
+                q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())

+            k_linear_layer = (
+                attn_module.k_proj.regular_linear_layer
+                if isinstance(attn_module.k_proj, PatchedLoraProjection)
+                else attn_module.k_proj
+            )
             attn_module.k_proj = PatchedLoraProjection(
-                attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
+                k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())

+            v_linear_layer = (
+                attn_module.v_proj.regular_linear_layer
+                if isinstance(attn_module.v_proj, PatchedLoraProjection)
+                else attn_module.v_proj
+            )
             attn_module.v_proj = PatchedLoraProjection(
-                attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
+                v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

+            out_linear_layer = (
+                attn_module.out_proj.regular_linear_layer
+                if isinstance(attn_module.out_proj, PatchedLoraProjection)
+                else attn_module.out_proj
+            )
             attn_module.out_proj = PatchedLoraProjection(
-                attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
+                out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
             )
             lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
```
```diff
@@ -1475,13 +1504,23 @@ def _modify_text_encoder(
             current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
             current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

+            fc1_linear_layer = (
+                mlp_module.fc1.regular_linear_layer
+                if isinstance(mlp_module.fc1, PatchedLoraProjection)
+                else mlp_module.fc1
+            )
             mlp_module.fc1 = PatchedLoraProjection(
-                mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
+                fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
             )
             lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

+            fc2_linear_layer = (
+                mlp_module.fc2.regular_linear_layer
+                if isinstance(mlp_module.fc2, PatchedLoraProjection)
+                else mlp_module.fc2
+            )
             mlp_module.fc2 = PatchedLoraProjection(
-                mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
+                fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
             )
             lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
```
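The repeated `*_linear_layer = (...)` blocks in these two hunks all follow the same unwrap-before-re-wrap pattern. Below is a stripped-down sketch of the idea with toy `PatchedProjection`/`Attn` classes and a hypothetical `patch` helper: if the attribute was already patched, reach through to the underlying `nn.Linear` so a second patching pass does not nest wrappers inside wrappers.

```python
import torch.nn as nn


class PatchedProjection(nn.Module):
    """Toy wrapper playing the role of PatchedLoraProjection."""

    def __init__(self, regular_linear_layer):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer


def patch(module):
    # Unwrap first if a previous call already patched this attribute.
    linear = (
        module.q_proj.regular_linear_layer
        if isinstance(module.q_proj, PatchedProjection)
        else module.q_proj
    )
    module.q_proj = PatchedProjection(linear)


class Attn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)


attn = Attn()
patch(attn)
patch(attn)  # second pass re-wraps the same underlying nn.Linear instead of the wrapper
print(isinstance(attn.q_proj.regular_linear_layer, nn.Linear))  # True: no nested wrappers
```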
Second changed file:
```diff
@@ -168,7 +168,6 @@ def _fuse_lora(self):
             return

         dtype, device = self.weight.data.dtype, self.weight.data.device
-        logger.info(f"Fusing LoRA weights for {self.__class__}")

         w_orig = self.weight.data.float()
         w_up = self.lora_layer.up.weight.data.float()
```
```diff
@@ -190,14 +189,14 @@ def _fuse_lora(self):
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
             return
-        logger.info(f"Unfusing LoRA weights for {self.__class__}")

         fused_weight = self.weight.data
         dtype, device = fused_weight.dtype, fused_weight.device

-        self.w_up = self.w_up.to(device=device, dtype=dtype)
-        self.w_down = self.w_down.to(device, dtype=dtype)
-        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
+        w_up = self.w_up.to(device=device).float()
+        w_down = self.w_down.to(device).float()
+
+        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
         self.weight.data = unfused_weight.to(device=device, dtype=dtype)

         self.w_up = None
```

Author comment on this change: Improve precision.
Reviewer reply: Nice. TIL.
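Finally, a hedged end-to-end sanity check of what the PR title promises, sketched from the public diffusers API; the checkpoint id, LoRA path, and tolerance are placeholders, and this is not the PR's actual test. After `fuse_lora()` followed by `unfuse_lora()`, the UNet parameters should be back (numerically) where they started.

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint and LoRA path; swap in whatever model/adapter you are testing.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
reference = {n: p.detach().clone() for n, p in pipe.unet.named_parameters()}

pipe.load_lora_weights("path/to/lora")  # placeholder path
pipe.fuse_lora()
pipe.unfuse_lora()

current = dict(pipe.unet.named_parameters())
# All pre-existing UNet parameters should match the pre-fuse reference up to fp16 round-off.
print(all(torch.allclose(reference[n], current[n], atol=1e-3) for n in reference))
```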