From 34ae6437c99427675f75ae7dbc155fc45e307e56 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Fri, 2 Dec 2022 01:07:21 +0800 Subject: [PATCH 1/3] Add xformers attention to VAE --- src/diffusers/models/attention.py | 62 +++++++++++++++++++------- src/diffusers/models/unet_2d_blocks.py | 4 ++ src/diffusers/models/vae.py | 16 +++++++ 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 0547bb4a0e47..c6ea2fa37603 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -290,6 +290,33 @@ def __init__( self.rescale_output_factor = rescale_output_factor self.proj_attn = nn.Linear(channels, channels, 1) + self._use_memory_efficient_attention_xformers = False + + def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.num_heads @@ -324,21 +351,26 @@ def forward(self, hidden_states): key_proj = self.reshape_heads_to_batch_dim(key_proj) value_proj = self.reshape_heads_to_batch_dim(value_proj) - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) + if self._use_memory_efficient_attention_xformers: + # Memory efficient attention + hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + hidden_states = hidden_states.to(query_proj.dtype) + else: + attention_scores = torch.baddbmm( + torch.empty( + query_proj.shape[0], + query_proj.shape[1], + key_proj.shape[1], + dtype=query_proj.dtype, + device=query_proj.device, + ), + query_proj, + key_proj.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + hidden_states = torch.bmm(attention_probs, value_proj) # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index cce7e7fd5a90..8b579acad353 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -305,6 +305,10 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: 
bool): + for attn in self.attentions: + attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, hidden_states, temb=None, encoder_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index e29f4e8afa2f..e4ad9795c0f1 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -125,6 +125,12 @@ def __init__( conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + for block in self.down_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, x): sample = x sample = self.conv_in(sample) @@ -205,6 +211,12 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + for block in self.up_blocks: + if hasattr(block, "attentions") and block.attentions is not None: + block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def forward(self, z): sample = z sample = self.conv_in(sample) @@ -567,6 +579,10 @@ def __init__( self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) self.use_slicing = False + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + self.decoder.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + self.encoder.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) moments = self.quant_conv(h) From 2fee0060dcc03c94a8fa66beeb3e077f68bee8b2 Mon Sep 17 00:00:00 2001 From: Ilmari Heikkinen Date: Sat, 3 Dec 2022 12:57:04 +0800 Subject: [PATCH 2/3] Simplify VAE xformers code --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/unet_2d_blocks.py | 4 ---- src/diffusers/models/vae.py | 16 ---------------- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 1380bb2c7444..7e7bb455eed7 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -288,7 +288,7 @@ def __init__( self._use_memory_efficient_attention_xformers = False - def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): print("Here is how to install it") raise ModuleNotFoundError( diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index c934d87ef70d..d78804b18e75 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -305,10 +305,6 @@ def __init__( 
self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - for attn in self.attentions: - attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - def forward(self, hidden_states, temb=None, encoder_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index e4ad9795c0f1..e29f4e8afa2f 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -125,12 +125,6 @@ def __init__( conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - def forward(self, x): sample = x sample = self.conv_in(sample) @@ -211,12 +205,6 @@ def __init__( self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - def forward(self, z): sample = z sample = self.conv_in(sample) @@ -579,10 +567,6 @@ def __init__( self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) self.use_slicing = False - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - self.decoder.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - self.encoder.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers) - def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) moments = self.quant_conv(h) From c86d6286f00ef4c525a37f595dce8a825a9d68a9 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sat, 3 Dec 2022 14:56:13 +0100 Subject: [PATCH 3/3] Update src/diffusers/models/attention.py --- src/diffusers/models/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7e7bb455eed7..31d1f1ff418a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -290,7 +290,6 @@ def __init__( def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): - print("Here is how to install it") raise ModuleNotFoundError( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" " xformers",
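
Two short sketches follow, illustrating what the series changes.

First, the core of patch 1: when the flag is set, AttentionBlock routes its
self-attention through xformers.ops.memory_efficient_attention instead of
materializing the full score matrix via torch.baddbmm. Both paths compute the
same scaled-dot-product attention, which the sketch below checks numerically.
This is a minimal illustration, assuming xformers is installed and a CUDA
device is available; the shapes are arbitrary illustrative values, not taken
from the patch.

    import math

    import torch
    import xformers.ops

    # Illustrative shapes in the (batch * heads, tokens, head_dim) layout
    # that AttentionBlock uses after reshape_heads_to_batch_dim.
    batch, seq_len, dim = 2, 64, 40
    q = torch.randn(batch, seq_len, dim, device="cuda")
    k = torch.randn(batch, seq_len, dim, device="cuda")
    v = torch.randn(batch, seq_len, dim, device="cuda")
    scale = 1 / math.sqrt(dim)  # 1/sqrt(head_dim), xformers' default scaling

    # Reference path: explicit (seq_len x seq_len) attention scores.
    scores = torch.baddbmm(
        torch.empty(batch, seq_len, seq_len, dtype=q.dtype, device=q.device),
        q,
        k.transpose(-1, -2),
        beta=0,
        alpha=scale,
    )
    probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
    reference = torch.bmm(probs, v)

    # Memory-efficient path: never materializes the full score matrix.
    efficient = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    assert torch.allclose(reference, efficient, atol=1e-3)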
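Second, usage once the series is applied. Patch 2 drops the per-module
plumbing in vae.py and unet_2d_blocks.py and gives AttentionBlock the public
setter name, so the existing recursive enabling logic on ModelMixin can reach
the VAE's attention blocks on its own. A minimal usage sketch, assuming a
CUDA machine with xformers installed; the checkpoint id is an illustrative
choice, not part of the patch:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to("cuda")

    # Walks child modules and flips the flag on every AttentionBlock.
    vae.set_use_memory_efficient_attention_xformers(True)

    image = torch.randn(1, 3, 512, 512, device="cuda")
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample()
        reconstruction = vae.decode(latents).sample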