From 497d304c88bd48d0bcd353675f89b6694775de1b Mon Sep 17 00:00:00 2001 From: Akash Pannu Date: Sun, 9 Oct 2022 20:27:03 +0000 Subject: [PATCH 1/5] pass norm_num_groups param and add tests --- src/diffusers/models/vae_flax.py | 58 ++++++++++++++++++++++-------- tests/test_modeling_common_flax.py | 38 ++++++++++++++++++++ tests/test_models_vae_flax.py | 33 +++++++++++++++++ 3 files changed, 115 insertions(+), 14 deletions(-) create mode 100644 tests/test_modeling_common_flax.py create mode 100644 tests/test_models_vae_flax.py diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index b3261b11cf6c..02cf6ad980e8 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + groups (:obj:`int`, *optional*, defaults to 32): + The number of groups to use for group norm. use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): Whether to use `nin_shortcut`. This activates a new layer inside ResNet block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): @@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module): in_channels: int out_channels: int = None dropout: float = 0.0 + groups: int = 32 use_nin_shortcut: bool = None dtype: jnp.dtype = jnp.float32 def setup(self): out_channels = self.in_channels if self.out_channels is None else self.out_channels - self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) self.conv1 = nn.Conv( out_channels, kernel_size=(3, 3), @@ -143,7 +146,7 @@ def setup(self): dtype=self.dtype, ) - self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6) self.dropout_layer = nn.Dropout(self.dropout) self.conv2 = nn.Conv( out_channels, @@ -191,12 +194,15 @@ class FlaxAttentionBlock(nn.Module): Input channels num_head_channels (:obj:`int`, *optional*, defaults to `None`): Number of attention heads + num_groups (:obj:`int`, *optional*, defaults to 32): + The number of groups to use for group norm dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ channels: int num_head_channels: int = None + num_groups: int = 32 dtype: jnp.dtype = jnp.float32 def setup(self): @@ -204,7 +210,7 @@ def setup(self): dense = partial(nn.Dense, self.channels, dtype=self.dtype) - self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6) self.query, self.key, self.value = dense(), dense(), dense() self.proj_attn = dense() @@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to 32): + The number of groups to use for the Resnet block group norm add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): @@ -273,6 +281,7 @@ class FlaxDownEncoderBlock2D(nn.Module): out_channels: int dropout: float = 0.0 num_layers: int = 1 + resnet_groups: int = 32 add_downsample: bool = True dtype: jnp.dtype = jnp.float32 @@ -285,6 +294,7 @@ def setup(self): in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, + groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) @@ -303,9 +313,9 @@ def 
__call__(self, hidden_states, deterministic=True): return hidden_states -class FlaxUpEncoderBlock2D(nn.Module): +class FlaxUpDecoderBlock2D(nn.Module): r""" - Flax Resnet blocks-based Encoder block for diffusion-based VAE. + Flax Resnet blocks-based Decoder block for diffusion-based VAE. Parameters: in_channels (:obj:`int`): @@ -316,8 +326,10 @@ class FlaxUpEncoderBlock2D(nn.Module): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block - add_downsample (:obj:`bool`, *optional*, defaults to `True`): - Whether to add downsample layer + resnet_groups (:obj:`int`, *optional*, defaults to 32): + The number of groups to use for the Resnet block group norm + add_upsample (:obj:`bool`, *optional*, defaults to `True`): + Whether to add upsample layer dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -325,6 +337,7 @@ class FlaxUpEncoderBlock2D(nn.Module): out_channels: int dropout: float = 0.0 num_layers: int = 1 + resnet_groups: int = 32 add_upsample: bool = True dtype: jnp.dtype = jnp.float32 @@ -336,6 +349,7 @@ def setup(self): in_channels=in_channels, out_channels=self.out_channels, dropout=self.dropout, + groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) @@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block + resnet_groups (:obj:`int`, *optional*, defaults to 32): + The number of groups to use for the Resnet and Attention block group norm attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): Number of attention heads for each attention block dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): @@ -374,6 +390,7 @@ class FlaxUNetMidBlock2D(nn.Module): in_channels: int dropout: float = 0.0 num_layers: int = 1 + resnet_groups: int = 32 attn_num_head_channels: int = 1 dtype: jnp.dtype = jnp.float32 @@ -384,6 +401,7 @@ def setup(self): in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, + groups=self.resnet_groups, dtype=self.dtype, ) ] @@ -392,7 +410,10 @@ def setup(self): for _ in range(self.num_layers): attn_block = FlaxAttentionBlock( - channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype + channels=self.in_channels, + num_head_channels=self.attn_num_head_channels, + num_groups=self.resnet_groups, + dtype=self.dtype, ) attentions.append(attn_block) @@ -400,6 +421,7 @@ def setup(self): in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, + groups=self.resnet_groups, dtype=self.dtype, ) resnets.append(res_block) @@ -441,7 +463,7 @@ class FlaxEncoder(nn.Module): Tuple containing the number of output channels for each block layers_per_block (:obj:`int`, *optional*, defaults to `2`): Number of Resnet layer for each block - norm_num_groups (:obj:`int`, *optional*, defaults to `2`): + norm_num_groups (:obj:`int`, *optional*, defaults to `32`): norm num group act_fn (:obj:`str`, *optional*, defaults to `silu`): Activation function @@ -483,6 +505,7 @@ def setup(self): in_channels=input_channel, out_channels=output_channel, num_layers=self.layers_per_block, + resnet_groups=self.norm_num_groups, add_downsample=not is_final_block, dtype=self.dtype, ) @@ -491,12 +514,15 @@ def setup(self): # middle self.mid_block = FlaxUNetMidBlock2D( - in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype + in_channels=block_out_channels[-1], + resnet_groups=self.norm_num_groups, + 
attn_num_head_channels=None, + dtype=self.dtype, ) # end conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels - self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( conv_out_channels, kernel_size=(3, 3), @@ -581,7 +607,10 @@ def setup(self): # middle self.mid_block = FlaxUNetMidBlock2D( - in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype + in_channels=block_out_channels[-1], + resnet_groups=self.norm_num_groups, + attn_num_head_channels=None, + dtype=self.dtype, ) # upsampling @@ -594,10 +623,11 @@ def setup(self): is_final_block = i == len(block_out_channels) - 1 - up_block = FlaxUpEncoderBlock2D( + up_block = FlaxUpDecoderBlock2D( in_channels=prev_output_channel, out_channels=output_channel, num_layers=self.layers_per_block + 1, + resnet_groups=self.norm_num_groups, add_upsample=not is_final_block, dtype=self.dtype, ) @@ -607,7 +637,7 @@ def setup(self): self.up_blocks = up_blocks # end - self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6) + self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6) self.conv_out = nn.Conv( self.out_channels, kernel_size=(3, 3), diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py new file mode 100644 index 000000000000..3873da135e27 --- /dev/null +++ b/tests/test_modeling_common_flax.py @@ -0,0 +1,38 @@ +import jax + + +class FlaxModelTesterMixin: + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) + + output = model.apply(variables, inputs_dict["sample"]) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) + + output = model.apply(variables, inputs_dict["sample"]) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") diff --git a/tests/test_models_vae_flax.py b/tests/test_models_vae_flax.py new file mode 100644 index 000000000000..c6f36b4c04d2 --- /dev/null +++ b/tests/test_models_vae_flax.py @@ -0,0 +1,33 @@ +import unittest + +import jax +from diffusers import FlaxAutoencoderKL + +from .test_modeling_common_flax import FlaxModelTesterMixin + + +class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): + model_class = FlaxAutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + prng_key = jax.random.PRNGKey(0) + image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) + + return {"sample": image, "prng_key": prng_key} + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64], + "in_channels": 3, + 
"out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict From 673f15fe18e9cf2bcfae3065886d0b75f36b1532 Mon Sep 17 00:00:00 2001 From: Akash Pannu Date: Sun, 9 Oct 2022 20:59:10 +0000 Subject: [PATCH 2/5] set resnet_groups for FlaxUNetMidBlock2D --- src/diffusers/models/vae_flax.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 02cf6ad980e8..4c023a48df4a 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -395,13 +395,15 @@ class FlaxUNetMidBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): + resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(in_channels // 4, 32) + # there is always at least one resnet resnets = [ FlaxResnetBlock2D( in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, - groups=self.resnet_groups, + groups=resnet_groups, dtype=self.dtype, ) ] @@ -412,7 +414,7 @@ def setup(self): attn_block = FlaxAttentionBlock( channels=self.in_channels, num_head_channels=self.attn_num_head_channels, - num_groups=self.resnet_groups, + num_groups=resnet_groups, dtype=self.dtype, ) attentions.append(attn_block) @@ -421,7 +423,7 @@ def setup(self): in_channels=self.in_channels, out_channels=self.in_channels, dropout=self.dropout, - groups=self.resnet_groups, + groups=resnet_groups, dtype=self.dtype, ) resnets.append(res_block) From 6a765339a7ecae9902288103bb4d08f654ff095e Mon Sep 17 00:00:00 2001 From: Akash Pannu Date: Sun, 9 Oct 2022 21:11:24 +0000 Subject: [PATCH 3/5] fixed docstrings --- src/diffusers/models/vae_flax.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 4c023a48df4a..ce1f0a9520f9 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -119,7 +119,7 @@ class FlaxResnetBlock2D(nn.Module): Output channels dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - groups (:obj:`int`, *optional*, defaults to 32): + groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for group norm. use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`): Whether to use `nin_shortcut`. 
This activates a new layer inside ResNet block @@ -194,7 +194,7 @@ class FlaxAttentionBlock(nn.Module): Input channels num_head_channels (:obj:`int`, *optional*, defaults to `None`): Number of attention heads - num_groups (:obj:`int`, *optional*, defaults to 32): + num_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for group norm dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` @@ -270,7 +270,7 @@ class FlaxDownEncoderBlock2D(nn.Module): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block - resnet_groups (:obj:`int`, *optional*, defaults to 32): + resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_downsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add downsample layer @@ -326,7 +326,7 @@ class FlaxUpDecoderBlock2D(nn.Module): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block - resnet_groups (:obj:`int`, *optional*, defaults to 32): + resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet block group norm add_upsample (:obj:`bool`, *optional*, defaults to `True`): Whether to add upsample layer @@ -380,7 +380,7 @@ class FlaxUNetMidBlock2D(nn.Module): Dropout rate num_layers (:obj:`int`, *optional*, defaults to 1): Number of Resnet layer block - resnet_groups (:obj:`int`, *optional*, defaults to 32): + resnet_groups (:obj:`int`, *optional*, defaults to `32`): The number of groups to use for the Resnet and Attention block group norm attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`): Number of attention heads for each attention block @@ -396,7 +396,7 @@ class FlaxUNetMidBlock2D(nn.Module): def setup(self): resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(in_channels // 4, 32) - + # there is always at least one resnet resnets = [ FlaxResnetBlock2D( From dbacb475362aba0f37b06d06038e6c317810f1e1 Mon Sep 17 00:00:00 2001 From: Akash Pannu Date: Sun, 9 Oct 2022 22:58:53 +0000 Subject: [PATCH 4/5] fixed typo --- src/diffusers/models/vae_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index ce1f0a9520f9..074133a05c4a 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -395,8 +395,8 @@ class FlaxUNetMidBlock2D(nn.Module): dtype: jnp.dtype = jnp.float32 def setup(self): - resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(in_channels // 4, 32) - + resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32) + # there is always at least one resnet resnets = [ FlaxResnetBlock2D( From 2c6481038d65c19602f19dea202af8a87ef920ff Mon Sep 17 00:00:00 2001 From: Akash Pannu Date: Mon, 10 Oct 2022 05:06:28 +0000 Subject: [PATCH 5/5] using is_flax_available util and created require_flax decorator --- src/diffusers/utils/testing_utils.py | 9 +++++++++ tests/test_modeling_common_flax.py | 8 +++++++- tests/test_models_vae_flax.py | 8 +++++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 0177d30abac9..f44b9cd394c9 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -14,6 +14,8 @@ import requests from packaging import version +from .import_utils import 
is_flax_available + global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -89,6 +91,13 @@ def slow(test_case): return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) +def require_flax(test_case): + """ + Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed + """ + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) + + def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: """ Args: diff --git a/tests/test_modeling_common_flax.py b/tests/test_modeling_common_flax.py index 3873da135e27..61849b22318f 100644 --- a/tests/test_modeling_common_flax.py +++ b/tests/test_modeling_common_flax.py @@ -1,6 +1,12 @@ -import jax +from diffusers.utils import is_flax_available +from diffusers.utils.testing_utils import require_flax +if is_flax_available(): + import jax + + +@require_flax class FlaxModelTesterMixin: def test_output(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/test_models_vae_flax.py b/tests/test_models_vae_flax.py index c6f36b4c04d2..e5c56b61a5a4 100644 --- a/tests/test_models_vae_flax.py +++ b/tests/test_models_vae_flax.py @@ -1,11 +1,17 @@ import unittest -import jax from diffusers import FlaxAutoencoderKL +from diffusers.utils import is_flax_available +from diffusers.utils.testing_utils import require_flax from .test_modeling_common_flax import FlaxModelTesterMixin +if is_flax_available(): + import jax + + +@require_flax class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): model_class = FlaxAutoencoderKL
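

For reference, a minimal end-to-end sketch of what this series enables, assembled only from the test code added above (the constructor arguments and the init/apply pattern mirror FlaxAutoencoderKLTests and FlaxModelTesterMixin.test_forward_with_norm_groups; treat it as a sketch of the intended usage, not as authoritative API documentation):

    import jax
    from diffusers import FlaxAutoencoderKL

    # Same configuration the norm-groups test exercises: norm_num_groups should
    # evenly divide each entry of block_out_channels, since it is forwarded to
    # flax GroupNorm inside the encoder/decoder blocks.
    model = FlaxAutoencoderKL(
        in_channels=3,
        out_channels=3,
        block_out_channels=(16, 32),
        norm_num_groups=16,
        down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
        latent_channels=4,
    )

    # Dummy input as in the tests: (batch, channels, height, width).
    prng_key = jax.random.PRNGKey(0)
    sample = jax.random.uniform(prng_key, (4, 3, 32, 32))

    variables = model.init(prng_key, sample)
    output = model.apply(variables, sample)
    if isinstance(output, dict):
        output = output.sample

    # The VAE reconstructs the input, so shapes must match.
    assert output.shape == sample.shape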