Merged
73 changes: 56 additions & 17 deletions src/diffusers/loaders.py
@@ -85,12 +85,21 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=

self.lora_scale = lora_scale

# overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
# when saving the whole text encoder model and when LoRA is unloaded or fused
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
if self.lora_linear_layer is None:
return self.regular_linear_layer.state_dict(
*args, destination=destination, prefix=prefix, keep_vars=keep_vars
)

return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
Comment on lines +90 to +96
Member
Nice. TIL.
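
For context, the pattern this comment reacts to can be reproduced in isolation: a wrapper module forwards `state_dict` to the wrapped layer whenever no LoRA is attached, so checkpoints saved before a LoRA is loaded (or after it is unloaded) keep the original key names. A minimal standalone sketch, not the diffusers class itself:

```python
import torch.nn as nn


class WrappedLinear(nn.Module):
    """Sketch of the state_dict-forwarding pattern; hypothetical class, not diffusers API."""

    def __init__(self, regular_linear_layer, lora_linear_layer=None):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_linear_layer = lora_linear_layer

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # Without a LoRA branch, expose only the wrapped layer's keys
        # ("weight"/"bias" rather than "regular_linear_layer.weight"/...).
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)


print(list(WrappedLinear(nn.Linear(4, 4)).state_dict().keys()))  # ['weight', 'bias']
```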


def _fuse_lora(self):
if self.lora_linear_layer is None:
return

dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
logger.info(f"Fusing LoRA weights for {self.__class__}")

w_orig = self.regular_linear_layer.weight.data.float()
w_up = self.lora_linear_layer.up.weight.data.float()
@@ -112,14 +121,14 @@ def _fuse_lora(self):
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
return
logger.info(f"Unfusing LoRA weights for {self.__class__}")

fused_weight = self.regular_linear_layer.weight.data
dtype, device = fused_weight.dtype, fused_weight.device

self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
w_up = self.w_up.to(device=device).float()
Contributor Author
Improve precision to fp32

w_down = self.w_down.to(device).float()

unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

self.w_up = None
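
To see the reviewer's point in isolation: the unfused weight is recovered by subtracting the low-rank product from the fused weight, and doing that product and the subtraction in float32, casting back to the storage dtype only at the end, keeps the round-trip error small. A standalone sketch of the pattern, not the diffusers method itself:

```python
import torch


def unfuse(fused_weight: torch.Tensor, w_up: torch.Tensor, w_down: torch.Tensor) -> torch.Tensor:
    """Subtract a LoRA delta from a fused weight, doing the math in float32."""
    dtype, device = fused_weight.dtype, fused_weight.device
    # Up-cast everything, subtract, then cast back to the original (e.g. float16) dtype.
    delta = torch.bmm(w_up.to(device).float()[None, :], w_down.to(device).float()[None, :])[0]
    return (fused_weight.float() - delta).to(device=device, dtype=dtype)


# Round-trip a random fp16 weight through fuse + unfuse.
w = torch.randn(32, 32, dtype=torch.float16)
up, down = torch.randn(32, 4), torch.randn(4, 32)
fused = (w.float() + torch.bmm(up[None, :], down[None, :])[0]).to(torch.float16)
print((unfuse(fused, up, down) - w).abs().max().item())  # small residual
```
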
@@ -1405,15 +1414,15 @@ def _remove_text_encoder_monkey_patch(self):
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj = attn_module.q_proj.regular_linear_layer
attn_module.k_proj = attn_module.k_proj.regular_linear_layer
attn_module.v_proj = attn_module.v_proj.regular_linear_layer
attn_module.out_proj = attn_module.out_proj.regular_linear_layer
attn_module.q_proj.lora_linear_layer = None
Contributor Author
If we remove the whole patched layer when unloading LoRAs we cannot unfuse it anymore. Let's just set the LoRA linear layer to None.
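
In other words: the fused-weight bookkeeping (`w_up`/`w_down`) and `_unfuse_lora` live on the `PatchedLoraProjection` wrapper, so swapping the wrapper back to the bare `nn.Linear` on unload would make a later `unfuse_lora()` impossible. Keeping the wrapper and only clearing its `lora_linear_layer` preserves that path, while the `state_dict` override above keeps the wrapper invisible in saved checkpoints.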

attn_module.k_proj.lora_linear_layer = None
attn_module.v_proj.lora_linear_layer = None
attn_module.out_proj.lora_linear_layer = None

for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1 = mlp_module.fc1.regular_linear_layer
mlp_module.fc2 = mlp_module.fc2.regular_linear_layer
mlp_module.fc1.lora_linear_layer = None
mlp_module.fc2.lora_linear_layer = None

@classmethod
def _modify_text_encoder(
@@ -1447,23 +1456,43 @@ def _modify_text_encoder(
else:
current_rank = rank

q_linear_layer = (
attn_module.q_proj.regular_linear_layer
if isinstance(attn_module.q_proj, PatchedLoraProjection)
else attn_module.q_proj
)
attn_module.q_proj = PatchedLoraProjection(
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
)
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())

k_linear_layer = (
attn_module.k_proj.regular_linear_layer
if isinstance(attn_module.k_proj, PatchedLoraProjection)
else attn_module.k_proj
)
attn_module.k_proj = PatchedLoraProjection(
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
)
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())

v_linear_layer = (
attn_module.v_proj.regular_linear_layer
if isinstance(attn_module.v_proj, PatchedLoraProjection)
else attn_module.v_proj
)
attn_module.v_proj = PatchedLoraProjection(
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
)
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())

out_linear_layer = (
attn_module.out_proj.regular_linear_layer
if isinstance(attn_module.out_proj, PatchedLoraProjection)
else attn_module.out_proj
)
attn_module.out_proj = PatchedLoraProjection(
attn_module.out_proj, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())

@@ -1475,13 +1504,23 @@ def _modify_text_encoder(
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")

fc1_linear_layer = (
mlp_module.fc1.regular_linear_layer
if isinstance(mlp_module.fc1, PatchedLoraProjection)
else mlp_module.fc1
)
mlp_module.fc1 = PatchedLoraProjection(
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
)
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())

fc2_linear_layer = (
mlp_module.fc2.regular_linear_layer
if isinstance(mlp_module.fc2, PatchedLoraProjection)
else mlp_module.fc2
)
mlp_module.fc2 = PatchedLoraProjection(
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
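
The repeated `isinstance(..., PatchedLoraProjection)` checks above follow from the unload change: since unloading now leaves the patched projections in place, loading another LoRA has to wrap the underlying `regular_linear_layer` rather than the wrapper itself, or the projections would nest. A minimal sketch of the idiom, assuming `PatchedLoraProjection` is importable from `diffusers.loaders` as in this PR:

```python
import torch.nn as nn

from diffusers.loaders import PatchedLoraProjection  # module path as of this PR; may move in later releases

proj = PatchedLoraProjection(nn.Linear(8, 8), lora_scale=1.0, rank=4)

# Re-patching must target the wrapped Linear, not the wrapper, to avoid nested projections.
base = proj.regular_linear_layer if isinstance(proj, PatchedLoraProjection) else proj
proj = PatchedLoraProjection(base, lora_scale=1.0, rank=4)
```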

9 changes: 4 additions & 5 deletions src/diffusers/models/lora.py
@@ -168,7 +168,6 @@ def _fuse_lora(self):
return

dtype, device = self.weight.data.dtype, self.weight.data.device
logger.info(f"Fusing LoRA weights for {self.__class__}")

w_orig = self.weight.data.float()
w_up = self.lora_layer.up.weight.data.float()
@@ -190,14 +189,14 @@ def _fuse_lora(self):
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
return
logger.info(f"Unfusing LoRA weights for {self.__class__}")

fused_weight = self.weight.data
dtype, device = fused_weight.dtype, fused_weight.device

self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
w_up = self.w_up.to(device=device).float()
Contributor Author
Improve precision

w_down = self.w_down.to(device).float()

unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
self.weight.data = unfused_weight.to(device=device, dtype=dtype)

self.w_up = None
71 changes: 71 additions & 0 deletions tests/models/test_lora_layers.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import tempfile
import time
@@ -100,6 +101,18 @@ def set_lora_weights(lora_attn_parameters, randn_weight=False):
torch.zero_(parameter)


def state_dicts_almost_equal(sd1, sd2):
sd1 = dict(sorted(sd1.items()))
sd2 = dict(sorted(sd2.items()))

models_are_equal = True
for ten1, ten2 in zip(sd1.values(), sd2.values()):
if (ten1 - ten2).abs().sum() > 1e-3:
models_are_equal = False

return models_are_equal


class LoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
@@ -674,6 +687,45 @@ def test_load_lora_locally(self):

sd_pipe.unload_lora_weights()

def test_text_encoder_lora_state_dict_unchanged(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)

text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys())
text_encoder_2_sd_keys = sorted(sd_pipe.text_encoder_2.state_dict().keys())

sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=False,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

text_encoder_1_sd_keys_2 = sorted(sd_pipe.text_encoder.state_dict().keys())
text_encoder_2_sd_keys_2 = sorted(sd_pipe.text_encoder_2.state_dict().keys())

sd_pipe.unload_lora_weights()

text_encoder_1_sd_keys_3 = sorted(sd_pipe.text_encoder.state_dict().keys())
text_encoder_2_sd_keys_3 = sorted(sd_pipe.text_encoder_2.state_dict().keys())

# default & unloaded LoRA weights should have identical state_dicts
assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3
# default & loaded LoRA weights should NOT have identical state_dicts
assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2

# default & unloaded LoRA weights should have identical state_dicts
assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3
# default & loaded LoRA weights should NOT have identical state_dicts
assert text_encoder_2_sd_keys != text_encoder_2_sd_keys_2

def test_load_lora_locally_safetensors(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
@@ -1187,3 +1239,22 @@ def test_sdxl_1_0_last_ben(self):
expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])

self.assertTrue(np.allclose(images, expected, atol=1e-3))

def test_sdxl_1_0_fuse_unfuse_all(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
unet_sd = copy.deepcopy(pipe.unet.state_dict())

pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
pipe.fuse_lora()
pipe.unload_lora_weights()
pipe.unfuse_lora()

new_text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict())
new_text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict())
new_unet_sd = copy.deepcopy(pipe.unet.state_dict())

assert state_dicts_almost_equal(text_encoder_1_sd, new_text_encoder_1_sd)
assert state_dicts_almost_equal(text_encoder_2_sd, new_text_encoder_2_sd)
assert state_dicts_almost_equal(unet_sd, new_unet_sd)
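
For reference, this is the user-facing sequence the test above exercises (the model and LoRA checkpoints are the ones used in the test):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Attach a LoRA and bake it into the base weights for inference.
pipe.load_lora_weights("davizca87/sun-flower", weight_name="snfw3rXL-000004.safetensors")
pipe.fuse_lora()

# ... run inference with the fused weights ...

# Undo both steps; after this PR the weights return to their original values.
pipe.unload_lora_weights()
pipe.unfuse_lora()
```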