diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 1b8609474750..71106e05452c 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -104,6 +104,8 @@ class FlaxBasicTransformerBlock(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -111,10 +113,11 @@ class FlaxBasicTransformerBlock(nn.Module):
     n_heads: int
     d_head: int
     dropout: float = 0.0
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        # self attention
+        # self attention (or cross attention if only_cross_attention is True)
         self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
         self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
@@ -126,7 +129,10 @@ def setup(self):
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
-        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+        if self.only_cross_attention:
+            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+        else:
+            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
         # cross attention
@@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
             Number of transformers block
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        use_linear_projection (`bool`, defaults to `False`): Whether to use a linear (`nn.Dense`) projection in place of the 1x1 convolution for the input and output projections.
+        only_cross_attention (`bool`, defaults to `False`): Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
     d_head: int
     depth: int = 1
     dropout: float = 0.0
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
-        self.proj_in = nn.Conv(
-            inner_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
+        if self.use_linear_projection:
+            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
+        else:
+            self.proj_in = nn.Conv(
+                inner_dim,
+                kernel_size=(1, 1),
+                strides=(1, 1),
+                padding="VALID",
+                dtype=self.dtype,
+            )
 
         self.transformer_blocks = [
-            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
+            FlaxBasicTransformerBlock(
+                inner_dim,
+                self.n_heads,
+                self.d_head,
+                dropout=self.dropout,
+                only_cross_attention=self.only_cross_attention,
+                dtype=self.dtype,
+            )
             for _ in range(self.depth)
         ]
 
-        self.proj_out = nn.Conv(
-            inner_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
+        if self.use_linear_projection:
+            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
+        else:
+            self.proj_out = nn.Conv(
+                inner_dim,
+                kernel_size=(1, 1),
+                strides=(1, 1),
+                padding="VALID",
+                dtype=self.dtype,
+            )
 
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
-
-        hidden_states = hidden_states.reshape(batch, height * width, channels)
+        if self.use_linear_projection:
+            hidden_states = hidden_states.reshape(batch, height * width, channels)
+            hidden_states = self.proj_in(hidden_states)
+        else:
+            hidden_states = self.proj_in(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height * width, channels)
 
         for transformer_block in self.transformer_blocks:
             hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
 
-        hidden_states = hidden_states.reshape(batch, height, width, channels)
+        if self.use_linear_projection:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, width, channels)
+        else:
+            hidden_states = hidden_states.reshape(batch, height, width, channels)
+            hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states + residual
-
         return hidden_states
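Note on `use_linear_projection`: the `nn.Dense` path and the 1x1 `nn.Conv` path compute the same per-pixel linear map; only the order of projection and reshape differs between the two branches. A minimal sketch of the equivalence (illustrative, not part of the patch; names and shapes are arbitrary):

import jax
import jax.numpy as jnp
import flax.linen as nn

batch, height, width, channels = 1, 8, 8, 32
x = jnp.ones((batch, height, width, channels))  # NHWC, as used by the Flax models

conv = nn.Conv(channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID")
dense = nn.Dense(channels)

conv_params = conv.init(jax.random.PRNGKey(0), x)
# A (1, 1, C_in, C_out) conv kernel reshapes to a (C_in, C_out) dense kernel.
kernel = conv_params["params"]["kernel"].reshape(channels, channels)
dense_params = {"params": {"kernel": kernel, "bias": conv_params["params"]["bias"]}}

# Conv first, then flatten spatial dims -- vs. flatten first, then Dense.
out_conv = conv.apply(conv_params, x).reshape(batch, height * width, channels)
out_dense = dense.apply(dense_params, x.reshape(batch, height * width, channels))
assert jnp.allclose(out_conv, out_dense, atol=1e-5)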
diff --git a/src/diffusers/models/unet_2d_blocks_flax.py b/src/diffusers/models/unet_2d_blocks_flax.py
index 5798385b9d28..96e76cb06a59 100644
--- a/src/diffusers/models/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unet_2d_blocks_flax.py
@@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     add_downsample: bool = True
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -68,6 +70,8 @@ def setup(self):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.out_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     add_upsample: bool = True
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -201,6 +207,8 @@ def setup(self):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.out_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
@@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     dropout: float = 0.0
     num_layers: int = 1
     attn_num_head_channels: int = 1
+    use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
@@ -331,6 +340,7 @@ def setup(self):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.in_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)
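With the flags threaded through the blocks, a single block can be exercised in isolation. A hedged sketch (shapes and embedding dims below are illustrative, not from a real config; the blocks operate on NHWC tensors internally):

import jax
import jax.numpy as jnp
from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

block = FlaxCrossAttnDownBlock2D(
    in_channels=32,
    out_channels=32,
    num_layers=1,
    attn_num_head_channels=2,
    add_downsample=False,
    use_linear_projection=True,
    only_cross_attention=True,
)

hidden_states = jnp.ones((1, 8, 8, 32))        # NHWC feature map
temb = jnp.ones((1, 128))                      # time embedding
encoder_hidden_states = jnp.ones((1, 77, 768)) # text conditioning
params = block.init(jax.random.PRNGKey(0), hidden_states, temb, encoder_hidden_states)
out, skips = block.apply(params, hidden_states, temb, encoder_hidden_states)
assert out.shape == hidden_states.shape        # no downsample, same channel count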
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index 7ca9c191b448..8a33853700d6 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
-        attention_head_dim (`int`, *optional*, defaults to 8):
+        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
             The dimension of the attention heads.
         cross_attention_dim (`int`, *optional*, defaults to 768):
             The dimension of the cross attention features.
@@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: int = 8
+    attention_head_dim: Union[int, Tuple[int]] = 8
     cross_attention_dim: int = 1280
     dropout: float = 0.0
+    use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32
     freq_shift: int = 0
 
@@ -134,6 +136,14 @@ def setup(self):
         self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
 
+        only_cross_attention = self.only_cross_attention
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+        attention_head_dim = self.attention_head_dim
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+
         # down
         down_blocks = []
         output_channel = block_out_channels[0]
@@ -148,8 +158,10 @@ def setup(self):
                     out_channels=output_channel,
                     dropout=self.dropout,
                     num_layers=self.layers_per_block,
-                    attn_num_head_channels=self.attention_head_dim,
+                    attn_num_head_channels=attention_head_dim[i],
                     add_downsample=not is_final_block,
+                    use_linear_projection=self.use_linear_projection,
+                    only_cross_attention=only_cross_attention[i],
                     dtype=self.dtype,
                 )
             else:
@@ -169,13 +181,16 @@ def setup(self):
         self.mid_block = FlaxUNetMidBlock2DCrossAttn(
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
-            attn_num_head_channels=self.attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
+            use_linear_projection=self.use_linear_projection,
             dtype=self.dtype,
         )
 
         # up
         up_blocks = []
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(self.up_block_types):
             prev_output_channel = output_channel
@@ -190,9 +205,11 @@ def setup(self):
                     out_channels=output_channel,
                     prev_output_channel=prev_output_channel,
                     num_layers=self.layers_per_block + 1,
-                    attn_num_head_channels=self.attention_head_dim,
+                    attn_num_head_channels=reversed_attention_head_dim[i],
                     add_upsample=not is_final_block,
                     dropout=self.dropout,
+                    use_linear_projection=self.use_linear_projection,
+                    only_cross_attention=only_cross_attention[i],
                     dtype=self.dtype,
                 )
             else:
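With `attention_head_dim` now accepting a tuple and the new flags exposed on the config, an SD2-style UNet can be described per block. A hedged sketch with toy dimensions (illustrative values, not a released checkpoint config; assumes `init_weights` returns the `params` tree, as `from_pretrained` relies on):

import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    attention_head_dim=(2, 4),   # one entry per down block
    cross_attention_dim=32,
    use_linear_projection=True,
    only_cross_attention=False,  # a single bool is broadcast to every block
)

params = unet.init_weights(jax.random.PRNGKey(0))
sample = unet.apply(
    {"params": params},
    jnp.ones((1, 4, 16, 16)),                     # (batch, channels, height, width)
    jnp.array(1, dtype=jnp.int32),                # timestep, as in the tests below
    encoder_hidden_states=jnp.ones((1, 77, 32)),  # last dim == cross_attention_dim
).sample
assert sample.shape == (1, 4, 16, 16)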
diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py
index 02c6d314bfff..59b9e02ff8b9 100644
--- a/tests/models/test_models_unet_2d.py
+++ b/tests/models/test_models_unet_2d.py
@@ -639,3 +639,29 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
         expected_output_slice = torch.tensor(expected_slice)
 
         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+
+    @parameterized.expand(
+        [
+            # fmt: off
+            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
+            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
+            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
+            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
+            # fmt: on
+        ]
+    )
+    @require_torch_gpu
+    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
+        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
+        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
+        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
+
+        with torch.no_grad():
+            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        assert sample.shape == latents.shape
+
+        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+        expected_output_slice = torch.tensor(expected_slice)
+
+        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py
new file mode 100644
index 000000000000..4b279d2f3386
--- /dev/null
+++ b/tests/models/test_models_unet_2d_flax.py
@@ -0,0 +1,103 @@
+import gc
+import unittest
+
+from diffusers import FlaxUNet2DConditionModel
+from diffusers.utils import is_flax_available
+from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
+from parameterized import parameterized
+
+
+if is_flax_available():
+    import jax
+    import jax.numpy as jnp
+
+
+@slow
+@require_flax
+class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
+    def get_file_format(self, seed, shape):
+        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
+
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+
+    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
+        dtype = jnp.bfloat16 if fp16 else jnp.float32
+        image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
+        return image
+
+    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
+        dtype = jnp.bfloat16 if fp16 else jnp.float32
+        revision = "bf16" if fp16 else None
+
+        model, params = FlaxUNet2DConditionModel.from_pretrained(
+            model_id, subfolder="unet", dtype=dtype, revision=revision
+        )
+        return model, params
+
+    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
+        dtype = jnp.bfloat16 if fp16 else jnp.float32
+        hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
+        return hidden_states
+
+    @parameterized.expand(
+        [
+            # fmt: off
+            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
+            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
+            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
+            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
+            # fmt: on
+        ]
+    )
+    def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
+        model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
+        latents = self.get_latents(seed, fp16=True)
+        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
+
+        sample = model.apply(
+            {"params": params},
+            latents,
+            jnp.array(timestep, dtype=jnp.int32),
+            encoder_hidden_states=encoder_hidden_states,
+        ).sample
+
+        assert sample.shape == latents.shape
+
+        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
+        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
+
+        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
+        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
+
+    @parameterized.expand(
+        [
+            # fmt: off
+            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
+            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
+            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
+            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
+            # fmt: on
+        ]
+    )
+    def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
+        model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
+        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
+        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
+
+        sample = model.apply(
+            {"params": params},
+            latents,
+            jnp.array(timestep, dtype=jnp.int32),
+            encoder_hidden_states=encoder_hidden_states,
+        ).sample
+
+        assert sample.shape == latents.shape
+
+        output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
+        expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
+
+        # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
+        assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
new file mode 100644
index 000000000000..f10f0e179827
--- /dev/null
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
@@ -0,0 +1,99 @@
+# coding=utf-8
+# Copyright 2022 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
+from diffusers.utils import is_flax_available, slow
+from diffusers.utils.testing_utils import require_flax
+
+
+if is_flax_available():
+    import jax
+    import jax.numpy as jnp
+    from flax.jax_utils import replicate
+    from flax.training.common_utils import shard
+
+
+@slow
+@require_flax
+class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+
+    def test_stable_diffusion_flax(self):
+        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-2",
+            revision="bf16",
+            dtype=jnp.bfloat16,
+        )
+
+        prompt = "A painting of a squirrel eating a burger"
+        num_samples = jax.device_count()
+        prompt = num_samples * [prompt]
+        prompt_ids = sd_pipe.prepare_inputs(prompt)
+
+        params = replicate(params)
+        prompt_ids = shard(prompt_ids)
+
+        prng_seed = jax.random.PRNGKey(0)
+        prng_seed = jax.random.split(prng_seed, jax.device_count())
+
+        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
+        assert images.shape == (jax.device_count(), 1, 768, 768, 3)
+
+        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+        image_slice = images[0, 253:256, 253:256, -1]
+
+        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
+        expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
+        print(f"output_slice: {output_slice}")
+        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_dpm_flax(self):
+        model_id = "stabilityai/stable-diffusion-2"
+        scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
+        sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
+            model_id,
+            scheduler=scheduler,
+            revision="bf16",
+            dtype=jnp.bfloat16,
+        )
+        params["scheduler"] = scheduler_params
+
+        prompt = "A painting of a squirrel eating a burger"
+        num_samples = jax.device_count()
+        prompt = num_samples * [prompt]
+        prompt_ids = sd_pipe.prepare_inputs(prompt)
+
+        params = replicate(params)
+        prompt_ids = shard(prompt_ids)
+
+        prng_seed = jax.random.PRNGKey(0)
+        prng_seed = jax.random.split(prng_seed, jax.device_count())
+
+        images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
+        assert images.shape == (jax.device_count(), 1, 768, 768, 3)
+
+        images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
+        image_slice = images[0, 253:256, 253:256, -1]
+
+        output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
+        expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
+        print(f"output_slice: {output_slice}")
+        assert jnp.abs(output_slice - expected_slice).max() < 1e-2
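The flax-vs-torch tests above rely on the published `bf16` weight revisions. As a hedged alternative sketch (assumes the PyTorch checkpoint and both frameworks are installed, and that `from_pt` conversion in `FlaxModelMixin.from_pretrained` covers this architecture), the same comparison could start from the PyTorch weights directly:

import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

# Convert the PyTorch checkpoint to Flax on the fly instead of depending on
# a pre-converted "bf16" revision existing on the Hub.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="unet", from_pt=True, dtype=jnp.bfloat16
)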