diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py index d1eac073..73ea36ea 100644 --- a/generative/networks/nets/autoencoderkl.py +++ b/generative/networks/nets/autoencoderkl.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util +import math from typing import Optional, Sequence, Tuple import torch @@ -16,6 +18,20 @@ import torch.nn.functional as F from monai.networks.blocks import Convolution +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + __all__ = ["AutoencoderKL"] @@ -154,106 +170,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + h -class AttnBlock(nn.Module): +class AttentionBlock(nn.Module): """ Attention block. Args: spatial_dims: number of spatial dimensions (1D, 2D, 3D). - in_channels: number of input channels. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of channels is divisible by this number. - norm_eps: epsilon for the normalisation. + norm_eps: epsilon value to use for the normalisation. """ def __init__( self, spatial_dims: int, - in_channels: int, - norm_num_groups: int, - norm_eps: float, + num_channels: int, + num_head_channels: Optional[int] = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, ) -> None: super().__init__() self.spatial_dims = spatial_dims - self.in_channels = in_channels + self.num_channels = num_channels - self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) - self.q = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.k = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.v = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, - ) - self.proj_out = Convolution( - spatial_dims=spatial_dims, - in_channels=in_channels, - out_channels=in_channels, - strides=1, - kernel_size=1, - padding=0, - conv_only=True, + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. 
+ """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x def forward(self, x: torch.Tensor) -> torch.Tensor: - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # Compute attention - b = q.shape[0] - c = q.shape[1] - h = q.shape[2] - w = q.shape[3] - # in order to Torchscript work, we initialise d = 1 - d = 1 + residual = x + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape if self.spatial_dims == 3: - d = q.shape[4] - n_spatial_elements = h * w * d - - q = q.reshape(b, c, n_spatial_elements) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, n_spatial_elements) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c) ** (-0.5)) - w_ = F.softmax(w_, dim=2) + batch, channel, height, width, depth = x.shape - # Attend to values - v = v.reshape(b, c, n_spatial_elements) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) + # norm + x = self.norm(x) if self.spatial_dims == 2: - h_ = h_.reshape(b, c, h, w) + x = x.view(batch, channel, height * width).transpose(1, 2) if self.spatial_dims == 3: - h_ = h_.reshape(b, c, h, w, d) + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if has_xformers: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) - h_ = self.proj_out(h_) + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) - return x + h_ + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual class Encoder(nn.Module): @@ -335,7 +365,14 @@ def __init__( ) block_in_ch = block_out_ch if attention_levels[i]: - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) if i != len(ch_mult) - 
1: blocks.append(Downsample(spatial_dims, block_in_ch)) @@ -343,7 +380,14 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) # Normalise and convert to latent size @@ -434,7 +478,14 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) for i in reversed(range(len(ch_mult))): @@ -445,7 +496,14 @@ def __init__( block_in_ch = block_out_ch if attention_levels[i]: - blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps)) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) if i != 0: blocks.append(Upsample(spatial_dims, block_in_ch)) diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index 3812a7f9..271f85e4 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -29,6 +29,7 @@ # limitations under the License. # ========================================================================= +import importlib.util import math from typing import List, Optional, Sequence, Tuple, Union @@ -38,6 +39,21 @@ from monai.networks.layers.factories import Pool from torch import nn +# To install xformers, use pip install xformers==0.0.16rc401 +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + __all__ = ["DiffusionModelUNet"] @@ -114,8 +130,8 @@ def __init__( inner_dim = num_head_channels * num_attention_heads cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.scale = num_head_channels**-0.5 - self.heads = num_attention_heads + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) @@ -124,28 +140,41 @@ def __init__( self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. 
+ """ batch_size, seq_len, dim = x.shape - head_size = self.heads - x = x.reshape(batch_size, seq_len, head_size, dim // head_size) - x = x.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) return x def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" batch_size, seq_len, dim = x.shape - head_size = self.heads - x = x.reshape(batch_size // head_size, head_size, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) return x def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) attention_probs = attention_scores.softmax(dim=-1) - # compute attention output - hidden_states = torch.matmul(attention_probs, value) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states + x = torch.bmm(attention_probs, value) + return x def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: query = self.to_q(x) @@ -153,11 +182,18 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> to key = self.to_k(context) value = self.to_v(context) + # Multi-Head Attention query = self.reshape_heads_to_batch_dim(query) key = self.reshape_heads_to_batch_dim(key) value = self.reshape_heads_to_batch_dim(value) - x = self._attention(query, key, value) + if has_xformers: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) return self.to_out(x) @@ -322,10 +358,11 @@ class AttentionBlock(nn.Module): Args: spatial_dims: number of spatial dimensions. - num_channels: number of channels in the input and output. + num_channels: number of input channels. num_head_channels: number of channels in each attention head. - norm_num_groups: number of groups to use for group norm. - norm_eps: epsilon value to use for group norm. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. 
""" def __init__( @@ -341,21 +378,48 @@ def __init__( self.num_channels = num_channels self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) - # define q,k,v as linear layers - self.query = nn.Linear(num_channels, num_channels) - self.key = nn.Linear(num_channels, num_channels) - self.value = nn.Linear(num_channels, num_channels) + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) self.proj_attn = nn.Linear(num_channels, num_channels) - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x def forward(self, x: torch.Tensor) -> torch.Tensor: residual = x @@ -375,29 +439,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.view(batch, channel, height * width * depth).transpose(1, 2) # proj to q, k, v - query_proj = self.query(x) - key_proj = self.key(x) - value_proj = self.value(x) - - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - # get scores - scale = 1 / math.sqrt(math.sqrt(self.num_channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) - attention_probs = torch.softmax(attention_scores.float(), dim=-1) + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) - # compute attention output - x = torch.matmul(attention_probs, value_states) + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) - x = x.permute(0, 2, 1, 3).contiguous() - new_x_shape = x.size()[:-2] + 
(self.num_channels,) - x = x.view(new_x_shape) + if has_xformers: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) - # compute next hidden states - x = self.proj_attn(x) + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) if self.spatial_dims == 2: x = x.transpose(-1, -2).reshape(batch, channel, height, width)
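Reviewer note (not part of the patch): both `_attention` implementations in this diff build the score matrix with a single fused `torch.baddbmm` call instead of a separate matmul/bmm followed by scaling. The minimal sketch below checks on random tensors that the two formulations agree; shapes and the scale factor mirror the patched code, where query/key are `(batch * num_heads, seq_len, head_dim)` and `scale = 1 / sqrt(head_dim)`.

import math

import torch

batch_heads, seq_len, head_dim = 4, 16, 32
query = torch.randn(batch_heads, seq_len, head_dim)
key = torch.randn(batch_heads, seq_len, head_dim)

scale = 1 / math.sqrt(head_dim)

# Previous formulation: matmul followed by scaling.
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) * scale

# Patched formulation: fused into one baddbmm call; with beta=0 the
# (uninitialised) first argument only provides the output shape/dtype/device.
scores_baddbmm = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_matmul, scores_baddbmm, atol=1e-5)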
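Reviewer note (not part of the patch): `AttentionBlock` and `CrossAttention` now share the same head-reshaping convention, folding the heads into the batch dimension before attention and unfolding them afterwards. The standalone sketch below mirrors `reshape_heads_to_batch_dim` / `reshape_batch_dim_to_heads` and checks that the round trip is lossless.

import torch


def reshape_heads_to_batch_dim(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # (batch, seq_len, dim) -> (batch * num_heads, seq_len, dim // num_heads)
    batch_size, seq_len, dim = x.shape
    x = x.reshape(batch_size, seq_len, num_heads, dim // num_heads)
    return x.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len, dim // num_heads)


def reshape_batch_dim_to_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # (batch * num_heads, seq_len, head_dim) -> (batch, seq_len, head_dim * num_heads)
    batch_size, seq_len, dim = x.shape
    x = x.reshape(batch_size // num_heads, num_heads, seq_len, dim)
    return x.permute(0, 2, 1, 3).reshape(batch_size // num_heads, seq_len, dim * num_heads)


num_heads = 8
hidden = torch.randn(2, 64, 256)  # (batch, seq_len, dim); dim divisible by num_heads

folded = reshape_heads_to_batch_dim(hidden, num_heads)
assert folded.shape == (2 * num_heads, 64, 256 // num_heads)

restored = reshape_batch_dim_to_heads(folded, num_heads)
assert torch.equal(restored, hidden)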
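Reviewer note (not part of the patch): a usage sketch of the renamed `AttentionBlock`, assuming the class from generative/networks/nets/autoencoderkl.py in this diff is importable from that module. The block attends over the flattened spatial grid and adds the result back as a residual, so it is shape-preserving; it runs with or without xformers installed, since `has_xformers` selects the attention path internally.

import torch

from generative.networks.nets.autoencoderkl import AttentionBlock

attn = AttentionBlock(
    spatial_dims=2,
    num_channels=64,       # must be divisible by norm_num_groups
    num_head_channels=32,  # 64 / 32 -> 2 attention heads
    norm_num_groups=32,
    norm_eps=1e-6,
)

x = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
with torch.no_grad():
    y = attn(x)

assert y.shape == x.shape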