
Commit 4d1e4e2

Flax support for Stable Diffusion 2 (#1423)
* Flax: start adapting to Stable Diffusion 2
* More changes.
* attention_head_dim can be a tuple.
* Fix typos
* Add simple SD 2 integration test. Slice values taken from my Ampere GPU.
* Add simple UNet integration tests for Flax. Note that the expected values are taken from the PyTorch results. This ensures the Flax and PyTorch versions are not too far off.
* Apply suggestions from code review. Co-authored-by: Patrick von Platen <[email protected]>
* Typos and style
* Tests: verify jax is available.
* Style
* Make flake happy
* Remove typo.
* Simple Flax SD 2 pipeline tests.
* Import order
* Remove unused import.

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: @camenduru
1 parent a808a85 commit 4d1e4e2

File tree

6 files changed: 312 additions & 28 deletions
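The changes below thread two new options through the Flax UNet (`use_linear_projection` and `only_cross_attention`) and allow `attention_head_dim` to be a tuple, which is what Stable Diffusion 2 requires. As a rough end-to-end sketch of what this enables, the snippet below runs SD 2 through `FlaxStableDiffusionPipeline`. Whether the `stabilityai/stable-diffusion-2` checkpoint ships Flax weights, and which dtype or extra loading arguments are needed, is an assumption here rather than something stated in this commit.

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

# Assumption: the hub repo provides weights loadable into the Flax pipeline.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2", dtype=jnp.bfloat16
)

prompt = "a photograph of an astronaut riding a horse"
num_devices = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_devices)

# Replicate parameters and shard inputs across the available devices.
params = replicate(params)
prompt_ids = shard(prompt_ids)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

images = pipeline(prompt_ids, params, rng, jit=True).images
```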

src/diffusers/models/attention_flax.py

Lines changed: 52 additions & 23 deletions
@@ -104,17 +104,20 @@ class FlaxBasicTransformerBlock(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
     dim: int
     n_heads: int
     d_head: int
     dropout: float = 0.0
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32

     def setup(self):
-        # self attention
+        # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
         self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)

@@ -126,7 +129,10 @@ def setup(self):
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
-        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+        if self.only_cross_attention:
+            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+        else:
+            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual

         # cross attention
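A minimal sketch of what the new flag changes: with `only_cross_attention=True`, the first attention layer (`attn1`) also receives the text `context` instead of doing pure self-attention. The import path and shapes below are illustrative assumptions based on the class shown in this diff.

```python
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxBasicTransformerBlock

# 64 = 2 heads x 32 dims per head, matching how FlaxTransformer2DModel sizes its blocks.
block = FlaxBasicTransformerBlock(dim=64, n_heads=2, d_head=32, only_cross_attention=True)

hidden_states = jnp.ones((1, 16, 64))  # (batch, tokens, dim)
context = jnp.ones((1, 77, 64))        # encoder hidden states

variables = block.init(jax.random.PRNGKey(0), hidden_states, context)
out = block.apply(variables, hidden_states, context)  # attn1 now attends to `context`
print(out.shape)  # (1, 16, 64)
```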
132138
# cross attention
@@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
159165
Number of transformers block
160166
dropout (:obj:`float`, *optional*, defaults to 0.0):
161167
Dropout rate
168+
use_linear_projection (`bool`, defaults to `False`): tbd
169+
only_cross_attention (`bool`, defaults to `False`): tbd
162170
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
163171
Parameters `dtype`
164172
"""
@@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
167175
d_head: int
168176
depth: int = 1
169177
dropout: float = 0.0
178+
use_linear_projection: bool = False
179+
only_cross_attention: bool = False
170180
dtype: jnp.dtype = jnp.float32
171181

172182
def setup(self):
173183
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
174184

175185
inner_dim = self.n_heads * self.d_head
176-
self.proj_in = nn.Conv(
177-
inner_dim,
178-
kernel_size=(1, 1),
179-
strides=(1, 1),
180-
padding="VALID",
181-
dtype=self.dtype,
182-
)
186+
if self.use_linear_projection:
187+
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
188+
else:
189+
self.proj_in = nn.Conv(
190+
inner_dim,
191+
kernel_size=(1, 1),
192+
strides=(1, 1),
193+
padding="VALID",
194+
dtype=self.dtype,
195+
)
183196

184197
self.transformer_blocks = [
185-
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
198+
FlaxBasicTransformerBlock(
199+
inner_dim,
200+
self.n_heads,
201+
self.d_head,
202+
dropout=self.dropout,
203+
only_cross_attention=self.only_cross_attention,
204+
dtype=self.dtype,
205+
)
186206
for _ in range(self.depth)
187207
]
188208

189-
self.proj_out = nn.Conv(
190-
inner_dim,
191-
kernel_size=(1, 1),
192-
strides=(1, 1),
193-
padding="VALID",
194-
dtype=self.dtype,
195-
)
209+
if self.use_linear_projection:
210+
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
211+
else:
212+
self.proj_out = nn.Conv(
213+
inner_dim,
214+
kernel_size=(1, 1),
215+
strides=(1, 1),
216+
padding="VALID",
217+
dtype=self.dtype,
218+
)
196219

197220
def __call__(self, hidden_states, context, deterministic=True):
198221
batch, height, width, channels = hidden_states.shape
199222
residual = hidden_states
200223
hidden_states = self.norm(hidden_states)
201-
hidden_states = self.proj_in(hidden_states)
202-
203-
hidden_states = hidden_states.reshape(batch, height * width, channels)
224+
if self.use_linear_projection:
225+
hidden_states = hidden_states.reshape(batch, height * width, channels)
226+
hidden_states = self.proj_in(hidden_states)
227+
else:
228+
hidden_states = self.proj_in(hidden_states)
229+
hidden_states = hidden_states.reshape(batch, height * width, channels)
204230

205231
for transformer_block in self.transformer_blocks:
206232
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
207233

208-
hidden_states = hidden_states.reshape(batch, height, width, channels)
234+
if self.use_linear_projection:
235+
hidden_states = self.proj_out(hidden_states)
236+
hidden_states = hidden_states.reshape(batch, height, width, channels)
237+
else:
238+
hidden_states = hidden_states.reshape(batch, height, width, channels)
239+
hidden_states = self.proj_out(hidden_states)
209240

210-
hidden_states = self.proj_out(hidden_states)
211241
hidden_states = hidden_states + residual
212-
213242
return hidden_states
214243

215244

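The practical difference introduced here: with `use_linear_projection=True` (the SD 2 layout) the spatial map is flattened first and then projected with `nn.Dense`, whereas the SD 1.x path projects with a 1x1 `nn.Conv` before flattening. A hedged sketch of the new path; the `in_channels` and `n_heads` field names come from the surrounding file rather than this hunk, and the shapes are illustrative.

```python
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxTransformer2DModel

# Dense projection after flattening to (batch, height * width, channels).
model = FlaxTransformer2DModel(
    in_channels=64, n_heads=2, d_head=32, depth=1, use_linear_projection=True
)

hidden_states = jnp.ones((1, 8, 8, 64))   # NHWC, as the Flax blocks expect
context = jnp.ones((1, 77, 1024))         # SD 2 text embeddings are 1024-dimensional

variables = model.init(jax.random.PRNGKey(0), hidden_states, context)
out = model.apply(variables, hidden_states, context)
assert out.shape == hidden_states.shape
```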
src/diffusers/models/unet_2d_blocks_flax.py

Lines changed: 10 additions & 0 deletions
@@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     add_downsample: bool = True
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32

     def setup(self):

@@ -68,6 +70,8 @@ def setup(self):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.out_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)

@@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
     num_layers: int = 1
     attn_num_head_channels: int = 1
     add_upsample: bool = True
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32

     def setup(self):

@@ -201,6 +207,8 @@ def setup(self):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.out_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                only_cross_attention=self.only_cross_attention,
                 dtype=self.dtype,
             )
             attentions.append(attn_block)

@@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     dropout: float = 0.0
     num_layers: int = 1
     attn_num_head_channels: int = 1
+    use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32

     def setup(self):

@@ -331,6 +340,7 @@ def setup(self):
                 n_heads=self.attn_num_head_channels,
                 d_head=self.in_channels // self.attn_num_head_channels,
                 depth=1,
+                use_linear_projection=self.use_linear_projection,
+                dtype=self.dtype,
             )
             attentions.append(attn_block)
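These block-level changes only forward the new flags to the `FlaxTransformer2DModel` instances each block owns. A small sketch of constructing a down block with the flags set; the constructor arguments not visible in these hunks and the `(hidden_states, skip_states)` return shape are assumptions based on the surrounding file.

```python
import jax
import jax.numpy as jnp

from diffusers.models.unet_2d_blocks_flax import FlaxCrossAttnDownBlock2D

block = FlaxCrossAttnDownBlock2D(
    in_channels=32,
    out_channels=32,
    attn_num_head_channels=8,
    use_linear_projection=True,
    only_cross_attention=False,
)

hidden_states = jnp.ones((1, 16, 16, 32))        # NHWC feature map
temb = jnp.ones((1, 128))                        # timestep embedding
encoder_hidden_states = jnp.ones((1, 77, 1024))  # text-encoder states

variables = block.init(jax.random.PRNGKey(0), hidden_states, temb, encoder_hidden_states)
hidden_states, skip_states = block.apply(variables, hidden_states, temb, encoder_hidden_states)
```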

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 22 additions & 5 deletions
@@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
             The number of layers per block.
-        attention_head_dim (`int`, *optional*, defaults to 8):
+        attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
             The dimension of the attention heads.
         cross_attention_dim (`int`, *optional*, defaults to 768):
             The dimension of the cross attention features.

@@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
-    attention_head_dim: int = 8
+    attention_head_dim: Union[int, Tuple[int]] = 8
     cross_attention_dim: int = 1280
     dropout: float = 0.0
+    use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32
     freq_shift: int = 0

@@ -134,6 +136,14 @@ def setup(self):
         self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

+        only_cross_attention = self.only_cross_attention
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = (only_cross_attention,) * len(self.down_block_types)
+
+        attention_head_dim = self.attention_head_dim
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(self.down_block_types)
+
         # down
         down_blocks = []
         output_channel = block_out_channels[0]

@@ -148,8 +158,10 @@ def setup(self):
                     out_channels=output_channel,
                     dropout=self.dropout,
                     num_layers=self.layers_per_block,
-                    attn_num_head_channels=self.attention_head_dim,
+                    attn_num_head_channels=attention_head_dim[i],
                     add_downsample=not is_final_block,
+                    use_linear_projection=self.use_linear_projection,
+                    only_cross_attention=only_cross_attention[i],
                     dtype=self.dtype,
                 )
             else:

@@ -169,13 +181,16 @@ def setup(self):
         self.mid_block = FlaxUNetMidBlock2DCrossAttn(
             in_channels=block_out_channels[-1],
             dropout=self.dropout,
-            attn_num_head_channels=self.attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
+            use_linear_projection=self.use_linear_projection,
             dtype=self.dtype,
         )

         # up
         up_blocks = []
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(self.up_block_types):
             prev_output_channel = output_channel

@@ -190,9 +205,11 @@ def setup(self):
                     out_channels=output_channel,
                     prev_output_channel=prev_output_channel,
                     num_layers=self.layers_per_block + 1,
-                    attn_num_head_channels=self.attention_head_dim,
+                    attn_num_head_channels=reversed_attention_head_dim[i],
                     add_upsample=not is_final_block,
                     dropout=self.dropout,
+                    use_linear_projection=self.use_linear_projection,
+                    only_cross_attention=only_cross_attention[i],
                     dtype=self.dtype,
                 )
             else:
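With these changes, `attention_head_dim` and `only_cross_attention` can be given per down block and are normalized to tuples in `setup`. A reduced, hedged configuration sketch that exercises the new options; the real Stable Diffusion 2 UNet is assumed to use `attention_head_dim=(5, 10, 20, 20)`, `cross_attention_dim=1024` and `use_linear_projection=True`, which is not stated in this diff.

```python
import jax
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Reduced sizes so the example initializes quickly; one attention_head_dim entry per down block.
unet = FlaxUNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64, 64, 64),
    attention_head_dim=(1, 2, 2, 2),
    cross_attention_dim=1024,
    use_linear_projection=True,
)

sample = jnp.ones((1, 4, 16, 16))               # latents, channels-first as in the PyTorch model
timesteps = jnp.array([4], dtype=jnp.int32)
encoder_hidden_states = jnp.ones((1, 77, 1024))

variables = unet.init(jax.random.PRNGKey(0), sample, timesteps, encoder_hidden_states)
out = unet.apply(variables, sample, timesteps, encoder_hidden_states).sample
assert out.shape == sample.shape
```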

tests/models/test_models_unet_2d.py

Lines changed: 26 additions & 0 deletions
@@ -639,3 +639,29 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
         expected_output_slice = torch.tensor(expected_slice)

         assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+
+    @parameterized.expand(
+        [
+            # fmt: off
+            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
+            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
+            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
+            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
+            # fmt: on
+        ]
+    )
+    @require_torch_gpu
+    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
+        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
+        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
+        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
+
+        with torch.no_grad():
+            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        assert sample.shape == latents.shape
+
+        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
+        expected_output_slice = torch.tensor(expected_slice)
+
+        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
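Per the commit message, the Flax-side integration tests compare the Flax UNet against slices recorded from this PyTorch run. A self-contained sketch of that comparison pattern, with placeholder inputs in place of the seeded latents; the `from_pt` conversion and loading keywords are assumptions, not taken from this diff.

```python
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

# Assumption: convert the PyTorch UNet weights on the fly for the Flax model.
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2", subfolder="unet", from_pt=True
)

latents = jnp.ones((4, 4, 96, 96))               # stand-in for the seeded latents above
encoder_hidden_states = jnp.ones((4, 77, 1024))  # SD 2 uses 1024-d text embeddings

sample = model.apply(
    {"params": params}, latents, jnp.array([4], dtype=jnp.int32), encoder_hidden_states
).sample
assert sample.shape == latents.shape
# The actual Flax test would compare a slice of `sample` against values
# recorded from the PyTorch model, within a small tolerance.
```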
