Merged. (This repository was archived by the owner on Feb 7, 2025, and is now read-only.)
generative/networks/nets/autoencoderkl.py: 174 changes (98 additions & 76 deletions)
@@ -296,37 +296,28 @@ class Encoder(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
-        num_channels: number of filters in the first downsampling.
+        num_channels: sequence of block output channels.
         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
-        ch_mult: list of multipliers of num_channels in the initial layer and in each downsampling layer. Example: if
-            you want three downsamplings, you have to input a 4-element list. If you input [1, 1, 2, 2],
-            the first downsampling will leave num_channels to num_channels, the next will multiply num_channels by 2,
-            and the next will multiply num_channels*2 by 2 again, resulting in 8, 8, 16 and 32 channels.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """

     def __init__(
         self,
         spatial_dims: int,
         in_channels: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()

-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.in_channels = in_channels
         self.num_channels = num_channels
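Note on migration: the old num_channels/ch_mult pair collapses into a single explicit per-level sequence, and attention_levels is no longer optional. A minimal before/after sketch of a call site (the argument values are illustrative, not taken from this PR):

```python
# Before (hypothetical call site): a base width of 8, scaled per level by ch_mult.
# encoder = Encoder(spatial_dims=2, in_channels=1, num_channels=8, out_channels=3,
#                   ch_mult=(1, 1, 2, 2), num_res_blocks=2,
#                   norm_num_groups=8, norm_eps=1e-6)

# After: each level's width is written out, and attention_levels must be passed.
encoder = Encoder(
    spatial_dims=2,
    in_channels=1,
    num_channels=(8, 8, 16, 16),  # 8 * m for each m in the old ch_mult=(1, 1, 2, 2)
    out_channels=3,
    num_res_blocks=2,
    norm_num_groups=8,
    norm_eps=1e-6,
    attention_levels=(False, False, False, False),  # one flag per entry of num_channels
)
```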
@@ -336,15 +327,13 @@ def __init__(
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels

-        in_ch_mult = (1,) + tuple(ch_mult)
-
         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=num_channels,
+                out_channels=num_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -353,52 +342,73 @@ def __init__(
         )

         # Residual and downsampling blocks
-        for i in range(len(ch_mult)):
-            block_in_ch = num_channels * in_ch_mult[i]
-            block_out_ch = num_channels * ch_mult[i]
+        output_channel = num_channels[0]
+        for i in range(len(num_channels)):
+            input_channel = output_channel
+            output_channel = num_channels[i]
+            is_final_block = i == len(num_channels) - 1

             for _ in range(self.num_res_blocks):
                 blocks.append(
                     ResBlock(
                         spatial_dims=spatial_dims,
-                        in_channels=block_in_ch,
+                        in_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
-                        out_channels=block_out_ch,
+                        out_channels=output_channel,
                     )
                 )
-                block_in_ch = block_out_ch
+                input_channel = output_channel
             if attention_levels[i]:
                 blocks.append(
                     AttentionBlock(
                         spatial_dims=spatial_dims,
-                        num_channels=block_in_ch,
+                        num_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
                     )
                 )

-            if i != len(ch_mult) - 1:
-                blocks.append(Downsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Downsample(spatial_dims=spatial_dims, in_channels=input_channel))

         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
             blocks.append(
-                AttentionBlock(
+                ResBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+
+            blocks.append(
+                AttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                )
+            )
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
+                )
+            )
         # Normalise and convert to latent size
-        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+        blocks.append(
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True)
+        )
         blocks.append(
             Convolution(
                 spatial_dims=self.spatial_dims,
-                in_channels=block_in_ch,
+                in_channels=num_channels[-1],
                 out_channels=out_channels,
                 strides=1,
                 kernel_size=3,
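The rewritten loop threads input_channel/output_channel through the sequence instead of multiplying a base width, with a downsampling layer after every level except the last. A standalone sketch of the resulting channel progression (values illustrative):

```python
num_channels = (8, 8, 16, 16)  # illustrative per-level widths

output_channel = num_channels[0]
for i in range(len(num_channels)):
    input_channel = output_channel
    output_channel = num_channels[i]
    is_final_block = i == len(num_channels) - 1
    print(f"level {i}: {input_channel} -> {output_channel}, downsample={not is_final_block}")

# level 0: 8 -> 8, downsample=True
# level 1: 8 -> 8, downsample=True
# level 2: 8 -> 16, downsample=True
# level 3: 16 -> 16, downsample=False
```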
@@ -421,56 +431,47 @@ class Decoder(nn.Module):

     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
-        num_channels: number of filters in the last upsampling.
+        num_channels: sequence of block output channels.
         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: number of output channels.
-        ch_mult: list of multipliers of num_channels that make for all the upsampling layers before the last. In the
-            last layer, there will be a transition from num_channels to out_channels. In the layers before that,
-            channels will be the product of the previous number of channels by ch_mult.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """

     def __init__(
         self,
         spatial_dims: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         in_channels: int,
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()

-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.ch_mult = ch_mult
         self.num_res_blocks = num_res_blocks
         self.norm_num_groups = norm_num_groups
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels

-        block_in_ch = num_channels * self.ch_mult[-1]
+        reversed_block_out_channels = list(reversed(num_channels))

         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=block_in_ch,
+                out_channels=reversed_block_out_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -480,25 +481,53 @@ def __init__(

         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )
             blocks.append(
                 AttentionBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )

-        for i in reversed(range(len(ch_mult))):
-            block_out_ch = num_channels * self.ch_mult[i]
+        reversed_attention_levels = list(reversed(attention_levels))
+        block_out_ch = reversed_block_out_channels[0]
+        for i in range(len(reversed_block_out_channels)):
+            block_in_ch = block_out_ch
+            block_out_ch = reversed_block_out_channels[i]
+            is_final_block = i == len(num_channels) - 1

             for _ in range(self.num_res_blocks):
-                blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch))
+                blocks.append(
+                    ResBlock(
+                        spatial_dims=spatial_dims,
+                        in_channels=block_in_ch,
+                        norm_num_groups=norm_num_groups,
+                        norm_eps=norm_eps,
+                        out_channels=block_out_ch,
+                    )
+                )
                 block_in_ch = block_out_ch

-            if attention_levels[i]:
+            if reversed_attention_levels[i]:
                 blocks.append(
                     AttentionBlock(
                         spatial_dims=spatial_dims,
@@ -508,8 +537,8 @@ def __init__(
                     )
                 )

-            if i != 0:
-                blocks.append(Upsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))

         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
         blocks.append(
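The Decoder mirrors the Encoder by reversing both sequences up front and upsampling after every level except the last. A small sketch of that bookkeeping (values illustrative):

```python
num_channels = (8, 8, 16, 16)                  # illustrative
attention_levels = (False, False, True, True)  # illustrative

reversed_block_out_channels = list(reversed(num_channels))    # [16, 16, 8, 8]
reversed_attention_levels = list(reversed(attention_levels))  # [True, True, False, False]

block_out_ch = reversed_block_out_channels[0]
for i in range(len(reversed_block_out_channels)):
    block_in_ch = block_out_ch
    block_out_ch = reversed_block_out_channels[i]
    is_final_block = i == len(num_channels) - 1
    print(f"level {i}: {block_in_ch} -> {block_out_ch}, upsample={not is_final_block}")

# level 0: 16 -> 16, upsample=True
# level 1: 16 -> 16, upsample=True
# level 2: 16 -> 8, upsample=True
# level 3: 8 -> 8, upsample=False
```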
@@ -542,50 +571,44 @@ class AutoencoderKL(nn.Module):
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
         out_channels: number of output channels.
-        num_channels: number of filters in the first downsampling / last upsampling.
-        latent_channels: latent embedding dimension.
-        ch_mult: multiplier of the number of channels in each downsampling layer (+ initial one). i.e.: If you want 3
-            downsamplings, it should be a 4-element list.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
+        num_channels: sequence of block output channels.
+        attention_levels: sequence of levels to add attention.
+        latent_channels: latent embedding dimension.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
     """

     def __init__(
         self,
         spatial_dims: int,
-        in_channels: int,
-        out_channels: int,
-        num_channels: int,
-        latent_channels: int,
-        ch_mult: Sequence[int],
-        num_res_blocks: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        num_res_blocks: int = 2,
+        num_channels: Sequence[int] = (32, 64, 64, 64),
+        attention_levels: Sequence[bool] = (False, False, True, True),
+        latent_channels: int = 3,
         norm_num_groups: int = 32,
         norm_eps: float = 1e-6,
-        attention_levels: Optional[Sequence[bool]] = None,
         with_encoder_nonlocal_attn: bool = True,
         with_decoder_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)

-        # The number of channels should be multiple of num_groups
-        if (num_channels % norm_num_groups) != 0:
-            raise ValueError("AutoencoderKL expects number of channels being multiple of number of groups")
+        # All number of channels should be multiple of num_groups
+        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+            raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")

-        if len(ch_mult) != len(attention_levels):
-            raise ValueError("AutoencoderKL expects ch_mult being same size of attention_levels")
+        if len(num_channels) != len(attention_levels):
+            raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")

         self.encoder = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
             out_channels=latent_channels,
-            ch_mult=ch_mult,
             num_res_blocks=num_res_blocks,
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
@@ -597,7 +620,6 @@ def __init__(
             num_channels=num_channels,
             in_channels=latent_channels,
             out_channels=out_channels,
-            ch_mult=ch_mult,
             num_res_blocks=num_res_blocks,
             norm_num_groups=norm_num_groups,
             norm_eps=norm_eps,
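Putting it together, constructing the refactored AutoencoderKL now only needs the per-level sequences. The values below are the new defaults from this diff; the import path follows the file's location (whether the class is also re-exported at package level is not shown here):

```python
from generative.networks.nets.autoencoderkl import AutoencoderKL

model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=2,
    num_channels=(32, 64, 64, 64),                # replaces the old num_channels + ch_mult pair
    attention_levels=(False, False, True, True),  # must be the same length as num_channels
    latent_channels=3,
    norm_num_groups=32,                           # every entry of num_channels must be divisible by this
)

# Misconfigurations now fail fast with the new validation, e.g.:
# AutoencoderKL(spatial_dims=2, num_channels=(32, 64), attention_levels=(False, False, True))
# raises ValueError: AutoencoderKL expects num_channels being same size of attention_levels
```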