Commit d8b0e4f
Authored by pcuenca, Mishig Davaadorj (mishig25), patil-suraj, and patrickvonplaten
UNet Flax with FlaxModelMixin (#502)
* First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization.
* Remove FlaxUNet2DConfig class.
* ignore_for_config non-config args.
* Implement `FlaxModelMixin`
* Use new mixins for Flax UNet. For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.
* Import `FlaxUNet2DConditionModel` if flax is available.
* Rm unused method `framework`
* Update src/diffusers/modeling_flax_utils.py
  Co-authored-by: Suraj Patil <[email protected]>
* Indicate types in flax.struct.dataclass as pointed out by @mishig25
  Co-authored-by: Mishig Davaadorj <[email protected]>
* Fix typo in transformer block.
* make style
* some more changes
* make style
* Add comment
* Update src/diffusers/modeling_flax_utils.py
  Co-authored-by: Patrick von Platen <[email protected]>
* Rm unneeded comment
* Update docstrings
* correct ignore kwargs
* make style
* Update docstring examples
* Make style
* Style: remove empty line.
* Apply style (after upgrading black from pinned version)
* Remove some commented code and unused imports.
* Add init_weights (not yet in use until #513).
* Trickle down deterministic to blocks.
* Rename q, k, v according to the latest PyTorch version. Note that weights were exported with the old names, so we need to be careful.
* Flax UNet docstrings, default props as in PyTorch.
* Fix minor typos in PyTorch docstrings.
* Use FlaxUNet2DConditionOutput as output from UNet.
* make style

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent fb5468a commit d8b0e4f

File tree: 8 files changed, +878 −2 lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@

 if is_flax_available():
     from .modeling_flax_utils import FlaxModelMixin
+    from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .schedulers import (
         FlaxDDIMScheduler,
         FlaxDDPMScheduler,
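
For reference, a minimal sketch (not part of the diff) of what this export enables at the top level; it assumes a working JAX/Flax installation, since the import sits behind the is_flax_available() guard:

# Only resolves when JAX and Flax are installed; otherwise the
# is_flax_available() guard keeps these names out of the namespace.
from diffusers import FlaxModelMixin, FlaxUNet2DConditionModel

# The Flax UNet builds on the FlaxModelMixin introduced in this PR.
assert issubclass(FlaxUNet2DConditionModel, FlaxModelMixin)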

src/diffusers/models/attention_flax.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim_head * self.heads
        self.scale = self.dim_head**-0.5

        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
        tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def __call__(self, hidden_states, context=None, deterministic=True):
        context = hidden_states if context is None else context

        query_proj = self.query(hidden_states)
        key_proj = self.key(context)
        value_proj = self.value(context)

        query_states = self.reshape_heads_to_batch_dim(query_proj)
        key_states = self.reshape_heads_to_batch_dim(key_proj)
        value_states = self.reshape_heads_to_batch_dim(value_proj)

        # compute attentions
        attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
        attention_scores = attention_scores * self.scale
        attention_probs = nn.softmax(attention_scores, axis=2)

        # attend to values
        hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        hidden_states = self.proj_attn(hidden_states)
        return hidden_states


class FlaxBasicTransformerBlock(nn.Module):
    dim: int
    n_heads: int
    d_head: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # self attention
        self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        # cross attention
        self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        return hidden_states


class FlaxSpatialTransformer(nn.Module):
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        self.proj_in = nn.Conv(
            inner_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

        self.transformer_blocks = [
            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
            for _ in range(self.depth)
        ]

        self.proj_out = nn.Conv(
            inner_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states, context, deterministic=True):
        batch, height, width, channels = hidden_states.shape
        # import ipdb; ipdb.set_trace()
        residual = hidden_states
        hidden_states = self.norm(hidden_states)
        hidden_states = self.proj_in(hidden_states)

        hidden_states = hidden_states.reshape(batch, height * width, channels)

        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

        hidden_states = hidden_states.reshape(batch, height, width, channels)

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states


class FlaxGluFeedForward(nn.Module):
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.dense1(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
        hidden_states = self.dense2(hidden_states)
        return hidden_states
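
For context, a small usage sketch of the attention blocks above (not part of the commit); the head count, head size, spatial resolution, and context shape are illustrative assumptions rather than values taken from the diff:

import jax
import jax.numpy as jnp

# 8 heads of size 40 give inner_dim = 320, matching in_channels here.
transformer = FlaxSpatialTransformer(in_channels=320, n_heads=8, d_head=40)

hidden_states = jnp.zeros((1, 32, 32, 320))  # NHWC layout, as these Flax blocks expect
context = jnp.zeros((1, 77, 768))            # e.g. text-encoder hidden states

params = transformer.init(jax.random.PRNGKey(0), hidden_states, context)
out = transformer.apply(params, hidden_states, context)  # same shape as hidden_states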

src/diffusers/models/embeddings_flax.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
def get_sinusoidal_embeddings(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D tensor of N indices, one per batch element. These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] tensor of positional embeddings.
    """
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
    return emb


class FlaxTimestepEmbedding(nn.Module):
    time_embed_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb):
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        temb = nn.silu(temb)
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
        return temb


class FlaxTimesteps(nn.Module):
    dim: int = 32

    @nn.compact
    def __call__(self, timesteps):
        return get_sinusoidal_embeddings(timesteps, self.dim)
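
Similarly, a brief sketch (not from the commit) of how the timestep helpers above compose; the 320/1280 dimensions are illustrative assumptions:

import jax
import jax.numpy as jnp

timesteps = jnp.array([0.0, 500.0, 999.0])

time_proj = FlaxTimesteps(dim=320)                           # no learnable parameters
time_embedding = FlaxTimestepEmbedding(time_embed_dim=1280)

proj_params = time_proj.init(jax.random.PRNGKey(0), timesteps)
t_emb = time_proj.apply(proj_params, timesteps)              # (3, 320) sinusoidal features

emb_params = time_embedding.init(jax.random.PRNGKey(1), t_emb)
temb = time_embedding.apply(emb_params, t_emb)               # (3, 1280)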

src/diffusers/models/resnet_flax.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
import flax.linen as nn
import jax
import jax.numpy as jnp


class FlaxUpsample2D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxDownsample2D(nn.Module):
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),  # padding="VALID",
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim
        # hidden_states = jnp.pad(hidden_states, pad_width=pad)
        hidden_states = self.conv(hidden_states)
        return hidden_states


class FlaxResnetBlock2D(nn.Module):
    in_channels: int
    out_channels: int = None
    dropout_prob: float = 0.0
    use_nin_shortcut: bool = None
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut

        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = nn.Conv(
                out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states, temb, deterministic=True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.swish(temb))
        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return hidden_states + residual
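
And a comparable sketch for a single resnet block call (again not part of the diff; the shapes are made up for illustration):

import jax
import jax.numpy as jnp

resnet = FlaxResnetBlock2D(in_channels=320, out_channels=640)

hidden_states = jnp.zeros((1, 64, 64, 320))  # NHWC
temb = jnp.zeros((1, 1280))                  # e.g. the output of FlaxTimestepEmbedding

params = resnet.init(jax.random.PRNGKey(0), hidden_states, temb)
out = resnet.apply(params, hidden_states, temb)  # (1, 64, 64, 640), residual mapped through the 1x1 shortcut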

src/diffusers/models/unet_2d_condition.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
     and returns sample shaped output.

     This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
-    implements for all the model (such as downloading or saving, etc.)
+    implements for all the models (such as downloading or saving, etc.)

     Parameters:
         sample_size (`int`, *optional*): The size of the input sample.
@@ -198,7 +198,7 @@ def forward(
         r"""
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
-            timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
             encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
