From 91b34abbd3c43cbd88e83140bddaa18c148fef1b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 10:14:02 +0530 Subject: [PATCH 01/78] throw warning when more than one lora is attempted to be fused. --- src/diffusers/loaders.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 52970e48147d..82d5c5cc4572 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -924,6 +924,7 @@ class LoraLoaderMixin: """ text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME + num_fused_loras = 0 def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): """ @@ -1823,6 +1824,14 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True): Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ + if fuse_unet or fuse_text_encoder: + self.num_fused_loras += 1 + if self.num_fused_loras > 1: + warnings.warn( + "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.", + RuntimeWarning, + ) + if fuse_unet: self.unet.fuse_lora() From c9eeb788d2feb7382d7d39a474d5af479060f671 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 11:32:06 +0530 Subject: [PATCH 02/78] introduce support of lora scale during fusion. --- src/diffusers/loaders.py | 30 ++++++++++-------- src/diffusers/models/lora.py | 22 +++++++------ tests/models/test_lora_layers.py | 54 +++++++++++++++++++++++++++++++- 3 files changed, 83 insertions(+), 23 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 82d5c5cc4572..6e12a608a8cb 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -95,7 +95,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False): return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars) - def _fuse_lora(self): + def _fuse_lora(self, lora_scale=1.0): if self.lora_linear_layer is None: return @@ -108,7 +108,7 @@ def _fuse_lora(self): if self.lora_linear_layer.network_alpha is not None: w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank - fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0] + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype) # we can drop the lora layer now @@ -117,6 +117,7 @@ def _fuse_lora(self): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() + self.lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): @@ -128,7 +129,7 @@ def _unfuse_lora(self): w_up = self.w_up.to(device=device).float() w_down = self.w_down.to(device).float() - unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0] + unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None @@ -576,12 +577,13 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") - def fuse_lora(self): + def fuse_lora(self, lora_scale=1.0): + self.lora_scale = lora_scale self.apply(self._fuse_lora_apply) def _fuse_lora_apply(self, module): if hasattr(module, "_fuse_lora"): - module._fuse_lora() + module._fuse_lora(self.lora_scale) def unfuse_lora(self): self.apply(self._unfuse_lora_apply) @@ -1808,7 +1810,7 @@ def unload_lora_weights(self): # Safe to call the following regardless of LoRA. self._remove_text_encoder_monkey_patch() - def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True): + def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0): r""" Fuses the LoRA parameters into the original parameters of the corresponding blocks. @@ -1823,6 +1825,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True): fuse_text_encoder (`bool`, defaults to `True`): Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. """ if fuse_unet or fuse_text_encoder: self.num_fused_loras += 1 @@ -1833,20 +1837,20 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True): ) if fuse_unet: - self.unet.fuse_lora() + self.unet.fuse_lora(lora_scale) def fuse_text_encoder_lora(text_encoder): for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._fuse_lora() - attn_module.k_proj._fuse_lora() - attn_module.v_proj._fuse_lora() - attn_module.out_proj._fuse_lora() + attn_module.q_proj._fuse_lora(lora_scale) + attn_module.k_proj._fuse_lora(lora_scale) + attn_module.v_proj._fuse_lora(lora_scale) + attn_module.out_proj._fuse_lora(lora_scale) for _, mlp_module in text_encoder_mlp_modules(text_encoder): if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._fuse_lora() - mlp_module.fc2._fuse_lora() + mlp_module.fc1._fuse_lora(lora_scale) + mlp_module.fc2._fuse_lora(lora_scale) if fuse_text_encoder: if hasattr(self, "text_encoder"): diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 671c93a3b2b2..b795f00ef720 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -97,7 +97,7 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): self.lora_layer = lora_layer - def _fuse_lora(self): + def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return @@ -113,7 +113,7 @@ def _fuse_lora(self): fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) fusion = fusion.reshape((w_orig.shape)) - fused_weight = w_orig + fusion + fused_weight = w_orig + (lora_scale * fusion) self.weight.data = fused_weight.to(device=device, dtype=dtype) # we can drop the lora layer now @@ -122,6 +122,7 @@ def _fuse_lora(self): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() + self.lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): @@ -136,19 +137,21 @@ def _unfuse_lora(self): fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) fusion = fusion.reshape((fused_weight.shape)) - unfused_weight = fused_weight - fusion + unfused_weight = fused_weight - (self.lora_scale * fusion) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None self.w_down = None - def forward(self, x): + def forward(self, hidden_states, lora_scale=1.0): if self.lora_layer is None: # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) else: - return super().forward(x) + self.lora_layer(x) + return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states) class LoRACompatibleLinear(nn.Linear): @@ -163,7 +166,7 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): self.lora_layer = lora_layer - def _fuse_lora(self): + def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return @@ -176,7 +179,7 @@ def _fuse_lora(self): if self.lora_layer.network_alpha is not None: w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank - fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0] + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.weight.data = fused_weight.to(device=device, dtype=dtype) # we can drop the lora layer now @@ -185,6 +188,7 @@ def _fuse_lora(self): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() + self.lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): @@ -196,7 +200,7 @@ def _unfuse_lora(self): w_up = self.w_up.to(device=device).float() w_down = self.w_down.to(device).float() - unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0] + unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 848f2f44adc9..85af4013b279 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -719,7 +719,7 @@ def test_text_encoder_lora_state_dict_unchanged(self): # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 # default & loaded LoRA weights should NOT have identical state_dicts - assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # + assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 @@ -863,6 +863,58 @@ def test_lora_fusion_is_not_affected_by_unloading(self): lora_image_slice, images_with_unloaded_lora_slice ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." + def test_unfuse_lora(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora(lora_scale=1.0) + lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1] + + # Reverse LoRA fusion. + sd_pipe.unfuse_lora() + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora(lora_scale=0.5) + lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] + + assert not np.allclose( + lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 + ), "Different LoRA scales should influence the outputs accordingly." + @slow @require_torch_gpu From 37692b12ddc93697426ca279f738b17687291c39 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 11:39:54 +0530 Subject: [PATCH 03/78] change test name --- tests/models/test_lora_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 85af4013b279..3f9ca734f9e9 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -863,7 +863,7 @@ def test_lora_fusion_is_not_affected_by_unloading(self): lora_image_slice, images_with_unloaded_lora_slice ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." - def test_unfuse_lora(self): + def test_fuse_lora_with_different_scales(self): pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) From cfd19a57cfd3f5b700bb5496084013a330244b30 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 12:39:32 +0530 Subject: [PATCH 04/78] changes --- src/diffusers/models/lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index b795f00ef720..be0024266002 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -143,7 +143,7 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states, lora_scale=1.0): + def forward(self, hidden_states, lora_scale: float = 1.0): if self.lora_layer is None: # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 @@ -206,7 +206,7 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states, lora_scale: int = 1): + def forward(self, hidden_states, lora_scale: float = 1.0): if self.lora_layer is None: return super().forward(hidden_states) else: From 8a9dad004dd78bd199448b454754703ad536200d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 15:44:13 +0530 Subject: [PATCH 05/78] change to _lora_scale --- src/diffusers/models/lora.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index be0024266002..1d671ed51c17 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -122,7 +122,7 @@ def _fuse_lora(self, lora_scale=1.0): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() - self.lora_scale = lora_scale + self._lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): @@ -137,7 +137,7 @@ def _unfuse_lora(self): fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) fusion = fusion.reshape((fused_weight.shape)) - unfused_weight = fused_weight - (self.lora_scale * fusion) + unfused_weight = fused_weight - (self._lora_scale * fusion) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None @@ -188,7 +188,7 @@ def _fuse_lora(self, lora_scale=1.0): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() - self.lora_scale = lora_scale + self._lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): @@ -200,7 +200,7 @@ def _unfuse_lora(self): w_up = self.w_up.to(device=device).float() w_down = self.w_down.to(device).float() - unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None From ed3b37ad8324b75e1e1c6258f4365fac38fa65d8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 16:03:16 +0530 Subject: [PATCH 06/78] lora_scale to call whenever applicable. --- src/diffusers/models/attention.py | 8 +++---- src/diffusers/models/resnet.py | 29 +++++++++++++++++--------- src/diffusers/models/transformer_2d.py | 11 ++++++---- 3 files changed, 30 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b017db158eda..043083915277 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -295,9 +295,9 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) - def forward(self, hidden_states): + def forward(self, hidden_states, lora_scale: float = 1.0): for module in self.net: - hidden_states = module(hidden_states) + hidden_states = module(hidden_states, lora_scale) return hidden_states @@ -342,8 +342,8 @@ def gelu(self, gate): # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + def forward(self, hidden_states, lora_scale: float = 1.0): + hidden_states, gate = self.proj(hidden_states, lora_scale=lora_scale).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 72aa17ed2c2d..6bebddc972e2 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -135,7 +135,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, hidden_states, output_size=None): + def forward(self, hidden_states, output_size=None, lora_scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -166,9 +166,15 @@ def forward(self, hidden_states, output_size=None): # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - hidden_states = self.conv(hidden_states) + if isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, lora_scale) + else: + hidden_states = self.conv(hidden_states) else: - hidden_states = self.Conv2d_0(hidden_states) + if isinstance(self.Conv2d_0, LoRACompatibleConv): + hidden_states = self.Conv2d_0(hidden_states, lora_scale) + else: + hidden_states = self.Conv2d_0(hidden_states) return hidden_states @@ -211,14 +217,17 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, hidden_states): + def forward(self, hidden_states, lora_scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) + if isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, lora_scale) + else: + hidden_states = self.conv(hidden_states) return hidden_states @@ -588,7 +597,7 @@ def __init__( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) - def forward(self, input_tensor, temb): + def forward(self, input_tensor, temb, lora_scale: float = 1.0): hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": @@ -609,12 +618,12 @@ def forward(self, input_tensor, temb): input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, lora_scale) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb)[:, :, None, None] + temb = self.time_emb_proj(temb, lora_scale)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -631,10 +640,10 @@ def forward(self, input_tensor, temb): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, lora_scale) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor, lora_scale) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 225f20a1e397..36698b300df9 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -219,6 +219,7 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, + lora_scale: float = 1.0, return_dict: bool = True, ): """ @@ -243,6 +244,8 @@ def forward( If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. + lora_scale: (`float`, *optional*, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -281,13 +284,13 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(hidden_states, lora_scale) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(hidden_states, lora_scale) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: @@ -322,9 +325,9 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(hidden_states, lora_scale) else: - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(hidden_states, lora_scale) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual From b86a8f6da946d1f9d32d2a7bec127232cd5c4de0 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 16:21:00 +0530 Subject: [PATCH 07/78] debugging --- src/diffusers/models/attention.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 043083915277..b00728913ffa 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -297,7 +297,10 @@ def __init__( def forward(self, hidden_states, lora_scale: float = 1.0): for module in self.net: - hidden_states = module(hidden_states, lora_scale) + if isinstance(module, LoRACompatibleLinear): + hidden_states = module(hidden_states, lora_scale) + else: + hidden_states = module(hidden_states) return hidden_states From 80839d63adfd8de5431a23ce2c50e0b90574618e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 16:24:20 +0530 Subject: [PATCH 08/78] lora_scale additional. --- src/diffusers/models/unet_2d_condition.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 1994649f4c59..f7e63ca45c9b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -914,6 +914,9 @@ def forward( gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + cross_attention_kwargs["lora_scale"] = cross_attention_kwargs.get("scale") + # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None From 2ed1f2a4c39a11ad866e34af9705392c3b8fb271 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 16:26:31 +0530 Subject: [PATCH 09/78] cross_attention_kwargs --- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 6 ++---- .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 6 ++---- .../pipeline_onnx_stable_diffusion_img2img.py | 6 ++---- .../pipelines/versatile_diffusion/modeling_text_unet.py | 3 +++ 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 8cf308588422..00e688907889 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -453,10 +453,8 @@ def run_safety_checker(self, image, device, dtype): def decode_latents(self, latents): warnings.warn( - ( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead" - ), + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 89c8279fb3c0..5f6dc07487bd 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -454,10 +454,8 @@ def run_safety_checker(self, image, device, dtype): def decode_latents(self, latents): warnings.warn( - ( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead" - ), + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 508085094b16..d418662a4b44 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -35,10 +35,8 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 def preprocess(image): warnings.warn( - ( - "The preprocess method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor.preprocess instead" - ), + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", FutureWarning, ) if isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 3fd9695c2d43..b62eb323a3f3 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1086,6 +1086,9 @@ def forward( gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + cross_attention_kwargs["lora_scale"] = cross_attention_kwargs.get("scale") + # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None From 3967da86f07434740fdd28a1f173bc84a0e3e10e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 16:41:35 +0530 Subject: [PATCH 10/78] lora_scale -> scale. --- src/diffusers/models/attention.py | 8 ++++---- src/diffusers/models/lora.py | 8 ++++---- src/diffusers/models/resnet.py | 20 +++++++++---------- src/diffusers/models/transformer_2d.py | 12 +++++------ src/diffusers/models/unet_2d_condition.py | 3 --- .../versatile_diffusion/modeling_text_unet.py | 3 --- 6 files changed, 24 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index b00728913ffa..0e2b334f2990 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -295,10 +295,10 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) - def forward(self, hidden_states, lora_scale: float = 1.0): + def forward(self, hidden_states, scale: float = 1.0): for module in self.net: if isinstance(module, LoRACompatibleLinear): - hidden_states = module(hidden_states, lora_scale) + hidden_states = module(hidden_states, scale) else: hidden_states = module(hidden_states) return hidden_states @@ -345,8 +345,8 @@ def gelu(self, gate): # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states, lora_scale: float = 1.0): - hidden_states, gate = self.proj(hidden_states, lora_scale=lora_scale).chunk(2, dim=-1) + def forward(self, hidden_states, scale: float = 1.0): + hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1d671ed51c17..1557ebcd1b3d 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -143,7 +143,7 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states, lora_scale: float = 1.0): + def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 @@ -151,7 +151,7 @@ def forward(self, hidden_states, lora_scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states) + return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) class LoRACompatibleLinear(nn.Linear): @@ -206,8 +206,8 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, hidden_states, lora_scale: float = 1.0): + def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: return super().forward(hidden_states) else: - return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states) + return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 6bebddc972e2..8acc0c826896 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -135,7 +135,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, hidden_states, output_size=None, lora_scale: float = 1.0): + def forward(self, hidden_states, output_size=None, scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -167,12 +167,12 @@ def forward(self, hidden_states, output_size=None, lora_scale: float = 1.0): if self.use_conv: if self.name == "conv": if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, lora_scale) + hidden_states = self.conv(hidden_states, scale) else: hidden_states = self.conv(hidden_states) else: if isinstance(self.Conv2d_0, LoRACompatibleConv): - hidden_states = self.Conv2d_0(hidden_states, lora_scale) + hidden_states = self.Conv2d_0(hidden_states, scale) else: hidden_states = self.Conv2d_0(hidden_states) @@ -217,7 +217,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, hidden_states, lora_scale: float = 1.0): + def forward(self, hidden_states, scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) @@ -225,7 +225,7 @@ def forward(self, hidden_states, lora_scale: float = 1.0): assert hidden_states.shape[1] == self.channels if isinstance(self.conv, LoRACompatibleConv): - hidden_states = self.conv(hidden_states, lora_scale) + hidden_states = self.conv(hidden_states, scale) else: hidden_states = self.conv(hidden_states) @@ -597,7 +597,7 @@ def __init__( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) - def forward(self, input_tensor, temb, lora_scale: float = 1.0): + def forward(self, input_tensor, temb, scale: float = 1.0): hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": @@ -618,12 +618,12 @@ def forward(self, input_tensor, temb, lora_scale: float = 1.0): input_tensor = self.downsample(input_tensor) hidden_states = self.downsample(hidden_states) - hidden_states = self.conv1(hidden_states, lora_scale) + hidden_states = self.conv1(hidden_states, scale) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb, lora_scale)[:, :, None, None] + temb = self.time_emb_proj(temb, scale)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -640,10 +640,10 @@ def forward(self, input_tensor, temb, lora_scale: float = 1.0): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, lora_scale) + hidden_states = self.conv2(hidden_states, scale) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, lora_scale) + input_tensor = self.conv_shortcut(input_tensor, scale) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 36698b300df9..ceea754dfd9b 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -219,7 +219,7 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - lora_scale: float = 1.0, + scale: float = 1.0, return_dict: bool = True, ): """ @@ -244,7 +244,7 @@ def forward( If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. - lora_scale: (`float`, *optional*, defaults to 1.0): + scale: (`float`, *optional*, defaults to 1.0): Controls how much to influence the outputs with the LoRA parameters. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain @@ -284,13 +284,13 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states, lora_scale) + hidden_states = self.proj_in(hidden_states, scale) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states, lora_scale) + hidden_states = self.proj_in(hidden_states, scale) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: @@ -325,9 +325,9 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states, lora_scale) + hidden_states = self.proj_out(hidden_states, scale) else: - hidden_states = self.proj_out(hidden_states, lora_scale) + hidden_states = self.proj_out(hidden_states, scale) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f7e63ca45c9b..1994649f4c59 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -914,9 +914,6 @@ def forward( gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - cross_attention_kwargs["lora_scale"] = cross_attention_kwargs.get("scale") - # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index b62eb323a3f3..3fd9695c2d43 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1086,9 +1086,6 @@ def forward( gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} - if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: - cross_attention_kwargs["lora_scale"] = cross_attention_kwargs.get("scale") - # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None From e24fd70aab34f80adce7c247300aba01d8651e38 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 16:44:54 +0530 Subject: [PATCH 11/78] lora_scale fix --- src/diffusers/models/attention_processor.py | 44 ++++++++++----------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9d3c576107d4..3150f13466d5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -570,15 +570,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, lora_scale=scale) - value = attn.to_v(encoder_hidden_states, lora_scale=scale) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -589,7 +589,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -722,17 +722,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) query = attn.head_to_batch_dim(query) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, lora_scale=scale) - value = attn.to_v(hidden_states, lora_scale=scale) + key = attn.to_k(hidden_states, scale=scale) + value = attn.to_v(hidden_states, scale=scale) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) @@ -746,7 +746,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -782,7 +782,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -791,8 +791,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, lora_scale=scale) - value = attn.to_v(hidden_states, lora_scale=scale) + key = attn.to_k(hidden_states, scale=scale) + value = attn.to_v(hidden_states, scale=scale) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) @@ -809,7 +809,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -937,15 +937,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, lora_scale=scale) - value = attn.to_v(encoder_hidden_states, lora_scale=scale) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() @@ -958,7 +958,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1015,15 +1015,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, lora_scale=scale) - value = attn.to_v(encoder_hidden_states, lora_scale=scale) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1043,7 +1043,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) From 21e765b274aaac29b05f981a7171fe416134063f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 17:27:09 +0530 Subject: [PATCH 12/78] lora_scale in patched projection. --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6e12a608a8cb..c0e5b1480b4b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -135,10 +135,10 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, input): + def forward(self, input, lora_scale: float = 1.0): if self.lora_linear_layer is None: return self.regular_linear_layer(input) - return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) + return self.regular_linear_layer(input) + lora_scale * self.lora_linear_layer(input) def text_encoder_attn_modules(text_encoder): From 9678ed296da2db2fdf84dd98b9e73dca40889fbe Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 17:29:05 +0530 Subject: [PATCH 13/78] debugging --- src/diffusers/loaders.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c0e5b1480b4b..b79c12177b5f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -137,7 +137,8 @@ def _unfuse_lora(self): def forward(self, input, lora_scale: float = 1.0): if self.lora_linear_layer is None: - return self.regular_linear_layer(input) + return self.regular_linear_layer(input) + print(f"lora scale from {self.__class__}: {lora_scale}") return self.regular_linear_layer(input) + lora_scale * self.lora_linear_layer(input) From acbbb4d885cbf2a2ccd8f8d78b186b2cbea35231 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 17:31:01 +0530 Subject: [PATCH 14/78] debugging --- src/diffusers/models/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1557ebcd1b3d..1aa7ede66f54 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -151,6 +151,7 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: + print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) From 0c2dad46220175ea9f83d9fb445ab4381e60b852 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 17:45:19 +0530 Subject: [PATCH 15/78] debugging --- src/diffusers/models/unet_2d_condition.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 1994649f4c59..2a400cb511d5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -752,6 +752,8 @@ def forward( # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + print(f"From the UNet: {cross_attention_kwargs['scale']}") default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` From 6269412098578e1489b20e16329a42d224749753 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 18:29:32 +0530 Subject: [PATCH 16/78] debugging --- src/diffusers/loaders.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 115 +++++++++++++++---------- 2 files changed, 69 insertions(+), 48 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b79c12177b5f..3d92d18888db 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -137,7 +137,7 @@ def _unfuse_lora(self): def forward(self, input, lora_scale: float = 1.0): if self.lora_linear_layer is None: - return self.regular_linear_layer(input) + return self.regular_linear_layer(input) print(f"lora scale from {self.__class__}: {lora_scale}") return self.regular_linear_layer(input) + lora_scale * self.lora_linear_layer(input) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 05b360def7c2..b7b7a08b0275 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1126,7 +1126,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1147,13 +1147,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale) output_states = output_states + (hidden_states,) @@ -1209,13 +1209,13 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, scale: float = 1.0): for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, scale=scale) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale) return hidden_states @@ -1292,14 +1292,14 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, scale=scale) hidden_states = attn(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale) return hidden_states @@ -1385,16 +1385,16 @@ def __init__( self.downsamplers = None self.skip_conv = None - def forward(self, hidden_states, temb=None, skip_sample=None): + def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = attn(hidden_states) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) + hidden_states = self.resnet_down(hidden_states, temb, scale=scale) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1465,15 +1465,15 @@ def __init__( self.downsamplers = None self.skip_conv = None - def forward(self, hidden_states, temb=None, skip_sample=None): + def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): output_states = () for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) + hidden_states = self.resnet_down(hidden_states, temb, scale) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1548,7 +1548,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1569,13 +1569,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) + hidden_states = downsampler(hidden_states, temb, scale) output_states = output_states + (hidden_states,) @@ -1720,7 +1720,10 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, @@ -1733,7 +1736,10 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = downsampler(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = downsampler(hidden_states, temb, scale=1.0) output_states = output_states + (hidden_states,) @@ -1786,7 +1792,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1807,7 +1813,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states += (hidden_states,) @@ -1922,7 +1928,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2033,20 +2042,20 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": - hidden_states = upsampler(hidden_states, temb=temb) + hidden_states = upsampler(hidden_states, temb=temb, scale=scale) else: hidden_states = upsampler(hidden_states) @@ -2183,7 +2192,10 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2248,7 +2260,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2272,7 +2284,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2325,9 +2337,9 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb=temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2404,9 +2416,9 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb=temb, scale=scale) hidden_states = attn(hidden_states, temb=temb) if self.upsamplers is not None: @@ -2507,14 +2519,14 @@ def __init__( self.skip_norm = None self.act = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = self.attentions[0](hidden_states) @@ -2530,7 +2542,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb) + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) return hidden_states, skip_sample @@ -2604,14 +2616,14 @@ def __init__( self.skip_norm = None self.act = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -2625,7 +2637,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb) + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) return hidden_states, skip_sample @@ -2697,7 +2709,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2721,11 +2733,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) + hidden_states = upsampler(hidden_states, temb, scale=scale) return hidden_states @@ -2877,7 +2889,10 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, @@ -2888,7 +2903,10 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = upsampler(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = upsampler(hidden_states, temb, scale=1.0) return hidden_states @@ -2941,7 +2959,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) @@ -2964,7 +2982,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -3100,7 +3118,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, From cc0c7ec0144e6e17e4f5dca94b69d62b646503a1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 18:40:43 +0530 Subject: [PATCH 17/78] debugging --- src/diffusers/models/transformer_2d.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index ceea754dfd9b..b7c13c930971 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -219,7 +219,6 @@ def forward( cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - scale: float = 1.0, return_dict: bool = True, ): """ @@ -244,8 +243,6 @@ def forward( If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. - scale: (`float`, *optional*, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -284,13 +281,19 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states, scale) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = self.proj_in(hidden_states, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = self.proj_in(hidden_states, scale=1.0) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states, scale) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = self.proj_in(hidden_states, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = self.proj_in(hidden_states, scale=1.0) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: @@ -325,9 +328,15 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states, scale) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = self.proj_out(hidden_states, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = self.proj_out(hidden_states, scale=1.0) else: - hidden_states = self.proj_out(hidden_states, scale) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = self.proj_out(hidden_states, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = self.proj_out(hidden_states, scale=1.0) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual From b357ffcfc1bc0732d4577bb3033bd8d8fe24a41c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 18:48:19 +0530 Subject: [PATCH 18/78] debugging --- src/diffusers/models/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1aa7ede66f54..14e36cae97fe 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -151,7 +151,7 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - print(f"From {self.__class__.__name__}: scale {scale}") + print(f"From {self.__class__.__name__}: scale {scale}, parent: {self.__bases__}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) From 8495b4342b8d431ccdf539381cc56aa0785c31b8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 18:53:56 +0530 Subject: [PATCH 19/78] debugging --- src/diffusers/models/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 14e36cae97fe..f9cdb259c5ea 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -151,7 +151,8 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - print(f"From {self.__class__.__name__}: scale {scale}, parent: {self.__bases__}") + print(f"Parents: {[cls.__name__ for cls in self.__class__.__bases__]}") + print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) From 96fc1afb3fa41ea3836ad7f773291001feba4f20 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 18:55:37 +0530 Subject: [PATCH 20/78] debugging --- src/diffusers/models/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index f9cdb259c5ea..d358d1119804 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -151,7 +151,7 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - print(f"Parents: {[cls.__name__ for cls in self.__class__.__bases__]}") + print(f"Parents: {[cls.__module__ for cls in self.__class__.__bases__]}") print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) From 0c501d3df097924435ad2930f6a86cace1cbec60 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 18:58:44 +0530 Subject: [PATCH 21/78] debugging --- src/diffusers/models/lora.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index d358d1119804..993b91ded0fd 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from torch import nn - +import inspect from ..utils import logging @@ -151,7 +151,9 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - print(f"Parents: {[cls.__module__ for cls in self.__class__.__bases__]}") + caller_frame = inspect.currentframe().f_back + caller_class_name = caller_frame.f_locals.get('self').__class__.__name__ + print("Caller class:", caller_class_name) print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) From 016d3e95a5152e3f87377c8d05ecf244447a10bb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 19:09:54 +0530 Subject: [PATCH 22/78] debugging --- src/diffusers/models/resnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 8acc0c826896..2ff000440525 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -612,11 +612,11 @@ def forward(self, input_tensor, temb, scale: float = 1.0): if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) + input_tensor = self.upsample(input_tensor, scale=scale) if isinstance(self.upsample, Upsample2D) else self.upsample(input_tensor) + hidden_states = self.upsample(hidden_states, scale=scale) if isinstance(self.upsample, Upsample2D) else self.upsample(hidden_states) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) + input_tensor = self.downsample(input_tensor, scale=scale) if isinstance(self.downsample, Downsample2D) else self.downsample(input_tensor) + hidden_states = self.downsample(hidden_states, scale=scale) if isinstance(self.downsample, Downsample2D) else self.downsample(hidden_states) hidden_states = self.conv1(hidden_states, scale) From 910e96b42cfd3e46176d5ab45c43eb65261e6003 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 19:11:20 +0530 Subject: [PATCH 23/78] debugging --- src/diffusers/models/lora.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 993b91ded0fd..1520ae9de7ed 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -151,9 +151,6 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - caller_frame = inspect.currentframe().f_back - caller_class_name = caller_frame.f_locals.get('self').__class__.__name__ - print("Caller class:", caller_class_name) print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) From de159daddf91a3aa6548b27b41bf626324345d5a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 19:28:44 +0530 Subject: [PATCH 24/78] debugging --- src/diffusers/models/attention.py | 9 +++++++-- src/diffusers/models/unet_2d_blocks.py | 28 ++++++++++++++++++++------ 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0e2b334f2990..91b0fc05f0b8 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -233,13 +233,18 @@ def forward( f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + [self.ff(hid_slice, scale=scale) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], dim=self._chunk_dim, ) else: - ff_output = self.ff(norm_hidden_states) + ff_output = self.ff(norm_hidden_states, scale=scale) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index b7b7a08b0275..cfcb43c21da8 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -639,8 +639,12 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -677,7 +681,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) return hidden_states @@ -1049,7 +1056,10 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1067,7 +1077,10 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = downsampler(hidden_states, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = downsampler(hidden_states, scale=1.0) output_states = output_states + (hidden_states,) @@ -2207,7 +2220,10 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = upsampler(hidden_states, upsample_size, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = upsampler(hidden_states, upsample_size, scale=1.0) return hidden_states From 4ee8dbfe9a4df1459c0394f887da7f359acbc57b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 19:32:09 +0530 Subject: [PATCH 25/78] debugging --- src/diffusers/models/attention.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 91b0fc05f0b8..e5d2d4b602d5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -244,6 +244,10 @@ def forward( dim=self._chunk_dim, ) else: + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 ff_output = self.ff(norm_hidden_states, scale=scale) if self.use_ada_layer_norm_zero: From cd9ac470d5327466403b0e3d94c2b70f7af8a7b2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 19:46:28 +0530 Subject: [PATCH 26/78] styling. --- src/diffusers/models/attention.py | 5 ++++- src/diffusers/models/lora.py | 2 +- src/diffusers/models/resnet.py | 24 ++++++++++++++++++++---- src/diffusers/models/transformer_2d.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 2 +- 5 files changed, 27 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e5d2d4b602d5..71af64a32069 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -240,7 +240,10 @@ def forward( num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( - [self.ff(hid_slice, scale=scale) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + [ + self.ff(hid_slice, scale=scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], dim=self._chunk_dim, ) else: diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1520ae9de7ed..1aa7ede66f54 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from torch import nn -import inspect + from ..utils import logging diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 2ff000440525..ac66e2271c61 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -612,11 +612,27 @@ def forward(self, input_tensor, temb, scale: float = 1.0): if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, scale=scale) if isinstance(self.upsample, Upsample2D) else self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states, scale=scale) if isinstance(self.upsample, Upsample2D) else self.upsample(hidden_states) + input_tensor = ( + self.upsample(input_tensor, scale=scale) + if isinstance(self.upsample, Upsample2D) + else self.upsample(input_tensor) + ) + hidden_states = ( + self.upsample(hidden_states, scale=scale) + if isinstance(self.upsample, Upsample2D) + else self.upsample(hidden_states) + ) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, scale=scale) if isinstance(self.downsample, Downsample2D) else self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states, scale=scale) if isinstance(self.downsample, Downsample2D) else self.downsample(hidden_states) + input_tensor = ( + self.downsample(input_tensor, scale=scale) + if isinstance(self.downsample, Downsample2D) + else self.downsample(input_tensor) + ) + hidden_states = ( + self.downsample(hidden_states, scale=scale) + if isinstance(self.downsample, Downsample2D) + else self.downsample(hidden_states) + ) hidden_states = self.conv1(hidden_states, scale) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index b7c13c930971..de390285f390 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -283,7 +283,7 @@ def forward( if not self.use_linear_projection: if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: hidden_states = self.proj_in(hidden_states, scale=cross_attention_kwargs["scale"]) - else: + else: hidden_states = self.proj_in(hidden_states, scale=1.0) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cfcb43c21da8..7e8a55049a40 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -639,7 +639,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - ) -> torch.FloatTensor: + ) -> torch.FloatTensor: if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: scale = cross_attention_kwargs["scale"] else: From 1cd983fbf1aa905d9ded41c3bc17a1983f5c13fb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 20:00:45 +0530 Subject: [PATCH 27/78] debugging --- src/diffusers/models/unet_2d_blocks.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 7e8a55049a40..6281633f8665 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -807,7 +807,10 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) return hidden_states @@ -904,20 +907,28 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, temb=None, upsample_size=None): + def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 if self.downsample_type == "resnet": - hidden_states = downsampler(hidden_states, temb=temb) + hidden_states = downsampler(hidden_states, temb=temb, scale=scale) else: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=scale) output_states += (hidden_states,) From 860a374d1a4d5bc8b1b406e15707ad676eabb8bc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 20:21:57 +0530 Subject: [PATCH 28/78] debugging --- src/diffusers/models/transformer_2d.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index de390285f390..c339f4d5cc84 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -251,6 +251,9 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + print(f"From {self.__class__.__name__}: scale - {scale}") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. From 1b2346caea14d88b90cdfa1c88149902f2616d12 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 20:26:01 +0530 Subject: [PATCH 29/78] debugging --- src/diffusers/models/resnet.py | 3 +++ src/diffusers/models/transformer_2d.py | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ac66e2271c61..d9ce7b96e049 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -136,6 +136,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann self.Conv2d_0 = conv def forward(self, hidden_states, output_size=None, scale: float = 1.0): + print(f"From {self.__class__.__name__} scale: {scale}") assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -218,6 +219,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= self.conv = conv def forward(self, hidden_states, scale: float = 1.0): + print(f"From {self.__class__.__name__} scale: {scale}") assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) @@ -598,6 +600,7 @@ def __init__( ) def forward(self, input_tensor, temb, scale: float = 1.0): + print(f"From {self.__class__.__name__} scale: {scale}") hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index c339f4d5cc84..de390285f390 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -251,9 +251,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - print(f"From {self.__class__.__name__}: scale - {scale}") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. From 583da5f01c064a51fa114f7ec404319037040cbb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 20:34:58 +0530 Subject: [PATCH 30/78] debugging --- src/diffusers/models/unet_2d_blocks.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 6281633f8665..35c522910e1e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -796,7 +796,11 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( From 77f64591c0efa1e87796854b6e21cbdc20233dc1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 20:53:14 +0530 Subject: [PATCH 31/78] debugging --- src/diffusers/models/unet_2d_condition.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 2a400cb511d5..5d1d85f7acb4 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -939,7 +939,11 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=scale) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) @@ -1002,8 +1006,12 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) else: + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=scale ) # 6. post-process From ec67361c22eb4a770f6dcf348942bcf93df433e5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 20:59:43 +0530 Subject: [PATCH 32/78] debugging --- src/diffusers/models/unet_2d_condition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5d1d85f7acb4..558bf024d094 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -984,6 +984,7 @@ def forward( # 5. up for i, upsample_block in enumerate(self.up_blocks): + print(f"Upblock: {upsample_block.__class__.__name__}") is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] From 6c9c5dc5a5b5a4fb02577e12ed62699014ee55ec Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 21:05:45 +0530 Subject: [PATCH 33/78] debugging --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 11e575d68269..57bcaa78dc97 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -877,6 +877,7 @@ def __call__( if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + print("Coming after VAE?") else: image = latents return StableDiffusionXLPipelineOutput(images=image) From d7b35d46484f0778c4a4446ebfce51b1eab6691b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 21:11:56 +0530 Subject: [PATCH 34/78] debugging --- src/diffusers/models/unet_2d_condition.py | 1 + .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 558bf024d094..32f591c959f1 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -996,6 +996,7 @@ def forward( upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + print(f"From cross_attention: {upsample_block.__class__.__name__}") sample = upsample_block( hidden_states=sample, temb=emb, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 57bcaa78dc97..11e575d68269 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -877,7 +877,6 @@ def __call__( if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - print("Coming after VAE?") else: image = latents return StableDiffusionXLPipelineOutput(images=image) From e601d2b991e2a25528e1b58a935750a067525253 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 21:17:06 +0530 Subject: [PATCH 35/78] debugging --- src/diffusers/models/unet_2d_blocks.py | 2 +- src/diffusers/models/unet_2d_condition.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 35c522910e1e..004e921c642a 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -2319,7 +2319,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) return hidden_states diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 32f591c959f1..558bf024d094 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -996,7 +996,6 @@ def forward( upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - print(f"From cross_attention: {upsample_block.__class__.__name__}") sample = upsample_block( hidden_states=sample, temb=emb, From 35148d072e15153162d60279e39e5c24093cf339 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 21:20:21 +0530 Subject: [PATCH 36/78] debugging --- src/diffusers/models/unet_2d_condition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 558bf024d094..65f86cc1fcdc 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1009,6 +1009,7 @@ def forward( else: if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: scale = cross_attention_kwargs["scale"] + print(f"cross_attention_kwargs: {scale}") else: scale = 1.0 sample = upsample_block( From 55efe9cde5893b55cdb207619ad8b95212957853 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 21:24:29 +0530 Subject: [PATCH 37/78] debugging --- src/diffusers/models/unet_2d_blocks.py | 1 + src/diffusers/models/unet_2d_condition.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 004e921c642a..f652d952cf3e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -2292,6 +2292,7 @@ def __init__( self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): + print(f"{self.__class__.__name__} scale: {scale}") for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 65f86cc1fcdc..558bf024d094 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1009,7 +1009,6 @@ def forward( else: if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: scale = cross_attention_kwargs["scale"] - print(f"cross_attention_kwargs: {scale}") else: scale = 1.0 sample = upsample_block( From 0d7b3df0481ecdbe27c9ab8668ee829271573da4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 22:35:41 +0530 Subject: [PATCH 38/78] debugging --- src/diffusers/models/lora.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1aa7ede66f54..04937e0c7510 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -151,7 +151,6 @@ def forward(self, hidden_states, scale: float = 1.0): hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) @@ -211,4 +210,5 @@ def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: return super().forward(hidden_states) else: + print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f652d952cf3e..8d53e16c7d6c 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1175,13 +1175,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb, scale) + hidden_states = resnet(hidden_states, temb, scale=scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, scale) + hidden_states = downsampler(hidden_states, scale=scale) output_states = output_states + (hidden_states,) @@ -2292,7 +2292,6 @@ def __init__( self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): - print(f"{self.__class__.__name__} scale: {scale}") for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] From cdc79631513382ff2e04a7a3647060d0573bc495 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 22:42:05 +0530 Subject: [PATCH 39/78] remove unneeded prints. --- src/diffusers/models/resnet.py | 3 --- src/diffusers/models/unet_2d_condition.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d9ce7b96e049..ac66e2271c61 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -136,7 +136,6 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann self.Conv2d_0 = conv def forward(self, hidden_states, output_size=None, scale: float = 1.0): - print(f"From {self.__class__.__name__} scale: {scale}") assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -219,7 +218,6 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= self.conv = conv def forward(self, hidden_states, scale: float = 1.0): - print(f"From {self.__class__.__name__} scale: {scale}") assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) @@ -600,7 +598,6 @@ def __init__( ) def forward(self, input_tensor, temb, scale: float = 1.0): - print(f"From {self.__class__.__name__} scale: {scale}") hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 558bf024d094..168376085f93 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -752,8 +752,6 @@ def forward( # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - print(f"From the UNet: {cross_attention_kwargs['scale']}") default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` @@ -984,7 +982,6 @@ def forward( # 5. up for i, upsample_block in enumerate(self.up_blocks): - print(f"Upblock: {upsample_block.__class__.__name__}") is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] From 2a3e358bc9de2b61e90b65d9c092539e8154f15c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 22:52:43 +0530 Subject: [PATCH 40/78] remove unneeded prints. --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 3150f13466d5..5cc75a6357bb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -991,6 +991,7 @@ def __call__( temb=None, scale: float = 1.0, ): + print(f"{self.__class__.__name__} yields a scale of {scale}.") residual = hidden_states if attn.spatial_norm is not None: From 42c2c0ae31e9da40f13f1a4caa5ceadd3b18c5ee Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:08:01 +0530 Subject: [PATCH 41/78] assign cross_attention_kwargs. --- src/diffusers/models/unet_2d_blocks.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 8d53e16c7d6c..64977ab59c5e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1323,7 +1323,8 @@ def __init__( def forward(self, hidden_states, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None, scale=scale) - hidden_states = attn(hidden_states) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -1418,7 +1419,8 @@ def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0 for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = attn(hidden_states) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) output_states += (hidden_states,) if self.downsamplers is not None: @@ -2078,14 +2080,15 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = attn(hidden_states) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": hidden_states = upsampler(hidden_states, temb=temb, scale=scale) else: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, scale=scale) return hidden_states @@ -2450,11 +2453,12 @@ def __init__( def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb, scale=scale) - hidden_states = attn(hidden_states, temb=temb) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, temb=temb, cross_attention_kwargs=cross_attention_kwargs) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, scale=scale) return hidden_states @@ -2559,7 +2563,8 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = self.attentions[0](hidden_states) + cross_attention_kwargs = {"scale": scale} + hidden_states = self.attentions[0](hidden_states, cross_attention_kwargs=cross_attention_kwargs) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) From 98e6eca909581ccf53ed0a7ab56a47c9b6b36737 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:28:55 +0530 Subject: [PATCH 42/78] debugging --- src/diffusers/loaders.py | 1 - src/diffusers/models/unet_2d_blocks.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3d92d18888db..c0e5b1480b4b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -138,7 +138,6 @@ def _unfuse_lora(self): def forward(self, input, lora_scale: float = 1.0): if self.lora_linear_layer is None: return self.regular_linear_layer(input) - print(f"lora scale from {self.__class__}: {lora_scale}") return self.regular_linear_layer(input) + lora_scale * self.lora_linear_layer(input) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 64977ab59c5e..2ced37c9d965 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -917,9 +917,11 @@ def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_ for resnet, attn in zip(self.resnets, self.attentions): if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + scale = cross_attention_kwargs["scale"] else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + scale = 1.0 + cross_attention_kwargs.update({"scale": scale}) + hidden_states = resnet(hidden_states, temb, scale=scale) hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) output_states = output_states + (hidden_states,) From 03abb4c0796d3b0ea45b8d1a915877ccc513818f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:31:02 +0530 Subject: [PATCH 43/78] debugging --- src/diffusers/models/unet_2d_condition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 168376085f93..2d1c6186b4bd 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1008,6 +1008,7 @@ def forward( scale = cross_attention_kwargs["scale"] else: scale = 1.0 + print("Last block.") sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=scale ) From 32a175fbdd8ec1e78c4bb2963b688dcf8b182db4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:32:28 +0530 Subject: [PATCH 44/78] debugging --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 2d1c6186b4bd..eccc765e84f1 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -982,6 +982,7 @@ def forward( # 5. up for i, upsample_block in enumerate(self.up_blocks): + print(f"upsample_block: {upsample_block.__class__.__name__}") is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] @@ -1008,7 +1009,6 @@ def forward( scale = cross_attention_kwargs["scale"] else: scale = 1.0 - print("Last block.") sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=scale ) From ef1ad841e9145f65cbbb177da183e40fc5a485ba Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:46:08 +0530 Subject: [PATCH 45/78] debugging --- src/diffusers/models/transformer_2d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index de390285f390..88dc62af501c 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -261,6 +261,8 @@ def forward( # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs["scale"]: + print(f"{self.__class__.name} scale: {cross_attention_kwargs['scale']}") if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) From 9a759b967c0e85ee84282e51a985c419b22cd5b6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:46:51 +0530 Subject: [PATCH 46/78] debugging --- src/diffusers/models/transformer_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 88dc62af501c..98bff42fe8a9 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -261,7 +261,7 @@ def forward( # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs["scale"]: + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: print(f"{self.__class__.name} scale: {cross_attention_kwargs['scale']}") if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: From 369a53f04b8e6706978e8adde07b922849386277 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:49:40 +0530 Subject: [PATCH 47/78] debugging --- src/diffusers/models/transformer_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 98bff42fe8a9..9863d0f3e2d5 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -262,7 +262,7 @@ def forward( # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - print(f"{self.__class__.name} scale: {cross_attention_kwargs['scale']}") + print(f"{self.__class__.__name__} scale: {cross_attention_kwargs['scale']}") if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) From 833fd358bb3a8f47a51cb3cd62a65ea1caa496da Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 30 Aug 2023 23:51:19 +0530 Subject: [PATCH 48/78] debugging --- src/diffusers/models/dual_transformer_2d.py | 3 ++- src/diffusers/models/transformer_2d.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index 3db7e73ca6af..a21d9468fd47 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -124,7 +124,8 @@ def forward( returning a tuple, the first element is the sample tensor. """ input_states = hidden_states - + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + print(f"{self.__class__.__name__} scale: {cross_attention_kwargs['scale']}") encoded_states = [] tokens_start = 0 # attention_mask is not used yet diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 9863d0f3e2d5..de390285f390 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -261,8 +261,6 @@ def forward( # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - print(f"{self.__class__.__name__} scale: {cross_attention_kwargs['scale']}") if attention_mask is not None and attention_mask.ndim == 2: # assume that mask is expressed as: # (1 = keep, 0 = discard) From d8371ab1c57ed9398df9c326ea45d7122373156d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:19:37 +0530 Subject: [PATCH 49/78] debugging --- src/diffusers/models/dual_transformer_2d.py | 3 +-- src/diffusers/models/unet_2d_condition.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index a21d9468fd47..d205bbece9a5 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -124,8 +124,7 @@ def forward( returning a tuple, the first element is the sample tensor. """ input_states = hidden_states - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - print(f"{self.__class__.__name__} scale: {cross_attention_kwargs['scale']}") + encoded_states = [] tokens_start = 0 # attention_mask is not used yet diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index eccc765e84f1..fff29cb84696 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -921,6 +921,7 @@ def forward( down_block_res_samples = (sample,) for downsample_block in self.down_blocks: + print(f"downsample_block: {downsample_block.__class__.__name__}") if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} @@ -961,6 +962,7 @@ def forward( # 4. mid if self.mid_block is not None: + print(f"mid_block: {self.mid_block.__class__.__name__}") sample = self.mid_block( sample, emb, @@ -982,7 +984,6 @@ def forward( # 5. up for i, upsample_block in enumerate(self.up_blocks): - print(f"upsample_block: {upsample_block.__class__.__name__}") is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] From a5925ab6b3ba7f76feadc43d785a2a0ff643b563 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:27:03 +0530 Subject: [PATCH 50/78] debugging --- src/diffusers/models/unet_2d_blocks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 2ced37c9d965..1309eecaab57 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1077,6 +1077,7 @@ def custom_forward(*inputs): hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) else: hidden_states = resnet(hidden_states, temb, scale=1.0) + print(f"From {self.__class__.__name__}: {cross_attention_kwargs}") hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, From d3d6ab11b27087a822c011b8833abbf9fa2611cd Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:33:38 +0530 Subject: [PATCH 51/78] debugging --- src/diffusers/models/transformer_2d.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index de390285f390..5399c1769e94 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -251,6 +251,7 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + print(f"{self.__class__.__name__} cross_attention_kwargs: {cross_attention_kwargs}") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. From 8c0b5846d9ba2427c646b502c8d4c6e583e14c6e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:37:05 +0530 Subject: [PATCH 52/78] debugging --- src/diffusers/models/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 71af64a32069..300366ece440 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -178,6 +178,7 @@ def forward( ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention + print(f"From {self.__class__.__name__}: {cross_attention_kwargs}") if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: From 43d6c8d6b42b9b91d2f394947ac913dd78566482 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:41:54 +0530 Subject: [PATCH 53/78] debugging --- src/diffusers/models/attention_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5cc75a6357bb..9b1d60ff8bf6 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -416,6 +416,7 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty + print(f"{self.__class__.__name__} {cross_attention_kwargs}") return self.processor( self, hidden_states, From caa86251b6c664f390c39f8adc5dc4f3da2a9ce5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:46:59 +0530 Subject: [PATCH 54/78] debugging --- src/diffusers/models/resnet.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index ac66e2271c61..1825dffb695d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -598,6 +598,7 @@ def __init__( ) def forward(self, input_tensor, temb, scale: float = 1.0): + print(f"{self.__class__.__name__} scale {scale}") hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": From 38cbe4616f5699e313557940ea4e97377807f13d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:50:00 +0530 Subject: [PATCH 55/78] debugging --- src/diffusers/models/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 300366ece440..001533c7f6f5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -311,6 +311,7 @@ def __init__( def forward(self, hidden_states, scale: float = 1.0): for module in self.net: if isinstance(module, LoRACompatibleLinear): + print(f"{self.__class__.__name__} scale: {scale}") hidden_states = module(hidden_states, scale) else: hidden_states = module(hidden_states) @@ -359,6 +360,7 @@ def gelu(self, gate): return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states, scale: float = 1.0): + print(f"{self.__class__.__name__} {scale}") hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) return hidden_states * self.gelu(gate) From b29e025a7b33d0e03954e85b21012b95fe91c0eb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:54:29 +0530 Subject: [PATCH 56/78] debugging --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 001533c7f6f5..bdcd4c12f2f0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -310,7 +310,7 @@ def __init__( def forward(self, hidden_states, scale: float = 1.0): for module in self.net: - if isinstance(module, LoRACompatibleLinear): + if isinstance(module, (LoRACompatibleLinear, GEGLU)): print(f"{self.__class__.__name__} scale: {scale}") hidden_states = module(hidden_states, scale) else: From b275947053459b264980c732c534b2e9ebc7a70a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:56:11 +0530 Subject: [PATCH 57/78] debugging --- src/diffusers/models/attention.py | 1 - src/diffusers/models/attention_processor.py | 1 - src/diffusers/models/resnet.py | 1 - 3 files changed, 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index bdcd4c12f2f0..95706ac0a819 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -178,7 +178,6 @@ def forward( ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention - print(f"From {self.__class__.__name__}: {cross_attention_kwargs}") if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9b1d60ff8bf6..5cc75a6357bb 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -416,7 +416,6 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None # The `Attention` class can call different attention processors / attention functions # here we simply pass along all tensors to the selected processor class # For standard processors that are defined here, `**cross_attention_kwargs` is empty - print(f"{self.__class__.__name__} {cross_attention_kwargs}") return self.processor( self, hidden_states, diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 1825dffb695d..ac66e2271c61 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -598,7 +598,6 @@ def __init__( ) def forward(self, input_tensor, temb, scale: float = 1.0): - print(f"{self.__class__.__name__} scale {scale}") hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": From d8b4bf71eb6a5fbc946b456e30602f08cec13818 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:58:04 +0530 Subject: [PATCH 58/78] debugging --- src/diffusers/models/unet_2d_blocks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 1309eecaab57..2ced37c9d965 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1077,7 +1077,6 @@ def custom_forward(*inputs): hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) else: hidden_states = resnet(hidden_states, temb, scale=1.0) - print(f"From {self.__class__.__name__}: {cross_attention_kwargs}") hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, From a3df6cd8f69193b071f75ffc99daf75ecf52bf77 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 00:59:26 +0530 Subject: [PATCH 59/78] debugging --- src/diffusers/models/attention_processor.py | 1 - src/diffusers/models/unet_2d_condition.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5cc75a6357bb..3150f13466d5 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -991,7 +991,6 @@ def __call__( temb=None, scale: float = 1.0, ): - print(f"{self.__class__.__name__} yields a scale of {scale}.") residual = hidden_states if attn.spatial_norm is not None: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index fff29cb84696..168376085f93 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -921,7 +921,6 @@ def forward( down_block_res_samples = (sample,) for downsample_block in self.down_blocks: - print(f"downsample_block: {downsample_block.__class__.__name__}") if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} @@ -962,7 +961,6 @@ def forward( # 4. mid if self.mid_block is not None: - print(f"mid_block: {self.mid_block.__class__.__name__}") sample = self.mid_block( sample, emb, From 265d5f4ccdefc02635fb7d46801d50e8f17e8088 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 01:01:10 +0530 Subject: [PATCH 60/78] debugging --- src/diffusers/models/transformer_2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 5399c1769e94..de390285f390 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -251,7 +251,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - print(f"{self.__class__.__name__} cross_attention_kwargs: {cross_attention_kwargs}") # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. From 7d348840ae489fc1d8a58dc736b351103a81ea55 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 31 Aug 2023 01:06:14 +0530 Subject: [PATCH 61/78] clean up. --- src/diffusers/models/attention.py | 2 - src/diffusers/models/dual_transformer_2d.py | 2 +- src/diffusers/models/lora.py | 1 - src/diffusers/models/unet_2d_condition.py | 6 +- .../versatile_diffusion/modeling_text_unet.py | 70 ++++++++++++++----- 5 files changed, 60 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 95706ac0a819..eef70adb0906 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -310,7 +310,6 @@ def __init__( def forward(self, hidden_states, scale: float = 1.0): for module in self.net: if isinstance(module, (LoRACompatibleLinear, GEGLU)): - print(f"{self.__class__.__name__} scale: {scale}") hidden_states = module(hidden_states, scale) else: hidden_states = module(hidden_states) @@ -359,7 +358,6 @@ def gelu(self, gate): return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states, scale: float = 1.0): - print(f"{self.__class__.__name__} {scale}") hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/dual_transformer_2d.py b/src/diffusers/models/dual_transformer_2d.py index d205bbece9a5..3db7e73ca6af 100644 --- a/src/diffusers/models/dual_transformer_2d.py +++ b/src/diffusers/models/dual_transformer_2d.py @@ -124,7 +124,7 @@ def forward( returning a tuple, the first element is the sample tensor. """ input_states = hidden_states - + encoded_states = [] tokens_start = 0 # attention_mask is not used yet diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 04937e0c7510..1557ebcd1b3d 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -210,5 +210,4 @@ def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: return super().forward(hidden_states) else: - print(f"From {self.__class__.__name__}: scale {scale}") return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 168376085f93..902745b854ba 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1009,7 +1009,11 @@ def forward( else: scale = 1.0 sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, scale=scale + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=scale, ) # 6. post-process diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 3fd9695c2d43..4aee8f066604 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1109,7 +1109,11 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=scale) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) @@ -1172,8 +1176,16 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) else: + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=scale, ) # 6. post-process @@ -1354,7 +1366,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1375,13 +1387,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=scale) output_states = output_states + (hidden_states,) @@ -1521,7 +1533,10 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1539,7 +1554,10 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = downsampler(hidden_states, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = downsampler(hidden_states, scale=1.0) output_states = output_states + (hidden_states,) @@ -1595,7 +1613,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1619,11 +1637,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) return hidden_states @@ -1759,7 +1777,10 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1771,7 +1792,10 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = upsampler(hidden_states, upsample_size, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = upsampler(hidden_states, upsample_size, scale=1.0) return hidden_states @@ -1876,7 +1900,11 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -1913,7 +1941,10 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) return hidden_states @@ -2026,7 +2057,11 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + scale = cross_attention_kwargs["scale"] + else: + scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -2037,6 +2072,9 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb) + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) + else: + hidden_states = resnet(hidden_states, temb, scale=1.0) return hidden_states From 9dee7d4f3960d7703da1ec2eb16cb2a2934e8848 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 11:55:56 +0530 Subject: [PATCH 62/78] refactor scale retrieval logic a bit. --- src/diffusers/models/attention.py | 31 ++--- src/diffusers/models/transformer_2d.py | 27 ++-- src/diffusers/models/unet_2d_blocks.py | 127 ++++++++---------- src/diffusers/models/unet_2d_condition.py | 16 +-- .../alt_diffusion/pipeline_alt_diffusion.py | 6 +- .../pipeline_alt_diffusion_img2img.py | 6 +- .../pipeline_onnx_stable_diffusion_img2img.py | 6 +- .../versatile_diffusion/modeling_text_unet.py | 72 +++++----- 8 files changed, 132 insertions(+), 159 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index eef70adb0906..2e5319a2f184 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -177,7 +177,7 @@ def forward( class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention + # 0. Self-Attention if self.use_ada_layer_norm: norm_hidden_states = self.norm1(hidden_states, timestep) elif self.use_ada_layer_norm_zero: @@ -187,7 +187,13 @@ def forward( else: norm_hidden_states = self.norm1(hidden_states) - # 0. Prepare GLIGEN inputs + # 1. Retrieve lora scale. + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + + # 2. Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) @@ -201,12 +207,12 @@ def forward( attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states - # 1.5 GLIGEN Control + # 2.5 GLIGEN Control if gligen_kwargs is not None: hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) - # 1.5 ends + # 2.5 ends - # 2. Cross-Attention + # 3. Cross-Attention if self.attn2 is not None: norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) @@ -220,7 +226,7 @@ def forward( ) hidden_states = attn_output + hidden_states - # 3. Feed-forward + # 4. Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: @@ -233,25 +239,16 @@ def forward( f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." ) - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( [ - self.ff(hid_slice, scale=scale) + self.ff(hid_slice, scale=lora_scale) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) ], dim=self._chunk_dim, ) else: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - ff_output = self.ff(norm_hidden_states, scale=scale) + ff_output = self.ff(norm_hidden_states, scale=lora_scale) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index de390285f390..677af6075616 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -274,6 +274,12 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # Retrieve lora scale. + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -281,19 +287,14 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = self.proj_in(hidden_states, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = self.proj_in(hidden_states, scale=1.0) + hidden_states = self.proj_in(hidden_states, lora_scale) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = self.proj_in(hidden_states, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = self.proj_in(hidden_states, scale=1.0) + hidden_states = self.proj_in(hidden_states, scale=lora_scale) + elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: @@ -328,15 +329,9 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = self.proj_out(hidden_states, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = self.proj_out(hidden_states, scale=1.0) + hidden_states = self.proj_out(hidden_states, scale=lora_scale) else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = self.proj_out(hidden_states, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = self.proj_out(hidden_states, scale=1.0) + hidden_states = self.proj_out(hidden_states, scale=lora_scale) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 2ced37c9d965..0cbbc154993f 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -641,10 +641,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] + lora_scale = cross_attention_kwargs["scale"] else: - scale = 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=scale) + lora_scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -681,10 +681,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states @@ -784,6 +781,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -796,11 +797,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=scale) + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -811,10 +808,7 @@ def forward( ) # resnet - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states @@ -913,28 +907,26 @@ def __init__( def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + output_states = () for resnet, attn in zip(self.resnets, self.attentions): - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - cross_attention_kwargs.update({"scale": scale}) - hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs.update({"scale": lora_scale}) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 if self.downsample_type == "resnet": - hidden_states = downsampler(hidden_states, temb=temb, scale=scale) + hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) else: - hidden_states = downsampler(hidden_states, scale=scale) + hidden_states = downsampler(hidden_states, scale=lora_scale) output_states += (hidden_states,) @@ -1043,6 +1035,11 @@ def forward( ): output_states = () + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["lora_scale"] + else: + lora_scale = 1.0 + blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1073,10 +1070,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1094,10 +1088,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = downsampler(hidden_states, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = downsampler(hidden_states, scale=1.0) + hidden_states = downsampler(hidden_states, scale=lora_scale) output_states = output_states + (hidden_states,) @@ -1721,6 +1712,11 @@ def forward( output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -1752,10 +1748,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, @@ -1768,10 +1761,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = downsampler(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = downsampler(hidden_states, temb, scale=1.0) + hidden_states = downsampler(hidden_states, temb, scale=lora_scale) output_states = output_states + (hidden_states,) @@ -1931,6 +1921,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -1960,10 +1954,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2192,6 +2183,11 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2225,10 +2221,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2240,10 +2233,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = upsampler(hidden_states, upsample_size, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = upsampler(hidden_states, upsample_size, scale=1.0) + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) return hidden_states @@ -2890,6 +2880,11 @@ def forward( ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -2927,10 +2922,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, @@ -2941,10 +2933,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = upsampler(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = upsampler(hidden_states, temb, scale=1.0) + hidden_states = upsampler(hidden_states, temb, scale=lora_scale) return hidden_states @@ -3128,6 +3117,11 @@ def forward( if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -3156,10 +3150,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 902745b854ba..5893ab66ad2e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -915,6 +915,10 @@ def forward( cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None @@ -937,11 +941,7 @@ def forward( **additional_residuals, ) else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=scale) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) @@ -1004,16 +1004,12 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, - scale=scale, + scale=lora_scale, ) # 6. post-process diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 00e688907889..8cf308588422 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -453,8 +453,10 @@ def run_safety_checker(self, image, device, dtype): def decode_latents(self, latents): warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 5f6dc07487bd..89c8279fb3c0 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -454,8 +454,10 @@ def run_safety_checker(self, image, device, dtype): def decode_latents(self, latents): warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index d418662a4b44..508085094b16 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -35,8 +35,10 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 def preprocess(image): warnings.warn( - "The preprocess method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor.preprocess instead", + ( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead" + ), FutureWarning, ) if isinstance(image, torch.Tensor): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 4aee8f066604..07cd9310624c 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1087,6 +1087,10 @@ def forward( cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None @@ -1109,11 +1113,7 @@ def forward( **additional_residuals, ) else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=scale) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) @@ -1176,16 +1176,12 @@ def forward( encoder_attention_mask=encoder_attention_mask, ) else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, - scale=scale, + scale=lora_scale, ) # 6. post-process @@ -1503,6 +1499,11 @@ def forward( ): output_states = () + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["lora_scale"] + else: + lora_scale = 1.0 + blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1533,10 +1534,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1554,10 +1552,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = downsampler(hidden_states, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = downsampler(hidden_states, scale=1.0) + hidden_states = downsampler(hidden_states, scale=lora_scale) output_states = output_states + (hidden_states,) @@ -1744,6 +1739,11 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1777,10 +1777,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1792,10 +1789,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = upsampler(hidden_states, upsample_size, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = upsampler(hidden_states, upsample_size, scale=1.0) + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) return hidden_states @@ -1901,10 +1895,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] + lora_scale = cross_attention_kwargs["scale"] else: - scale = 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=scale) + lora_scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -1941,10 +1935,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states @@ -2045,6 +2036,10 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + lora_scale = cross_attention_kwargs["scale"] + else: + lora_scale = 1.0 if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -2057,11 +2052,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - scale = cross_attention_kwargs["scale"] - else: - scale = 1.0 - hidden_states = self.resnets[0](hidden_states, temb, scale=scale) + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -2072,9 +2063,6 @@ def forward( ) # resnet - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - hidden_states = resnet(hidden_states, temb, scale=cross_attention_kwargs["scale"]) - else: - hidden_states = resnet(hidden_states, temb, scale=1.0) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states From f81f77d5e41c199707af0dc40718bcb5b1c176fe Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 12:04:19 +0530 Subject: [PATCH 63/78] fix nonetypw --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2e5319a2f184..9e5b8ef0cbce 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -188,7 +188,7 @@ def forward( norm_hidden_states = self.norm1(hidden_states) # 1. Retrieve lora scale. - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: lora_scale = cross_attention_kwargs["scale"] else: lora_scale = 1.0 diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 0cbbc154993f..f445ff6ec585 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1921,7 +1921,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: + if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: lora_scale = cross_attention_kwargs["scale"] else: lora_scale = 1.0 From 92e1194fbdf7ce2fe748efb9931255a03b832df9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 12:37:23 +0530 Subject: [PATCH 64/78] fix: tests --- src/diffusers/models/unet_2d_blocks.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f445ff6ec585..f6ca0dce7f96 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -918,7 +918,7 @@ def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_ for resnet, attn in zip(self.resnets, self.attentions): cross_attention_kwargs.update({"scale": lora_scale}) hidden_states = resnet(hidden_states, temb, scale=lora_scale) - hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: @@ -1317,7 +1317,7 @@ def forward(self, hidden_states, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None, scale=scale) cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = attn(hidden_states, **cross_attention_kwargs) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -1413,7 +1413,7 @@ def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0 for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb, scale=scale) cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states += (hidden_states,) if self.downsamplers is not None: @@ -2074,7 +2074,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_si hidden_states = resnet(hidden_states, temb, scale=scale) cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = attn(hidden_states, **cross_attention_kwargs) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2556,7 +2556,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample hidden_states = resnet(hidden_states, temb, scale=scale) cross_attention_kwargs = {"scale": scale} - hidden_states = self.attentions[0](hidden_states, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) From 4511f48ef4396162b35c98a55f10b24fbed64bce Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 13:22:30 +0530 Subject: [PATCH 65/78] add more tests --- src/diffusers/models/unet_2d_blocks.py | 2 +- .../versatile_diffusion/modeling_text_unet.py | 2 +- tests/models/test_lora_layers.py | 47 +++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index f6ca0dce7f96..624e1feaa55c 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1036,7 +1036,7 @@ def forward( output_states = () if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["lora_scale"] + lora_scale = cross_attention_kwargs["scale"] else: lora_scale = 1.0 diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 07cd9310624c..8820b4359d84 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1500,7 +1500,7 @@ def forward( output_states = () if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["lora_scale"] + lora_scale = cross_attention_kwargs["scale"] else: lora_scale = 1.0 diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 3f9ca734f9e9..d323d3988564 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -915,6 +915,53 @@ def test_fuse_lora_with_different_scales(self): lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 ), "Different LoRA scales should influence the outputs accordingly." + def test_with_different_scales(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + original_imagee_slice = original_images[0, -3:, -3:, -1] + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1] + + lora_images_scale_0_5 = sd_pipe( + **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} + ).images + lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] + + lora_images_scale_0_0 = sd_pipe( + **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0} + ).images + lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1] + + assert not np.allclose( + lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 + ), "Different LoRA scales should influence the outputs accordingly." + + assert np.allclose( + original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-03 + ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA." + @slow @require_torch_gpu From 6667e6869cc549a916375e2e112e3b6774c48da9 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 13:40:22 +0530 Subject: [PATCH 66/78] more fixes. --- src/diffusers/models/unet_2d_blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 624e1feaa55c..0d697f0cbf25 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -2446,7 +2446,7 @@ def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb, scale=scale) cross_attention_kwargs = {"scale": scale} - hidden_states = attn(hidden_states, temb=temb, cross_attention_kwargs=cross_attention_kwargs) + hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) if self.upsamplers is not None: for upsampler in self.upsamplers: From b941b88daa954fe7f3a7497193f48f8c0f7479dd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 31 Aug 2023 14:35:52 +0530 Subject: [PATCH 67/78] figure out a way to pass lora_scale. --- src/diffusers/loaders.py | 6 +++-- .../pipeline_stable_diffusion_xl.py | 27 ++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c0e5b1480b4b..2ba06f59e70e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -135,10 +135,12 @@ def _unfuse_lora(self): self.w_up = None self.w_down = None - def forward(self, input, lora_scale: float = 1.0): + def forward(self, input): + if self.lora_scale is None: + self.lora_scale = 1.0 if self.lora_linear_layer is None: return self.regular_linear_layer(input) - return self.regular_linear_layer(input) + lora_scale * self.lora_linear_layer(input) + return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) def text_encoder_attn_modules(text_encoder): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 11e575d68269..310695c03f61 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -20,7 +20,14 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + PatchedLoraProjection, + TextualInversionLoaderMixin, + text_encoder_attn_modules, + text_encoder_mlp_modules, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -559,6 +566,21 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) + def _adjust_lora_scale_text_encoder( + self, text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], lora_scale + ): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_scale = lora_scale + attn_module.k_proj.lora_scale = lora_scale + attn_module.v_proj.lora_scale = lora_scale + attn_module.out_proj.lora_scale = lora_scale + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_scale = lora_scale + mlp_module.fc2.lora_scale = lora_scale + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -755,6 +777,9 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) + # 3.1 Dynamically adjust the LoRA scale + self._adjust_lora_scale_text_encoder(self.text_encoder, text_encoder_lora_scale) + self._adjust_lora_scale_text_encoder(self.text_encoder_2, text_encoder_lora_scale) ( prompt_embeds, negative_prompt_embeds, From 9705cc28485d4898c51d1971d3cac12d2de5d651 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 09:13:23 +0530 Subject: [PATCH 68/78] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/loaders.py | 3 +- src/diffusers/models/attention.py | 5 +- src/diffusers/models/transformer_2d.py | 5 +- src/diffusers/models/unet_2d_blocks.py | 47 ++++--------------- src/diffusers/models/unet_2d_condition.py | 5 +- .../versatile_diffusion/modeling_text_unet.py | 21 ++------- 6 files changed, 17 insertions(+), 69 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2ba06f59e70e..44fb67a57986 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1833,9 +1833,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora if fuse_unet or fuse_text_encoder: self.num_fused_loras += 1 if self.num_fused_loras > 1: - warnings.warn( + logger.warn( "The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.", - RuntimeWarning, ) if fuse_unet: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9e5b8ef0cbce..f76650406a65 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -188,10 +188,7 @@ def forward( norm_hidden_states = self.norm1(hidden_states) # 1. Retrieve lora scale. - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 2. Prepare GLIGEN inputs cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 677af6075616..4819d3be48e1 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -275,10 +275,7 @@ def forward( encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # Retrieve lora scale. - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 # 1. Input if self.is_input_continuous: diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 0d697f0cbf25..1697765626bd 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -640,10 +640,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -781,10 +778,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -908,10 +902,7 @@ def __init__( def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) output_states = () @@ -1035,10 +1026,7 @@ def forward( ): output_states = () - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 blocks = list(zip(self.resnets, self.attentions)) @@ -1712,10 +1700,7 @@ def forward( output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -1921,10 +1906,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -2183,10 +2165,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -2880,11 +2859,7 @@ def forward( ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 - + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -3117,11 +3092,7 @@ def forward( if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 - + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5893ab66ad2e..b7e02ad84336 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -915,10 +915,7 @@ def forward( cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 8820b4359d84..970fb2577d1b 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1087,10 +1087,7 @@ def forward( cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None @@ -1499,11 +1496,7 @@ def forward( ): output_states = () - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 - + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1739,10 +1732,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -1894,10 +1884,7 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - if cross_attention_kwargs is not None and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: From bebab12926b94691c9a66cc5f190afb2c76600cf Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 09:17:54 +0530 Subject: [PATCH 69/78] unify the retrieval logic of lora_scale. --- src/diffusers/models/unet_2d_blocks.py | 6 +++--- .../pipelines/versatile_diffusion/modeling_text_unet.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 1697765626bd..3751806c180d 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -1906,7 +1906,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () - lora_scale = cross_attention_kwargs.get("scale", 1.0) + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -2165,7 +2165,7 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -2859,7 +2859,7 @@ def forward( ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 970fb2577d1b..4825195d8d17 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1497,6 +1497,7 @@ def forward( output_states = () lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -2023,10 +2024,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - if len(cross_attention_kwargs) >= 1 and "scale" in cross_attention_kwargs: - lora_scale = cross_attention_kwargs["scale"] - else: - lora_scale = 1.0 + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. From 81f7ddf978827edb7becdad0fcfb3204c96d2a3f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 09:31:06 +0530 Subject: [PATCH 70/78] move adjust_lora_scale_text_encoder to lora.py. --- src/diffusers/models/lora.py | 15 +++++++++++ .../controlnet/pipeline_controlnet_sd_xl.py | 5 ++++ .../pipeline_controlnet_sd_xl_img2img.py | 5 ++++ .../pipeline_stable_diffusion_xl.py | 26 ++++--------------- .../pipeline_stable_diffusion_xl_img2img.py | 5 ++++ .../pipeline_stable_diffusion_xl_inpaint.py | 5 ++++ .../pipeline_stable_diffusion_xl_adapter.py | 5 ++++ 7 files changed, 45 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 1557ebcd1b3d..fbb69b787a84 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -18,12 +18,27 @@ import torch.nn.functional as F from torch import nn +from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules from ..utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_scale = lora_scale + attn_module.k_proj.lora_scale = lora_scale + attn_module.v_proj.lora_scale = lora_scale + attn_module.out_proj.lora_scale = lora_scale + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_scale = lora_scale + mlp_module.fc2.lora_scale = lora_scale + + class LoRALinearLayer(nn.Module): def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): super().__init__() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index c99f007f9c82..388a70d911f1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -34,6 +34,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -321,6 +322,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 6e66bdcc4472..a7d9328b3e88 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -33,6 +33,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -352,6 +353,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 310695c03f61..a599f5040ca9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -23,10 +23,7 @@ from ...loaders import ( FromSingleFileMixin, LoraLoaderMixin, - PatchedLoraProjection, TextualInversionLoaderMixin, - text_encoder_attn_modules, - text_encoder_mlp_modules, ) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( @@ -35,6 +32,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -291,6 +289,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -566,21 +568,6 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - def _adjust_lora_scale_text_encoder( - self, text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], lora_scale - ): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj.lora_scale = lora_scale - attn_module.k_proj.lora_scale = lora_scale - attn_module.v_proj.lora_scale = lora_scale - attn_module.out_proj.lora_scale = lora_scale - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1.lora_scale = lora_scale - mlp_module.fc2.lora_scale = lora_scale - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -777,9 +764,6 @@ def __call__( text_encoder_lora_scale = ( cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) - # 3.1 Dynamically adjust the LoRA scale - self._adjust_lora_scale_text_encoder(self.text_encoder, text_encoder_lora_scale) - self._adjust_lora_scale_text_encoder(self.text_encoder_2, text_encoder_lora_scale) ( prompt_embeds, negative_prompt_embeds, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ff51f8765e4a..be558b70e382 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -29,6 +29,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -294,6 +295,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index eecbdc7e669e..47119df63cc8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -30,6 +30,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -444,6 +445,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 6311c02be475..07947b7c8e88 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -31,6 +31,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -312,6 +313,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): From e2c835c4737afd191df265928e4376b9878a0d1f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 09:43:51 +0530 Subject: [PATCH 71/78] introduce dynamic adjustment lora scale support to sd --- .../pipelines/alt_diffusion/pipeline_alt_diffusion.py | 4 ++++ .../alt_diffusion/pipeline_alt_diffusion_img2img.py | 4 ++++ src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 4 ++++ .../pipelines/controlnet/pipeline_controlnet_img2img.py | 5 ++++- .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 4 ++++ .../pipelines/stable_diffusion/pipeline_cycle_diffusion.py | 4 ++++ .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 ++++ .../pipeline_stable_diffusion_attend_and_excite.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_depth2img.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_diffedit.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_gligen.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 ++++ .../pipeline_stable_diffusion_inpaint_legacy.py | 4 ++++ .../pipeline_stable_diffusion_k_diffusion.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_ldm3d.py | 4 ++++ .../pipeline_stable_diffusion_model_editing.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_panorama.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_paradigms.py | 4 ++++ .../pipeline_stable_diffusion_pix2pix_zero.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_sag.py | 4 ++++ .../stable_diffusion/pipeline_stable_diffusion_upscale.py | 4 ++++ .../pipelines/stable_diffusion/pipeline_stable_unclip.py | 4 ++++ .../stable_diffusion/pipeline_stable_unclip_img2img.py | 4 ++++ .../t2i_adapter/pipeline_stable_diffusion_adapter.py | 4 ++++ .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 4 ++++ .../pipeline_text_to_video_synth_img2img.py | 4 ++++ 27 files changed, 108 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 8cf308588422..5d362b0f218a 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -26,6 +26,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -323,6 +324,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 89c8279fb3c0..6855dceb63ea 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -28,6 +28,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -324,6 +325,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 84e89b5049cd..77cb83de942c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -26,6 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -314,6 +315,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 7c2db0d8c70f..7387f42f00a9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import inspect import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -26,6 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -339,6 +339,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 271f996e8286..e6cede499e2a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -27,6 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -465,6 +466,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index ceeaeb64949e..9123b5e28146 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -28,6 +28,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -333,6 +334,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index fd655822376f..1cc55860e1ff 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -24,6 +24,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -322,6 +323,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 0dd91327251b..43c3fb167ca3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -26,6 +26,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -322,6 +323,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index f4d352b8956e..f620d607c4a3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -27,6 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -207,6 +208,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index c26c96710041..44e67120ac60 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -27,6 +27,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -508,6 +509,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py index 8293d4c59d57..78d0e852a632 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py @@ -24,6 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -298,6 +299,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 69d040ad5957..f36d892f3b70 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -26,6 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -327,6 +328,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cef711461137..bf8965d3e9c8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -25,6 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -393,6 +394,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index b0688d50a045..6c4c2e95fda3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -26,6 +26,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -323,6 +324,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 650fb647defd..312e5ac15899 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -23,6 +23,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import LMSDiscreteScheduler from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -230,6 +231,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index 326830f0f9c6..13ccb226b0d7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -24,6 +24,7 @@ from ...image_processor import VaeImageProcessorLDM3D from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BaseOutput, @@ -293,6 +294,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index 1e9a6a9dfa14..e07baa62f4cb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -22,6 +22,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin from ...utils import deprecate, logging, randn_tensor @@ -234,6 +235,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 84e3e209ae1b..ec629983f6e9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -22,6 +22,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -211,6 +212,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py index cbbc6c197788..7ce3dfc35908 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py @@ -21,6 +21,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -277,6 +278,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index b7f1eb19a7e3..a1744419c8d4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -33,6 +33,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ...utils import ( @@ -466,6 +467,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index af828991307e..bceff47a501b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -23,6 +23,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -234,6 +235,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index a0cadaaa52d1..5fad253c890b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -30,6 +30,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -255,6 +256,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index c7b916c3878f..5a841ebfbc60 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -24,6 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -365,6 +366,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 99a279d80955..6ab3ba9ffdf2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -26,6 +26,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -314,6 +315,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index ccb9215b8e30..e1e0a747ed6e 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -319,6 +320,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 5315329e692e..72063769c868 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -21,6 +21,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -246,6 +247,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 5122f114f342..cb0c24c474a4 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -22,6 +22,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -308,6 +309,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): From f2026acbc1dd6bf5e1290d00abb9dc97fb8a84c4 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 09:54:00 +0530 Subject: [PATCH 72/78] fix up copies --- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 5 +++++ .../pipeline_stable_diffusion_gligen_text_image.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 030ebf51f143..5a8e68e1aab7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -33,6 +33,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -342,6 +343,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py index 74061245c73e..0940b830065c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py @@ -30,6 +30,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -331,6 +332,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): From 744489664b6afab3ae4f1ca1701d2ea5aea90b7e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 10:42:17 +0530 Subject: [PATCH 73/78] Empty-Commit From e60f45094112b55fd261afb373e2845fd5877788 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 14:43:46 +0530 Subject: [PATCH 74/78] add: test to check fusion equivalence on different scales. --- src/diffusers/loaders.py | 6 ++++- src/diffusers/models/lora.py | 20 ++++++++++++---- tests/models/test_lora_layers.py | 40 ++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 44fb67a57986..a103fabb876b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -99,6 +99,8 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_linear_layer is None: return + print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") + logger.info(f"Fusing LoRA weights for {self.__class__}") dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device w_orig = self.regular_linear_layer.weight.data.float() @@ -123,6 +125,7 @@ def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return + logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.regular_linear_layer.weight.data dtype, device = fused_weight.dtype, fused_weight.device @@ -136,11 +139,12 @@ def _unfuse_lora(self): self.w_down = None def forward(self, input): + print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}") if self.lora_scale is None: self.lora_scale = 1.0 if self.lora_linear_layer is None: return self.regular_linear_layer(input) - return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input) + return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input)) def text_encoder_attn_modules(text_encoder): diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index fbb69b787a84..6e570dcc55cd 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -116,6 +116,7 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return + print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") dtype, device = self.weight.data.dtype, self.weight.data.device logger.info(f"Fusing LoRA weights for {self.__class__}") @@ -147,12 +148,12 @@ def _unfuse_lora(self): fused_weight = self.weight.data dtype, device = fused_weight.data.dtype, fused_weight.data.device - self.w_up = self.w_up.to(device=device, dtype=dtype) - self.w_down = self.w_down.to(device, dtype=dtype) + self.w_up = self.w_up.to(device=device).float() + self.w_down = self.w_down.to(device).float() fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) fusion = fusion.reshape((fused_weight.shape)) - unfused_weight = fused_weight - (self._lora_scale * fusion) + unfused_weight = fused_weight.float() - (self._lora_scale * fusion) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None @@ -160,13 +161,16 @@ def _unfuse_lora(self): def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: + if hasattr(self, "_lora_scale"): + print(f"{self.__class__.__name__} has a lora_scale of {self._lora_scale}") # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 return F.conv2d( hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) + print(f"{self.__class__.__name__} has a scale of {scale}") + return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) class LoRACompatibleLinear(nn.Linear): @@ -185,6 +189,8 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return + print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") + logger.info(f"Fusing LoRA weights for {self.__class__}") dtype, device = self.weight.data.dtype, self.weight.data.device w_orig = self.weight.data.float() @@ -208,6 +214,7 @@ def _fuse_lora(self, lora_scale=1.0): def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return + logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.weight.data dtype, device = fused_weight.dtype, fused_weight.device @@ -223,6 +230,9 @@ def _unfuse_lora(self): def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: + if hasattr(self, "_lora_scale"): + print(f"{self.__class__.__name__} has a lora_scale of {self._lora_scale}") return super().forward(hidden_states) else: - return super().forward(hidden_states) + scale * self.lora_layer(hidden_states) + print(f"{self.__class__.__name__} has a scale of {scale}") + return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index d323d3988564..5f87c8e52411 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -962,6 +962,46 @@ def test_with_different_scales(self): original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-03 ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA." + def test_with_different_scales_fusion_equivalence(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + print("load_lora_weights().") + lora_images_scale_0_5 = sd_pipe( + **pipeline_inputs, + generator=torch.manual_seed(0), # cross_attention_kwargs={"scale": 0.5} + ).images + lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] + + print("LoRA fusion.") + sd_pipe.fuse_lora() + lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1] + + assert np.allclose( + lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03 + ), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale." + @slow @require_torch_gpu From bf1052b7f93d05c874ac65ea0e9efb3d15b90227 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 4 Sep 2023 17:42:13 +0530 Subject: [PATCH 75/78] handle lora fusion warning. --- src/diffusers/loaders.py | 2 ++ src/diffusers/models/lora.py | 14 ++++++++++++-- tests/models/test_lora_layers.py | 8 ++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a103fabb876b..278c51b1c9a9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1901,6 +1901,8 @@ def unfuse_text_encoder_lora(text_encoder): unfuse_text_encoder_lora(self.text_encoder) if hasattr(self, "text_encoder_2"): unfuse_text_encoder_lora(self.text_encoder_2) + + self.num_fused_loras -= 1 class FromSingleFileMixin: diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 6e570dcc55cd..cfd482e9b1cf 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -232,7 +232,17 @@ def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: if hasattr(self, "_lora_scale"): print(f"{self.__class__.__name__} has a lora_scale of {self._lora_scale}") - return super().forward(hidden_states) + out = super().forward(hidden_states) + # if out.ndim == 2: + # print(out[0, :3]) + # else: + # print(out[0, :3, :3]) + return out else: print(f"{self.__class__.__name__} has a scale of {scale}") - return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + # if out.ndim == 2: + # print(out[0, :3]) + # else: + # print(out[0, :3, :3]) + return out diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 5f87c8e52411..1a8c42f84a2f 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -979,21 +979,21 @@ def test_with_different_scales_fusion_equivalence(self): StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + # text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + # text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - print("load_lora_weights().") + print("Before fusion.") lora_images_scale_0_5 = sd_pipe( **pipeline_inputs, generator=torch.manual_seed(0), # cross_attention_kwargs={"scale": 0.5} ).images lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] - print("LoRA fusion.") + print("After fusion.") sd_pipe.fuse_lora() lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1] From 47333846d53bf5b7c61a2163dedf888322ed0a9b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Sep 2023 21:08:46 +0000 Subject: [PATCH 76/78] make lora smaller --- src/diffusers/loaders.py | 6 ++--- src/diffusers/models/lora.py | 17 -------------- tests/models/test_lora_layers.py | 38 +++++++++++++++++++++++--------- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 278c51b1c9a9..e693596a1b8a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -99,7 +99,7 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_linear_layer is None: return - print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") + # print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") logger.info(f"Fusing LoRA weights for {self.__class__}") dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device @@ -139,7 +139,7 @@ def _unfuse_lora(self): self.w_down = None def forward(self, input): - print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}") + # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}") if self.lora_scale is None: self.lora_scale = 1.0 if self.lora_linear_layer is None: @@ -1901,7 +1901,7 @@ def unfuse_text_encoder_lora(text_encoder): unfuse_text_encoder_lora(self.text_encoder) if hasattr(self, "text_encoder_2"): unfuse_text_encoder_lora(self.text_encoder_2) - + self.num_fused_loras -= 1 diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index cfd482e9b1cf..5173029a2e56 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -116,9 +116,7 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return - print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") dtype, device = self.weight.data.dtype, self.weight.data.device - logger.info(f"Fusing LoRA weights for {self.__class__}") w_orig = self.weight.data.float() w_up = self.lora_layer.up.weight.data.float() @@ -143,7 +141,6 @@ def _fuse_lora(self, lora_scale=1.0): def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return - logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.weight.data dtype, device = fused_weight.data.dtype, fused_weight.data.device @@ -189,8 +186,6 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return - print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") - logger.info(f"Fusing LoRA weights for {self.__class__}") dtype, device = self.weight.data.dtype, self.weight.data.device w_orig = self.weight.data.float() @@ -214,7 +209,6 @@ def _fuse_lora(self, lora_scale=1.0): def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return - logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.weight.data dtype, device = fused_weight.dtype, fused_weight.device @@ -230,19 +224,8 @@ def _unfuse_lora(self): def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: - if hasattr(self, "_lora_scale"): - print(f"{self.__class__.__name__} has a lora_scale of {self._lora_scale}") out = super().forward(hidden_states) - # if out.ndim == 2: - # print(out[0, :3]) - # else: - # print(out[0, :3, :3]) return out else: - print(f"{self.__class__.__name__} has a scale of {scale}") out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) - # if out.ndim == 2: - # print(out[0, :3]) - # else: - # print(out[0, :3, :3]) return out diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 1a8c42f84a2f..253427a87b1f 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -92,11 +92,11 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module): return text_encoder_lora_layers -def set_lora_weights(lora_attn_parameters, randn_weight=False): +def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): with torch.no_grad(): for parameter in lora_attn_parameters: if randn_weight: - parameter[:] = torch.randn_like(parameter) + parameter[:] = torch.randn_like(parameter) * var else: torch.zero_(parameter) @@ -966,35 +966,41 @@ def test_with_different_scales_fusion_equivalence(self): pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) + # sd_pipe.unet.set_default_attn_processor() sd_pipe.set_progress_bar_config(disable=None) _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + images = sd_pipe( + **pipeline_inputs, + generator=torch.manual_seed(0), + ).images + images_slice = images[0, -3:, -3:, -1] + # Emulate training. - set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - # text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - # text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - print("Before fusion.") lora_images_scale_0_5 = sd_pipe( **pipeline_inputs, - generator=torch.manual_seed(0), # cross_attention_kwargs={"scale": 0.5} + generator=torch.manual_seed(0), + cross_attention_kwargs={"scale": 0.5}, ).images lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] - print("After fusion.") - sd_pipe.fuse_lora() + sd_pipe.fuse_lora(lora_scale=0.5) lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1] @@ -1002,6 +1008,16 @@ def test_with_different_scales_fusion_equivalence(self): lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03 ), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale." + sd_pipe.unfuse_lora() + images_unfused = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + images_slice_unfused = images_unfused[0, -3:, -3:, -1] + + assert np.allclose(images_slice, images_slice_unfused, atol=1e-03), "Unfused should match no LoRA" + + assert not np.allclose( + images_slice, lora_image_slice_scale_0_5, atol=1e-03 + ), "0.5 scale and no scale shouldn't match" + @slow @require_torch_gpu From dabdd58cb02c25a5d5f89fc5c2caa52110b2d071 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Sep 2023 21:09:56 +0000 Subject: [PATCH 77/78] make lora smaller --- src/diffusers/models/lora.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 5173029a2e56..834a7051b06d 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -158,15 +158,12 @@ def _unfuse_lora(self): def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: - if hasattr(self, "_lora_scale"): - print(f"{self.__class__.__name__} has a lora_scale of {self._lora_scale}") # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 return F.conv2d( hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups ) else: - print(f"{self.__class__.__name__} has a scale of {scale}") return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) From 51824c74293fb5a133afd3cbc8f6499b1a00611f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 4 Sep 2023 21:10:41 +0000 Subject: [PATCH 78/78] make lora smaller --- src/diffusers/loaders.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e693596a1b8a..e1ac6a498a12 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -99,8 +99,6 @@ def _fuse_lora(self, lora_scale=1.0): if self.lora_linear_layer is None: return - # print(f"From _fuse_lora of {self.__class__.__name__} {lora_scale}") - logger.info(f"Fusing LoRA weights for {self.__class__}") dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device w_orig = self.regular_linear_layer.weight.data.float() @@ -125,7 +123,6 @@ def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return - logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.regular_linear_layer.weight.data dtype, device = fused_weight.dtype, fused_weight.device