
Conversation

@patrickvonplaten (Contributor) commented Aug 29, 2023

What does this PR do?

This PR fixes the unfusing issue as discovered by @apolinario here
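For context, a minimal sketch of the fuse/unfuse round trip this fix targets; the model ID and LoRA path are placeholders (not from this PR), and exact call signatures may differ slightly across diffusers versions:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder base model
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA checkpoint
pipe.fuse_lora()    # bake the LoRA delta into the base weights for faster inference
pipe.unfuse_lora()  # should restore the original base weights -- the step this PR repairs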

attn_module.k_proj = attn_module.k_proj.regular_linear_layer
attn_module.v_proj = attn_module.v_proj.regular_linear_layer
attn_module.out_proj = attn_module.out_proj.regular_linear_layer
attn_module.q_proj.lora_linear_layer = None
@patrickvonplaten (Contributor, Author) commented:
If we remove the whole patched layer when unloading LoRAs we cannot unfuse it anymore. Let's just set the LoRA linear layer to None.
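To illustrate the reasoning, a minimal, hypothetical stand-in for the patched projection (the class name and everything except regular_linear_layer/lora_linear_layer are made up): with the LoRA layer set to None the wrapper simply falls through to the plain Linear, while the wrapper object itself, and any fuse/unfuse bookkeeping it carries, stays in place.

import torch.nn as nn

class PatchedLinear(nn.Module):
    # Hypothetical stand-in, not the actual diffusers class.
    def __init__(self, regular_linear_layer, lora_linear_layer=None, lora_scale=1.0):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_linear_layer = lora_linear_layer
        self.lora_scale = lora_scale

    def forward(self, hidden_states):
        if self.lora_linear_layer is None:
            # LoRA unloaded: behave exactly like the wrapped nn.Linear.
            return self.regular_linear_layer(hidden_states)
        return self.regular_linear_layer(hidden_states) + self.lora_scale * self.lora_linear_layer(hidden_states)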

self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
w_up = self.w_up.to(device=device).float()
@patrickvonplaten (Contributor, Author) commented:
Improve precision to fp32
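A rough sketch of the fp32 unfuse math implied here; the tensor shapes and the dtype round trip are assumptions, not the literal diff:

import torch

def unfuse_lora_weight(fused_weight, w_up, w_down):
    # Compute the LoRA delta in float32 so fp16 rounding in the bmm and the
    # subtraction does not leave residue in the restored base weight.
    dtype, device = fused_weight.dtype, fused_weight.device
    w_up = w_up.to(device=device).float()
    w_down = w_down.to(device=device).float()
    unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
    return unfused_weight.to(device=device, dtype=dtype)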

self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
w_up = self.w_up.to(device=device).float()
@patrickvonplaten (Contributor, Author) commented:
Improve precision
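For completeness, the fuse direction would mirror the unfuse sketch above, adding the delta in float32 instead of subtracting it (again an assumption about shapes and scaling, not the literal diff):

import torch

def fuse_lora_weight(base_weight, w_up, w_down, lora_scale=1.0):
    # Same fp32 trick in the fuse direction: add the LoRA delta at full precision,
    # then cast back to the module's working dtype.
    dtype, device = base_weight.dtype, base_weight.device
    w_up = w_up.to(device=device).float()
    w_down = w_down.to(device=device).float()
    fused_weight = base_weight.float() + lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]
    return fused_weight.to(device=device, dtype=dtype)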

@HuggingFaceDocBuilderDev commented Aug 29, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten mentioned this pull request Aug 29, 2023
Comment on lines +90 to +96
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
    if self.lora_linear_layer is None:
        return self.regular_linear_layer.state_dict(
            *args, destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
Member commented:
Nice. TIL.
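For illustration, a small self-contained demo of what the override buys; the class below is a hypothetical stand-in carrying the state_dict override from the diff above, not the actual diffusers wrapper:

import torch.nn as nn

class PatchedLinearDemo(nn.Module):
    # Hypothetical stand-in with the state_dict fallback shown in the diff.
    def __init__(self, regular_linear_layer, lora_linear_layer=None):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_linear_layer = lora_linear_layer

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

patched = PatchedLinearDemo(nn.Linear(4, 4))  # LoRA unloaded: lora_linear_layer is None
print(sorted(patched.state_dict().keys()))    # ['bias', 'weight'] -- same keys as a plain nn.Linear

Because the keys match a plain nn.Linear, checkpoints saved after unloading LoRA stay loadable into an unpatched model.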

@sayakpaul (Member) commented:
Thanks a lot.

@sayakpaul merged commit 9f1936d into main Aug 30, 2023
@sayakpaul deleted the fix_unfuse_lora branch August 30, 2023 04:02
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Fix Unfuse Lora

* add tests

* Fix more

* Fix more

* Fix all

* make style

* make style
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024