diff --git a/generative/networks/nets/diffusion_model_unet.py b/generative/networks/nets/diffusion_model_unet.py index d29aa2d6..3812a7f9 100644 --- a/generative/networks/nets/diffusion_model_unet.py +++ b/generative/networks/nets/diffusion_model_unet.py @@ -472,7 +472,8 @@ def __init__( assert self.num_channels == self.out_channels self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor: + del emb assert x.shape[1] == self.num_channels return self.op(x) @@ -512,8 +513,11 @@ def __init__( padding=padding, conv_only=True, ) + else: + self.conv = None - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor: + del emb assert x.shape[1] == self.num_channels x = F.interpolate(x, scale_factor=2.0, mode="nearest") if self.use_conv: @@ -645,6 +649,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, + resblock_updown: bool = False, downsample_padding: int = 1, ) -> None: """ @@ -659,9 +664,12 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] for i in range(num_res_blocks): @@ -680,13 +688,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) else: self.downsampler = None @@ -701,7 +720,7 @@ def forward( output_states.append(hidden_states) if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) + hidden_states = self.downsampler(hidden_states, temb) output_states.append(hidden_states) return hidden_states, output_states @@ -718,6 +737,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, + resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, ) -> None: @@ -733,10 +753,13 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -766,13 +789,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) else: self.downsampler = None @@ -788,7 +822,7 @@ def forward( output_states.append(hidden_states) if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) + hidden_states = self.downsampler(hidden_states, temb) output_states.append(hidden_states) return hidden_states, output_states @@ -805,6 +839,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_downsample: bool = True, + resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, transformer_num_layers: int = 1, @@ -822,12 +857,15 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -861,13 +899,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_downsample: - self.downsampler = Downsample( - spatial_dims=spatial_dims, - num_channels=out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - ) + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) else: self.downsampler = None @@ -882,7 +931,7 @@ def forward( output_states.append(hidden_states) if self.downsampler is not None: - hidden_states = self.downsampler(hidden_states) + hidden_states = self.downsampler(hidden_states, temb) output_states.append(hidden_states) return hidden_states, output_states @@ -1025,6 +1074,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, + resblock_updown: bool = False, ) -> None: """ Unet's up block containing resnet and upsamplers blocks. @@ -1039,8 +1089,10 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. """ super().__init__() + self.resblock_updown = resblock_updown resnets = [] for i in range(num_res_blocks): @@ -1061,9 +1113,20 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) else: self.upsampler = None @@ -1084,7 +1147,7 @@ def forward( hidden_states = resnet(hidden_states, temb) if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states) + hidden_states = self.upsampler(hidden_states, temb) return hidden_states @@ -1101,6 +1164,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, + resblock_updown: bool = False, num_head_channels: int = 1, ) -> None: """ @@ -1116,9 +1180,12 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -1150,9 +1217,20 @@ def __init__( self.attentions = nn.ModuleList(attentions) if add_upsample: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) else: self.upsampler = None @@ -1174,7 +1252,7 @@ def forward( hidden_states = attn(hidden_states) if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states) + hidden_states = self.upsampler(hidden_states, temb) return hidden_states @@ -1191,6 +1269,7 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, add_upsample: bool = True, + resblock_updown: bool = False, num_head_channels: int = 1, transformer_num_layers: int = 1, cross_attention_dim: Optional[int] = None, @@ -1208,11 +1287,14 @@ def __init__( norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. """ super().__init__() + self.resblock_updown = resblock_updown + resnets = [] attentions = [] @@ -1247,9 +1329,20 @@ def __init__( self.resnets = nn.ModuleList(resnets) if add_upsample: - self.upsampler = Upsample( - spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels - ) + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) else: self.upsampler = None @@ -1270,7 +1363,7 @@ def forward( hidden_states = attn(hidden_states, context=context) if self.upsampler is not None: - hidden_states = self.upsampler(hidden_states) + hidden_states = self.upsampler(hidden_states, temb) return hidden_states @@ -1284,6 +1377,7 @@ def get_down_block( norm_num_groups: int, norm_eps: float, add_downsample: bool, + resblock_updown: bool, with_attn: bool, with_cross_attn: bool, num_head_channels: int, @@ -1300,6 +1394,7 @@ def get_down_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, ) elif with_cross_attn: @@ -1312,6 +1407,7 @@ def get_down_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, @@ -1326,6 +1422,7 @@ def get_down_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=add_downsample, + resblock_updown=resblock_updown, ) @@ -1372,6 +1469,7 @@ def get_up_block( norm_num_groups: int, norm_eps: float, add_upsample: bool, + resblock_updown: bool, with_attn: bool, with_cross_attn: bool, num_head_channels: int, @@ -1389,6 +1487,7 @@ def get_up_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, ) elif with_cross_attn: @@ -1402,6 +1501,7 @@ def get_up_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, + resblock_updown=resblock_updown, num_head_channels=num_head_channels, transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, @@ -1417,6 +1517,7 @@ def get_up_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=add_upsample, + resblock_updown=resblock_updown, ) @@ -1435,6 +1536,7 @@ class DiffusionModelUNet(nn.Module): attention_levels: list of levels to add attention. norm_num_groups: number of groups for the normalization. norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. num_head_channels: number of channels in each attention head. with_conditioning: if True add spatial transformers to perform conditioning. transformer_num_layers: number of layers of Transformer blocks to use. @@ -1453,6 +1555,7 @@ def __init__( attention_levels: Sequence[bool] = (False, False, True, True), norm_num_groups: int = 32, norm_eps: float = 1e-6, + resblock_updown: bool = False, num_head_channels: Union[int, Sequence[int]] = 8, with_conditioning: bool = False, transformer_num_layers: int = 1, @@ -1534,6 +1637,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_downsample=not is_final_block, + resblock_updown=resblock_updown, with_attn=(attention_levels[i] and not with_conditioning), with_cross_attn=(attention_levels[i] and with_conditioning), num_head_channels=num_head_channels[i], @@ -1579,6 +1683,7 @@ def __init__( norm_num_groups=norm_num_groups, norm_eps=norm_eps, add_upsample=not is_final_block, + resblock_updown=resblock_updown, with_attn=(reversed_attention_levels[i] and not with_conditioning), with_cross_attn=(reversed_attention_levels[i] and with_conditioning), num_head_channels=reversed_num_head_channels[i], diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py index 80382b38..255e2fc9 100644 --- a/tests/test_diffusion_model_unet.py +++ b/tests/test_diffusion_model_unet.py @@ -30,6 +30,30 @@ "norm_num_groups": 8, }, ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + ], + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + }, + ], [ { "spatial_dims": 2, @@ -40,6 +64,7 @@ "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, + "resblock_updown": True, }, ], [ @@ -80,6 +105,30 @@ "norm_num_groups": 8, }, ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, False), + "norm_num_groups": 8, + "resblock_updown": True, + }, + ], + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "num_res_blocks": 1, + "num_channels": (8, 8, 8), + "attention_levels": (False, False, True), + "num_head_channels": 8, + "norm_num_groups": 8, + }, + ], [ { "spatial_dims": 3, @@ -90,6 +139,7 @@ "attention_levels": (False, False, True), "num_head_channels": 8, "norm_num_groups": 8, + "resblock_updown": True, }, ], [ @@ -290,6 +340,22 @@ def test_script_conditioned_2d_models(self): ) test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + def test_script_conditioned_2d_models_with_resblock_updown(self): + net = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + num_res_blocks=1, + num_channels=(8, 8, 8), + attention_levels=(False, False, True), + norm_num_groups=8, + with_conditioning=True, + transformer_num_layers=1, + cross_attention_dim=3, + resblock_updown=True, + ) + test_script_save(net, torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3))) + class TestDiffusionModelUNet3D(unittest.TestCase): @parameterized.expand(UNCOND_CASES_3D)