diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 52970e48147d..e1ac6a498a12 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -95,7 +95,7 @@ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
         return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
 
-    def _fuse_lora(self):
+    def _fuse_lora(self, lora_scale=1.0):
         if self.lora_linear_layer is None:
             return
 
@@ -108,7 +108,7 @@ def _fuse_lora(self):
         if self.lora_linear_layer.network_alpha is not None:
             w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
 
-        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
+        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
 
         # we can drop the lora layer now
@@ -117,6 +117,7 @@ def _fuse_lora(self):
         # offload the up and down matrices to CPU to not blow the memory
         self.w_up = w_up.cpu()
         self.w_down = w_down.cpu()
+        self.lora_scale = lora_scale
 
     def _unfuse_lora(self):
         if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
@@ -128,16 +129,19 @@ def _unfuse_lora(self):
         w_up = self.w_up.to(device=device).float()
         w_down = self.w_down.to(device).float()
 
-        unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
+        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
         self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
 
         self.w_up = None
         self.w_down = None
 
     def forward(self, input):
+        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
+        if self.lora_scale is None:
+            self.lora_scale = 1.0
         if self.lora_linear_layer is None:
             return self.regular_linear_layer(input)
-        return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
 
 
 def text_encoder_attn_modules(text_encoder):
@@ -576,12 +580,13 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self):
+    def fuse_lora(self, lora_scale=1.0):
+        self.lora_scale = lora_scale
         self.apply(self._fuse_lora_apply)
 
     def _fuse_lora_apply(self, module):
         if hasattr(module, "_fuse_lora"):
-            module._fuse_lora()
+            module._fuse_lora(self.lora_scale)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
@@ -924,6 +929,7 @@ class LoraLoaderMixin:
     """
 
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
+    num_fused_loras = 0
 
     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         """
@@ -1807,7 +1813,7 @@ def unload_lora_weights(self):
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
 
-    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
+    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
 
@@ -1822,22 +1828,31 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
             fuse_text_encoder (`bool`, defaults to `True`):
                 Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with
                 the LoRA parameters, then it won't have any effect.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much the LoRA parameters influence the outputs.
         """
+        if fuse_unet or fuse_text_encoder:
+            self.num_fused_loras += 1
+            if self.num_fused_loras > 1:
+                logger.warning(
+                    "The current API supports operating with a single LoRA file. You are trying to load and fuse more than one LoRA, which is not well supported.",
+                )
+
         if fuse_unet:
-            self.unet.fuse_lora()
+            self.unet.fuse_lora(lora_scale)
 
         def fuse_text_encoder_lora(text_encoder):
             for _, attn_module in text_encoder_attn_modules(text_encoder):
                 if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._fuse_lora()
-                    attn_module.k_proj._fuse_lora()
-                    attn_module.v_proj._fuse_lora()
-                    attn_module.out_proj._fuse_lora()
+                    attn_module.q_proj._fuse_lora(lora_scale)
+                    attn_module.k_proj._fuse_lora(lora_scale)
+                    attn_module.v_proj._fuse_lora(lora_scale)
+                    attn_module.out_proj._fuse_lora(lora_scale)
 
             for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                 if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._fuse_lora()
-                    mlp_module.fc2._fuse_lora()
+                    mlp_module.fc1._fuse_lora(lora_scale)
+                    mlp_module.fc2._fuse_lora(lora_scale)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
@@ -1884,6 +1899,8 @@ def unfuse_text_encoder_lora(text_encoder):
         if hasattr(self, "text_encoder_2"):
             unfuse_text_encoder_lora(self.text_encoder_2)
 
+        self.num_fused_loras -= 1
+
 
 class FromSingleFileMixin:
     """
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index bca272101d95..185b87f2046a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -177,7 +177,7 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
     ):
         # Notice that normalization is always applied before the real computation in the following blocks.
-        # 1. Self-Attention
+        # 0. Self-Attention
         if self.use_ada_layer_norm:
             norm_hidden_states = self.norm1(hidden_states, timestep)
         elif self.use_ada_layer_norm_zero:
@@ -187,7 +187,10 @@ def forward(
         else:
             norm_hidden_states = self.norm1(hidden_states)
 
-        # 0. Prepare GLIGEN inputs
+        # 1. Retrieve lora scale.
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+        # 2. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
@@ -201,12 +204,12 @@ def forward(
             attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = attn_output + hidden_states
 
-        # 1.5 GLIGEN Control
+        # 2.5 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
-        # 1.5 ends
+        # 2.5 ends
 
-        # 2. Cross-Attention
+        # 3. Cross-Attention
         if self.attn2 is not None:
             norm_hidden_states = (
                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
             )
@@ -220,7 +223,7 @@ def forward(
             )
             hidden_states = attn_output + hidden_states
 
-        # 3. Feed-forward
+        # 4.
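# --- Illustrative usage note (not part of the patch) ---
# A minimal sketch of how the new `lora_scale` argument to `fuse_lora()` could be exercised once
# this change lands. The base checkpoint id is a common example; the LoRA path is a placeholder.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA weights

# Bake the LoRA into the UNet/text-encoder weights at 70% strength, run inference, then restore.
pipe.fuse_lora(lora_scale=0.7)
image = pipe("a pencil sketch of a castle").images[0]
pipe.unfuse_lora()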
Feed-forward norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: @@ -235,11 +238,14 @@ def forward( num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim) + ], dim=self._chunk_dim, ) else: - ff_output = self.ff(norm_hidden_states) + ff_output = self.ff(norm_hidden_states, scale=lora_scale) if self.use_ada_layer_norm_zero: ff_output = gate_mlp.unsqueeze(1) * ff_output @@ -295,9 +301,12 @@ def __init__( if final_dropout: self.net.append(nn.Dropout(dropout)) - def forward(self, hidden_states): + def forward(self, hidden_states, scale: float = 1.0): for module in self.net: - hidden_states = module(hidden_states) + if isinstance(module, (LoRACompatibleLinear, GEGLU)): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) return hidden_states @@ -342,8 +351,8 @@ def gelu(self, gate): # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + def forward(self, hidden_states, scale: float = 1.0): + hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1) return hidden_states * self.gelu(gate) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 7eb6b495ab6a..1ec89a5ff68e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -570,15 +570,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, lora_scale=scale) - value = attn.to_v(encoder_hidden_states, lora_scale=scale) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) query = attn.head_to_batch_dim(query) key = attn.head_to_batch_dim(key) @@ -589,7 +589,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -722,17 +722,17 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) query = attn.head_to_batch_dim(query) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale) encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) encoder_hidden_states_value_proj = 
attn.head_to_batch_dim(encoder_hidden_states_value_proj) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, lora_scale=scale) - value = attn.to_v(hidden_states, lora_scale=scale) + key = attn.to_k(hidden_states, scale=scale) + value = attn.to_v(hidden_states, scale=scale) key = attn.head_to_batch_dim(key) value = attn.head_to_batch_dim(value) key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) @@ -746,7 +746,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -782,7 +782,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) query = attn.head_to_batch_dim(query, out_dim=4) encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) @@ -791,8 +791,8 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4) if not attn.only_cross_attention: - key = attn.to_k(hidden_states, lora_scale=scale) - value = attn.to_v(hidden_states, lora_scale=scale) + key = attn.to_k(hidden_states, scale=scale) + value = attn.to_v(hidden_states, scale=scale) key = attn.head_to_batch_dim(key, out_dim=4) value = attn.head_to_batch_dim(value, out_dim=4) key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) @@ -809,7 +809,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1]) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -937,15 +937,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, lora_scale=scale) - value = attn.to_v(encoder_hidden_states, lora_scale=scale) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) query = attn.head_to_batch_dim(query).contiguous() key = attn.head_to_batch_dim(key).contiguous() @@ -958,7 +958,7 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) @@ -1015,15 +1015,15 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states, lora_scale=scale) + query = attn.to_q(hidden_states, scale=scale) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: 
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - key = attn.to_k(encoder_hidden_states, lora_scale=scale) - value = attn.to_v(encoder_hidden_states, lora_scale=scale) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads @@ -1043,7 +1043,7 @@ def __call__( hidden_states = hidden_states.to(query.dtype) # linear proj - hidden_states = attn.to_out[0](hidden_states, lora_scale=scale) + hidden_states = attn.to_out[0](hidden_states, scale=scale) # dropout hidden_states = attn.to_out[1](hidden_states) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 671c93a3b2b2..834a7051b06d 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -18,12 +18,27 @@ import torch.nn.functional as F from torch import nn +from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules from ..utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name +def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_scale = lora_scale + attn_module.k_proj.lora_scale = lora_scale + attn_module.v_proj.lora_scale = lora_scale + attn_module.out_proj.lora_scale = lora_scale + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_scale = lora_scale + mlp_module.fc2.lora_scale = lora_scale + + class LoRALinearLayer(nn.Module): def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None): super().__init__() @@ -97,12 +112,11 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): self.lora_layer = lora_layer - def _fuse_lora(self): + def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return dtype, device = self.weight.data.dtype, self.weight.data.device - logger.info(f"Fusing LoRA weights for {self.__class__}") w_orig = self.weight.data.float() w_up = self.lora_layer.up.weight.data.float() @@ -113,7 +127,7 @@ def _fuse_lora(self): fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1)) fusion = fusion.reshape((w_orig.shape)) - fused_weight = w_orig + fusion + fused_weight = w_orig + (lora_scale * fusion) self.weight.data = fused_weight.to(device=device, dtype=dtype) # we can drop the lora layer now @@ -122,33 +136,35 @@ def _fuse_lora(self): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() + self._lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): return - logger.info(f"Unfusing LoRA weights for {self.__class__}") fused_weight = self.weight.data dtype, device = fused_weight.data.dtype, fused_weight.data.device - self.w_up = self.w_up.to(device=device, dtype=dtype) - self.w_down = self.w_down.to(device, dtype=dtype) + self.w_up = self.w_up.to(device=device).float() + self.w_down = self.w_down.to(device).float() fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1)) fusion = fusion.reshape((fused_weight.shape)) - unfused_weight = fused_weight - fusion + unfused_weight = fused_weight.float() - (self._lora_scale * 
fusion) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None self.w_down = None - def forward(self, x): + def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: # make sure to the functional Conv2D function as otherwise torch.compile's graph will break # see: https://github.com/huggingface/diffusers/pull/4315 - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return F.conv2d( + hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) else: - return super().forward(x) + self.lora_layer(x) + return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) class LoRACompatibleLinear(nn.Linear): @@ -163,7 +179,7 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): self.lora_layer = lora_layer - def _fuse_lora(self): + def _fuse_lora(self, lora_scale=1.0): if self.lora_layer is None: return @@ -176,7 +192,7 @@ def _fuse_lora(self): if self.lora_layer.network_alpha is not None: w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank - fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0] + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.weight.data = fused_weight.to(device=device, dtype=dtype) # we can drop the lora layer now @@ -185,6 +201,7 @@ def _fuse_lora(self): # offload the up and down matrices to CPU to not blow the memory self.w_up = w_up.cpu() self.w_down = w_down.cpu() + self._lora_scale = lora_scale def _unfuse_lora(self): if not (hasattr(self, "w_up") and hasattr(self, "w_down")): @@ -196,14 +213,16 @@ def _unfuse_lora(self): w_up = self.w_up.to(device=device).float() w_down = self.w_down.to(device).float() - unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0] + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) self.weight.data = unfused_weight.to(device=device, dtype=dtype) self.w_up = None self.w_down = None - def forward(self, hidden_states, lora_scale: int = 1): + def forward(self, hidden_states, scale: float = 1.0): if self.lora_layer is None: - return super().forward(hidden_states) + out = super().forward(hidden_states) + return out else: - return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states) + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + return out diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 72aa17ed2c2d..ac66e2271c61 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -135,7 +135,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann else: self.Conv2d_0 = conv - def forward(self, hidden_states, output_size=None): + def forward(self, hidden_states, output_size=None, scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv_transpose: @@ -166,9 +166,15 @@ def forward(self, hidden_states, output_size=None): # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if self.use_conv: if self.name == "conv": - hidden_states = self.conv(hidden_states) + if isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) else: - hidden_states = self.Conv2d_0(hidden_states) + if 
isinstance(self.Conv2d_0, LoRACompatibleConv): + hidden_states = self.Conv2d_0(hidden_states, scale) + else: + hidden_states = self.Conv2d_0(hidden_states) return hidden_states @@ -211,14 +217,17 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= else: self.conv = conv - def forward(self, hidden_states): + def forward(self, hidden_states, scale: float = 1.0): assert hidden_states.shape[1] == self.channels if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) + if isinstance(self.conv, LoRACompatibleConv): + hidden_states = self.conv(hidden_states, scale) + else: + hidden_states = self.conv(hidden_states) return hidden_states @@ -588,7 +597,7 @@ def __init__( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) - def forward(self, input_tensor, temb): + def forward(self, input_tensor, temb, scale: float = 1.0): hidden_states = input_tensor if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": @@ -603,18 +612,34 @@ def forward(self, input_tensor, temb): if hidden_states.shape[0] >= 64: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor) - hidden_states = self.upsample(hidden_states) + input_tensor = ( + self.upsample(input_tensor, scale=scale) + if isinstance(self.upsample, Upsample2D) + else self.upsample(input_tensor) + ) + hidden_states = ( + self.upsample(hidden_states, scale=scale) + if isinstance(self.upsample, Upsample2D) + else self.upsample(hidden_states) + ) elif self.downsample is not None: - input_tensor = self.downsample(input_tensor) - hidden_states = self.downsample(hidden_states) + input_tensor = ( + self.downsample(input_tensor, scale=scale) + if isinstance(self.downsample, Downsample2D) + else self.downsample(input_tensor) + ) + hidden_states = ( + self.downsample(hidden_states, scale=scale) + if isinstance(self.downsample, Downsample2D) + else self.downsample(hidden_states) + ) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, scale) if self.time_emb_proj is not None: if not self.skip_time_act: temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb)[:, :, None, None] + temb = self.time_emb_proj(temb, scale)[:, :, None, None] if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb @@ -631,10 +656,10 @@ def forward(self, input_tensor, temb): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, scale) if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) + input_tensor = self.conv_shortcut(input_tensor, scale) output_tensor = (input_tensor + hidden_states) / self.output_scale_factor diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 225f20a1e397..4819d3be48e1 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -274,6 +274,9 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # Retrieve lora scale. 
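# --- Illustrative sketch (not part of the patch) ---
# The scaled fuse/unfuse arithmetic used by `_fuse_lora(lora_scale)` and `_unfuse_lora()` above,
# checked on plain tensors. The tensor names are made up for illustration, not diffusers API.
import torch

out_features, in_features, rank, scale = 8, 8, 4, 0.5
w = torch.randn(out_features, in_features)   # original weight
w_up = torch.randn(out_features, rank)        # LoRA "up" matrix
w_down = torch.randn(rank, in_features)       # LoRA "down" matrix

fused = w + scale * (w_up @ w_down)            # what fusing with lora_scale=scale stores
restored = fused - scale * (w_up @ w_down)     # what unfusing with the remembered scale undoes
assert torch.allclose(restored, w, atol=1e-5)

x = torch.randn(2, in_features)
unfused_out = x @ w.t() + scale * (x @ (w_up @ w_down).t())  # runtime path: forward(x, scale=scale)
fused_out = x @ fused.t()                                     # fused weights give the same result
assert torch.allclose(unfused_out, fused_out, atol=1e-4)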
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -281,13 +284,14 @@ def forward( hidden_states = self.norm(hidden_states) if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(hidden_states, lora_scale) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.proj_in(hidden_states, scale=lora_scale) + elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) elif self.is_input_patches: @@ -322,9 +326,9 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(hidden_states, scale=lora_scale) else: - hidden_states = self.proj_out(hidden_states) + hidden_states = self.proj_out(hidden_states, scale=lora_scale) hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 05b360def7c2..3751806c180d 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -640,7 +640,8 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -677,7 +678,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states @@ -777,6 +778,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
@@ -789,7 +791,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -800,7 +802,7 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states @@ -897,20 +899,25 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, temb=None, upsample_size=None): + def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None): + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + cross_attention_kwargs.update({"scale": lora_scale}) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: if self.downsample_type == "resnet": - hidden_states = downsampler(hidden_states, temb=temb) + hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) else: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=lora_scale) output_states += (hidden_states,) @@ -1019,6 +1026,8 @@ def forward( ): output_states = () + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1049,7 +1058,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1067,7 +1076,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=lora_scale) output_states = output_states + (hidden_states,) @@ -1126,7 +1135,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1147,13 +1156,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=scale) output_states = output_states + (hidden_states,) @@ -1209,13 +1218,13 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, scale: float = 1.0): for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=None, scale=scale) if self.downsamplers 
is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale) return hidden_states @@ -1292,14 +1301,15 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=None, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale) return hidden_states @@ -1385,16 +1395,17 @@ def __init__( self.downsamplers = None self.skip_conv = None - def forward(self, hidden_states, temb=None, skip_sample=None): + def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) + hidden_states = self.resnet_down(hidden_states, temb, scale=scale) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1465,15 +1476,15 @@ def __init__( self.downsamplers = None self.skip_conv = None - def forward(self, hidden_states, temb=None, skip_sample=None): + def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0): output_states = () for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states += (hidden_states,) if self.downsamplers is not None: - hidden_states = self.resnet_down(hidden_states, temb) + hidden_states = self.resnet_down(hidden_states, temb, scale) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) @@ -1548,7 +1559,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1569,13 +1580,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) + hidden_states = downsampler(hidden_states, temb, scale) output_states = output_states + (hidden_states,) @@ -1689,6 +1700,8 @@ def forward( output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + lora_scale = cross_attention_kwargs.get("scale", 1.0) + if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
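# --- Illustrative usage note (not part of the patch) ---
# Without fusing, the scale travels through `cross_attention_kwargs`: each block above reads
# `cross_attention_kwargs.get("scale", 1.0)` and hands it to its resnets, attentions and
# up/downsamplers. Sketch with a placeholder LoRA path.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder

# Scales the UNet LoRA layers at run time; `encode_prompt` also picks the value up as `lora_scale`
# and applies it to the patched text encoder via `adjust_lora_scale_text_encoder`.
image = pipe("a watercolor fox", cross_attention_kwargs={"scale": 0.5}).images[0]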
mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -1720,7 +1733,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, @@ -1733,7 +1746,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, temb) + hidden_states = downsampler(hidden_states, temb, scale=lora_scale) output_states = output_states + (hidden_states,) @@ -1786,7 +1799,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1807,7 +1820,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale) output_states += (hidden_states,) @@ -1893,6 +1906,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -1922,7 +1936,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2033,22 +2047,23 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": - hidden_states = upsampler(hidden_states, temb=temb) + hidden_states = upsampler(hidden_states, temb=temb, scale=scale) else: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, scale=scale) return hidden_states @@ -2150,6 +2165,8 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2183,7 +2200,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -2195,7 +2212,7 @@ def custom_forward(*inputs): if self.upsamplers is not 
None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) return hidden_states @@ -2248,7 +2265,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2272,11 +2289,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) return hidden_states @@ -2325,9 +2342,9 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb=temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2404,14 +2421,15 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=temb) - hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, scale=scale) return hidden_states @@ -2507,16 +2525,17 @@ def __init__( self.skip_norm = None self.act = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) - hidden_states = self.attentions[0](hidden_states) + cross_attention_kwargs = {"scale": scale} + hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -2530,7 +2549,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb) + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) return hidden_states, skip_sample @@ -2604,14 +2623,14 @@ def __init__( self.skip_norm = None self.act = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0): for resnet in 
self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) @@ -2625,7 +2644,7 @@ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample skip_sample = skip_sample + skip_sample_states - hidden_states = self.resnet_up(hidden_states, temb) + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) return hidden_states, skip_sample @@ -2697,7 +2716,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -2721,11 +2740,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) + hidden_states = upsampler(hidden_states, temb, scale=scale) return hidden_states @@ -2840,6 +2859,7 @@ def forward( ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. 
mask = None if encoder_hidden_states is None else encoder_attention_mask @@ -2877,7 +2897,7 @@ def custom_forward(*inputs): **cross_attention_kwargs, ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, @@ -2888,7 +2908,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, temb) + hidden_states = upsampler(hidden_states, temb, scale=lora_scale) return hidden_states @@ -2941,7 +2961,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) @@ -2964,7 +2984,7 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -3072,6 +3092,7 @@ def forward( if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: @@ -3100,7 +3121,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f793be49c32b..8c56cccdbddf 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -934,6 +934,7 @@ def forward( cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None @@ -956,7 +957,7 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) @@ -1020,7 +1021,11 @@ def forward( ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, ) # 6. 
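# --- Illustrative sketch (not part of the patch) ---
# The same scale can be passed when driving the UNet directly; `UNet2DConditionModel.forward`
# above reads it from `cross_attention_kwargs` and threads it into every down/mid/up block.
# Shapes assume a standard SD 1.x UNet; with no LoRA layers loaded the scale is a no-op.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)

with torch.no_grad():
    noise_pred = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        cross_attention_kwargs={"scale": 0.8},  # becomes `lora_scale` inside the blocks
    ).sample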
post-process diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 3110e0b0c942..78e46990b50c 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -322,6 +323,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index d148882a3ae6..5713395639cc 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -27,6 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -320,6 +321,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index f3a7a5112752..82e3851377d9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -25,6 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -312,6 +313,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 
fbc52b7d31cd..88410ad0d7c3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -25,6 +24,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -337,6 +337,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 4d8abb9c3afc..f98e4bb20c3c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -26,6 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -463,6 +464,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 48c0f289d067..b20d1f0c636e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -33,6 +33,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -342,6 +343,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 997994afb4f7..f54adc1a1aba 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ 
b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -34,6 +34,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -315,6 +316,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 43ee40857e6c..2b33a631bfcd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -33,6 +33,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -352,6 +353,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index b33b1285cd0f..9a3b828828e3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -27,6 +27,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -329,6 +330,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 26dace9f460a..6faec1f9a140 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -23,6 +23,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import 
( deprecate, @@ -321,6 +322,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 044c8aaa903f..b5f94add9f18 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -25,6 +25,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -321,6 +322,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 0ad72f12aa0e..0ab0b85a46c2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -26,6 +26,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -203,6 +204,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index e258909d8c5d..261dabe46754 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -26,6 +26,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -502,6 
+503,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py index 8293d4c59d57..78d0e852a632 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py @@ -24,6 +24,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -298,6 +299,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py index 74061245c73e..0940b830065c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py @@ -30,6 +30,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention import GatedSelfAttentionDense +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -331,6 +332,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 68c4664ed312..5e7f5f01cb28 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -25,6 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -323,6 +324,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, 
lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index fadf5fdca007..c1fb5831a305 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -25,6 +25,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -393,6 +394,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index c8554a644ff9..f5b60d95e543 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -322,6 +323,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 4b3a303ca463..92e481e707c3 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -22,6 +22,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import LMSDiscreteScheduler from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -229,6 +230,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt 
is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index 326830f0f9c6..13ccb226b0d7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -24,6 +24,7 @@ from ...image_processor import VaeImageProcessorLDM3D from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( BaseOutput, @@ -293,6 +294,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index bd182cc711b1..0b96a2cc8195 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -21,6 +21,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin from ...utils import deprecate, logging, randn_tensor @@ -233,6 +234,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index f3f2738811f0..84bd9f7e8815 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -21,6 +21,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -210,6 +211,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py index cbbc6c197788..7ce3dfc35908 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py @@ -21,6 +21,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -277,6 +278,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 4a30ba33630c..f2b281e8c6c7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -32,6 +32,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler from ...utils import ( @@ -462,6 +463,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 93861ed6f877..539696e9d5b6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -22,6 +22,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline @@ -233,6 +234,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py 
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 05943ddf48f7..eaa76cab6812 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -30,6 +30,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -255,6 +256,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index e32c92cdb4f4..10207d0ba32d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -23,6 +23,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -364,6 +365,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 650b43a92f7c..1a7427c21bc5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -25,6 +25,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -313,6 +314,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 094fe026c2a6..2ab8003b95f0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ 
b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -20,7 +20,11 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + FromSingleFileMixin, + LoraLoaderMixin, + TextualInversionLoaderMixin, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -28,6 +32,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -284,6 +289,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 9e3a34385d87..bfb3861c9965 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -29,6 +29,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( is_accelerate_available, @@ -294,6 +295,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index ba8ddabc4da3..ab8722b0bfa0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -30,6 +30,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -444,6 +445,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 2b097d091450..93b5f3b25d8b 100644 --- 
a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -24,6 +24,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -318,6 +319,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 6311c02be475..07947b7c8e88 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -31,6 +31,7 @@ LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, @@ -312,6 +313,10 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 5315329e692e..72063769c868 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -21,6 +21,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -246,6 +247,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 5122f114f342..cb0c24c474a4 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -22,6 +22,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel +from ...models.lora import 
adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, @@ -308,6 +309,9 @@ def encode_prompt( if lora_scale is not None and isinstance(self, LoraLoaderMixin): self._lora_scale = lora_scale + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 8af417241b1a..b653dd80d9b1 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1143,6 +1143,7 @@ def forward( cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None @@ -1165,7 +1166,7 @@ def forward( **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) @@ -1229,7 +1230,11 @@ def forward( ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, ) # 6. 
post-process @@ -1410,7 +1415,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, scale: float = 1.0): output_states = () for resnet in self.resnets: @@ -1431,13 +1436,13 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=scale) output_states = output_states + (hidden_states,) @@ -1547,6 +1552,8 @@ def forward( ): output_states = () + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): @@ -1577,7 +1584,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1595,7 +1602,7 @@ def custom_forward(*inputs): if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, scale=lora_scale) output_states = output_states + (hidden_states,) @@ -1651,7 +1658,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1675,11 +1682,11 @@ def custom_forward(*inputs): create_custom_forward(resnet), hidden_states, temb ) else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=scale) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) return hidden_states @@ -1782,6 +1789,8 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1815,7 +1824,7 @@ def custom_forward(*inputs): return_dict=False, )[0] else: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -1827,7 +1836,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) return hidden_states @@ -1932,7 +1941,8 @@ def forward( cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: - hidden_states = self.resnets[0](hidden_states, temb) + lora_scale = 
cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: @@ -1969,7 +1979,7 @@ def custom_forward(*inputs): encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states @@ -2070,6 +2080,7 @@ def forward( encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + lora_scale = cross_attention_kwargs.get("scale", 1.0) if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. @@ -2082,7 +2093,7 @@ def forward( # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask - hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( @@ -2093,6 +2104,6 @@ def forward( ) # resnet - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) return hidden_states diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 848f2f44adc9..253427a87b1f 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -92,11 +92,11 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module): return text_encoder_lora_layers -def set_lora_weights(lora_attn_parameters, randn_weight=False): +def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0): with torch.no_grad(): for parameter in lora_attn_parameters: if randn_weight: - parameter[:] = torch.randn_like(parameter) + parameter[:] = torch.randn_like(parameter) * var else: torch.zero_(parameter) @@ -719,7 +719,7 @@ def test_text_encoder_lora_state_dict_unchanged(self): # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_1_sd_keys == text_encoder_1_sd_keys_3 # default & loaded LoRA weights should NOT have identical state_dicts - assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # + assert text_encoder_1_sd_keys != text_encoder_1_sd_keys_2 # default & unloaded LoRA weights should have identical state_dicts assert text_encoder_2_sd_keys == text_encoder_2_sd_keys_3 @@ -863,6 +863,161 @@ def test_lora_fusion_is_not_affected_by_unloading(self): lora_image_slice, images_with_unloaded_lora_slice ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." + def test_fuse_lora_with_different_scales(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + _ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + + # Emulate training. 
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        sd_pipe.fuse_lora(lora_scale=1.0)
+        lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
+
+        # Reverse LoRA fusion.
+        sd_pipe.unfuse_lora()
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        sd_pipe.fuse_lora(lora_scale=0.5)
+        lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
+        ), "Different LoRA scales should influence the outputs accordingly."
+
+    def test_with_different_scales(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        original_image_slice = original_images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        lora_images_scale_one = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_one = lora_images_scale_one[0, -3:, -3:, -1]
+
+        lora_images_scale_0_5 = sd_pipe(
+            **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+        ).images
+        lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+        lora_images_scale_0_0 = sd_pipe(
+            **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+        ).images
+        lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03
+        ), "Different LoRA scales should influence the outputs accordingly."
+
+        assert np.allclose(
+            original_image_slice, lora_image_slice_scale_0_0, atol=1e-03
+        ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA."
+
+    def test_with_different_scales_fusion_equivalence(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
+        sd_pipe = sd_pipe.to(torch_device)
+        # sd_pipe.unet.set_default_attn_processor()
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+
+        images = sd_pipe(
+            **pipeline_inputs,
+            generator=torch.manual_seed(0),
+        ).images
+        images_slice = images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
+        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
+        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            StableDiffusionXLPipeline.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
+                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=True,
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        lora_images_scale_0_5 = sd_pipe(
+            **pipeline_inputs,
+            generator=torch.manual_seed(0),
+            cross_attention_kwargs={"scale": 0.5},
+        ).images
+        lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
+
+        sd_pipe.fuse_lora(lora_scale=0.5)
+        lora_images_scale_0_5_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice_scale_0_5_fusion = lora_images_scale_0_5_fusion[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            lora_image_slice_scale_0_5, lora_image_slice_scale_0_5_fusion, atol=1e-03
+        ), "Fusion shouldn't affect the results when calling the pipeline with a non-default LoRA scale."
+
+        sd_pipe.unfuse_lora()
+        images_unfused = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        images_slice_unfused = images_unfused[0, -3:, -3:, -1]
+
+        assert np.allclose(images_slice, images_slice_unfused, atol=1e-03), "Unfused should match no LoRA"
+
+        assert not np.allclose(
+            images_slice, lora_image_slice_scale_0_5, atol=1e-03
+        ), "0.5 scale and no scale shouldn't match"
+
 @slow
 @require_torch_gpu
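Note: adjust_lora_scale_text_encoder is only called in the pipeline hunks above; its body is not part of this diff. A minimal, duck-typed sketch of what such a helper could do, given that the patched projections expose a lora_scale attribute used in their forward pass (an illustration under that assumption, not the library source):

import torch

def adjust_lora_scale_text_encoder_sketch(text_encoder: torch.nn.Module, lora_scale: float = 1.0):
    # Hypothetical helper (note the _sketch suffix): walk every submodule of the
    # text encoder and, wherever a linear projection has been monkey-patched
    # with a LoRA layer and therefore carries a lora_scale attribute, overwrite
    # that attribute so the next forward pass multiplies the LoRA branch by the
    # requested scale.
    for module in text_encoder.modules():
        if hasattr(module, "lora_scale"):
            module.lora_scale = lora_scale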
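The tests above check that passing cross_attention_kwargs={"scale": s} at inference time and baking the same factor in with fuse_lora(lora_scale=s) produce numerically close results, and that unfuse_lora() restores the original weights. A minimal usage sketch of that API surface; the model id, LoRA file path, and prompt below are placeholders rather than values taken from this diff:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/pytorch_lora_weights.safetensors")  # placeholder path

# Runtime scaling: the scale is read from cross_attention_kwargs and routed to
# both the text encoder and the UNet blocks.
image_runtime = pipe(
    "a prompt", generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images[0]

# Fused scaling: the same factor is folded into the weights once, so no extra
# kwargs are needed at call time.
pipe.fuse_lora(lora_scale=0.5)
image_fused = pipe("a prompt", generator=torch.manual_seed(0)).images[0]

# Revert to the original, non-LoRA weights.
pipe.unfuse_lora()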