This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
212 changes: 135 additions & 77 deletions generative/networks/nets/autoencoderkl.py
@@ -9,13 +9,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import math
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.blocks import Convolution

# To install xformers, use pip install xformers==0.0.16rc401
if importlib.util.find_spec("xformers") is not None:
import xformers
import xformers.ops

has_xformers = True
else:
xformers = None
has_xformers = False

# TODO: Use MONAI's optional_import
# from monai.utils import optional_import
# xformers, has_xformers = optional_import("xformers.ops", name="xformers")
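The TODO above refers to MONAI's `optional_import` helper, which wraps exactly this kind of guarded import. A minimal sketch of what the replacement could look like (an assumption based on the public `monai.utils.optional_import` API, not code from this PR):

```python
# Sketch only: optional_import returns the module (or a lazy stub) plus a
# boolean flag indicating whether the import actually succeeded.
from monai.utils import optional_import

xformers, has_xformers = optional_import("xformers")
xformers_ops, _ = optional_import("xformers.ops")
```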

__all__ = ["AutoencoderKL"]


@@ -154,106 +170,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + h


class AttnBlock(nn.Module):
class AttentionBlock(nn.Module):
"""
Attention block.

Args:
spatial_dims: number of spatial dimensions (1D, 2D, 3D).
in_channels: number of input channels.
num_channels: number of input channels.
num_head_channels: number of channels in each attention head.
norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
channels is divisible by this number.
norm_eps: epsilon for the normalisation.
norm_eps: epsilon value to use for the normalisation.
"""

def __init__(
self,
spatial_dims: int,
in_channels: int,
norm_num_groups: int,
norm_eps: float,
num_channels: int,
num_head_channels: Optional[int] = None,
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
) -> None:
super().__init__()
self.spatial_dims = spatial_dims
self.in_channels = in_channels
self.num_channels = num_channels

self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
self.q = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
self.k = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
self.v = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
)
self.proj_out = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=in_channels,
strides=1,
kernel_size=1,
padding=0,
conv_only=True,
self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
self.scale = 1 / math.sqrt(num_channels / self.num_heads)

self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)

self.to_q = nn.Linear(num_channels, num_channels)
self.to_k = nn.Linear(num_channels, num_channels)
self.to_v = nn.Linear(num_channels, num_channels)

self.proj_attn = nn.Linear(num_channels, num_channels)

def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor:
"""
Divide the hidden state dimension across the attention heads and fold the heads into the batch dimension.
"""
batch_size, seq_len, dim = x.shape
x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads)
x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads)
return x

def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor:
"""Combine the output of the attention heads back into the hidden state dimension."""
batch_size, seq_len, dim = x.shape
x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim)
x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads)
return x
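These two helpers are pure reshapes, so they invert each other exactly. A small shape walk-through (illustrative only; assumes the `AttentionBlock` above with 64 channels split into 4 heads of 16 channels):

```python
# Round-trip check for the head reshaping helpers (sketch, not part of the PR).
import torch

block = AttentionBlock(spatial_dims=2, num_channels=64, num_head_channels=16)
h = torch.randn(2, 1024, 64)                           # (batch, seq_len, hidden)
per_head = block.reshape_heads_to_batch_dim(h)         # (2 * 4, 1024, 16)
restored = block.reshape_batch_dim_to_heads(per_head)  # (2, 1024, 64)
assert per_head.shape == (8, 1024, 16)
assert torch.equal(restored, h)
```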

def _memory_efficient_attention_xformers(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
return x

def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
attention_probs = attention_scores.softmax(dim=-1)
x = torch.bmm(attention_probs, value)
return x
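The `baddbmm` call with `beta=0` is a fused way of computing `softmax(Q @ K^T * scale) @ V`; the empty tensor only supplies the output shape and is ignored. A quick equivalence check under that reading (sketch, not part of the PR):

```python
# Compare the fused baddbmm formulation against the textbook computation.
import torch

q, k, v = (torch.randn(8, 1024, 16) for _ in range(3))
scale = 1 / 16 ** 0.5

reference = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v
scores = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
    q,
    k.transpose(-1, -2),
    beta=0,
    alpha=scale,
)
fused = torch.bmm(scores.softmax(dim=-1), v)
assert torch.allclose(reference, fused, atol=1e-5)
```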

def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)

# Compute attention
b = q.shape[0]
c = q.shape[1]
h = q.shape[2]
w = q.shape[3]
# for TorchScript to work, we initialise d = 1
d = 1
residual = x

batch = channel = height = width = depth = -1
if self.spatial_dims == 2:
batch, channel, height, width = x.shape
if self.spatial_dims == 3:
d = q.shape[4]
n_spatial_elements = h * w * d

q = q.reshape(b, c, n_spatial_elements)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, n_spatial_elements)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = F.softmax(w_, dim=2)
batch, channel, height, width, depth = x.shape

# Attend to values
v = v.reshape(b, c, n_spatial_elements)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
# norm
x = self.norm(x)

if self.spatial_dims == 2:
h_ = h_.reshape(b, c, h, w)
x = x.view(batch, channel, height * width).transpose(1, 2)
if self.spatial_dims == 3:
h_ = h_.reshape(b, c, h, w, d)
x = x.view(batch, channel, height * width * depth).transpose(1, 2)

# proj to q, k, v
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)

# Multi-Head Attention
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)

if has_xformers:
x = self._memory_efficient_attention_xformers(query, key, value)
else:
x = self._attention(query, key, value)

h_ = self.proj_out(h_)
x = self.reshape_batch_dim_to_heads(x)
x = x.to(query.dtype)

return x + h_
if self.spatial_dims == 2:
x = x.transpose(-1, -2).reshape(batch, channel, height, width)
if self.spatial_dims == 3:
x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth)

return x + residual
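End-to-end, the new block normalises the input, flattens the spatial dimensions into a sequence, runs (optionally memory-efficient) multi-head self-attention, and adds the result back to the input. A hypothetical usage sketch, assuming the plain PyTorch attention path (no xformers installed):

```python
# Apply the refactored AttentionBlock to a 2D feature map (sketch only).
import torch

block = AttentionBlock(spatial_dims=2, num_channels=64, num_head_channels=16, norm_num_groups=32)
x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
out = block(x)                   # self-attention over the 32 * 32 spatial positions
assert out.shape == x.shape      # residual connection preserves the shape
```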


class Encoder(nn.Module):
@@ -335,15 +365,29 @@ def __init__(
)
block_in_ch = block_out_ch
if attention_levels[i]:
blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps))
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
)
)

if i != len(ch_mult) - 1:
blocks.append(Downsample(spatial_dims, block_in_ch))

# Non-local attention block
if with_nonlocal_attn is True:
blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps))
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
)
)
blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))

# Normalise and convert to latent size
@@ -434,7 +478,14 @@ def __init__(
# Non-local attention block
if with_nonlocal_attn is True:
blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps))
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
)
)
blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))

for i in reversed(range(len(ch_mult))):
@@ -445,7 +496,14 @@
block_in_ch = block_out_ch

if attention_levels[i]:
blocks.append(AttnBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps))
blocks.append(
AttentionBlock(
spatial_dims=spatial_dims,
num_channels=block_in_ch,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
)
)

if i != 0:
blocks.append(Upsample(spatial_dims, block_in_ch))
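In both the Encoder's downsampling loop and the Decoder's (reversed) upsampling loop, a level flagged in `attention_levels` now receives an `AttentionBlock` after its residual blocks, with `num_channels` taken from that level's width (`block_in_ch`). A small illustration of how the flags line up with `ch_mult` (hypothetical values, not from this PR):

```python
# Which resolution levels get an AttentionBlock, for example settings.
ch_mult = (1, 2, 4)
attention_levels = (False, False, True)

for i, (mult, use_attn) in enumerate(zip(ch_mult, attention_levels)):
    print(f"level {i}: channel multiplier x{mult}, attention block: {use_attn}")
# level 0: channel multiplier x1, attention block: False
# level 1: channel multiplier x2, attention block: False
# level 2: channel multiplier x4, attention block: True
```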