From bcc310848957197c0d664b399de26fccba256eae Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 10 May 2023 14:08:24 -0700 Subject: [PATCH 1/2] Replace `AttentionBlock` with `Attention` --- src/diffusers/models/attention.py | 174 +----------------- src/diffusers/models/attention_processor.py | 124 ++++++++++++- src/diffusers/models/modeling_utils.py | 38 ++++ src/diffusers/models/unet_2d_blocks.py | 60 ++++-- .../pipeline_stable_diffusion_upscale.py | 12 +- tests/models/test_layers_utils.py | 55 +----- 6 files changed, 215 insertions(+), 248 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 134f84fc9d50..0b313b83d360 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,189 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Callable, Optional +from typing import Optional import torch import torch.nn.functional as F from torch import nn from ..utils import maybe_allow_in_graph -from ..utils.import_utils import is_xformers_available from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. 
Do not use it anymore - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - norm_num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, bias=True) - - self._use_memory_efficient_attention_xformers = False - self._use_2_0_attn = True - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - if merge_head_and_batch: - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True): - head_size = self.num_heads - - if unmerge_head_and_batch: - batch_head_size, seq_len, dim = tensor.shape - batch_size = batch_head_size // head_size - - tensor = tensor.reshape(batch_size, head_size, seq_len, dim) - else: - batch_size, _, seq_len, dim = tensor.shape - - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self._attention_op = attention_op - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers - use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn - - query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn) - key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn) - value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn) - - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale - ) - hidden_states = hidden_states.to(query_proj.dtype) - elif use_torch_2_0_attn: - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.to(query_proj.dtype) - else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - @maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b727c76e2137..f3606261baf2 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -65,6 +65,9 @@ def __init__( out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -72,6 +75,8 @@ def __init__( cross_attention_dim = 
cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 @@ -91,7 +96,7 @@ def __init__( ) if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None @@ -407,10 +412,22 @@ def __call__( encoder_hidden_states=None, attention_mask=None, ): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -434,6 +451,14 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -474,11 +499,22 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) @@ -502,6 +538,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -762,12 +806,23 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * 
width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -792,6 +847,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -801,6 +865,14 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -812,6 +884,9 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -840,6 +915,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -858,11 +942,22 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() @@ -887,6 +982,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if 
input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -980,11 +1083,22 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) @@ -1025,6 +1139,14 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ef14ec3d09ef..4196d3cb01d8 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -583,6 +583,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) + cls.convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if len(missing_keys) > 0: @@ -625,6 +626,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) + cls.convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, @@ -761,6 +763,42 @@ def _find_mismatched_keys( return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + @classmethod + def convert_deprecated_attention_blocks(cls, state_dict): + # We check for the deprecated attention block via the `proj_attn` layer in the state dict + # The only other class with a layer called `proj_attn` is in `SelfAttention1d` which is used + # by only the top level model, `UNet1DModel`. Since, `UNet1DModel` wont have any of the deprecated + # attention blocks, we can just early return. 
+ if cls.__name__ == "UNet1DModel": + return + + deprecated_attention_block_paths = [] + + for k in state_dict.keys(): + if "proj_attn.weight" in k: + index = k.index("proj_attn.weight") + path = k[: index - 1] + deprecated_attention_block_paths.append(path) + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") + @property def device(self) -> device: """ diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 2f7b19b7328a..a034baf74676 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from torch import nn -from .attention import AdaGroupNorm, AttentionBlock +from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D @@ -427,12 +427,16 @@ def __init__( for _ in range(num_layers): if self.add_attention: attentions.append( - AttentionBlock( + Attention( in_channels, - num_head_channels=attn_num_head_channels, + heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) else: @@ -711,12 +715,16 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) @@ -1060,12 +1068,16 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) @@ -1134,11 +1146,16 @@ def __init__( ) ) self.attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if 
attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) @@ -1703,12 +1720,16 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) @@ -2037,12 +2058,16 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) @@ -2109,11 +2134,16 @@ def __init__( ) self.attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, ) ) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index b7530ac4ec5c..6bb463a6a65f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -19,11 +19,11 @@ import numpy as np import PIL import torch -import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -709,12 +709,14 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) - # TODO(Patrick, William) - clean up when attention is refactored - use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") - use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers + use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + ] # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory - if not use_torch_2_0_attn and not use_xformers: + if not use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype) diff --git 
a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index db0d6c78d902..98fa1afcbb9d 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -20,7 +20,7 @@ import torch from torch import nn -from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformer_2d import Transformer2DModel @@ -314,59 +314,6 @@ def test_restnet_with_kernel_sde_vp(self): assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) -class AttentionBlockTests(unittest.TestCase): - @unittest.skipIf( - torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039" - ) - def test_attention_block_default(self): - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - sample = torch.randn(1, 32, 64, 64).to(torch_device) - attentionBlock = AttentionBlock( - channels=32, - num_head_channels=1, - rescale_output_factor=1.0, - eps=1e-6, - norm_num_groups=32, - ).to(torch_device) - with torch.no_grad(): - attention_scores = attentionBlock(sample) - - assert attention_scores.shape == (1, 32, 64, 64) - output_slice = attention_scores[0, -1, -3:, -3:] - - expected_slice = torch.tensor( - [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device - ) - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - - def test_attention_block_sd(self): - # This version uses SD params and is compatible with mps - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - sample = torch.randn(1, 512, 64, 64).to(torch_device) - attentionBlock = AttentionBlock( - channels=512, - rescale_output_factor=1.0, - eps=1e-6, - norm_num_groups=32, - ).to(torch_device) - with torch.no_grad(): - attention_scores = attentionBlock(sample) - - assert attention_scores.shape == (1, 512, 64, 64) - output_slice = attention_scores[0, -1, -3:, -3:] - - expected_slice = torch.tensor( - [-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device - ) - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - - class Transformer2DModelTests(unittest.TestCase): def test_spatial_transformer_default(self): torch.manual_seed(0) From 500ce022db209904ad00ac7b568bb631d505f86e Mon Sep 17 00:00:00 2001 From: William Berman Date: Thu, 11 May 2023 10:50:41 -0700 Subject: [PATCH 2/2] use _from_deprecated_attn_block check re: @patrickvonplaten --- src/diffusers/models/attention_processor.py | 5 ++ src/diffusers/models/modeling_utils.py | 84 +++++++++++---------- src/diffusers/models/unet_2d_blocks.py | 7 ++ 3 files changed, 58 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f3606261baf2..f88400da0333 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -68,6 +68,7 @@ def __init__( eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, + _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -78,6 +79,10 @@ def __init__( self.rescale_output_factor = rescale_output_factor self.residual_connection = 
residual_connection + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + self.scale_qk = scale_qk self.scale = dim_head**-0.5 if self.scale_qk else 1.0 diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 4196d3cb01d8..e7cfcd71062f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -583,7 +583,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) - cls.convert_deprecated_attention_blocks(state_dict) + model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if len(missing_keys) > 0: @@ -626,7 +626,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) - cls.convert_deprecated_attention_blocks(state_dict) + model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, @@ -763,42 +763,6 @@ def _find_mismatched_keys( return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs - @classmethod - def convert_deprecated_attention_blocks(cls, state_dict): - # We check for the deprecated attention block via the `proj_attn` layer in the state dict - # The only other class with a layer called `proj_attn` is in `SelfAttention1d` which is used - # by only the top level model, `UNet1DModel`. Since, `UNet1DModel` wont have any of the deprecated - # attention blocks, we can just early return. 
- if cls.__name__ == "UNet1DModel": - return - - deprecated_attention_block_paths = [] - - for k in state_dict.keys(): - if "proj_attn.weight" in k: - index = k.index("proj_attn.weight") - path = k[: index - 1] - deprecated_attention_block_paths.append(path) - - for path in deprecated_attention_block_paths: - # group_norm path stays the same - - # query -> to_q - state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") - state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") - - # key -> to_k - state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") - state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") - - # value -> to_v - state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") - state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") - - # proj_attn -> to_out.0 - state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") - state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") - @property def device(self) -> device: """ @@ -841,3 +805,47 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def _convert_deprecated_attention_blocks(self, state_dict): + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index a034baf74676..0004f074c563 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -437,6 +437,7 @@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) else: @@ -725,6 +726,7 
@@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1078,6 +1080,7 @@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1156,6 +1159,7 @@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1730,6 +1734,7 @@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -2068,6 +2073,7 @@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -2144,6 +2150,7 @@ def __init__( residual_connection=True, bias=True, upcast_softmax=True, + _from_deprecated_attn_block=True, ) )
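
For reference, the `Attention` configuration used here in place of the old `AttentionBlock` can be sketched as follows. This is a minimal sketch mirroring the constructor calls added in `unet_2d_blocks.py`; the concrete values (channels=512, attn_num_head_channels=None) are illustrative only, and it assumes a diffusers checkout with these two patches applied (otherwise `residual_connection` and `_from_deprecated_attn_block` do not exist yet).

    from diffusers.models.attention_processor import Attention

    channels = 512                 # old `channels`
    attn_num_head_channels = None  # old `num_head_channels`; None meant a single head

    attn = Attention(
        channels,
        heads=channels // attn_num_head_channels if attn_num_head_channels is not None else 1,
        dim_head=attn_num_head_channels if attn_num_head_channels is not None else channels,
        rescale_output_factor=1.0,  # old `rescale_output_factor`
        eps=1e-6,                   # old `eps`
        norm_num_groups=32,         # old `norm_num_groups`
        residual_connection=True,   # the old block always added the residual
        bias=True,                  # the old query/key/value/proj_attn linears used bias
        upcast_softmax=True,        # the old block ran softmax in float32
        _from_deprecated_attn_block=True,  # marks the module for checkpoint conversion on load
    )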
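
The checkpoint conversion added in `_convert_deprecated_attention_blocks` reduces to a key rename (query -> to_q, key -> to_k, value -> to_v, proj_attn -> to_out.0), skipping keys that are already in the new layout. Below is a minimal, self-contained sketch of that mapping, independent of diffusers; the function name, the `mid_block.attentions.0` prefix, and the tensor shapes are made up for illustration.

    import torch

    def rename_deprecated_attention_keys(state_dict, path):
        # Map old AttentionBlock parameter names to the Attention naming scheme.
        mapping = {
            f"{path}.query": f"{path}.to_q",
            f"{path}.key": f"{path}.to_k",
            f"{path}.value": f"{path}.to_v",
            f"{path}.proj_attn": f"{path}.to_out.0",
        }
        for old, new in mapping.items():
            for suffix in ("weight", "bias"):
                old_key = f"{old}.{suffix}"
                if old_key in state_dict:  # skip keys of an already-converted checkpoint
                    state_dict[f"{new}.{suffix}"] = state_dict.pop(old_key)
        return state_dict

    # Toy 32-channel example; real prefixes come from the model's module hierarchy.
    sd = {
        "mid_block.attentions.0.query.weight": torch.randn(32, 32),
        "mid_block.attentions.0.query.bias": torch.randn(32),
        "mid_block.attentions.0.proj_attn.weight": torch.randn(32, 32),
        "mid_block.attentions.0.proj_attn.bias": torch.randn(32),
    }
    rename_deprecated_attention_keys(sd, "mid_block.attentions.0")
    print(sorted(sd.keys()))  # now to_out.0.* and to_q.* instead of proj_attn.* and query.*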