Merged
Changes from all commits

40 commits
67e245c
First UNet Flax modeling blocks.
pcuenca Sep 12, 2022
c3fdbf9
Remove FlaxUNet2DConfig class.
pcuenca Sep 12, 2022
1067e34
ignore_for_config non-config args.
pcuenca Sep 12, 2022
95073e1
Implement `FlaxModelMixin`
mishig25 Sep 13, 2022
b9f6eb4
Merge remote-tracking branch 'origin/flax_model_mixin' into flax-unet…
pcuenca Sep 13, 2022
9891e5c
Use new mixins for Flax UNet.
pcuenca Sep 13, 2022
2d90544
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 13, 2022
25c615a
Import `FlaxUNet2DConditionModel` if flax is available.
pcuenca Sep 13, 2022
91559f3
Rm unused method `framework`
mishig25 Sep 14, 2022
f7a0ab2
Update src/diffusers/modeling_flax_utils.py
Sep 14, 2022
d41f2bf
Indicate types in flax.struct.dataclass as pointed out by @mishig25
pcuenca Sep 14, 2022
e0ec7bf
Fix typo in transformer block.
pcuenca Sep 14, 2022
5e7aeea
make style
pcuenca Sep 14, 2022
70ce383
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 14, 2022
5d81bf8
some more changes
patrickvonplaten Sep 14, 2022
1430ab8
make style
patrickvonplaten Sep 14, 2022
6a2a4c1
Add comment
mishig25 Sep 14, 2022
8d20417
Merge remote-tracking branch 'origin/flax_model_mixin' into flax-unet…
pcuenca Sep 14, 2022
2bf0267
Update src/diffusers/modeling_flax_utils.py
Sep 14, 2022
25ab3ca
Rm unneeded comment
mishig25 Sep 14, 2022
1e8466e
Update docstrings
mishig25 Sep 14, 2022
6842d29
correct ignore kwargs
patrickvonplaten Sep 14, 2022
4f6b01b
Merge branch 'flax_model_mixin' of https://github.com/huggingface/dif…
patrickvonplaten Sep 14, 2022
0f26c05
make style
patrickvonplaten Sep 14, 2022
d98e8c7
Update docstring examples
mishig25 Sep 14, 2022
5a7b784
Merge branch 'flax_model_mixin' of https://github.com/huggingface/dif…
mishig25 Sep 14, 2022
5d08577
Make style
mishig25 Sep 14, 2022
31caae9
Merge remote-tracking branch 'origin/flax_model_mixin' into flax-unet…
pcuenca Sep 14, 2022
0611b17
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 14, 2022
39bbd13
Style: remove empty line.
pcuenca Sep 14, 2022
ea99f35
Apply style (after upgrading black from pinned version)
pcuenca Sep 14, 2022
2d896f6
Remove some commented code and unused imports.
pcuenca Sep 15, 2022
da6ddfd
Add init_weights (not yet in use until #513).
pcuenca Sep 15, 2022
e7347c0
Trickle down deterministic to blocks.
pcuenca Sep 15, 2022
cfca52f
Rename q, k, v according to the latest PyTorch version.
pcuenca Sep 15, 2022
a48500a
Flax UNet docstrings, default props as in PyTorch.
pcuenca Sep 15, 2022
b33ef5e
Fix minor typos in PyTorch docstrings.
pcuenca Sep 15, 2022
b8798ba
Use FlaxUNet2DConditionOutput as output from UNet.
pcuenca Sep 15, 2022
da97b21
make style
pcuenca Sep 15, 2022
802e710
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 15, 2022
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -64,6 +64,7 @@

if is_flax_available():
from .modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
180 changes: 180 additions & 0 deletions src/diffusers/models/attention_flax.py
@@ -0,0 +1,180 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):
Contributor

We should use the same names as the PyTorch modules.

Suggested change
class FlaxAttentionBlock(nn.Module):
class FlaxCrossAttention(nn.Module):

query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
Contributor

dropout is not used; we should add the dropout layer here.
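
A minimal, hypothetical sketch of that suggestion (not part of this PR; the attribute name dropout_layer and its placement after the output projection, mirroring PyTorch's to_out, are assumptions):

    # in setup():
    self.dropout_layer = nn.Dropout(rate=self.dropout)

    # at the end of __call__():
    hidden_states = self.proj_attn(hidden_states)
    hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
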

dtype: jnp.dtype = jnp.float32

def setup(self):
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5

# Weights were exported with old names {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
Comment on lines +30 to +35
Contributor

(nit) Since we are using setup here, we could just use self.to_q = nn.Dense(...) instead of passing name. This would also make it easier to compare the Flax and PyTorch code when reading.

Member Author

Yes, the original name was self.to_q; I changed it here to match the renamed PyTorch version but kept the same weight names.
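
For illustration, a hedged sketch of what the reviewer's alternative might look like (under setup, attribute names double as parameter names, so the exported weight names would stay to_q/to_k/to_v/to_out; this is not part of the PR):

    self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
    self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
    self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
    self.to_out = nn.Dense(self.query_dim, dtype=self.dtype)
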


def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def __call__(self, hidden_states, context=None, deterministic=True):
context = hidden_states if context is None else context

query_proj = self.query(hidden_states)
key_proj = self.key(context)
value_proj = self.value(context)

query_states = self.reshape_heads_to_batch_dim(query_proj)
key_states = self.reshape_heads_to_batch_dim(key_proj)
value_states = self.reshape_heads_to_batch_dim(value_proj)

# compute attentions
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
attention_scores = attention_scores * self.scale
attention_probs = nn.softmax(attention_scores, axis=2)

# attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.proj_attn(hidden_states)
return hidden_states

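For context, a hedged usage sketch of the block above under Flax's init/apply pattern (all shapes and dimensions are illustrative only, not fixed by this PR):

    import jax
    import jax.numpy as jnp

    attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)
    hidden = jnp.ones((1, 64 * 64, 320))      # (batch, seq_len, query_dim)
    context = jnp.ones((1, 77, 768))          # e.g. text-encoder hidden states
    params = attn.init(jax.random.PRNGKey(0), hidden, context)
    out = attn.apply(params, hidden, context) # (1, 4096, 320), same shape as hidden
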

class FlaxBasicTransformerBlock(nn.Module):
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
Contributor

Let's add the dropout layer.

dtype: jnp.dtype = jnp.float32

def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
Comment on lines +85 to +87
Contributor

The names should match the PyTorch version for automatic weight conversion.

Suggested change
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)

self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual

# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
hidden_states = hidden_states + residual

# feed forward
residual = hidden_states
hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual

return hidden_states


class FlaxSpatialTransformer(nn.Module):
in_channels: int
n_heads: int
d_head: int
depth: int = 1
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32

def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

self.transformer_blocks = [
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]

self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
# import ipdb; ipdb.set_trace()
Contributor

Suggested change
# import ipdb; ipdb.set_trace()

residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)

hidden_states = hidden_states.reshape(batch, height * width, channels)

for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

hidden_states = hidden_states.reshape(batch, height, width, channels)

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual

return hidden_states


class FlaxGluFeedForward(nn.Module):
Contributor

We will have to split this into two modules, FeedForward and GEGLU, like in PyTorch (a rough sketch follows after this class).

dim: int
dropout: float = 0.0
Contributor

Let's add the dropout layer.

dtype: jnp.dtype = jnp.float32

def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)

def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
return hidden_states
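
Following up on the suggestions above (split into FeedForward/GEGLU and add the dropout layer), a rough hypothetical sketch of how that might look, reusing the imports already at the top of attention_flax.py. The names FlaxGEGLU/FlaxFeedForward and the net_0/net_2/dropout_layer attributes are assumptions, not part of this PR:

    class FlaxGEGLU(nn.Module):
        dim_out: int
        dtype: jnp.dtype = jnp.float32

        def setup(self):
            # one projection produces both the linear half and the gating half
            self.proj = nn.Dense(self.dim_out * 2, dtype=self.dtype)

        def __call__(self, hidden_states):
            hidden_states = self.proj(hidden_states)
            hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
            return hidden_linear * nn.gelu(hidden_gelu)


    class FlaxFeedForward(nn.Module):
        dim: int
        dropout: float = 0.0
        dtype: jnp.dtype = jnp.float32

        def setup(self):
            inner_dim = self.dim * 4
            self.net_0 = FlaxGEGLU(inner_dim, dtype=self.dtype)
            self.dropout_layer = nn.Dropout(rate=self.dropout)
            self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

        def __call__(self, hidden_states, deterministic=True):
            hidden_states = self.net_0(hidden_states)
            hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
            return self.net_2(hidden_states)
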
56 changes: 56 additions & 0 deletions src/diffusers/models/embeddings_flax.py
@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
Comment on lines +20 to +21
Contributor

Yes, we could update this once we start converting other models.

def get_sinusoidal_embeddings(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

:param timesteps: a 1-D tensor of N indices, one per batch element. These may be fractional.
:param embedding_dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb


class FlaxTimestepEmbedding(nn.Module):
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32

@nn.compact
def __call__(self, temb):
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
temb = nn.silu(temb)
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
return temb


class FlaxTimesteps(nn.Module):
dim: int = 32

@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim)
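
A hedged usage sketch for the timestep-embedding path above (the 320/1280 sizes are illustrative assumptions, not fixed by this file):

    import jax
    import jax.numpy as jnp

    timesteps = jnp.array([0, 250, 999])                            # one timestep per batch element
    temb = get_sinusoidal_embeddings(timesteps, embedding_dim=320)  # (3, 320)

    mlp = FlaxTimestepEmbedding(time_embed_dim=1280)
    params = mlp.init(jax.random.PRNGKey(0), temb)
    temb = mlp.apply(params, temb)                                  # (3, 1280)
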
111 changes: 111 additions & 0 deletions src/diffusers/models/resnet_flax.py
@@ -0,0 +1,111 @@
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32

def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)

def __call__(self, hidden_states):
batch, height, width, channels = hidden_states.shape
hidden_states = jax.image.resize(
hidden_states,
shape=(batch, height * 2, width * 2, channels),
method="nearest",
)
hidden_states = self.conv(hidden_states)
return hidden_states


class FlaxDownsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32

def setup(self):
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)), # padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states):
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
# hidden_states = jnp.pad(hidden_states, pad_width=pad)
Comment on lines +44 to +45
Contributor

Suggested change
# pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
# hidden_states = jnp.pad(hidden_states, pad_width=pad)

hidden_states = self.conv(hidden_states)
return hidden_states

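A quick, hedged sanity check on the padding choice above (illustrative shapes): with a 3x3 kernel, stride 2 and ((1, 1), (1, 1)) padding, the spatial dimensions halve exactly.

    import jax
    import jax.numpy as jnp

    down = FlaxDownsample2D(out_channels=320)
    x = jnp.ones((1, 64, 64, 320))               # NHWC, as used throughout these Flax blocks
    params = down.init(jax.random.PRNGKey(0), x)
    y = down.apply(params, x)                    # (1, 32, 32, 320)
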

class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout_prob: float = 0.0
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32

def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels

self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)

self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
self.dropout = nn.Dropout(self.dropout_prob)
self.conv2 = nn.Conv(
out_channels,
kernel_size=(3, 3),
strides=(1, 1),
padding=((1, 1), (1, 1)),
dtype=self.dtype,
)

use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

self.conv_shortcut = None
if use_nin_shortcut:
self.conv_shortcut = nn.Conv(
out_channels,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states, temb, deterministic=True):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.conv1(hidden_states)

temb = self.time_emb_proj(nn.swish(temb))
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
Comment on lines +99 to +100
Contributor

(nit) Let's add a comment here about the shapes; it's a bit hard to understand from the code alone.

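One way the suggested shape comment could read (hypothetical wording, not part of this PR):

    # temb: (batch, out_channels) -> (batch, 1, 1, out_channels)
    # so it broadcasts over the spatial dims of hidden_states, which is (batch, height, width, out_channels)
    temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
    hidden_states = hidden_states + temb
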
hidden_states = hidden_states + temb

hidden_states = self.norm2(hidden_states)
hidden_states = nn.swish(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic)
hidden_states = self.conv2(hidden_states)

if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)

return hidden_states + residual
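
A hedged end-to-end usage sketch for the resnet block (channel sizes are illustrative only):

    import jax
    import jax.numpy as jnp

    block = FlaxResnetBlock2D(in_channels=320, out_channels=640)
    x = jnp.ones((1, 32, 32, 320))                  # NHWC feature map
    temb = jnp.ones((1, 1280))                      # time embedding, e.g. from FlaxTimestepEmbedding
    params = block.init(jax.random.PRNGKey(0), x, temb)
    y = block.apply(params, x, temb)                # (1, 32, 32, 640); conv_shortcut handles the channel change
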
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_2d_condition.py
@@ -28,7 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
and returns sample shaped output.

This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
implements for all the models (such as downloading or saving, etc.)

Parameters:
sample_size (`int`, *optional*): The size of the input sample.
@@ -198,7 +198,7 @@ def forward(
"""r
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.