src/diffusers/models/vae_flax.py: 46 additions & 14 deletions
@@ -119,6 +119,8 @@ class FlaxResnetBlock2D(nn.Module):
Output channels
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm.
use_nin_shortcut (:obj:`bool`, *optional*, defaults to `None`):
Whether to use `nin_shortcut`. This activates a new layer inside ResNet block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -128,13 +130,14 @@ class FlaxResnetBlock2D(nn.Module):
in_channels: int
out_channels: int = None
dropout: float = 0.0
groups: int = 32
use_nin_shortcut: bool = None
dtype: jnp.dtype = jnp.float32

def setup(self):
out_channels = self.in_channels if self.out_channels is None else self.out_channels

self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.conv1 = nn.Conv(
out_channels,
kernel_size=(3, 3),
@@ -143,7 +146,7 @@ def setup(self):
dtype=self.dtype,
)

self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.norm2 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
self.dropout_layer = nn.Dropout(self.dropout)
self.conv2 = nn.Conv(
out_channels,
@@ -191,20 +194,23 @@ class FlaxAttentionBlock(nn.Module):
Input channels
num_head_channels (:obj:`int`, *optional*, defaults to `None`):
Number of attention heads
num_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for group norm
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`

"""
channels: int
num_head_channels: int = None
num_groups: int = 32
dtype: jnp.dtype = jnp.float32

def setup(self):
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1

dense = partial(nn.Dense, self.channels, dtype=self.dtype)

self.group_norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.group_norm = nn.GroupNorm(num_groups=self.num_groups, epsilon=1e-6)
self.query, self.key, self.value = dense(), dense(), dense()
self.proj_attn = dense()
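
A side note on the setup above (illustrative numbers, not from the diff): the new num_groups field is forwarded to the block's group norm exactly the way groups is in FlaxResnetBlock2D, and the head count is derived from num_head_channels when it is given:

# Illustrative only: how FlaxAttentionBlock.setup derives its head count (hypothetical values)
channels, num_head_channels = 512, 64
num_heads = channels // num_head_channels if num_head_channels is not None else 1
assert num_heads == 8  # falls back to a single head when num_head_channels is None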

@@ -264,6 +270,8 @@ class FlaxDownEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet block group norm
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -273,6 +281,7 @@
out_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
add_downsample: bool = True
dtype: jnp.dtype = jnp.float32

@@ -285,6 +294,7 @@ def setup(self):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -303,9 +313,9 @@ def __call__(self, hidden_states, deterministic=True):
return hidden_states


class FlaxUpEncoderBlock2D(nn.Module):
class FlaxUpDecoderBlock2D(nn.Module):
r"""
Flax Resnet blocks-based Encoder block for diffusion-based VAE.
Flax Resnet blocks-based Decoder block for diffusion-based VAE.

Parameters:
in_channels (:obj:`int`):
@@ -316,15 +326,18 @@ class FlaxUpEncoderBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsample layer
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet block group norm
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsample layer
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
"""
in_channels: int
out_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
add_upsample: bool = True
dtype: jnp.dtype = jnp.float32

@@ -336,6 +349,7 @@ def setup(self):
in_channels=in_channels,
out_channels=self.out_channels,
dropout=self.dropout,
groups=self.resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
@@ -366,6 +380,8 @@ class FlaxUNetMidBlock2D(nn.Module):
Dropout rate
num_layers (:obj:`int`, *optional*, defaults to 1):
Number of Resnet layer block
resnet_groups (:obj:`int`, *optional*, defaults to `32`):
The number of groups to use for the Resnet and Attention block group norm
attn_num_head_channels (:obj:`int`, *optional*, defaults to `1`):
Number of attention heads for each attention block
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
@@ -374,16 +390,20 @@
in_channels: int
dropout: float = 0.0
num_layers: int = 1
resnet_groups: int = 32
attn_num_head_channels: int = 1
dtype: jnp.dtype = jnp.float32

def setup(self):
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)

# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
groups=resnet_groups,
dtype=self.dtype,
)
]
@@ -392,14 +412,18 @@ def setup(self):

for _ in range(self.num_layers):
attn_block = FlaxAttentionBlock(
channels=self.in_channels, num_head_channels=self.attn_num_head_channels, dtype=self.dtype
channels=self.in_channels,
num_head_channels=self.attn_num_head_channels,
num_groups=resnet_groups,
dtype=self.dtype,
)
attentions.append(attn_block)

res_block = FlaxResnetBlock2D(
in_channels=self.in_channels,
out_channels=self.in_channels,
dropout=self.dropout,
groups=resnet_groups,
dtype=self.dtype,
)
resnets.append(res_block)
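
The first line of setup above gives FlaxUNetMidBlock2D a fallback for the case where resnet_groups is explicitly passed as None: it derives a group count from the channel width, capped at 32. A small illustration with hypothetical numbers:

# Illustrative only: the group-count fallback used by FlaxUNetMidBlock2D.setup
in_channels, resnet_groups = 64, None
groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
assert groups == 16  # for in_channels >= 128 the fallback saturates at 32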
@@ -441,7 +465,7 @@ class FlaxEncoder(nn.Module):
Tuple containing the number of output channels for each block
layers_per_block (:obj:`int`, *optional*, defaults to `2`):
Number of Resnet layer for each block
norm_num_groups (:obj:`int`, *optional*, defaults to `2`):
norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
norm num group
act_fn (:obj:`str`, *optional*, defaults to `silu`):
Activation function
@@ -483,6 +507,7 @@ def setup(self):
in_channels=input_channel,
out_channels=output_channel,
num_layers=self.layers_per_block,
resnet_groups=self.norm_num_groups,
add_downsample=not is_final_block,
dtype=self.dtype,
)
@@ -491,12 +516,15 @@

# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
dtype=self.dtype,
)

# end
conv_out_channels = 2 * self.out_channels if self.double_z else self.out_channels
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
conv_out_channels,
kernel_size=(3, 3),
@@ -581,7 +609,10 @@ def setup(self):

# middle
self.mid_block = FlaxUNetMidBlock2D(
in_channels=block_out_channels[-1], attn_num_head_channels=None, dtype=self.dtype
in_channels=block_out_channels[-1],
resnet_groups=self.norm_num_groups,
attn_num_head_channels=None,
dtype=self.dtype,
)

# upsampling
@@ -594,10 +625,11 @@

is_final_block = i == len(block_out_channels) - 1

up_block = FlaxUpEncoderBlock2D(
up_block = FlaxUpDecoderBlock2D(
in_channels=prev_output_channel,
out_channels=output_channel,
num_layers=self.layers_per_block + 1,
resnet_groups=self.norm_num_groups,
add_upsample=not is_final_block,
dtype=self.dtype,
)
@@ -607,7 +639,7 @@
self.up_blocks = up_blocks

# end
self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
self.conv_norm_out = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-6)
self.conv_out = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
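Taken together, these changes thread a single norm_num_groups value from the top-level encoder/decoder down to every nn.GroupNorm they build. A minimal usage sketch, mirroring the configuration exercised by the tests added below (the exact values are illustrative, and the block widths are kept divisible by the group count):

import jax
from diffusers import FlaxAutoencoderKL

# 16 groups instead of the default 32; both block widths (16, 32) divide evenly by 16
model = FlaxAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
    up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
    block_out_channels=(16, 32),
    latent_channels=4,
    norm_num_groups=16,
)
prng_key = jax.random.PRNGKey(0)
sample = jax.random.uniform(prng_key, (4, 3, 32, 32))

variables = model.init(prng_key, sample)  # initialize parameters
output = model.apply(variables, sample)   # forward pass through encoder + decoder
if isinstance(output, dict):
    output = output.sample
assert output.shape == sample.shape       # reconstruction keeps the input shape
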
src/diffusers/utils/testing_utils.py: 9 additions & 0 deletions
@@ -14,6 +14,8 @@
import requests
from packaging import version

from .import_utils import is_flax_available


global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -89,6 +91,13 @@ def slow(test_case):
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)


def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax. These tests are skipped when one or both are not installed.
"""
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Args:
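For reference, the new decorator is applied just like the existing slow marker: decorate a test function or class and the test is skipped whenever JAX or Flax is missing. A short hypothetical example:

import unittest
from diffusers.utils.testing_utils import require_flax

@require_flax
class MyFlaxSmokeTest(unittest.TestCase):  # hypothetical test class, for illustration only
    def test_jax_is_usable(self):
        import jax  # safe: the whole class is skipped when JAX/Flax are unavailable
        self.assertEqual(jax.numpy.ones(3).shape, (3,))
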
tests/test_modeling_common_flax.py: 44 additions & 0 deletions (new file)
@@ -0,0 +1,44 @@
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
import jax


@require_flax
class FlaxModelTesterMixin:
def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)

output = model.apply(variables, inputs_dict["sample"])

if isinstance(output, dict):
output = output.sample

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)

model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
jax.lax.stop_gradient(variables)

output = model.apply(variables, inputs_dict["sample"])

if isinstance(output, dict):
output = output.sample

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
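
One detail worth calling out in test_forward_with_norm_groups (my reading, not stated in the diff): group norm requires each normalized channel count to be divisible by the number of groups, which is presumably why block_out_channels is shrunk to (16, 32) alongside norm_num_groups=16. A quick illustrative check:

# Illustrative only: both block widths must divide evenly by the group count
for channels in (16, 32):
    assert channels % 16 == 0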
tests/test_models_vae_flax.py: 39 additions & 0 deletions (new file)
@@ -0,0 +1,39 @@
import unittest

from diffusers import FlaxAutoencoderKL
from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax

from .test_modeling_common_flax import FlaxModelTesterMixin


if is_flax_available():
import jax


@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
model_class = FlaxAutoencoderKL

@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)

prng_key = jax.random.PRNGKey(0)
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))

return {"sample": image, "prng_key": prng_key}

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict