Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 52 additions & 23 deletions src/diffusers/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,20 @@ class FlaxBasicTransformerBlock(nn.Module):
Hidden states dimension inside each head
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
only_cross_attention (`bool`, defaults to `False`):
Whether to only apply cross attention.
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
# self attention
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
Expand All @@ -126,7 +129,10 @@ def setup(self):
def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
if self.only_cross_attention:
hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
else:
hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
hidden_states = hidden_states + residual

# cross attention
Expand Down Expand Up @@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
Number of transformers block
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_linear_projection (`bool`, defaults to `False`): tbd
only_cross_attention (`bool`, defaults to `False`): tbd
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
Expand All @@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
d_head: int
depth: int = 1
dropout: float = 0.0
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
if self.use_linear_projection:
self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

self.transformer_blocks = [
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
FlaxBasicTransformerBlock(
inner_dim,
self.n_heads,
self.d_head,
dropout=self.dropout,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
for _ in range(self.depth)
]

self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)
if self.use_linear_projection:
self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
else:
self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)

hidden_states = hidden_states.reshape(batch, height * width, channels)
if self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height * width, channels)
hidden_states = self.proj_in(hidden_states)
else:
hidden_states = self.proj_in(hidden_states)
hidden_states = hidden_states.reshape(batch, height * width, channels)

for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)

hidden_states = hidden_states.reshape(batch, height, width, channels)
if self.use_linear_projection:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channels)
else:
hidden_states = hidden_states.reshape(batch, height, width, channels)
hidden_states = self.proj_out(hidden_states)

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual

return hidden_states


Expand Down
10 changes: 10 additions & 0 deletions src/diffusers/models/unet_2d_blocks_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_downsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
Expand All @@ -68,6 +70,8 @@ def setup(self):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
Expand Down Expand Up @@ -178,6 +182,8 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
num_layers: int = 1
attn_num_head_channels: int = 1
add_upsample: bool = True
use_linear_projection: bool = False
only_cross_attention: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
Expand All @@ -201,6 +207,8 @@ def setup(self):
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
only_cross_attention=self.only_cross_attention,
dtype=self.dtype,
)
attentions.append(attn_block)
Expand Down Expand Up @@ -310,6 +318,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
dropout: float = 0.0
num_layers: int = 1
attn_num_head_channels: int = 1
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32

def setup(self):
Expand All @@ -331,6 +340,7 @@ def setup(self):
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,
depth=1,
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)
attentions.append(attn_block)
Expand Down
27 changes: 22 additions & 5 deletions src/diffusers/models/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
attention_head_dim (`int`, *optional*, defaults to 8):
attention_head_dim (`int` or `Tuple[int]`, *optional*, defaults to 8):
The dimension of the attention heads.
cross_attention_dim (`int`, *optional*, defaults to 768):
The dimension of the cross attention features.
Expand All @@ -97,11 +97,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
"DownBlock2D",
)
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
only_cross_attention: Union[bool, Tuple[bool]] = False
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: int = 2
attention_head_dim: int = 8
attention_head_dim: Union[int, Tuple[int]] = 8
cross_attention_dim: int = 1280
dropout: float = 0.0
use_linear_projection: bool = False
dtype: jnp.dtype = jnp.float32
freq_shift: int = 0

Expand Down Expand Up @@ -134,6 +136,14 @@ def setup(self):
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

only_cross_attention = self.only_cross_attention
if isinstance(only_cross_attention, bool):
only_cross_attention = (only_cross_attention,) * len(self.down_block_types)

attention_head_dim = self.attention_head_dim
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(self.down_block_types)

# down
down_blocks = []
output_channel = block_out_channels[0]
Expand All @@ -148,8 +158,10 @@ def setup(self):
out_channels=output_channel,
dropout=self.dropout,
num_layers=self.layers_per_block,
attn_num_head_channels=self.attention_head_dim,
attn_num_head_channels=attention_head_dim[i],
add_downsample=not is_final_block,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
Expand All @@ -169,13 +181,16 @@ def setup(self):
self.mid_block = FlaxUNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
dropout=self.dropout,
attn_num_head_channels=self.attention_head_dim,
attn_num_head_channels=attention_head_dim[-1],
use_linear_projection=self.use_linear_projection,
dtype=self.dtype,
)

# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(self.up_block_types):
prev_output_channel = output_channel
Expand All @@ -190,9 +205,11 @@ def setup(self):
out_channels=output_channel,
prev_output_channel=prev_output_channel,
num_layers=self.layers_per_block + 1,
attn_num_head_channels=self.attention_head_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
add_upsample=not is_final_block,
dropout=self.dropout,
use_linear_projection=self.use_linear_projection,
only_cross_attention=only_cross_attention[i],
dtype=self.dtype,
)
else:
Expand Down
26 changes: 26 additions & 0 deletions tests/models/test_models_unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,3 +639,29 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
expected_output_slice = torch.tensor(expected_slice)

assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

@parameterized.expand(
[
# fmt: off
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
# fmt: on
]
)
@require_torch_gpu
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for adding this!

latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

assert sample.shape == latents.shape

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)

assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
Loading