diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a45cb2af5ae1..ecf4fe5fef95 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -64,6 +64,7 @@ if is_flax_available(): from .modeling_flax_utils import FlaxModelMixin + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel from .schedulers import ( FlaxDDIMScheduler, FlaxDDPMScheduler, diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py new file mode 100644 index 000000000000..918c7469a74c --- /dev/null +++ b/src/diffusers/models/attention_flax.py @@ -0,0 +1,180 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +import jax.numpy as jnp + + +class FlaxAttentionBlock(nn.Module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim_head * self.heads + self.scale = self.dim_head**-0.5 + + # Weights were exported with old names {to_q, to_k, to_v, to_out} + self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") + self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") + self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") + + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def __call__(self, hidden_states, context=None, deterministic=True): + context = hidden_states if context is None else context + + query_proj = self.query(hidden_states) + key_proj = self.key(context) + value_proj = self.value(context) + + query_states = self.reshape_heads_to_batch_dim(query_proj) + key_states = self.reshape_heads_to_batch_dim(key_proj) + value_states = self.reshape_heads_to_batch_dim(value_proj) + + # compute attentions + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=2) + + # attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.proj_attn(hidden_states) + return hidden_states + + +class FlaxBasicTransformerBlock(nn.Module): + dim: int + n_heads: int + 
d_head: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # self attention + self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + # cross attention + self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + + def __call__(self, hidden_states, context, deterministic=True): + # self attention + residual = hidden_states + hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxSpatialTransformer(nn.Module): + in_channels: int + n_heads: int + d_head: int + depth: int = 1 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) + + inner_dim = self.n_heads * self.d_head + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.transformer_blocks = [ + FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) + for _ in range(self.depth) + ] + + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, context, deterministic=True): + batch, height, width, channels = hidden_states.shape + # import ipdb; ipdb.set_trace() + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + + hidden_states = hidden_states.reshape(batch, height * width, channels) + + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) + + hidden_states = hidden_states.reshape(batch, height, width, channels) + + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxGluFeedForward(nn.Module): + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim * 4 + self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) + self.dense2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.dense1(hidden_states) + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) + hidden_states = hidden_linear * nn.gelu(hidden_gelu) + hidden_states = self.dense2(hidden_states) + return hidden_states diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py new file mode 100644 index 000000000000..63442ab997b4 --- /dev/null +++ b/src/diffusers/models/embeddings_flax.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. 
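# Sketch (not part of the patch): exercising the spatial transformer defined above with random
# weights. The 320/8/40 sizes and the (1, 77, 768) context are illustrative; inputs follow the
# NHWC layout used throughout this Flax port, and with these sizes the output shape matches the
# input thanks to the residual connection.
import jax
import jax.numpy as jnp
from diffusers.models.attention_flax import FlaxSpatialTransformer

transformer = FlaxSpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1)
hidden_states = jnp.zeros((1, 32, 32, 320))        # (batch, height, width, channels)
context = jnp.zeros((1, 77, 768))                  # e.g. text-encoder hidden states
variables = transformer.init(jax.random.PRNGKey(0), hidden_states, context)
out = transformer.apply(variables, hidden_states, context)   # (1, 32, 32, 320)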
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import flax.linen as nn +import jax.numpy as jnp + + +# This is like models.embeddings.get_timestep_embedding (PyTorch) but +# less general (only handles the case we currently need). +def get_sinusoidal_embeddings(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] tensor of positional embeddings. + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = jnp.exp(jnp.arange(half_dim) * -emb) + emb = timesteps[:, None] * emb[None, :] + emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) + return emb + + +class FlaxTimestepEmbedding(nn.Module): + time_embed_dim: int = 32 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, temb): + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) + temb = nn.silu(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) + return temb + + +class FlaxTimesteps(nn.Module): + dim: int = 32 + + @nn.compact + def __call__(self, timesteps): + return get_sinusoidal_embeddings(timesteps, self.dim) diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py new file mode 100644 index 000000000000..46ccee35adcc --- /dev/null +++ b/src/diffusers/models/resnet_flax.py @@ -0,0 +1,111 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class FlaxUpsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), # padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + # hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + in_channels: int + out_channels: int = None + dropout_prob: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels 
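# Sketch (not part of the patch): how the timestep-embedding pieces above fit together. The
# 320/1280 sizes are illustrative. get_sinusoidal_embeddings produces [cos | sin] features and
# FlaxTimestepEmbedding maps them through two Dense layers with a silu in between.
import jax
import jax.numpy as jnp
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, get_sinusoidal_embeddings

t = jnp.array([1, 500, 999])                                 # one timestep per batch element
features = get_sinusoidal_embeddings(t, embedding_dim=320)   # (3, 320)
mlp = FlaxTimestepEmbedding(time_embed_dim=1280)
variables = mlp.init(jax.random.PRNGKey(0), features)
temb = mlp.apply(variables, features)                        # (3, 1280), fed to every resnet block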
+ + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.dropout = nn.Dropout(self.dropout_prob) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, temb, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(nn.swish(temb)) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a9989a5ed27e..cdb0462194a6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -28,7 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): and returns sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + implements for all the models (such as downloading or saving, etc.) Parameters: sample_size (`int`, *optional*): The size of the input sample. @@ -198,7 +198,7 @@ def forward( """r Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 
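The resnet block above uses the NHWC layout and broadcasts the time embedding over the spatial
dimensions. A minimal sketch of how it is exercised (not part of the patch; the 320/640 channel
choice and the 1280-dim time embedding are illustrative):

    import jax
    import jax.numpy as jnp
    from diffusers.models.resnet_flax import FlaxResnetBlock2D

    block = FlaxResnetBlock2D(in_channels=320, out_channels=640)
    hidden_states = jnp.zeros((1, 32, 32, 320))        # (batch, height, width, channels)
    temb = jnp.zeros((1, 1280))                        # projected to 640 channels inside the block
    variables = block.init(jax.random.PRNGKey(0), hidden_states, temb)
    out = block.apply(variables, hidden_states, temb)  # (1, 32, 32, 640)
    # in_channels != out_channels, so the 1x1 conv_shortcut is applied to the residual branch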
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py new file mode 100644 index 000000000000..1ac68e10c159 --- /dev/null +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -0,0 +1,258 @@ +from dataclasses import dataclass +from typing import Tuple, Union + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin +from ..utils import BaseOutput +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .unet_blocks_flax import ( + FlaxCrossAttnDownBlock2D, + FlaxCrossAttnUpBlock2D, + FlaxDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, +) + + +@dataclass +class FlaxUNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + + +@flax_register_to_config +class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", + "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D", + "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. 
+ """ + + sample_size: int = 32 + in_channels: int = 4 + out_channels: int = 4 + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) + layers_per_block: int = 2 + attention_head_dim: int = 8 + cross_attention_dim: int = 1280 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + # init input tensors + sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + + def setup(self): + block_out_channels = self.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps(block_out_channels[0]) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(self.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + attn_num_head_channels=self.attention_head_dim, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=self.dropout, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=self.dropout, + attn_num_head_channels=self.attention_head_dim, + dtype=self.dtype, + ) + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(self.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + attn_num_head_channels=self.attention_head_dim, + add_upsample=not is_final_block, + dropout=self.dropout, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=self.layers_per_block + 1, + add_upsample=not is_final_block, 
+ dropout=self.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + return_dict: bool = True, + train: bool = False, + ) -> Union[FlaxUNet2DConditionOutput, Tuple]: + """r + Args: + sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + plain tuple. + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + # 1. time + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + else: + sample, res_samples = down_block(sample, t_emb, deterministic=not train) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + deterministic=not train, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return FlaxUNet2DConditionOutput(sample=sample) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py new file mode 100644 index 000000000000..ce67eb12b19f --- /dev/null +++ b/src/diffusers/models/unet_blocks_flax.py @@ -0,0 +1,263 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
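# Sketch (not part of the patch): end-to-end use of the Flax UNet defined above with freshly
# initialized weights, assuming JAX/Flax are installed so the real class (not the dummy) is
# exported. The 64x64x4 latent shape and the (1, 77, 768) text-encoder states are illustrative;
# note that inputs are NHWC, unlike the NCHW PyTorch model.
import jax
import jax.numpy as jnp
from diffusers import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(sample_size=64, cross_attention_dim=768)
params = unet.init_weights(jax.random.PRNGKey(0))

latents = jnp.zeros((1, 64, 64, 4))                  # (batch, height, width, channels)
timesteps = jnp.array([10], dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 768))
out = unet.apply({"params": params}, latents, timesteps, encoder_hidden_states)
noise_pred = out.sample                              # same shape as `latents`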
+# See the License for the specific language governing permissions and + +import flax.linen as nn +import jax.numpy as jnp + +from .attention_flax import FlaxSpatialTransformer +from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D + + +class FlaxCrossAttnDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + 
n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxSpatialTransformer( + in_channels=self.in_channels, + n_heads=self.attn_num_head_channels, + d_head=self.in_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) 
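# Sketch (not part of the patch): the down blocks above return a tuple of residuals that the UNet
# later feeds to the matching up blocks. Sizes are illustrative.
import jax
import jax.numpy as jnp
from diffusers.models.unet_blocks_flax import FlaxCrossAttnDownBlock2D

down = FlaxCrossAttnDownBlock2D(in_channels=320, out_channels=320, num_layers=2, attn_num_head_channels=8)
x = jnp.zeros((1, 64, 64, 320))
temb = jnp.zeros((1, 1280))
ctx = jnp.zeros((1, 77, 768))
variables = down.init(jax.random.PRNGKey(0), x, temb, ctx)
x_out, res_samples = down.apply(variables, x, temb, ctx)
# x_out is (1, 32, 32, 320); len(res_samples) == num_layers + 1 (the extra entry comes from the downsampler)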
+ + return hidden_states diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index d246f41a72d2..9615afb6f920 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -46,6 +46,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxScoreSdeVeScheduler(metaclass=DummyObject): _backends = ["flax"]
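A small worked example of the skip-connection bookkeeping in FlaxUNet2DConditionModel.__call__
(illustrative, not part of the patch): with layers_per_block=2 and the default four down blocks,
the down pass collects 12 residuals: the conv_in output, three from each of the first three blocks
(two resnets plus a downsample), and two from the final block, which has no downsampler. Each up
block then consumes the last layers_per_block + 1 = 3 of them:

    down_block_res_samples = tuple(range(12))    # stand-ins for the 12 residual tensors
    for _ in range(4):                           # one iteration per up block
        res_samples = down_block_res_samples[-3:]
        down_block_res_samples = down_block_res_samples[:-3]
        print(res_samples)
    # (9, 10, 11)  (6, 7, 8)  (3, 4, 5)  (0, 1, 2): every residual, including conv_in's, is consumed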