diff --git a/generative/networks/nets/autoencoderkl.py b/generative/networks/nets/autoencoderkl.py
index d081cad5..d13b26ad 100644
--- a/generative/networks/nets/autoencoderkl.py
+++ b/generative/networks/nets/autoencoderkl.py
@@ -296,16 +296,12 @@ class Encoder(nn.Module):
     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
         in_channels: number of input channels.
-        num_channels: number of filters in the first downsampling.
+        num_channels: sequence of block output channels.
         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
-        ch_mult: list of multipliers of num_channels in the initial layer and in each downsampling layer. Example: if
-            you want three downsamplings, you have to input a 4-element list. If you input [1, 1, 2, 2],
-            the first downsampling will leave num_channels to num_channels, the next will multiply num_channels by 2,
-            and the next will multiply num_channels*2 by 2 again, resulting in 8, 8, 16 and 32 channels.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """

@@ -313,20 +309,15 @@ def __init__(
         self,
         spatial_dims: int,
         in_channels: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.in_channels = in_channels
         self.num_channels = num_channels
@@ -336,15 +327,13 @@ def __init__(
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels

-        in_ch_mult = (1,) + tuple(ch_mult)
-
         blocks = []
         # Initial convolution
         blocks.append(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=num_channels,
+                out_channels=num_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -353,52 +342,73 @@ def __init__(
         )

         # Residual and downsampling blocks
-        for i in range(len(ch_mult)):
-            block_in_ch = num_channels * in_ch_mult[i]
-            block_out_ch = num_channels * ch_mult[i]
+        output_channel = num_channels[0]
+        for i in range(len(num_channels)):
+            input_channel = output_channel
+            output_channel = num_channels[i]
+            is_final_block = i == len(num_channels) - 1
+
             for _ in range(self.num_res_blocks):
                 blocks.append(
                     ResBlock(
                         spatial_dims=spatial_dims,
-                        in_channels=block_in_ch,
+                        in_channels=input_channel,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
-                        out_channels=block_out_ch,
+                        out_channels=output_channel,
                     )
                 )
-                block_in_ch = block_out_ch
+                input_channel = output_channel
                 if attention_levels[i]:
                     blocks.append(
                         AttentionBlock(
                             spatial_dims=spatial_dims,
-                            num_channels=block_in_ch,
+                            num_channels=input_channel,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
                         )
                     )

-            if i != len(ch_mult) - 1:
-                blocks.append(Downsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Downsample(spatial_dims=spatial_dims, in_channels=input_channel))

         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
             blocks.append(
-                AttentionBlock(
+                ResBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                AttentionBlock(
+                    spatial_dims=spatial_dims,
+                    num_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                )
+            )
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=num_channels[-1],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=num_channels[-1],
+                )
+            )

         # Normalise and convert to latent size
-        blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+        blocks.append(
+            nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True)
+        )
         blocks.append(
             Convolution(
                 spatial_dims=self.spatial_dims,
-                in_channels=block_in_ch,
+                in_channels=num_channels[-1],
                 out_channels=out_channels,
                 strides=1,
                 kernel_size=3,
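Reviewer aside (not part of the patch): the rewritten encoder loop makes the per-level channel arithmetic explicit instead of hiding it behind `ch_mult` products. A minimal sketch of what the new loop computes, assuming `num_res_blocks=1` and the new default `num_channels=(32, 64, 64, 64)`:

```python
# Per-level channel progression of the new Encoder loop (illustration only).
num_channels = (32, 64, 64, 64)  # new default
output_channel = num_channels[0]
for i in range(len(num_channels)):
    input_channel, output_channel = output_channel, num_channels[i]
    is_final_block = i == len(num_channels) - 1
    # With num_res_blocks=1 there is exactly one ResBlock per level.
    print(f"level {i}: ResBlock {input_channel} -> {output_channel}, downsample={not is_final_block}")
# level 0: ResBlock 32 -> 32, downsample=True
# level 1: ResBlock 32 -> 64, downsample=True
# level 2: ResBlock 64 -> 64, downsample=True
# level 3: ResBlock 64 -> 64, downsample=False
```

As before, downsampling happens at every level except the last; only the way the channel counts are specified has changed.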
@@ -421,48 +431,39 @@ class Decoder(nn.Module):

     Args:
         spatial_dims: number of spatial dimensions (1D, 2D, 3D).
-        num_channels: number of filters in the last upsampling.
+        num_channels: sequence of block output channels.
         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: number of output channels.
-        ch_mult: list of multipliers of num_channels that make for all the upsampling layers before the last. In the
-            last layer, there will be a transition from num_channels to out_channels. In the layers before that,
-            channels will be the product of the previous number of channels by ch_mult.
         num_res_blocks: number of residual blocks (see ResBlock) per level.
         norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from ch_mult contain an attention block.
+        attention_levels: indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
     """

     def __init__(
         self,
         spatial_dims: int,
-        num_channels: int,
+        num_channels: Sequence[int],
         in_channels: int,
         out_channels: int,
-        ch_mult: Sequence[int],
         num_res_blocks: int,
         norm_num_groups: int,
         norm_eps: float,
-        attention_levels: Optional[Sequence[bool]] = None,
+        attention_levels: Sequence[bool],
         with_nonlocal_attn: bool = True,
     ) -> None:
         super().__init__()
-
-        if attention_levels is None:
-            attention_levels = (False,) * len(ch_mult)
-
         self.spatial_dims = spatial_dims
         self.num_channels = num_channels
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.ch_mult = ch_mult
         self.num_res_blocks = num_res_blocks
         self.norm_num_groups = norm_num_groups
         self.norm_eps = norm_eps
         self.attention_levels = attention_levels

-        block_in_ch = num_channels * self.ch_mult[-1]
+        reversed_block_out_channels = list(reversed(num_channels))

         blocks = []
         # Initial convolution
@@ -470,7 +471,7 @@ def __init__(
             Convolution(
                 spatial_dims=spatial_dims,
                 in_channels=in_channels,
-                out_channels=block_in_ch,
+                out_channels=reversed_block_out_channels[0],
                 strides=1,
                 kernel_size=3,
                 padding=1,
@@ -480,25 +481,53 @@

         # Non-local attention block
         if with_nonlocal_attn is True:
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )
             blocks.append(
                 AttentionBlock(
                     spatial_dims=spatial_dims,
-                    num_channels=block_in_ch,
+                    num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
                 )
             )
-            blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch))
+            blocks.append(
+                ResBlock(
+                    spatial_dims=spatial_dims,
+                    in_channels=reversed_block_out_channels[0],
+                    norm_num_groups=norm_num_groups,
+                    norm_eps=norm_eps,
+                    out_channels=reversed_block_out_channels[0],
+                )
+            )

-        for i in reversed(range(len(ch_mult))):
-            block_out_ch = num_channels * self.ch_mult[i]
+        reversed_attention_levels = list(reversed(attention_levels))
+        block_out_ch = reversed_block_out_channels[0]
+        for i in range(len(reversed_block_out_channels)):
+            block_in_ch = block_out_ch
+            block_out_ch = reversed_block_out_channels[i]
+            is_final_block = i == len(num_channels) - 1

             for _ in range(self.num_res_blocks):
-                blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch))
+                blocks.append(
+                    ResBlock(
+                        spatial_dims=spatial_dims,
+                        in_channels=block_in_ch,
+                        norm_num_groups=norm_num_groups,
+                        norm_eps=norm_eps,
+                        out_channels=block_out_ch,
+                    )
+                )
                 block_in_ch = block_out_ch

-                if attention_levels[i]:
+                if reversed_attention_levels[i]:
                     blocks.append(
                         AttentionBlock(
                             spatial_dims=spatial_dims,
@@ -508,8 +537,8 @@
                         )
                     )

-            if i != 0:
-                blocks.append(Upsample(spatial_dims, block_in_ch))
+            if not is_final_block:
+                blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch))

         blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
         blocks.append(
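The decoder now takes the same `num_channels` and `attention_levels` arguments as the encoder and reverses both before building the upsampling path. A short illustration of the reversal introduced above, using the new default values (illustration only):

```python
num_channels = (32, 64, 64, 64)
attention_levels = (False, False, True, True)

# As in the new Decoder.__init__:
reversed_block_out_channels = list(reversed(num_channels))    # [64, 64, 64, 32]
reversed_attention_levels = list(reversed(attention_levels))  # [True, True, False, False]
# The upsampling path then walks 64 -> 64 -> 64 -> 32, adding an Upsample after
# every level except the final one (is_final_block), mirroring the encoder.
```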
""" def __init__( self, spatial_dims: int, - num_channels: int, + num_channels: Sequence[int], in_channels: int, out_channels: int, - ch_mult: Sequence[int], num_res_blocks: int, norm_num_groups: int, norm_eps: float, - attention_levels: Optional[Sequence[bool]] = None, + attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, ) -> None: super().__init__() - - if attention_levels is None: - attention_levels = (False,) * len(ch_mult) - self.spatial_dims = spatial_dims self.num_channels = num_channels self.in_channels = in_channels self.out_channels = out_channels - self.ch_mult = ch_mult self.num_res_blocks = num_res_blocks self.norm_num_groups = norm_num_groups self.norm_eps = norm_eps self.attention_levels = attention_levels - block_in_ch = num_channels * self.ch_mult[-1] + reversed_block_out_channels = list(reversed(num_channels)) blocks = [] # Initial convolution @@ -470,7 +471,7 @@ def __init__( Convolution( spatial_dims=spatial_dims, in_channels=in_channels, - out_channels=block_in_ch, + out_channels=reversed_block_out_channels[0], strides=1, kernel_size=3, padding=1, @@ -480,25 +481,53 @@ def __init__( # Non-local attention block if with_nonlocal_attn is True: - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) blocks.append( AttentionBlock( spatial_dims=spatial_dims, - num_channels=block_in_ch, + num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, ) ) - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_in_ch)) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) - for i in reversed(range(len(ch_mult))): - block_out_ch = num_channels * self.ch_mult[i] + reversed_attention_levels = list(reversed(attention_levels)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 for _ in range(self.num_res_blocks): - blocks.append(ResBlock(spatial_dims, block_in_ch, norm_num_groups, norm_eps, block_out_ch)) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) block_in_ch = block_out_ch - if attention_levels[i]: + if reversed_attention_levels[i]: blocks.append( AttentionBlock( spatial_dims=spatial_dims, @@ -508,8 +537,8 @@ def __init__( ) ) - if i != 0: - blocks.append(Upsample(spatial_dims, block_in_ch)) + if not is_final_block: + blocks.append(Upsample(spatial_dims=spatial_dims, in_channels=block_in_ch)) blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True)) blocks.append( @@ -542,14 +571,12 @@ class AutoencoderKL(nn.Module): spatial_dims: number of spatial dimensions (1D, 2D, 3D). in_channels: number of input channels. out_channels: number of output channels. - num_channels: number of filters in the first downsampling / last upsampling. - latent_channels: latent embedding dimension. 
diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py
index 11b1ec68..bb6af8f8 100644
--- a/tests/test_autoencoderkl.py
+++ b/tests/test_autoencoderkl.py
@@ -26,10 +26,9 @@
         "spatial_dims": 2,
         "in_channels": 1,
         "out_channels": 1,
-        "num_channels": 4,
+        "num_channels": (4, 4, 4),
         "latent_channels": 4,
-        "ch_mult": [1, 1, 1],
-        "attention_levels": None,
+        "attention_levels": (False, False, False),
         "num_res_blocks": 1,
         "norm_num_groups": 4,
     },
@@ -42,9 +41,8 @@
         "spatial_dims": 2,
         "in_channels": 1,
         "out_channels": 1,
-        "num_channels": 4,
+        "num_channels": (4, 4, 4),
         "latent_channels": 4,
-        "ch_mult": [1, 1, 1],
         "attention_levels": (False, False, False),
         "num_res_blocks": 1,
         "norm_num_groups": 4,
@@ -58,9 +56,8 @@
         "spatial_dims": 2,
         "in_channels": 1,
         "out_channels": 1,
-        "num_channels": 4,
+        "num_channels": (4, 4, 4),
         "latent_channels": 4,
-        "ch_mult": [1, 1, 1],
         "attention_levels": (False, False, True),
         "num_res_blocks": 1,
         "norm_num_groups": 4,
@@ -74,9 +71,8 @@
         "spatial_dims": 2,
         "in_channels": 1,
         "out_channels": 1,
-        "num_channels": 4,
+        "num_channels": (4, 4, 4),
         "latent_channels": 4,
-        "ch_mult": [1, 1, 1],
         "attention_levels": (False, False, False),
         "num_res_blocks": 1,
         "norm_num_groups": 4,
@@ -91,9 +87,8 @@
         "spatial_dims": 2,
         "in_channels": 1,
         "out_channels": 1,
-        "num_channels": 4,
+        "num_channels": (4, 4, 4),
         "latent_channels": 4,
-        "ch_mult": [1, 1, 1],
         "attention_levels": (False, False, False),
         "num_res_blocks": 1,
         "norm_num_groups": 4,
@@ -109,9 +104,8 @@
         "spatial_dims": 3,
         "in_channels": 1,
         "out_channels": 1,
-        "num_channels": 4,
+        "num_channels": (4, 4, 4),
         "latent_channels": 4,
-        "ch_mult": [1, 1, 1],
         "attention_levels": (False, False, True),
         "num_res_blocks": 1,
         "norm_num_groups": 4,
@@ -145,25 +139,24 @@ def test_model_channels_not_multiple_of_norm_num_group(self):
         with self.assertRaises(ValueError):
             AutoencoderKL(
                 spatial_dims=2,
                 in_channels=1,
                 out_channels=1,
-                num_channels=24,
+                num_channels=(24, 24, 24),
+                attention_levels=(False, False, False),
                 latent_channels=8,
-                ch_mult=[1, 1, 1],
                 num_res_blocks=1,
                 norm_num_groups=16,
             )

-    def test_model_ch_mult_not_same_size_of_attention_levels(self):
+    def test_model_num_channels_not_same_size_of_attention_levels(self):
         with self.assertRaises(ValueError):
             AutoencoderKL(
                 spatial_dims=2,
                 in_channels=1,
                 out_channels=1,
-                num_channels=24,
+                num_channels=(24, 24, 24),
+                attention_levels=(False, False),
                 latent_channels=8,
-                ch_mult=[1, 1, 1],
                 num_res_blocks=1,
                 norm_num_groups=16,
-                attention_levels=(True,),
             )

     def test_shape_reconstruction(self):
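The renamed test pins down the two validation rules in the new `__init__`: every entry of `num_channels` must be divisible by `norm_num_groups`, and `num_channels` must have the same length as `attention_levels`. A small illustrative snippet (not part of the test suite) triggering each check in the order the constructor applies them:

```python
from generative.networks.nets import AutoencoderKL

checks = [
    # 24 % 16 != 0: every entry must be divisible by norm_num_groups
    dict(num_channels=(24, 24, 24), attention_levels=(False, False, False)),
    # 3 entries vs. 2 flags: lengths must match (32 % 16 == 0, so this reaches the length check)
    dict(num_channels=(32, 32, 32), attention_levels=(False, False)),
]
for kwargs in checks:
    try:
        AutoencoderKL(
            spatial_dims=2, in_channels=1, out_channels=1,
            latent_channels=8, num_res_blocks=1, norm_num_groups=16, **kwargs,
        )
    except ValueError as e:
        print(e)
# -> AutoencoderKL expects all num_channels being multiple of norm_num_groups
# -> AutoencoderKL expects num_channels being same size of attention_levels
```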
"ch_mult": [1, 1, 1], "attention_levels": (False, False, True), "num_res_blocks": 1, "norm_num_groups": 4, @@ -74,9 +71,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, @@ -91,9 +87,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, False), "num_res_blocks": 1, "norm_num_groups": 4, @@ -109,9 +104,8 @@ "spatial_dims": 3, "in_channels": 1, "out_channels": 1, - "num_channels": 4, + "num_channels": (4, 4, 4), "latent_channels": 4, - "ch_mult": [1, 1, 1], "attention_levels": (False, False, True), "num_res_blocks": 1, "norm_num_groups": 4, @@ -145,25 +139,24 @@ def test_model_channels_not_multiple_of_norm_num_group(self): spatial_dims=2, in_channels=1, out_channels=1, - num_channels=24, + num_channels=(24, 24, 24), + attention_levels=(False, False, False), latent_channels=8, - ch_mult=[1, 1, 1], num_res_blocks=1, norm_num_groups=16, ) - def test_model_ch_mult_not_same_size_of_attention_levels(self): + def test_model_num_channels_not_same_size_of_attention_levels(self): with self.assertRaises(ValueError): AutoencoderKL( spatial_dims=2, in_channels=1, out_channels=1, - num_channels=24, + num_channels=(24, 24, 24), + attention_levels=(False, False), latent_channels=8, - ch_mult=[1, 1, 1], num_res_blocks=1, norm_num_groups=16, - attention_levels=(True,), ) def test_shape_reconstruction(self): diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py index e258f2e9..58394754 100644 --- a/tests/test_latent_diffusion_inferer.py +++ b/tests/test_latent_diffusion_inferer.py @@ -25,9 +25,8 @@ "spatial_dims": 2, "in_channels": 1, "out_channels": 1, - "num_channels": 8, + "num_channels": (8, 8, 8), "latent_channels": 3, - "ch_mult": [1, 1, 1], "attention_levels": [False, False, False], "num_res_blocks": 1, "with_encoder_nonlocal_attn": False, diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb index 5a0d0ed0..3a8c8019 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.ipynb @@ -598,9 +598,8 @@ " spatial_dims=2,\n", " in_channels=1,\n", " out_channels=1,\n", - " num_channels=128,\n", + " num_channels=(128, 256, 384),\n", " latent_channels=8,\n", - " ch_mult=(1, 2, 3),\n", " num_res_blocks=1,\n", " norm_num_groups=32,\n", " attention_levels=(False, False, True),\n", diff --git a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py index 9081de10..aab95a25 100644 --- a/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py +++ b/tutorials/generative/2d_autoencoderkl/2d_autoencoderkl_tutorial.py @@ -121,9 +121,8 @@ spatial_dims=2, in_channels=1, out_channels=1, - num_channels=128, + num_channels=(128, 256, 384), latent_channels=8, - ch_mult=(1, 2, 3), num_res_blocks=1, norm_num_groups=32, attention_levels=(False, False, True), diff --git a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.ipynb index 8e164b82..fd2eb249 100644 --- 
diff --git a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py
index 97979189..cfce9e80 100644
--- a/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py
+++ b/tutorials/generative/3d_autoencoderkl/3d_autoencoderkl_tutorial.py
@@ -175,9 +175,8 @@
     spatial_dims=3,
     in_channels=1,
     out_channels=1,
-    num_channels=32,
+    num_channels=(32, 64, 64),
     latent_channels=3,
-    ch_mult=(1, 2, 2),
     num_res_blocks=1,
     norm_num_groups=32,
     attention_levels=(False, False, True),
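As a closing sanity check, a sketch instantiating the new signature end to end. Assumptions, not verified against this patch: the import path used in the tests, and that `AutoencoderKL.forward` returns the reconstruction plus the latent mean and sigma at this revision.

```python
import torch

from generative.networks.nets import AutoencoderKL

# New defaults: num_channels=(32, 64, 64, 64), attention_levels=(False, False, True, True)
model = AutoencoderKL(spatial_dims=2)
x = torch.randn(1, 1, 64, 64)  # three downsamplings -> 8x8 latent
with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(x)
print(reconstruction.shape, z_mu.shape)  # torch.Size([1, 1, 64, 64]) torch.Size([1, 3, 8, 8])
```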