From 67e245c2b5bd85b175219fba17b5ba59d9d8a801 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 12 Sep 2022 18:23:21 +0200 Subject: [PATCH 01/31] First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization. --- src/diffusers/models/attention_flax.py | 181 ++++++++++++ src/diffusers/models/embeddings_flax.py | 56 ++++ src/diffusers/models/resnet_flax.py | 111 ++++++++ .../models/unet_2d_condition_flax.py | 257 +++++++++++++++++ src/diffusers/models/unet_blocks_flax.py | 263 ++++++++++++++++++ 5 files changed, 868 insertions(+) create mode 100644 src/diffusers/models/attention_flax.py create mode 100644 src/diffusers/models/embeddings_flax.py create mode 100644 src/diffusers/models/resnet_flax.py create mode 100644 src/diffusers/models/unet_2d_condition_flax.py create mode 100644 src/diffusers/models/unet_blocks_flax.py diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py new file mode 100644 index 000000000000..77e5ad9c75fb --- /dev/null +++ b/src/diffusers/models/attention_flax.py @@ -0,0 +1,181 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import flax.linen as nn +import jax.numpy as jnp + + +class FlaxAttentionBlock(nn.Module): + query_dim: int + heads: int = 8 + dim_head: int = 64 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim_head * self.heads + self.scale = self.dim_head**-0.5 + + self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + + self.to_out = nn.Dense(self.query_dim, dtype=self.dtype) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def __call__(self, hidden_states, context=None, deterministic=True): + context = hidden_states if context is None else context + + q = self.to_q(hidden_states) + k = self.to_k(context) + v = self.to_v(context) + + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) + + # compute attentions + attn_weights = jnp.einsum("b i d, b j d->b i j", q, k) + attn_weights = attn_weights * self.scale + attn_weights = nn.softmax(attn_weights, axis=2) + + ## attend to values + hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.to_out(hidden_states) + return hidden_states + + +class FlaxBasicTransformerBlock(nn.Module): + dim: int + n_heads: int + d_head: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # self attention + self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + # cross attention + self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) + self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) + + def __call__(self, hidden_states, context, deterministic=True): + # self attention + residual = hidden_states + hidden_states = self.self_attn(self.norm1(hidden_states)) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.cross_attn(self.norm2(hidden_states), context) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxSpatialTransformer(nn.Module): + in_channels: int + n_heads: int + d_head: int + depth: int = 1 + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) + + inner_dim = self.n_heads * self.d_head + self.proj_in = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + self.transformer_blocks = [ + TransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) + for _ in range(self.depth) + ] + + self.proj_out = nn.Conv( + inner_dim, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, context, deterministic=True): + batch, height, width, channels = hidden_states.shape + # import ipdb; ipdb.set_trace() + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + + # hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 1)) + hidden_states = hidden_states.reshape(batch, height * width, channels) + + for transformer_block in self.transformer_blocks: + hidden_states = transformer_block(hidden_states, context) + + hidden_states = hidden_states.reshape(batch, height, width, channels) + # hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) + + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + + +class FlaxGluFeedForward(nn.Module): + dim: int + dropout: float = 0.0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + inner_dim = self.dim * 4 + self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) + self.dense2 = nn.Dense(self.dim, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.dense1(hidden_states) + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) + hidden_states = hidden_linear * nn.gelu(hidden_gelu) + hidden_states = self.dense2(hidden_states) + return hidden_states diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py new file mode 100644 index 000000000000..63442ab997b4 --- /dev/null +++ b/src/diffusers/models/embeddings_flax.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import flax.linen as nn +import jax.numpy as jnp + + +# This is like models.embeddings.get_timestep_embedding (PyTorch) but +# less general (only handles the case we currently need). +def get_sinusoidal_embeddings(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] tensor of positional embeddings. + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = jnp.exp(jnp.arange(half_dim) * -emb) + emb = timesteps[:, None] * emb[None, :] + emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1) + return emb + + +class FlaxTimestepEmbedding(nn.Module): + time_embed_dim: int = 32 + dtype: jnp.dtype = jnp.float32 + + @nn.compact + def __call__(self, temb): + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb) + temb = nn.silu(temb) + temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb) + return temb + + +class FlaxTimesteps(nn.Module): + dim: int = 32 + + @nn.compact + def __call__(self, timesteps): + return get_sinusoidal_embeddings(timesteps, self.dim) diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py new file mode 100644 index 000000000000..46ccee35adcc --- /dev/null +++ b/src/diffusers/models/resnet_flax.py @@ -0,0 +1,111 @@ +import flax.linen as nn +import jax +import jax.numpy as jnp + + +class FlaxUpsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + batch, height, width, channels = hidden_states.shape + hidden_states = jax.image.resize( + hidden_states, + shape=(batch, height * 2, width * 2, channels), + method="nearest", + ) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxDownsample2D(nn.Module): + out_channels: int + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + self.out_channels, + kernel_size=(3, 3), + strides=(2, 2), + padding=((1, 1), (1, 1)), # padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + # pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim + # hidden_states = jnp.pad(hidden_states, pad_width=pad) + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FlaxResnetBlock2D(nn.Module): + in_channels: int + out_channels: int = None + dropout_prob: float = 0.0 + use_nin_shortcut: bool = None + dtype: jnp.dtype = jnp.float32 + + def setup(self): + out_channels = self.in_channels if self.out_channels is None else self.out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv1 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype) + + self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.dropout = nn.Dropout(self.dropout_prob) + self.conv2 = nn.Conv( + out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut + + self.conv_shortcut = None + if use_nin_shortcut: + self.conv_shortcut = nn.Conv( + out_channels, + kernel_size=(1, 1), + strides=(1, 1), + padding="VALID", + dtype=self.dtype, + ) + + def __call__(self, hidden_states, temb, deterministic=True): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.conv1(hidden_states) + + temb = self.time_emb_proj(nn.swish(temb)) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + hidden_states = nn.swish(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py new file mode 100644 index 000000000000..85c1e965a951 --- /dev/null +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -0,0 +1,257 @@ +from typing import Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +from ..configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import FlaxModelMixin +from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps +from .unet_blocks_flax import ( + FlaxDownBlock2D, + FlaxCrossAttnDownBlock2D, + FlaxUNetMidBlock2DCrossAttn, + FlaxUpBlock2D, + FlaxCrossAttnUpBlock2D, +) + + +# Configuration - we may not need this any more +class FlaxUNet2DConfig(ConfigMixin): + def __init__( + self, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + attention_head_dim=8, + cross_attention_dim=768, + dropout=0.1, + **kwargs, + ): + super().__init__(**kwargs) + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = out_channels + self.down_block_types = down_block_types + self.up_block_types = up_block_types + self.block_out_channels = block_out_channels + self.layers_per_block = layers_per_block + self.attention_head_dim = attention_head_dim + self.cross_attention_dim = cross_attention_dim + self.dropout = dropout + + +# This is TBD. We may not need the module + the class +class FlaxUNet2DModule(nn.Module): + config: FlaxUNet2DConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + config = self.config + + self.sample_size = config.sample_size + block_out_channels = config.block_out_channels + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = nn.Conv( + block_out_channels[0], + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + # time + self.time_proj = FlaxTimesteps(block_out_channels[0]) + self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype) + + # down + down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(config.down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "CrossAttnDownBlock2D": + down_block = FlaxCrossAttnDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=config.dropout, + num_layers=config.layers_per_block, + attn_num_head_channels=config.attention_head_dim, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + else: + down_block = FlaxDownBlock2D( + in_channels=input_channel, + out_channels=output_channel, + dropout=config.dropout, + num_layers=config.layers_per_block, + add_downsample=not is_final_block, + dtype=self.dtype, + ) + + down_blocks.append(down_block) + self.down_blocks = down_blocks + + # mid + self.mid_block = FlaxUNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + dropout=config.dropout, + attn_num_head_channels=config.attention_head_dim, + dtype=self.dtype, + ) + + # up + up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(config.up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "CrossAttnUpBlock2D": + up_block = FlaxCrossAttnUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=config.layers_per_block + 1, + attn_num_head_channels=config.attention_head_dim, + add_upsample=not is_final_block, + dropout=config.dropout, + dtype=self.dtype, + ) + else: + up_block = FlaxUpBlock2D( + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + num_layers=config.layers_per_block + 1, + add_upsample=not is_final_block, + dropout=config.dropout, + dtype=self.dtype, + ) + + up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = up_blocks + + # out + self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) + self.conv_out = nn.Conv( + config.out_channels, + kernel_size=(3, 3), + strides=(1, 1), + padding=((1, 1), (1, 1)), + dtype=self.dtype, + ) + + def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True): + # 1. time + # broadcast to batch dimension + # timesteps = jnp.broadcast_to(timesteps, (sample.shape[0],) + timesteps.shape) + t_emb = self.time_proj(timesteps) + t_emb = self.time_embedding(t_emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for down_block in self.down_blocks: + if isinstance(down_block, FlaxCrossAttnDownBlock2D): + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states) + else: + sample, res_samples = down_block(sample, t_emb) + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block(sample, t_emb, encoder_hidden_states) + + # 5. up + for up_block in self.up_blocks: + res_samples = down_block_res_samples[-(self.config.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.config.layers_per_block + 1)] + if isinstance(up_block, FlaxCrossAttnUpBlock2D): + sample = up_block( + sample, + temb=t_emb, + encoder_hidden_states=encoder_hidden_states, + res_hidden_states_tuple=res_samples, + ) + else: + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = nn.silu(sample) + sample = self.conv_out(sample) + + return sample + + +class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): + module_class = FlaxUNet2DModule + config_class = FlaxUNet2DConfig + base_model_prefix = "model" + module_class: nn.Module = None + + def __init__( + self, + config: FlaxUNet2DConfig, + input_shape: Tuple = (1, 32, 32, 4), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # init input tensors + sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + return self.module.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + ): + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + jnp.array(sample), + jnp.array(timesteps, dtype=jnp.int32), + encoder_hidden_states, + not train, + rngs=rngs, + ) + + +# class UNet2D(UNet2DPretrainedModel): +# module_class = UNet2DModule diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py new file mode 100644 index 000000000000..5de83bb2a559 --- /dev/null +++ b/src/diffusers/models/unet_blocks_flax.py @@ -0,0 +1,263 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import flax.linen as nn +import jax.numpy as jnp + +from .attention_flax import FlaxAttentionBlock, FlaxSpatialTransformer +from .resnet_flax import FlaxDownsample2D, FlaxUpsample2D, FlaxResnetBlock2D + +class FlaxCrossAttnDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxDownBlock2D(nn.Module): + in_channels: int + out_channels: int + dropout: float = 0.0 + num_layers: int = 1 + add_downsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + in_channels = self.in_channels if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=in_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + self.resnets = resnets + + if self.add_downsample: + self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, temb, deterministic=True): + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + output_states += (hidden_states,) + + return hidden_states, output_states + + +class FlaxCrossAttnUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + attentions = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + attn_block = FlaxSpatialTransformer( + in_channels=self.out_channels, + n_heads=self.attn_num_head_channels, + d_head=self.out_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + self.resnets = resnets + self.attentions = attentions + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUpBlock2D(nn.Module): + in_channels: int + out_channels: int + prev_output_channel: int + dropout: float = 0.0 + num_layers: int = 1 + add_upsample: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self): + resnets = [] + + for i in range(self.num_layers): + res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels + resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels + + res_block = FlaxResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=self.out_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + + if self.add_upsample: + self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) + + def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) + + hidden_states = resnet(hidden_states, temb) + + if self.add_upsample: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class FlaxUNetMidBlock2DCrossAttn(nn.Module): + in_channels: int + dropout: float = 0.0 + num_layers: int = 1 + attn_num_head_channels: int = 1 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # there is always at least one resnet + resnets = [ + FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + ] + + attentions = [] + + for _ in range(self.num_layers): + attn_block = FlaxSpatialTransformer( + in_channels=self.in_channels, + n_heads=self.attn_num_head_channels, + d_head=self.in_channels // self.attn_num_head_channels, + depth=1, + dtype=self.dtype, + ) + attentions.append(attn_block) + + res_block = FlaxResnetBlock2D( + in_channels=self.in_channels, + out_channels=self.in_channels, + dropout_prob=self.dropout, + dtype=self.dtype, + ) + resnets.append(res_block) + + self.resnets = resnets + self.attentions = attentions + + def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb) + + return hidden_states From c3fdbf95320f893701f05cd5ec2ec2f906f36bca Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 12 Sep 2022 18:39:30 +0200 Subject: [PATCH 02/31] Remove FlaxUNet2DConfig class. --- .../models/unet_2d_condition_flax.py | 127 ++++++++++-------- 1 file changed, 68 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 85c1e965a951..d3b54cf9d743 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -17,45 +17,25 @@ ) -# Configuration - we may not need this any more -class FlaxUNet2DConfig(ConfigMixin): - def __init__( - self, - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), - up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels=(224, 448, 672, 896), - layers_per_block=2, - attention_head_dim=8, - cross_attention_dim=768, - dropout=0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.sample_size = sample_size - self.in_channels = in_channels - self.out_channels = out_channels - self.down_block_types = down_block_types - self.up_block_types = up_block_types - self.block_out_channels = block_out_channels - self.layers_per_block = layers_per_block - self.attention_head_dim = attention_head_dim - self.cross_attention_dim = cross_attention_dim - self.dropout = dropout - - # This is TBD. We may not need the module + the class class FlaxUNet2DModule(nn.Module): - config: FlaxUNet2DConfig + # config args + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + attention_head_dim=8, + cross_attention_dim=768, + dropout=0.1, + + # model args dtype: jnp.dtype = jnp.float32 def setup(self): - config = self.config - - self.sample_size = config.sample_size - block_out_channels = config.block_out_channels + block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 # input @@ -74,7 +54,7 @@ def setup(self): # down down_blocks = [] output_channel = block_out_channels[0] - for i, down_block_type in enumerate(config.down_block_types): + for i, down_block_type in enumerate(self.down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 @@ -83,9 +63,9 @@ def setup(self): down_block = FlaxCrossAttnDownBlock2D( in_channels=input_channel, out_channels=output_channel, - dropout=config.dropout, - num_layers=config.layers_per_block, - attn_num_head_channels=config.attention_head_dim, + dropout=self.dropout, + num_layers=self.layers_per_block, + attn_num_head_channels=self.attention_head_dim, add_downsample=not is_final_block, dtype=self.dtype, ) @@ -93,8 +73,8 @@ def setup(self): down_block = FlaxDownBlock2D( in_channels=input_channel, out_channels=output_channel, - dropout=config.dropout, - num_layers=config.layers_per_block, + dropout=self.dropout, + num_layers=self.layers_per_block, add_downsample=not is_final_block, dtype=self.dtype, ) @@ -105,8 +85,8 @@ def setup(self): # mid self.mid_block = FlaxUNetMidBlock2DCrossAttn( in_channels=block_out_channels[-1], - dropout=config.dropout, - attn_num_head_channels=config.attention_head_dim, + dropout=self.dropout, + attn_num_head_channels=self.attention_head_dim, dtype=self.dtype, ) @@ -114,7 +94,7 @@ def setup(self): up_blocks = [] reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(config.up_block_types): + for i, up_block_type in enumerate(self.up_block_types): prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] @@ -126,10 +106,10 @@ def setup(self): in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - num_layers=config.layers_per_block + 1, - attn_num_head_channels=config.attention_head_dim, + num_layers=self.layers_per_block + 1, + attn_num_head_channels=self.attention_head_dim, add_upsample=not is_final_block, - dropout=config.dropout, + dropout=self.dropout, dtype=self.dtype, ) else: @@ -137,9 +117,9 @@ def setup(self): in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - num_layers=config.layers_per_block + 1, + num_layers=self.layers_per_block + 1, add_upsample=not is_final_block, - dropout=config.dropout, + dropout=self.dropout, dtype=self.dtype, ) @@ -150,7 +130,7 @@ def setup(self): # out self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5) self.conv_out = nn.Conv( - config.out_channels, + self.out_channels, kernel_size=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)), @@ -181,8 +161,8 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) # 5. up for up_block in self.up_blocks: - res_samples = down_block_res_samples[-(self.config.layers_per_block + 1) :] - down_block_res_samples = down_block_res_samples[: -(self.config.layers_per_block + 1)] + res_samples = down_block_res_samples[-(self.layers_per_block + 1) :] + down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)] if isinstance(up_block, FlaxCrossAttnUpBlock2D): sample = up_block( sample, @@ -202,29 +182,58 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): - module_class = FlaxUNet2DModule - config_class = FlaxUNet2DConfig base_model_prefix = "model" - module_class: nn.Module = None + module_class = FlaxUNet2DModule + @register_to_config def __init__( self, - config: FlaxUNet2DConfig, + # config args + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + block_out_channels=(224, 448, 672, 896), + layers_per_block=2, + attention_head_dim=8, + cross_attention_dim=768, + dropout=0.1, + + # model args - to be ignored for config input_shape: Tuple = (1, 32, 32, 4), seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, **kwargs, ): - module = self.module_class(config=config, dtype=dtype, **kwargs) - super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + module = self.module_class( + sample_size=sample_size, + in_channels=in_channels, + out_channels=out_channels, + down_block_types=down_block_types, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + dropout=dropout, + dtype=dtype, **kwargs) + super().__init__( + module, + input_shape=input_shape, + seed=seed, + dtype=dtype, + _do_init=_do_init + ) + # Note: input_shape is ignored def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: # init input tensors - sample_shape = (1, self.config.sample_size, self.config.sample_size, self.config.in_channels) + sample_shape = (1, self.module.sample_size, self.module.sample_size, self.module.in_channels) sample = jnp.zeros(sample_shape, dtype=jnp.float32) timesteps = jnp.ones((1,), dtype=jnp.int32) - encoder_hidden_states = jnp.zeros((1, 1, self.config.cross_attention_dim), dtype=jnp.float32) + encoder_hidden_states = jnp.zeros((1, 1, self.module.cross_attention_dim), dtype=jnp.float32) params_rng, dropout_rng = jax.random.split(rng) rngs = {"params": params_rng, "dropout": dropout_rng} From 1067e3415527ea442b871708cf085170cc15dd0e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 12 Sep 2022 18:45:33 +0200 Subject: [PATCH 03/31] ignore_for_config non-config args. --- src/diffusers/models/unet_2d_condition_flax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index d3b54cf9d743..249390a16294 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -184,6 +184,7 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): base_model_prefix = "model" module_class = FlaxUNet2DModule + ignore_for_config = ["input_shape", "seed", "dtype", "_do_init"] @register_to_config def __init__( @@ -200,7 +201,7 @@ def __init__( cross_attention_dim=768, dropout=0.1, - # model args - to be ignored for config + # model args input_shape: Tuple = (1, 32, 32, 4), seed: int = 0, dtype: jnp.dtype = jnp.float32, From 95073e13b2172bcadddd4ed528360c4f216d753b Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Tue, 13 Sep 2022 09:20:34 +0000 Subject: [PATCH 04/31] Implement `FlaxModelMixin` --- src/diffusers/configuration_utils.py | 40 +++ src/diffusers/modeling_flax_utils.py | 500 +++++++++++++++++++++++++++ 2 files changed, 540 insertions(+) create mode 100644 src/diffusers/modeling_flax_utils.py diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fbe75f3f1441..bd08f25bffdf 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -401,3 +401,43 @@ def inner_init(self, *args, **kwargs): getattr(self, "register_to_config")(**new_kwargs) return inner_init + + +def flax_register_to_config(cls): + original_init = cls.__init__ + + @functools.wraps(original_init) + def init(self, *args, **kwargs): + # Ignore private kwargs in the init. + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + # original_init(self, *args, **init_kwargs) + if not isinstance(self, ConfigMixin): + raise RuntimeError( + f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " + "not inherit from `ConfigMixin`." + ) + + ignore = getattr(self, "ignore_for_config", []) + # Get positional arguments aligned with kwargs + new_kwargs = {} + signature = inspect.signature(init) + parameters = { + name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore + } + for arg, name in zip(args, parameters.keys()): + new_kwargs[name] = arg + + # Then add all kwargs + new_kwargs.update( + { + k: init_kwargs.get(k, default) + for k, default in parameters.items() + if k not in ignore and k not in new_kwargs + } + ) + getattr(self, "register_to_config")(**new_kwargs) + + original_init(self, *args, **init_kwargs) + + cls.__init__ = init + return cls diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py new file mode 100644 index 000000000000..dbfe258c9c13 --- /dev/null +++ b/src/diffusers/modeling_flax_utils.py @@ -0,0 +1,500 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from pickle import UnpicklingError +from typing import Any, Dict, Union + +import jax +import jax.numpy as jnp +import msgpack.exceptions +from flax.core.frozen_dict import FrozenDict +from flax.serialization import from_bytes, to_bytes +from flax.traverse_util import flatten_dict, unflatten_dict +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from requests import HTTPError + +from .modeling_utils import WEIGHTS_NAME +from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging + + +FLAX_WEIGHTS_NAME = "flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" + +logger = logging.get_logger(__name__) + + +class FlaxModelMixin: + r""" + Base class for all flax models. + + [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models. + """ + _missing_keys = set() + config_name = CONFIG_NAME + ignore_for_config = ["parent", "name"] + _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + """ + return cls(config, **kwargs) + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a Flax model. + """ + return "flax" + + def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: + """ + Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. + """ + + # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27 + def conditional_cast(param): + if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating): + param = param.astype(dtype) + return param + + if mask is None: + return jax.tree_map(conditional_cast, params) + + flat_params = flatten_dict(params) + flat_mask, _ = jax.tree_flatten(mask) + + for masked, key in zip(flat_mask, flat_params.keys()): + if masked: + param = flat_params[key] + flat_params[key] = conditional_cast(param) + + return unflatten_dict(flat_params) + + def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast + the `params` in place. + + This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip. + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision + >>> model.params = model.to_bf16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_bf16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.bfloat16, mask) + + def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `parmas` to `jax.numpy.float32`. This method can be used to explicitly convert the + model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # Download model and configuration from huggingface.co + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # By default, the model params will be in fp32, to illustrate the use of this method, + >>> # we'll first cast to fp16 and back to fp32 + >>> model.params = model.to_f16(model.params) + >>> # now cast back to fp32 + >>> model.params = model.to_fp32(model.params) + ```""" + return self._cast_floating_to(params, jnp.float32, mask) + + def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): + r""" + Cast the floating-point `parmas` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + `params` in place. + + This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + half-precision training or to save weights in float16 for inference in order to save memory and improve speed. + + Arguments: + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + mask (`Union[Dict, FrozenDict]`): + A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params + you want to cast, and should be `False` for those you want to skip + + Examples: + + ```python + >>> from transformers import FlaxBertModel + + >>> # load model + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # By default, the model params will be in fp32, to cast these to float16 + >>> model.params = model.to_fp16(model.params) + >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) + >>> # then pass the mask as follows + >>> from flax import traverse_util + + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> flat_params = traverse_util.flatten_dict(model.params) + >>> mask = { + ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) + ... for path in flat_params + ... } + >>> mask = traverse_util.unflatten_dict(mask) + >>> model.params = model.to_fp16(model.params, mask) + ```""" + return self._cast_floating_to(params, jnp.float16, mask) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + dtype: jnp.dtype = jnp.float32, + *model_args, + **kwargs, + ): + r""" + Instantiate a pretrained flax model from a pre-trained model configuration. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], + e.g., `./my_model_directory/`. + - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, + `from_pt` should be set to `True`. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and + [`~ModelMixin.to_bf16`]. + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_pt (`bool`, *optional*, defaults to `False`): + Load the model weights from a PyTorch checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + Examples: + + ```python + >>> from transformers import BertConfig, FlaxBertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") + >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./pt_model/config.json") + >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) + ```""" + config = kwargs.pop("config", None) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + from_auto_class = kwargs.pop("_from_auto", False) + subfolder = kwargs.pop("subfolder", None) + + user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class} + + # Load config if we don't provide a configuration + config_path = config if config is not None else pretrained_model_name_or_path + model, model_kwargs = cls.from_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + # model args + dtype=dtype, + **kwargs, + ) + + # Load model + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): + # Load from a Flax checkpoint + model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) + # At this stage we don't have a weight file so we will raise an error. + elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " + "but there is a file for PyTorch weights." + ) + else: + raise EnvironmentError( + f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory " + f"{pretrained_model_name_or_path}." + ) + else: + try: + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n" + f"{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your" + " internet connection or see how to run the library in offline mode at" + " 'https://huggingface.co/docs/transformers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}." + ) + + try: + with open(model_file, "rb") as state_f: + state = from_bytes(cls, state_f.read()) + except (UnpicklingError, msgpack.exceptions.ExtraData) as e: + try: + with open(model_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please" + " install git-lfs and run `git lfs install` followed by `git lfs pull` in the" + " folder you cloned." + ) + else: + raise ValueError from e + except (UnicodeDecodeError, ValueError): + raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ") + # make sure all arrays are stored as jnp.arrays + # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4: + # https://github.com/google/flax/issues/1261 + state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state) + + # flatten dicts + state = flatten_dict(state) + + # dictionary of key: dtypes for the model params + param_dtypes = jax.tree_map(lambda x: x.dtype, state) + # extract keys of parameters not in jnp.float32 + fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16] + bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16] + + # raise a warning if any of the parameters are not in jnp.float32 + if len(fp16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~ModelMixin.to_fp32`] for further information on how to do this." + ) + + if len(bf16_params) > 0: + logger.warning( + f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from " + f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n" + "You should probably UPCAST the model weights to float32 if this was not intended. " + "See [`~ModelMixin.to_fp32`] for further information on how to do this." + ) + + return model, unflatten_dict(state) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + is_main_process: bool = True, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + `[`~FlaxPreTrainedModel.from_pretrained`]` class method + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + + + + Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, + which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing + folder. Pass along `temp_dir=True` to use a temporary directory instead. + + + + kwargs: + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + # Save the config + if is_main_process: + model_to_save.save_config(save_directory) + + # save model + output_model_file = os.path.join(save_directory, FLAX_WEIGHTS_NAME) + with open(output_model_file, "wb") as f: + model_bytes = to_bytes(params) + f.write(model_bytes) + + logger.info(f"Model weights saved in {output_model_file}") From 9891e5c49b95bd45b2f0f567000d31eddc481236 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 13 Sep 2022 19:15:48 +0200 Subject: [PATCH 05/31] Use new mixins for Flax UNet. For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`. --- .../models/unet_2d_condition_flax.py | 148 ++++++------------ 1 file changed, 44 insertions(+), 104 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 249390a16294..d75ee8cf5817 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -3,10 +3,10 @@ import flax.linen as nn import jax import jax.numpy as jnp -from flax.core.frozen_dict import FrozenDict +# from flax.core.frozen_dict import FrozenDict -from ..configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import FlaxModelMixin +from ..configuration_utils import ConfigMixin, flax_register_to_config +from ..modeling_flax_utils import FlaxModelMixin from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .unet_blocks_flax import ( FlaxDownBlock2D, @@ -17,22 +17,37 @@ ) -# This is TBD. We may not need the module + the class -class FlaxUNet2DModule(nn.Module): - # config args - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), - up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels=(224, 448, 672, 896), - layers_per_block=2, - attention_head_dim=8, - cross_attention_dim=768, - dropout=0.1, +@flax_register_to_config +class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + sample_size=32 + in_channels=4 + out_channels=4 + down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") + up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + block_out_channels=(224, 448, 672, 896) + layers_per_block=2 + attention_head_dim=8 + cross_attention_dim=768 + dropout=0.1 + dtype: jnp.dtype = jnp.float32 # model args - dtype: jnp.dtype = jnp.float32 + # input_shape: Tuple = (1, 32, 32, 4) + # seed: int = 0 + + # # Note: input_shape is ignored + # def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: + # # init input tensors + # sample_shape = (1, self.module.sample_size, self.module.sample_size, self.module.in_channels) + # sample = jnp.zeros(sample_shape, dtype=jnp.float32) + # timesteps = jnp.ones((1,), dtype=jnp.int32) + # encoder_hidden_states = jnp.zeros((1, 1, self.module.cross_attention_dim), dtype=jnp.float32) + + # params_rng, dropout_rng = jax.random.split(rng) + # rngs = {"params": params_rng, "dropout": dropout_rng} + + # return self.module.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + def setup(self): block_out_channels = self.block_out_channels @@ -137,7 +152,18 @@ def setup(self): dtype=self.dtype, ) - def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True): + def __call__( + self, + sample, + timesteps, + encoder_hidden_states, + # params: dict = None, + # dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + ): + # Handle any PRNG if needed + # rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + # 1. time # broadcast to batch dimension # timesteps = jnp.broadcast_to(timesteps, (sample.shape[0],) + timesteps.shape) @@ -179,89 +205,3 @@ def __call__(self, sample, timesteps, encoder_hidden_states, deterministic=True) sample = self.conv_out(sample) return sample - - -class FlaxUNet2DConditionModel(nn.Module, ConfigMixin, FlaxModelMixin): - base_model_prefix = "model" - module_class = FlaxUNet2DModule - ignore_for_config = ["input_shape", "seed", "dtype", "_do_init"] - - @register_to_config - def __init__( - self, - # config args - sample_size=32, - in_channels=4, - out_channels=4, - down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"), - up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - block_out_channels=(224, 448, 672, 896), - layers_per_block=2, - attention_head_dim=8, - cross_attention_dim=768, - dropout=0.1, - - # model args - input_shape: Tuple = (1, 32, 32, 4), - seed: int = 0, - dtype: jnp.dtype = jnp.float32, - _do_init: bool = True, - **kwargs, - ): - module = self.module_class( - sample_size=sample_size, - in_channels=in_channels, - out_channels=out_channels, - down_block_types=down_block_types, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - attention_head_dim=attention_head_dim, - cross_attention_dim=cross_attention_dim, - dropout=dropout, - dtype=dtype, **kwargs) - super().__init__( - module, - input_shape=input_shape, - seed=seed, - dtype=dtype, - _do_init=_do_init - ) - - # Note: input_shape is ignored - def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: - # init input tensors - sample_shape = (1, self.module.sample_size, self.module.sample_size, self.module.in_channels) - sample = jnp.zeros(sample_shape, dtype=jnp.float32) - timesteps = jnp.ones((1,), dtype=jnp.int32) - encoder_hidden_states = jnp.zeros((1, 1, self.module.cross_attention_dim), dtype=jnp.float32) - - params_rng, dropout_rng = jax.random.split(rng) - rngs = {"params": params_rng, "dropout": dropout_rng} - - return self.module.init(rngs, sample, timesteps, encoder_hidden_states)["params"] - - def __call__( - self, - sample, - timesteps, - encoder_hidden_states, - params: dict = None, - dropout_rng: jax.random.PRNGKey = None, - train: bool = False, - ): - # Handle any PRNG if needed - rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - - return self.module.apply( - {"params": params or self.params}, - jnp.array(sample), - jnp.array(timesteps, dtype=jnp.int32), - encoder_hidden_states, - not train, - rngs=rngs, - ) - - -# class UNet2D(UNet2DPretrainedModel): -# module_class = UNet2DModule From 25c615a2d90ad3e69ec211e7212dd923fc4de26c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 13 Sep 2022 19:20:03 +0200 Subject: [PATCH 06/31] Import `FlaxUNet2DConditionModel` if flax is available. --- src/diffusers/__init__.py | 1 + src/diffusers/utils/dummy_flax_objects.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 34cc16591d40..fd6acd6c416c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -64,5 +64,6 @@ if is_flax_available(): from .schedulers import FlaxPNDMScheduler + from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel else: from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py index b5f4362bcb6e..610d7723cd9e 100644 --- a/src/diffusers/utils/dummy_flax_objects.py +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -9,3 +9,10 @@ class FlaxPNDMScheduler(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + + +class FlaxUNet2DConditionModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) From 91559f3107c7d31763094dae4076837d81cbcea1 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 08:23:27 +0000 Subject: [PATCH 07/31] Rm unused method `framework` --- src/diffusers/modeling_flax_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index dbfe258c9c13..b8deb4487c66 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -55,13 +55,6 @@ def _from_config(cls, config, **kwargs): """ return cls(config, **kwargs) - @property - def framework(self) -> str: - """ - :str: Identifies that this is a Flax model. - """ - return "flax" - def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any: """ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`. From f7a0ab2d45b6c6d6068cf5c506f3b704fcf9a388 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 10:24:12 +0200 Subject: [PATCH 08/31] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Suraj Patil --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index b8deb4487c66..4eab80085d92 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -452,7 +452,7 @@ def save_pretrained( ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~FlaxPreTrainedModel.from_pretrained`]` class method + `[`~FlaxModelMixin.from_pretrained`]` class method Arguments: save_directory (`str` or `os.PathLike`): From d41f2bf0fdea84debcca7778b2305ddf52bd9330 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 14 Sep 2022 10:27:42 +0200 Subject: [PATCH 09/31] Indicate types in flax.struct.dataclass as pointed out by @mishig25 Co-authored-by: Mishig Davaadorj --- .../models/unet_2d_condition_flax.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index d75ee8cf5817..82f044b65874 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -19,16 +19,16 @@ @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): - sample_size=32 - in_channels=4 - out_channels=4 - down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") - up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") - block_out_channels=(224, 448, 672, 896) - layers_per_block=2 - attention_head_dim=8 - cross_attention_dim=768 - dropout=0.1 + sample_size:int=32 + in_channels:int=4 + out_channels:int=4 + down_block_types:Tuple=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") + up_block_types:Tuple=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + block_out_channels:Tuple=(224, 448, 672, 896) + layers_per_block:int=2 + attention_head_dim:int=8 + cross_attention_dim:int=768 + dropout:float=0.1 dtype: jnp.dtype = jnp.float32 # model args From e0ec7bffacd892a2122f3dd24d6c480780a079c0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 14 Sep 2022 10:41:17 +0200 Subject: [PATCH 10/31] Fix typo in transformer block. --- src/diffusers/models/attention_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 77e5ad9c75fb..e6e33cb16a1e 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -129,7 +129,7 @@ def setup(self): ) self.transformer_blocks = [ - TransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) + FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) for _ in range(self.depth) ] From 5e7aeea3c74a317bab2d4c4dc603776411d7b5b0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 14 Sep 2022 10:44:15 +0200 Subject: [PATCH 11/31] make style --- src/diffusers/__init__.py | 2 +- .../models/unet_2d_condition_flax.py | 30 ++++++++++--------- src/diffusers/models/unet_blocks_flax.py | 3 +- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index fd6acd6c416c..5ff99d64b756 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -63,7 +63,7 @@ from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 if is_flax_available(): - from .schedulers import FlaxPNDMScheduler from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel + from .schedulers import FlaxPNDMScheduler else: from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 82f044b65874..90e62509baea 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -3,32 +3,35 @@ import flax.linen as nn import jax import jax.numpy as jnp -# from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .unet_blocks_flax import ( - FlaxDownBlock2D, FlaxCrossAttnDownBlock2D, + FlaxCrossAttnUpBlock2D, + FlaxDownBlock2D, FlaxUNetMidBlock2DCrossAttn, FlaxUpBlock2D, - FlaxCrossAttnUpBlock2D, ) +# from flax.core.frozen_dict import FrozenDict + + + @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): - sample_size:int=32 - in_channels:int=4 - out_channels:int=4 - down_block_types:Tuple=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") - up_block_types:Tuple=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") - block_out_channels:Tuple=(224, 448, 672, 896) - layers_per_block:int=2 - attention_head_dim:int=8 - cross_attention_dim:int=768 - dropout:float=0.1 + sample_size: int = 32 + in_channels: int = 4 + out_channels: int = 4 + down_block_types: Tuple = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") + up_block_types: Tuple = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + block_out_channels: Tuple = (224, 448, 672, 896) + layers_per_block: int = 2 + attention_head_dim: int = 8 + cross_attention_dim: int = 768 + dropout: float = 0.1 dtype: jnp.dtype = jnp.float32 # model args @@ -48,7 +51,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): # return self.module.init(rngs, sample, timesteps, encoder_hidden_states)["params"] - def setup(self): block_out_channels = self.block_out_channels time_embed_dim = block_out_channels[0] * 4 diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index 5de83bb2a559..b9802a246305 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -15,7 +15,8 @@ import jax.numpy as jnp from .attention_flax import FlaxAttentionBlock, FlaxSpatialTransformer -from .resnet_flax import FlaxDownsample2D, FlaxUpsample2D, FlaxResnetBlock2D +from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D + class FlaxCrossAttnDownBlock2D(nn.Module): in_channels: int From 5d81bf8eef9677d058bffc68b0330dbb55a88d9e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 14:03:54 +0200 Subject: [PATCH 12/31] some more changes --- src/diffusers/configuration_utils.py | 41 ++++++++++++++-------------- src/diffusers/modeling_flax_utils.py | 2 -- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index bd08f25bffdf..07f2f73d2cf3 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -21,6 +21,7 @@ import re from collections import OrderedDict from typing import Any, Dict, Tuple, Union +import dataclasses from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError @@ -408,36 +409,36 @@ def flax_register_to_config(cls): @functools.wraps(original_init) def init(self, *args, **kwargs): - # Ignore private kwargs in the init. - init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} - # original_init(self, *args, **init_kwargs) if not isinstance(self, ConfigMixin): raise RuntimeError( f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does " "not inherit from `ConfigMixin`." ) - ignore = getattr(self, "ignore_for_config", []) + # Ignore private kwargs in the init. Retrieve all passed attributes + init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + + # Retrieve default values + fields = dataclasses.fields(self) + default_kwargs = {} + for field in fields: + if field.name in ("parent", "name"): + continue + if type(field.default) == dataclasses._MISSING_TYPE: + default_kwargs[field.name] = None + else: + default_kwargs[field.name] = getattr(self, field.name) + + # Make sure init_kwargs override default kwargs + new_kwargs = {**default_kwargs, **init_kwargs} + # Get positional arguments aligned with kwargs - new_kwargs = {} - signature = inspect.signature(init) - parameters = { - name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore - } - for arg, name in zip(args, parameters.keys()): + for i, arg in enumerate(args): + name = fields[i].name new_kwargs[name] = arg - # Then add all kwargs - new_kwargs.update( - { - k: init_kwargs.get(k, default) - for k, default in parameters.items() - if k not in ignore and k not in new_kwargs - } - ) getattr(self, "register_to_config")(**new_kwargs) - - original_init(self, *args, **init_kwargs) + original_init(self, *args, **kwargs) cls.__init__ = init return cls diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 4eab80085d92..ba431686d7ec 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -43,9 +43,7 @@ class FlaxModelMixin: [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading and saving models. """ - _missing_keys = set() config_name = CONFIG_NAME - ignore_for_config = ["parent", "name"] _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] @classmethod From 1430ab80807feccf34876053f8e11e7e4eb7d08e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 14:04:52 +0200 Subject: [PATCH 13/31] make style --- src/diffusers/configuration_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 07f2f73d2cf3..f648c932083a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ ConfigMixinuration base class and utilities.""" +import dataclasses import functools import inspect import json @@ -21,7 +22,6 @@ import re from collections import OrderedDict from typing import Any, Dict, Tuple, Union -import dataclasses from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError From 6a2a4c1f13beeb704b967509c3e4db7e2f4c96e0 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 12:45:48 +0000 Subject: [PATCH 14/31] Add comment --- src/diffusers/configuration_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index f648c932083a..e567655f4f77 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -422,6 +422,7 @@ def init(self, *args, **kwargs): fields = dataclasses.fields(self) default_kwargs = {} for field in fields: + # ignore flax specific attributes if field.name in ("parent", "name"): continue if type(field.default) == dataclasses._MISSING_TYPE: From 2bf02677cd94cd5f3ae0b6e0a90180aa32fe0282 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 14:53:18 +0200 Subject: [PATCH 15/31] Update src/diffusers/modeling_flax_utils.py Co-authored-by: Patrick von Platen --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index ba431686d7ec..e06552e032f1 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -31,7 +31,7 @@ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging -FLAX_WEIGHTS_NAME = "flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" logger = logging.get_logger(__name__) From 25ab3cad11e624fd16d16283fa7224fe082881b8 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 12:53:51 +0000 Subject: [PATCH 16/31] Rm unneeded comment --- src/diffusers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index e06552e032f1..61b66e5bb546 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -31,7 +31,7 @@ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging -FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" # TODO should be "diffusion_flax_model.msgpack" +FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack" logger = logging.get_logger(__name__) From 1e8466e49c5678ac0c719e478d32a9ad3966d82c Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:05:34 +0000 Subject: [PATCH 17/31] Update docstrings --- src/diffusers/modeling_flax_utils.py | 44 ++++++---------------------- 1 file changed, 9 insertions(+), 35 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 61b66e5bb546..7a00cce850fc 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -118,7 +118,7 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" - Cast the floating-point `parmas` to `jax.numpy.float32`. This method can be used to explicitly convert the + Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place. Arguments: @@ -145,7 +145,7 @@ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): r""" - Cast the floating-point `parmas` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the + Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the `params` in place. This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full @@ -225,27 +225,9 @@ def from_pretrained( [`~ModelMixin.to_bf16`]. model_args (sequence of positional arguments, *optional*): All remaining positional arguments will be passed to the underlying model's `__init__` method. - config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): - Can be either: - - - an instance of a class derived from [`PretrainedConfig`], - - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. - - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a @@ -274,7 +256,7 @@ def from_pretrained( underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function. @@ -446,7 +428,6 @@ def save_pretrained( save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict], is_main_process: bool = True, - **kwargs, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the @@ -455,19 +436,12 @@ def save_pretrained( Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. - - - - Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, - which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing - folder. Pass along `temp_dir=True` to use a temporary directory instead. - - - - kwargs: - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + params (`Union[Dict, FrozenDict]`): + A `PyTree` of model parameters. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. """ if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") From 6842d29e30a497f52331c90d277241bde9fad3e9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 15:07:55 +0200 Subject: [PATCH 18/31] correct ignore kwargs --- src/diffusers/configuration_utils.py | 7 ++++++- src/diffusers/modeling_flax_utils.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index e567655f4f77..bb66205412c3 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -272,6 +272,11 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove general kwargs if present in dict if "kwargs" in expected_keys: expected_keys.remove("kwargs") + # remove flax interal keys + if hasattr(cls, "_flax_internal_args"): + for arg in cls._flax_internal_args: + expected_keys.remove(arg) + # remove keys to be ignored if len(cls.ignore_for_config) > 0: expected_keys = expected_keys - set(cls.ignore_for_config) @@ -423,7 +428,7 @@ def init(self, *args, **kwargs): default_kwargs = {} for field in fields: # ignore flax specific attributes - if field.name in ("parent", "name"): + if field.name in self._flax_internal_args: continue if type(field.default) == dataclasses._MISSING_TYPE: default_kwargs[field.name] = None diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 61b66e5bb546..06658cf60ad8 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -45,6 +45,7 @@ class FlaxModelMixin: """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _flax_internal_args = ["name", "parent"] @classmethod def _from_config(cls, config, **kwargs): From 0f26c05ab096f5914cd3eb5f81888af59e61a9eb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 14 Sep 2022 15:08:06 +0200 Subject: [PATCH 19/31] make style --- src/diffusers/modeling_flax_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 77121c9327fc..f4aa7f683b60 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -257,10 +257,10 @@ def from_pretrained( underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. + initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to + a configuration attribute will be used to override said attribute with the supplied `kwargs` + value. Remaining keys that do not correspond to any configuration attribute will be passed to the + underlying model's `__init__` function. Examples: From d98e8c70565919f3983fe333aa42fdedd856e781 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:14:47 +0000 Subject: [PATCH 20/31] Update docstring examples --- src/diffusers/modeling_flax_utils.py | 44 +++++++++++++--------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 7a00cce850fc..eb5b29c3ac44 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -95,24 +95,24 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None): Examples: ```python - >>> from transformers import FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # load model - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision - >>> model.params = model.to_bf16(model.params) + >>> params = model.to_bf16(params) >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") - >>> flat_params = traverse_util.flatten_dict(model.params) + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... } >>> mask = traverse_util.unflatten_dict(mask) - >>> model.params = model.to_bf16(model.params, mask) + >>> params = model.to_bf16(params, mask) ```""" return self._cast_floating_to(params, jnp.bfloat16, mask) @@ -131,15 +131,15 @@ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None): Examples: ```python - >>> from transformers import FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model params will be in fp32, to illustrate the use of this method, >>> # we'll first cast to fp16 and back to fp32 - >>> model.params = model.to_f16(model.params) + >>> params = model.to_f16(params) >>> # now cast back to fp32 - >>> model.params = model.to_fp32(model.params) + >>> params = model.to_fp32(params) ```""" return self._cast_floating_to(params, jnp.float32, mask) @@ -161,24 +161,24 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): Examples: ```python - >>> from transformers import FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # load model - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # By default, the model params will be in fp32, to cast these to float16 - >>> model.params = model.to_fp16(model.params) + >>> params = model.to_fp16(params) >>> # If you want don't want to cast certain parameters (for example layer norm bias and scale) >>> # then pass the mask as follows >>> from flax import traverse_util - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") - >>> flat_params = traverse_util.flatten_dict(model.params) + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") + >>> flat_params = traverse_util.flatten_dict(params) >>> mask = { ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale")) ... for path in flat_params ... } >>> mask = traverse_util.unflatten_dict(mask) - >>> model.params = model.to_fp16(model.params, mask) + >>> params = model.to_fp16(params, mask) ```""" return self._cast_floating_to(params, jnp.float16, mask) @@ -205,8 +205,7 @@ def from_pretrained( Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - user or organization name, like `dbmdz/bert-base-german-cased`. + Valid model ids are namespaced under a user or organization name, like `CompVis/stable-diffusion-v1-4`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, @@ -264,15 +263,12 @@ def from_pretrained( Examples: ```python - >>> from transformers import BertConfig, FlaxBertModel + >>> from diffusers import FlaxUNet2DConditionModel >>> # Download model and configuration from huggingface.co and cache. - >>> model = FlaxBertModel.from_pretrained("bert-base-cased") + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). - >>> model = FlaxBertModel.from_pretrained("./test/saved_model/") - >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). - >>> config = BertConfig.from_json_file("./pt_model/config.json") - >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config) + >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") ```""" config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) From 5d085770979b3e61ae70da88210d0dd7a61c9ddc Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Wed, 14 Sep 2022 13:21:24 +0000 Subject: [PATCH 21/31] Make style --- src/diffusers/modeling_flax_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_flax_utils.py b/src/diffusers/modeling_flax_utils.py index 6b5b56164981..1abf900fa359 100644 --- a/src/diffusers/modeling_flax_utils.py +++ b/src/diffusers/modeling_flax_utils.py @@ -206,7 +206,8 @@ def from_pretrained( Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids are namespaced under a user or organization name, like `CompVis/stable-diffusion-v1-4`. + Valid model ids are namespaced under a user or organization name, like + `CompVis/stable-diffusion-v1-4`. - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case, From 39bbd13c36d4de5888390e31102e59289dd4e545 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 14 Sep 2022 18:37:56 +0200 Subject: [PATCH 22/31] Style: remove empty line. --- src/diffusers/models/unet_2d_condition_flax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 90e62509baea..0783895c0204 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -19,7 +19,6 @@ # from flax.core.frozen_dict import FrozenDict - @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): sample_size: int = 32 From ea99f35f3f3b867e959d19986774ed9d9c514248 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 14 Sep 2022 19:08:04 +0200 Subject: [PATCH 23/31] Apply style (after upgrading black from pinned version) --- src/diffusers/models/unet_blocks_flax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index b9802a246305..be1c6b021e60 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -154,7 +154,6 @@ def setup(self): self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype) def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True): - for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] From 2d896f6d3812f8a8aa3e64abe499b6a19f93785e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 09:50:03 +0200 Subject: [PATCH 24/31] Remove some commented code and unused imports. --- src/diffusers/models/attention_flax.py | 2 +- src/diffusers/models/unet_2d_condition_flax.py | 12 ------------ src/diffusers/models/unet_blocks_flax.py | 2 +- 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index e6e33cb16a1e..7090de328e63 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -65,7 +65,7 @@ def __call__(self, hidden_states, context=None, deterministic=True): attn_weights = attn_weights * self.scale attn_weights = nn.softmax(attn_weights, axis=2) - ## attend to values + # attend to values hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) hidden_states = self.to_out(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 0783895c0204..402e2e563e01 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,7 +1,6 @@ from typing import Tuple import flax.linen as nn -import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, flax_register_to_config @@ -16,9 +15,6 @@ ) -# from flax.core.frozen_dict import FrozenDict - - @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): sample_size: int = 32 @@ -33,10 +29,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dropout: float = 0.1 dtype: jnp.dtype = jnp.float32 - # model args - # input_shape: Tuple = (1, 32, 32, 4) - # seed: int = 0 - # # Note: input_shape is ignored # def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: # # init input tensors @@ -158,16 +150,12 @@ def __call__( sample, timesteps, encoder_hidden_states, - # params: dict = None, - # dropout_rng: jax.random.PRNGKey = None, train: bool = False, ): # Handle any PRNG if needed # rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} # 1. time - # broadcast to batch dimension - # timesteps = jnp.broadcast_to(timesteps, (sample.shape[0],) + timesteps.shape) t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index be1c6b021e60..49c89fc211ed 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -14,7 +14,7 @@ import flax.linen as nn import jax.numpy as jnp -from .attention_flax import FlaxAttentionBlock, FlaxSpatialTransformer +from .attention_flax import FlaxSpatialTransformer from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D From da6ddfd382fd3790529b2cc6b879752989e6a795 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 09:54:59 +0200 Subject: [PATCH 25/31] Add init_weights (not yet in use until #513). --- .../models/unet_2d_condition_flax.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 402e2e563e01..aaa3b5c2c143 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,6 +1,9 @@ from typing import Tuple import flax.linen as nn +from flax.core.frozen_dict import FrozenDict + +import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, flax_register_to_config @@ -29,18 +32,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dropout: float = 0.1 dtype: jnp.dtype = jnp.float32 - # # Note: input_shape is ignored - # def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: - # # init input tensors - # sample_shape = (1, self.module.sample_size, self.module.sample_size, self.module.in_channels) - # sample = jnp.zeros(sample_shape, dtype=jnp.float32) - # timesteps = jnp.ones((1,), dtype=jnp.int32) - # encoder_hidden_states = jnp.zeros((1, 1, self.module.cross_attention_dim), dtype=jnp.float32) - # params_rng, dropout_rng = jax.random.split(rng) - # rngs = {"params": params_rng, "dropout": dropout_rng} + def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + # init input tensors + sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) + sample = jnp.zeros(sample_shape, dtype=jnp.float32) + timesteps = jnp.ones((1,), dtype=jnp.int32) + encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} - # return self.module.init(rngs, sample, timesteps, encoder_hidden_states)["params"] + return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"] def setup(self): block_out_channels = self.block_out_channels From e7347c0b51dbf2f3dd44b43abeacab06acad009f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 10:30:31 +0200 Subject: [PATCH 26/31] Trickle down deterministic to blocks. --- src/diffusers/models/attention_flax.py | 10 ++++------ src/diffusers/models/unet_2d_condition_flax.py | 12 +++++------- src/diffusers/models/unet_blocks_flax.py | 16 ++++++++-------- 3 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 7090de328e63..1a901f701163 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -92,17 +92,17 @@ def setup(self): def __call__(self, hidden_states, context, deterministic=True): # self attention residual = hidden_states - hidden_states = self.self_attn(self.norm1(hidden_states)) + hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual # cross attention residual = hidden_states - hidden_states = self.cross_attn(self.norm2(hidden_states), context) + hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) hidden_states = hidden_states + residual # feed forward residual = hidden_states - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) hidden_states = hidden_states + residual return hidden_states @@ -148,14 +148,12 @@ def __call__(self, hidden_states, context, deterministic=True): hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) - # hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 1)) hidden_states = hidden_states.reshape(batch, height * width, channels) for transformer_block in self.transformer_blocks: - hidden_states = transformer_block(hidden_states, context) + hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) hidden_states = hidden_states.reshape(batch, height, width, channels) - # hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2)) hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states + residual diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index aaa3b5c2c143..0c65d5e1746f 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -155,9 +155,6 @@ def __call__( encoder_hidden_states, train: bool = False, ): - # Handle any PRNG if needed - # rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} - # 1. time t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) @@ -169,13 +166,13 @@ def __call__( down_block_res_samples = (sample,) for down_block in self.down_blocks: if isinstance(down_block, FlaxCrossAttnDownBlock2D): - sample, res_samples = down_block(sample, t_emb, encoder_hidden_states) + sample, res_samples = down_block(sample, t_emb, encoder_hidden_states, deterministic=not train) else: - sample, res_samples = down_block(sample, t_emb) + sample, res_samples = down_block(sample, t_emb, deterministic=not train) down_block_res_samples += res_samples # 4. mid - sample = self.mid_block(sample, t_emb, encoder_hidden_states) + sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train) # 5. up for up_block in self.up_blocks: @@ -187,9 +184,10 @@ def __call__( temb=t_emb, encoder_hidden_states=encoder_hidden_states, res_hidden_states_tuple=res_samples, + deterministic=not train, ) else: - sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples) + sample = up_block(sample, temb=t_emb, res_hidden_states_tuple=res_samples, deterministic=not train) # 6. post-process sample = self.conv_norm_out(sample) diff --git a/src/diffusers/models/unet_blocks_flax.py b/src/diffusers/models/unet_blocks_flax.py index 49c89fc211ed..ce67eb12b19f 100644 --- a/src/diffusers/models/unet_blocks_flax.py +++ b/src/diffusers/models/unet_blocks_flax.py @@ -61,8 +61,8 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) output_states += (hidden_states,) if self.add_downsample: @@ -102,7 +102,7 @@ def __call__(self, hidden_states, temb, deterministic=True): output_states = () for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) output_states += (hidden_states,) if self.add_downsample: @@ -160,8 +160,8 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_ res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) if self.add_upsample: hidden_states = self.upsample(hidden_states) @@ -205,7 +205,7 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1) - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) if self.add_upsample: hidden_states = self.upsample(hidden_states) @@ -257,7 +257,7 @@ def setup(self): def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): - hidden_states = attn(hidden_states, encoder_hidden_states) - hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic) + hidden_states = resnet(hidden_states, temb, deterministic=deterministic) return hidden_states From cfca52fcc33812872142be2b9bee70f7d4b5a720 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 10:58:19 +0200 Subject: [PATCH 27/31] Rename q, k, v according to the latest PyTorch version. Note that weights were exported with the old names, so we need to be careful. --- src/diffusers/models/attention_flax.py | 33 +++++++++++++------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 1a901f701163..fa8536617570 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -25,13 +25,14 @@ class FlaxAttentionBlock(nn.Module): def setup(self): inner_dim = self.dim_head * self.heads - self.scale = self.dim_head**-0.5 + self.scale = self.dim_head ** -0.5 - self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) - self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) - self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype) + # Weights were exported with old names {to_q, to_k, to_v, to_out} + self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") + self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") + self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") - self.to_out = nn.Dense(self.query_dim, dtype=self.dtype) + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape @@ -52,23 +53,23 @@ def reshape_batch_dim_to_heads(self, tensor): def __call__(self, hidden_states, context=None, deterministic=True): context = hidden_states if context is None else context - q = self.to_q(hidden_states) - k = self.to_k(context) - v = self.to_v(context) + query_proj = self.query(hidden_states) + key_proj = self.key(context) + value_proj = self.value(context) - q = self.reshape_heads_to_batch_dim(q) - k = self.reshape_heads_to_batch_dim(k) - v = self.reshape_heads_to_batch_dim(v) + query_states = self.reshape_heads_to_batch_dim(query_proj) + key_states = self.reshape_heads_to_batch_dim(key_proj) + value_states = self.reshape_heads_to_batch_dim(value_proj) # compute attentions - attn_weights = jnp.einsum("b i d, b j d->b i j", q, k) - attn_weights = attn_weights * self.scale - attn_weights = nn.softmax(attn_weights, axis=2) + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) + attention_scores = attention_scores * self.scale + attention_probs = nn.softmax(attention_scores, axis=2) # attend to values - hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v) + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - hidden_states = self.to_out(hidden_states) + hidden_states = self.proj_attn(hidden_states) return hidden_states From a48500a7749c14fd8aa0d6d2d6b7b9ff8c90f83e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 11:11:37 +0200 Subject: [PATCH 28/31] Flax UNet docstrings, default props as in PyTorch. --- .../models/unet_2d_condition_flax.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 0c65d5e1746f..5f34ed56a45a 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -20,16 +20,41 @@ @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): + r""" + FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int`, *optional*): The size of the input sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. The corresponding class names will be: + "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. The corresponding class names will be: + "FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features. + dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks. + """ + sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 - down_block_types: Tuple = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") - up_block_types: Tuple = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") - block_out_channels: Tuple = (224, 448, 672, 896) + down_block_types: Tuple[str] = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") + block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 attention_head_dim: int = 8 - cross_attention_dim: int = 768 - dropout: float = 0.1 + cross_attention_dim: int = 1280 + dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 @@ -155,6 +180,17 @@ def __call__( encoder_hidden_states, train: bool = False, ): + """r + Args: + sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor + timestep (`jnp.ndarray` or `float` or `int`): timesteps + encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + train (`bool`, *optional*, defaults to `False`): + Use deterministic functions and disable dropout when not training. + + Returns: + `jnp.ndarray` sample. + """ # 1. time t_emb = self.time_proj(timesteps) t_emb = self.time_embedding(t_emb) From b33ef5eaf0bbd3a69feb74db746f6dc9fc551f12 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 11:20:51 +0200 Subject: [PATCH 29/31] Fix minor typos in PyTorch docstrings. --- src/diffusers/models/unet_2d_condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 92caaca92e24..d18c5435f1ba 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -28,7 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): and returns sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + implements for all the models (such as downloading or saving, etc.) Parameters: sample_size (`int`, *optional*): The size of the input sample. @@ -196,7 +196,7 @@ def forward( """r Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. From b8798ba4db3527e3b64df996b840d7c28a8df7d2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 11:21:25 +0200 Subject: [PATCH 30/31] Use FlaxUNet2DConditionOutput as output from UNet. --- .../models/unet_2d_condition_flax.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 5f34ed56a45a..3c6fe83d90ae 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -1,4 +1,5 @@ -from typing import Tuple +from dataclasses import dataclass +from typing import Tuple, Union import flax.linen as nn from flax.core.frozen_dict import FrozenDict @@ -8,6 +9,8 @@ from ..configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin +from ..utils import BaseOutput + from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .unet_blocks_flax import ( FlaxCrossAttnDownBlock2D, @@ -17,6 +20,16 @@ FlaxUpBlock2D, ) +@dataclass +class FlaxUNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: jnp.ndarray + @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): @@ -178,18 +191,23 @@ def __call__( sample, timesteps, encoder_hidden_states, + return_dict: bool = True, train: bool = False, - ): + ) -> Union[FlaxUNet2DConditionOutput, Tuple]: """r Args: sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor timestep (`jnp.ndarray` or `float` or `int`): timesteps encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: - `jnp.ndarray` sample. + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. """ # 1. time t_emb = self.time_proj(timesteps) @@ -230,4 +248,7 @@ def __call__( sample = nn.silu(sample) sample = self.conv_out(sample) - return sample + if not return_dict: + return (sample,) + + return FlaxUNet2DConditionOutput(sample=sample) From da97b21cd8aab7e9621b60ba373872f3642c56c0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Sep 2022 11:25:06 +0200 Subject: [PATCH 31/31] make style --- src/diffusers/models/attention_flax.py | 2 +- .../models/unet_2d_condition_flax.py | 32 +++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index fa8536617570..918c7469a74c 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -25,7 +25,7 @@ class FlaxAttentionBlock(nn.Module): def setup(self): inner_dim = self.dim_head * self.heads - self.scale = self.dim_head ** -0.5 + self.scale = self.dim_head**-0.5 # Weights were exported with old names {to_q, to_k, to_v, to_out} self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 3c6fe83d90ae..1ac68e10c159 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -2,15 +2,13 @@ from typing import Tuple, Union import flax.linen as nn -from flax.core.frozen_dict import FrozenDict - import jax import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict from ..configuration_utils import ConfigMixin, flax_register_to_config from ..modeling_flax_utils import FlaxModelMixin from ..utils import BaseOutput - from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .unet_blocks_flax import ( FlaxCrossAttnDownBlock2D, @@ -20,6 +18,7 @@ FlaxUpBlock2D, ) + @dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ @@ -34,8 +33,8 @@ class FlaxUNet2DConditionOutput(BaseOutput): @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. + FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a + timestep and returns sample shaped output. This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library implements for all the models (such as downloading or saving, etc.) @@ -45,11 +44,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. The corresponding class names will be: - "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" + The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", + "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. The corresponding class names will be: - "FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" + The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D", + "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. @@ -61,7 +60,12 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): sample_size: int = 32 in_channels: int = 4 out_channels: int = 4 - down_block_types: Tuple[str] = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D") + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ) up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D") block_out_channels: Tuple[int] = (320, 640, 1280, 1280) layers_per_block: int = 2 @@ -70,7 +74,6 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): dropout: float = 0.0 dtype: jnp.dtype = jnp.float32 - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: # init input tensors sample_shape = (1, self.sample_size, self.sample_size, self.in_channels) @@ -200,14 +203,15 @@ def __call__( timestep (`jnp.ndarray` or `float` or `int`): timesteps encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a + plain tuple. train (`bool`, *optional*, defaults to `False`): Use deterministic functions and disable dropout when not training. Returns: [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ # 1. time t_emb = self.time_proj(timesteps)