From 76bb311c9d34be0ac6cf2ebc10e322e9bd3636c1 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 25 Nov 2022 17:00:36 +0100 Subject: [PATCH 01/15] Flax: start adapting to Stable Diffusion 2 --- src/diffusers/models/attention_flax.py | 78 ++++++++++++++++++-------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1b8609474750..40a29f91c43c 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -15,7 +15,6 @@ import flax.linen as nn import jax.numpy as jnp - class FlaxAttentionBlock(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 @@ -104,6 +103,10 @@ class FlaxBasicTransformerBlock(nn.Module): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + cross_attention_dim (`int`, *optional*): + The size of the context vector for cross attention. + only_cross_attention (`bool`, defaults to `False`): + Whether to only apply cross attention. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -111,10 +114,11 @@ class FlaxBasicTransformerBlock(nn.Module): n_heads: int d_head: int dropout: float = 0.0 + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): - # self attention + # self attention (or cross_attention if only_cross_attention is True) self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) # cross attention self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) @@ -126,7 +130,10 @@ def setup(self): def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states - hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) + if self.only_cross_attention: + hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic) + else: + hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention @@ -159,6 +166,8 @@ class FlaxTransformer2DModel(nn.Module): Number of transformers block dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate + use_linear_projection (`bool`, defaults to `False`): tbd + only_cross_attention (`bool`, defaults to `False`): tbd dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): Parameters `dtype` """ @@ -167,49 +176,70 @@ class FlaxTransformer2DModel(nn.Module): d_head: int depth: int = 1 dropout: float = 0.0 + use_linear_projection: bool = False + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) inner_dim = self.n_heads * self.d_head - self.proj_in = nn.Conv( - inner_dim, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) + if self.use_linear_projection: + self.proj_in = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) self.transformer_blocks = [ - FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) + FlaxBasicTransformerBlock( + inner_dim, + self.n_heads, + self.d_head, + dropout=self.dropout, + only_cross_attention=self.only_cross_attention, + dtype=self.dtype, + ) for _ in 
range(self.depth) ] - self.proj_out = nn.Conv( - inner_dim, - kernel_size=(1, 1), - strides=(1, 1), - padding="VALID", - dtype=self.dtype, - ) + if self.use_linear_projection: + self.proj_out = nn.Dense(inner_dim, dtype=self.dtype) + else: + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) def __call__(self, hidden_states, context, deterministic=True): batch, height, width, channels = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - - hidden_states = hidden_states.reshape(batch, height * width, channels) + if self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height * width, channels) + hidden_states = self.proj_in(hidden_states) + else: + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) - hidden_states = hidden_states.reshape(batch, height, width, channels) + if self.use_linear_projection: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, channels) + else: + hidden_states = hidden_states.reshape(batch, height, width, channels) + hidden_states = self.proj_out(hidden_states) - hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states + residual - return hidden_states From a87c3d98c51e87b15e003f1a48d56260dae26970 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 25 Nov 2022 17:18:12 +0100 Subject: [PATCH 02/15] More changes. --- src/diffusers/models/unet_2d_blocks_flax.py | 10 ++++++++++ src/diffusers/models/unet_2d_condition_flax.py | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py index 5798385b9d28..96e76cb06a59 100644 --- a/src/diffusers/models/unet_2d_blocks_flax.py +++ b/src/diffusers/models/unet_2d_blocks_flax.py @@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module): num_layers: int = 1 attn_num_head_channels: int = 1 add_downsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -68,6 +70,8 @@ def setup(self): n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, depth=1, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, dtype=self.dtype, ) attentions.append(attn_block) @@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module): num_layers: int = 1 attn_num_head_channels: int = 1 add_upsample: bool = True + use_linear_projection: bool = False + only_cross_attention: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -201,6 +207,8 @@ def setup(self): n_heads=self.attn_num_head_channels, d_head=self.out_channels // self.attn_num_head_channels, depth=1, + use_linear_projection=self.use_linear_projection, + only_cross_attention=self.only_cross_attention, dtype=self.dtype, ) attentions.append(attn_block) @@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module): dropout: float = 0.0 num_layers: int = 1 attn_num_head_channels: int = 1 + use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 def setup(self): @@ -331,6 +340,7 @@ def setup(self): n_heads=self.attn_num_head_channels, d_head=self.in_channels // 
self.attn_num_head_channels, depth=1, + use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) attentions.append(attn_block) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 7ca9c191b448..ed71231cc905 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): "DownBlock2D", ) up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: int = 8 cross_attention_dim: int = 1280 dropout: float = 0.0 + use_linear_projection: bool = False dtype: jnp.dtype = jnp.float32 freq_shift: int = 0 @@ -134,6 +136,10 @@ def setup(self): self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift) self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + only_cross_attention = self.only_cross_attention + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(self.down_block_types) + # down down_blocks = [] output_channel = block_out_channels[0] @@ -150,6 +156,8 @@ def setup(self): num_layers=self.layers_per_block, attn_num_head_channels=self.attention_head_dim, add_downsample=not is_final_block, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], dtype=self.dtype, ) else: @@ -170,12 +178,14 @@ def setup(self): in_channels=block_out_channels[-1], dropout=self.dropout, attn_num_head_channels=self.attention_head_dim, + use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) # up up_blocks = [] reversed_block_out_channels = list(reversed(block_out_channels)) + only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(self.up_block_types): prev_output_channel = output_channel @@ -193,6 +203,8 @@ def setup(self): attn_num_head_channels=self.attention_head_dim, add_upsample=not is_final_block, dropout=self.dropout, + use_linear_projection=self.use_linear_projection, + only_cross_attention=only_cross_attention[i], dtype=self.dtype, ) else: From 992427704b08a06ee44b9747cad11e9a83539c9b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 26 Nov 2022 01:20:11 +0100 Subject: [PATCH 03/15] attention_head_dim can be a tuple. --- src/diffusers/models/unet_2d_condition_flax.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index ed71231cc905..20b47259a9b0 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - attention_head_dim (`int`, *optional*, defaults to 8): + attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8): The dimension of the attention heads. cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. 
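# A standalone sketch (not part of the patch itself) of the per-block broadcasting rule
# that patches 02-04 add to FlaxUNet2DConditionModel.setup(); the names mirror the diff,
# and the concrete tuple values are illustrative assumptions only.
down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")

only_cross_attention = False              # a single bool, or a per-down-block tuple such as (True, True, True, False)
if isinstance(only_cross_attention, bool):
    only_cross_attention = (only_cross_attention,) * len(down_block_types)

attention_head_dim = 8                    # a single int, or a per-down-block tuple such as (5, 10, 20, 20)
if isinstance(attention_head_dim, int):
    attention_head_dim = (attention_head_dim,) * len(down_block_types)

# Down blocks consume these entries in order, the mid block uses attention_head_dim[-1],
# and the up blocks iterate over the reversed sequences:
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_only_cross_attention = list(reversed(only_cross_attention))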
@@ -100,7 +100,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 - attention_head_dim: int = 8 + attention_head_dim: Union[int, Tuple[int]] = 8, cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False @@ -140,6 +140,10 @@ def setup(self): if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(self.down_block_types) + attention_head_dim = self.attention_head_dim + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(self.down_block_types) + # down down_blocks = [] output_channel = block_out_channels[0] @@ -154,7 +158,7 @@ def setup(self): out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, - attn_num_head_channels=self.attention_head_dim, + attn_num_head_channels=self.attention_head_dim[i], add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], @@ -177,7 +181,7 @@ def setup(self): self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], dropout=self.dropout, - attn_num_head_channels=self.attention_head_dim, + attn_num_head_channels=self.attention_head_dim[-1], use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) @@ -185,6 +189,7 @@ def setup(self): # up up_blocks = [] reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(self.up_block_types): @@ -200,7 +205,7 @@ def setup(self): out_channels=output_channel, prev_output_channel=prev_output_channel, num_layers=self.layers_per_block + 1, - attn_num_head_channels=self.attention_head_dim, + attn_num_head_channels=self.attention_head_dim[i], add_upsample=not is_final_block, dropout=self.dropout, use_linear_projection=self.use_linear_projection, From b684d8c86b924660443d31816690705baeba2194 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 26 Nov 2022 01:58:05 +0100 Subject: [PATCH 04/15] Fix typos --- src/diffusers/models/unet_2d_condition_flax.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 20b47259a9b0..4339837aff3f 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -97,7 +97,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): "DownBlock2D", ) up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") - only_cross_attention: Union[bool, Tuple[bool]] = False, + only_cross_attention: Union[bool, Tuple[bool]] = False block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: Union[int, Tuple[int]] = 8, @@ -138,7 +138,7 @@ def setup(self): only_cross_attention = self.only_cross_attention if isinstance(only_cross_attention, bool): - only_cross_attention = [only_cross_attention] * len(self.down_block_types) + only_cross_attention = (only_cross_attention,) * len(self.down_block_types) attention_head_dim = self.attention_head_dim if isinstance(attention_head_dim, int): @@ -158,7 +158,7 @@ def 
setup(self): out_channels=output_channel, dropout=self.dropout, num_layers=self.layers_per_block, - attn_num_head_channels=self.attention_head_dim[i], + attn_num_head_channels=attention_head_dim[i], add_downsample=not is_final_block, use_linear_projection=self.use_linear_projection, only_cross_attention=only_cross_attention[i], @@ -181,7 +181,7 @@ def setup(self): self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], dropout=self.dropout, - attn_num_head_channels=self.attention_head_dim[-1], + attn_num_head_channels=attention_head_dim[-1], use_linear_projection=self.use_linear_projection, dtype=self.dtype, ) @@ -205,7 +205,7 @@ def setup(self): out_channels=output_channel, prev_output_channel=prev_output_channel, num_layers=self.layers_per_block + 1, - attn_num_head_channels=self.attention_head_dim[i], + attn_num_head_channels=reversed_attention_head_dim[i], add_upsample=not is_final_block, dropout=self.dropout, use_linear_projection=self.use_linear_projection, From 7782f50b27fe99e4b7f0dfd648a88e5f8296feae Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 17:28:37 +0100 Subject: [PATCH 05/15] Add simple SD 2 integration test. Slice values taken from my Ampere GPU. --- tests/models/test_models_unet_2d.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 02c6d314bfff..8c1db0831a30 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -639,3 +639,29 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): expected_output_slice = torch.tensor(expected_slice) assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [ 0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [ 0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [ 0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + @require_torch_gpu + def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): + model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + with torch.no_grad(): + sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + assert sample.shape == latents.shape + + output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) From 40f755049396d8daabc82a3d078b07dc57222abc Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 17:29:40 +0100 Subject: [PATCH 06/15] Add simple UNet integration tests for Flax. Note that the expected values are taken from the PyTorch results. This ensures the Flax and PyTorch versions are not too far off. 
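[Note added in review, not part of the patch: the expected slices hard-coded in the tests below can be regenerated on the PyTorch side with a short script along the lines of the sketch that follows. The "fp16" revision, the load_hf_numpy helper, and the file-naming scheme are assumptions carried over from the test helpers in this series rather than a documented recipe.]

import torch
from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import load_hf_numpy

# Load the half-precision SD2 UNet (assumes an "fp16" revision exists for this checkpoint).
model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="unet", revision="fp16", torch_dtype=torch.float16
).to("cuda")

seed, timestep = 3, 1000
latents = torch.from_numpy(
    load_hf_numpy(f"gaussian_noise_s={seed}_shape=4_4_96_96.npy")
).to("cuda", dtype=torch.float16)
encoder_hidden_states = torch.from_numpy(
    load_hf_numpy(f"gaussian_noise_s={seed}_shape=4_77_1024.npy")
).to("cuda", dtype=torch.float16)

with torch.no_grad():
    sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

# The printed values are what get pasted into `expected_slice` in both the PyTorch and the
# Flax tests, which is why they share the same reference numbers (tolerances 5e-3 and 1e-2).
print(sample[-1, -2:, -2:, :2].flatten().float().cpu())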
--- tests/models/test_models_unet_2d_flax.py | 102 +++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 tests/models/test_models_unet_2d_flax.py diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py new file mode 100644 index 000000000000..82c1777428e2 --- /dev/null +++ b/tests/models/test_models_unet_2d_flax.py @@ -0,0 +1,102 @@ +import gc +import jax +import jax.numpy as jnp +import unittest + +from diffusers import FlaxUNet2DConditionModel +from diffusers.utils.testing_utils import ( + load_hf_numpy, + require_flax, + slow, +) +from parameterized import parameterized + +@slow +@require_flax +class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + revision = "bf16" if fp16 else None + + model, params = FlaxUNet2DConditionModel.from_pretrained( + model_id, subfolder="unet", dtype=dtype, revision=revision + ) + return model, params + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], + [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], + [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], + [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], + # fmt: on + ] + ) + def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [ 0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [ 0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [ 0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) 
+ latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) From 0cf9dbdfed6408fba3c02c1cb7d68c203d03734a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 17:33:21 +0100 Subject: [PATCH 07/15] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/models/attention_flax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 40a29f91c43c..99467d1340ba 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -103,7 +103,6 @@ class FlaxBasicTransformerBlock(nn.Module): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention. only_cross_attention (`bool`, defaults to `False`): Whether to only apply cross attention. From 084d83a344756e311196a8799ed11c4196fdf72c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 18:25:46 +0100 Subject: [PATCH 08/15] Typos and style --- src/diffusers/models/attention_flax.py | 2 +- src/diffusers/models/unet_2d_condition_flax.py | 2 +- tests/models/test_models_unet_2d_flax.py | 11 ++++------- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 99467d1340ba..71106e05452c 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -15,6 +15,7 @@ import flax.linen as nn import jax.numpy as jnp + class FlaxAttentionBlock(nn.Module): r""" A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762 @@ -103,7 +104,6 @@ class FlaxBasicTransformerBlock(nn.Module): Hidden states dimension inside each head dropout (:obj:`float`, *optional*, defaults to 0.0): Dropout rate - The size of the context vector for cross attention. only_cross_attention (`bool`, defaults to `False`): Whether to only apply cross attention. 
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 4339837aff3f..8a33853700d6 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -100,7 +100,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): only_cross_attention: Union[bool, Tuple[bool]] = False block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 - attention_head_dim: Union[int, Tuple[int]] = 8, + attention_head_dim: Union[int, Tuple[int]] = 8 cross_attention_dim: int = 1280 dropout: float = 0.0 use_linear_projection: bool = False diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py index 82c1777428e2..582e0cc38948 100644 --- a/tests/models/test_models_unet_2d_flax.py +++ b/tests/models/test_models_unet_2d_flax.py @@ -1,16 +1,13 @@ import gc -import jax -import jax.numpy as jnp import unittest +import jax +import jax.numpy as jnp from diffusers import FlaxUNet2DConditionModel -from diffusers.utils.testing_utils import ( - load_hf_numpy, - require_flax, - slow, -) +from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow from parameterized import parameterized + @slow @require_flax class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): From ff84eccb3a615e9b05ecd82c4113d9789560256f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 18:29:43 +0100 Subject: [PATCH 09/15] Tests: verify jax is available. --- tests/models/test_models_unet_2d_flax.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py index 582e0cc38948..f8a7687a0e5a 100644 --- a/tests/models/test_models_unet_2d_flax.py +++ b/tests/models/test_models_unet_2d_flax.py @@ -1,12 +1,15 @@ import gc import unittest -import jax -import jax.numpy as jnp from diffusers import FlaxUNet2DConditionModel +from diffusers.utils import is_flax_available from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow from parameterized import parameterized +if is_flax_available(): + import jax + import jax.numpy as jnp + @slow @require_flax From 93d833d431c88e6556685b43169d881b19856a5f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 18:47:41 +0100 Subject: [PATCH 10/15] Style --- src/diffusers/utils/testing_utils.py | 1 + tests/models/test_models_unet_2d_flax.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index bf398e5b6fe5..46707b351185 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -108,6 +108,7 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. 
""" + return test_case return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py index f8a7687a0e5a..ec34c4a901f1 100644 --- a/tests/models/test_models_unet_2d_flax.py +++ b/tests/models/test_models_unet_2d_flax.py @@ -6,6 +6,7 @@ from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow from parameterized import parameterized + if is_flax_available(): import jax import jax.numpy as jnp From eee35b83ecbc87736961bd2cd521f9bd601fac22 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 19:47:02 +0100 Subject: [PATCH 11/15] Make flake happy --- tests/models/test_models_unet_2d.py | 8 ++++---- tests/models/test_models_unet_2d_flax.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 8c1db0831a30..59b9e02ff8b9 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -643,10 +643,10 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): @parameterized.expand( [ # fmt: off - [83, 4, [ 0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], - [17, 0.55, [ 0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], - [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], - [3, 1000, [ 0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], # fmt: on ] ) diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py index ec34c4a901f1..4b279d2f3386 100644 --- a/tests/models/test_models_unet_2d_flax.py +++ b/tests/models/test_models_unet_2d_flax.py @@ -75,10 +75,10 @@ def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice @parameterized.expand( [ # fmt: off - [83, 4, [ 0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], - [17, 0.55, [ 0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], - [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], - [3, 1000, [ 0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], # fmt: on ] ) From 783d8cd2ea63993f35909a3e8e8567e21084c64f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 19:47:58 +0100 Subject: [PATCH 12/15] Remove typo. --- src/diffusers/utils/testing_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 46707b351185..bf398e5b6fe5 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -108,7 +108,6 @@ def slow(test_case): Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. 
""" - return test_case return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) From a32af1f93e98dd43871d3a12ec5a06e19b45c81b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 20:31:39 +0000 Subject: [PATCH 13/15] Simple Flax SD 2 pipeline tests. --- .../test_stable_diffusion_flax.py | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py new file mode 100644 index 000000000000..991fda4f6aa4 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest +import numpy as np + +from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler +from diffusers.utils import is_flax_available, slow +from diffusers.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from flax.jax_utils import replicate + from flax.training.common_utils import shard + + +@slow +@require_flax +class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def test_stable_diffusion_flax(self): + sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-2", + revision="bf16", + dtype=jnp.bfloat16, + ) + + prompt = "A painting of a squirrel eating a burger" + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prompt_ids = sd_pipe.prepare_inputs(prompt) + + params = replicate(params) + prompt_ids = shard(prompt_ids) + + prng_seed = jax.random.PRNGKey(0) + prng_seed = jax.random.split(prng_seed, jax.device_count()) + + images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0] + assert images.shape == (jax.device_count(), 1, 768, 768, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512]) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 + + def test_stable_diffusion_dpm_flax(self): + model_id = "stabilityai/stable-diffusion-2" + scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler") + sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained( + model_id, + scheduler=scheduler, + revision="bf16", + dtype=jnp.bfloat16, + ) + params["scheduler"] = scheduler_params + + prompt = "A painting of a squirrel eating a burger" + num_samples = jax.device_count() + prompt = num_samples * 
[prompt] + prompt_ids = sd_pipe.prepare_inputs(prompt) + + params = replicate(params) + prompt_ids = shard(prompt_ids) + + prng_seed = jax.random.PRNGKey(0) + prng_seed = jax.random.split(prng_seed, jax.device_count()) + + images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0] + assert images.shape == (jax.device_count(), 1, 768, 768, 3) + + images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) + image_slice = images[0, 253:256, 253:256, -1] + + output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) + expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297]) + print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 From 383dacd3999fcd6f001251e7f46f131ec7fb88c4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 22:41:59 +0100 Subject: [PATCH 14/15] Import order --- .../pipelines/stable_diffusion_2/test_stable_diffusion_flax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py index 991fda4f6aa4..5c751791739f 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py @@ -15,9 +15,10 @@ import gc import unittest + import numpy as np -from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler +from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline from diffusers.utils import is_flax_available, slow from diffusers.utils.testing_utils import require_flax From d99411bea2a5904b6a609b78c38a1f8ff13c90ca Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 28 Nov 2022 22:48:37 +0100 Subject: [PATCH 15/15] Remove unused import. --- .../pipelines/stable_diffusion_2/test_stable_diffusion_flax.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py index 5c751791739f..f10f0e179827 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py @@ -16,8 +16,6 @@ import gc import unittest -import numpy as np - from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline from diffusers.utils import is_flax_available, slow from diffusers.utils.testing_utils import require_flax
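[Closing note, not part of the patch series: the sketch below shows how the options introduced here surface in an SD2-style checkpoint. It only inspects the loaded configuration; the "bf16" revision is the one used by the tests above, while the concrete values mentioned in the comments (per-block head dimensions, linear projections, 1024-dim cross attention) are assumptions about the stabilityai/stable-diffusion-2 UNet config rather than guarantees.]

import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Same checkpoint and revision as the integration tests in this series.
unet, params = FlaxUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="unet", revision="bf16", dtype=jnp.bfloat16
)

print(unet.config.use_linear_projection)  # expected True: nn.Dense projections instead of 1x1 convolutions
print(unet.config.attention_head_dim)     # expected a per-block sequence such as [5, 10, 20, 20]
print(unet.config.cross_attention_dim)    # expected 1024 for the OpenCLIP text encoder used by SD2
print(unet.config.only_cross_attention)   # defaults to False; may also be a per-down-block tuple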