
Commit 8b0be93

Authored by younesbelkada, mishig25, and pcuenca
Flax documentation (#589)
* documenting `attention_flax.py` file
* documenting `embeddings_flax.py`
* documenting `unet_blocks_flax.py`
* Add new objs to doc page
* document `vae_flax.py`
* Apply suggestions from code review
* modify `unet_2d_condition_flax.py`
* make style
* Apply suggestions from code review
* make style
* Apply suggestions from code review
* fix indent
* fix typo
* fix indent unet
* Update src/diffusers/models/vae_flax.py
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent df80ccf commit 8b0be93

6 files changed: +429 -8 lines

docs/source/api/models.mdx

Lines changed: 18 additions & 0 deletions
@@ -45,3 +45,21 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
 
 ## AutoencoderKL
 [[autodoc]] AutoencoderKL
+
+## FlaxModelMixin
+[[autodoc]] FlaxModelMixin
+
+## FlaxUNet2DConditionOutput
+[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput
+
+## FlaxUNet2DConditionModel
+[[autodoc]] FlaxUNet2DConditionModel
+
+## FlaxDecoderOutput
+[[autodoc]] models.vae_flax.FlaxDecoderOutput
+
+## FlaxAutoencoderKLOutput
+[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
+
+## FlaxAutoencoderKL
+[[autodoc]] FlaxAutoencoderKL

src/diffusers/models/attention_flax.py

Lines changed: 76 additions & 0 deletions
@@ -17,6 +17,22 @@
 
 
 class FlaxAttentionBlock(nn.Module):
+    r"""
+    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
+
+    Parameters:
+        query_dim (:obj:`int`):
+            Input hidden states dimension
+        heads (:obj:`int`, *optional*, defaults to 8):
+            Number of heads
+        dim_head (:obj:`int`, *optional*, defaults to 64):
+            Hidden states dimension inside each head
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+
+    """
     query_dim: int
     heads: int = 8
     dim_head: int = 64
@@ -74,6 +90,23 @@ def __call__(self, hidden_states, context=None, deterministic=True):
 
 
 class FlaxBasicTransformerBlock(nn.Module):
+    r"""
+    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
+    https://arxiv.org/abs/1706.03762
+
+
+    Parameters:
+        dim (:obj:`int`):
+            Inner hidden states dimension
+        n_heads (:obj:`int`):
+            Number of heads
+        d_head (:obj:`int`):
+            Hidden states dimension inside each head
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     dim: int
     n_heads: int
     d_head: int
@@ -110,6 +143,25 @@ def __call__(self, hidden_states, context, deterministic=True):
 
 
 class FlaxSpatialTransformer(nn.Module):
+    r"""
+    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
+    https://arxiv.org/pdf/1506.02025.pdf
+
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input number of channels
+        n_heads (:obj:`int`):
+            Number of heads
+        d_head (:obj:`int`):
+            Hidden states dimension inside each head
+        depth (:obj:`int`, *optional*, defaults to 1):
+            Number of transformer blocks
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     in_channels: int
     n_heads: int
     d_head: int
@@ -162,6 +214,18 @@ def __call__(self, hidden_states, context, deterministic=True):
 
 
 class FlaxGluFeedForward(nn.Module):
+    r"""
+    Flax module that encapsulates two Linear layers separated by a gated linear unit activation from:
+    https://arxiv.org/abs/2002.05202
+
+    Parameters:
+        dim (:obj:`int`):
+            Inner hidden states dimension
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
@@ -179,6 +243,18 @@ def __call__(self, hidden_states, deterministic=True):
 
 
 class FlaxGEGLU(nn.Module):
+    r"""
+    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
+    https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim (:obj:`int`):
+            Input hidden states dimension
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     dim: int
     dropout: float = 0.0
     dtype: jnp.dtype = jnp.float32
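
The attention modules above follow the usual Flax Linen pattern: construct the module with its hyperparameters, `init` it to obtain parameters, then `apply` it to inputs. Below is a minimal, illustrative sketch for `FlaxAttentionBlock` used as a self-attention layer; the import path matches the file documented in this diff, while the shapes and head configuration are made-up values for demonstration.

# Minimal, illustrative sketch (not from the commit): exercising FlaxAttentionBlock
# as a self-attention layer via the standard Flax init/apply pattern.
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import FlaxAttentionBlock

# 8 heads of size 40 -> inner dimension 320, projected back to query_dim.
block = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40, dropout=0.0, dtype=jnp.float32)

# Hypothetical input: a 16x16 feature map flattened to a sequence of 256 tokens.
hidden_states = jnp.ones((1, 256, 320), dtype=jnp.float32)

variables = block.init(jax.random.PRNGKey(0), hidden_states)   # context=None -> self-attention
out = block.apply(variables, hidden_states)                    # shape: (1, 256, 320)

Passing a `context` tensor instead of `None` turns the same block into cross-attention over that sequence.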

src/diffusers/models/embeddings_flax.py

Lines changed: 16 additions & 0 deletions
@@ -37,6 +37,15 @@ def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
 
 
 class FlaxTimestepEmbedding(nn.Module):
+    r"""
+    Time step Embedding Module. Learns embeddings for input time steps.
+
+    Args:
+        time_embed_dim (`int`, *optional*, defaults to `32`):
+            Time step embedding dimension
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     time_embed_dim: int = 32
     dtype: jnp.dtype = jnp.float32
 
@@ -49,6 +58,13 @@ def __call__(self, temb):
 
 
 class FlaxTimesteps(nn.Module):
+    r"""
+    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+
+    Args:
+        dim (`int`, *optional*, defaults to `32`):
+            Time step embedding dimension
+    """
     dim: int = 32
     freq_shift: float = 1
 
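
As a rough sketch of how the two embedding modules documented here compose, mirroring the timestep path of the Flax UNet: `FlaxTimesteps` produces the sinusoidal features and `FlaxTimestepEmbedding` learns a projection on top of them. The dimensions (320, 1280) are illustrative assumptions, not values taken from this diff.

# Illustrative composition of the two modules above (assumed usage).
import jax
import jax.numpy as jnp

from diffusers.models.embeddings_flax import FlaxTimesteps, FlaxTimestepEmbedding

timesteps = jnp.array([10, 500], dtype=jnp.int32)      # one diffusion timestep per sample

# Sinusoidal projection: parameter-free, but still a Flax module.
proj = FlaxTimesteps(dim=320, freq_shift=1)
t_proj = proj.apply(proj.init(jax.random.PRNGKey(0), timesteps), timesteps)   # (2, 320)

# Learned projection on top of the sinusoidal features.
embed = FlaxTimestepEmbedding(time_embed_dim=1280, dtype=jnp.float32)
variables = embed.init(jax.random.PRNGKey(1), t_proj)
t_emb = embed.apply(variables, t_proj)                                        # (2, 1280)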

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 24 additions & 7 deletions
@@ -39,10 +39,23 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library
     implements for all the models (such as downloading or saving, etc.)
 
+    Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
+    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to
+    general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
     Parameters:
-        sample_size (`int`, *optional*): The size of the input sample.
-        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
-        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        sample_size (`int`, *optional*):
+            The size of the input sample.
+        in_channels (`int`, *optional*, defaults to 4):
+            The number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4):
+            The number of channels in the output.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
             The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D",
             "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D"
@@ -51,10 +64,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D"
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
-        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
-        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
-        cross_attention_dim (`int`, *optional*, defaults to 768): The dimension of the cross attention features.
-        dropout (`float`, *optional*, defaults to 0): Dropout probability for down, up and bottleneck blocks.
+        layers_per_block (`int`, *optional*, defaults to 2):
+            The number of layers per block.
+        attention_head_dim (`int`, *optional*, defaults to 8):
+            The dimension of the attention heads.
+        cross_attention_dim (`int`, *optional*, defaults to 768):
+            The dimension of the cross attention features.
+        dropout (`float`, *optional*, defaults to 0):
+            Dropout probability for down, up and bottleneck blocks.
     """
 
     sample_size: int = 32
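
To make the JIT/vmap/pmap support mentioned in the new docstring concrete, here is a hedged sketch of one forward pass through `FlaxUNet2DConditionModel` with randomly initialized weights; in practice the parameters would come from a pretrained checkpoint. The channels-first sample layout, the `int32` timesteps, and the 768-wide text conditioning are assumptions for illustration.

# Hedged sketch of a single forward pass via the generic Flax init/apply pattern.
import jax
import jax.numpy as jnp

from diffusers import FlaxUNet2DConditionModel

unet = FlaxUNet2DConditionModel(sample_size=32, dtype=jnp.float32)

sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)               # (batch, in_channels, H, W), assumed layout
timesteps = jnp.array([0], dtype=jnp.int32)                         # one timestep per sample
encoder_hidden_states = jnp.zeros((1, 77, 768), dtype=jnp.float32)  # (batch, tokens, cross_attention_dim)

variables = unet.init(jax.random.PRNGKey(0), sample, timesteps, encoder_hidden_states)
out = unet.apply(variables, sample, timesteps, encoder_hidden_states)
noise_pred = out.sample                                             # FlaxUNet2DConditionOutput.sample

Because `apply` is a pure function of its parameters and inputs, the call above can be wrapped in `jax.jit`, or in `jax.pmap` across devices, without further changes.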

src/diffusers/models/unet_blocks_flax.py

Lines changed: 91 additions & 0 deletions
@@ -19,6 +19,26 @@
 
 
 class FlaxCrossAttnDownBlock2D(nn.Module):
+    r"""
+    Cross Attention 2D Downsizing block - original architecture from Unet transformers:
+    https://arxiv.org/abs/2103.06104
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads of each spatial transformer block
+        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add a downsampling layer before each final output
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -73,6 +93,23 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru
 
 
 class FlaxDownBlock2D(nn.Module):
+    r"""
+    Flax 2D downsizing block
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of ResNet block layers
+        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add a downsampling layer before each final output
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
@@ -113,6 +150,26 @@ def __call__(self, hidden_states, temb, deterministic=True):
 
 
 class FlaxCrossAttnUpBlock2D(nn.Module):
+    r"""
+    Cross Attention 2D Upsampling block - original architecture from Unet transformers:
+    https://arxiv.org/abs/2103.06104
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads of each spatial transformer block
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add an upsampling layer before each final output
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -170,6 +227,25 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_
 
 
 class FlaxUpBlock2D(nn.Module):
+    r"""
+    Flax 2D upsampling block
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        prev_output_channel (:obj:`int`):
+            Output channels from the previous block
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of ResNet block layers
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add an upsampling layer before each final output
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     in_channels: int
     out_channels: int
     prev_output_channel: int
@@ -214,6 +290,21 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T
 
 
 class FlaxUNetMidBlock2DCrossAttn(nn.Module):
+    r"""
+    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads of each spatial transformer block
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            Parameters `dtype`
+    """
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
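
For completeness, a usage sketch of one of the blocks documented above. It assumes the block consumes channels-last feature maps (the Flax convolution default) and returns the transformed hidden states together with the per-layer residual states consumed later by the matching up block; treat all shapes as placeholders.

# Rough usage sketch under the stated assumptions, not a definitive API example.
import jax
import jax.numpy as jnp

from diffusers.models.unet_blocks_flax import FlaxCrossAttnDownBlock2D

block = FlaxCrossAttnDownBlock2D(
    in_channels=320,
    out_channels=320,
    num_layers=2,
    attn_num_head_channels=8,
    add_downsample=True,
    dtype=jnp.float32,
)

hidden_states = jnp.zeros((1, 32, 32, 320))       # (batch, H, W, C), assumed channels-last
temb = jnp.zeros((1, 1280))                        # timestep embedding
encoder_hidden_states = jnp.zeros((1, 77, 768))    # text-conditioning sequence

variables = block.init(jax.random.PRNGKey(0), hidden_states, temb, encoder_hidden_states)
hidden_states, skip_states = block.apply(variables, hidden_states, temb, encoder_hidden_states)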
