 
 
 class FlaxCrossAttnDownBlock2D(nn.Module):
+    r"""
+    Cross-attention 2D downsampling block - original architecture from the UNet transformer paper:
+    https://arxiv.org/abs/2103.06104
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads in each spatial transformer block
+        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add a downsampling layer at the end of the block
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            The `dtype` of the parameters
+    """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
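For context on how this block is used, here is a minimal usage sketch. The `__call__` signature is taken from the hunk header just below; the import path (`diffusers.models.unet_blocks_flax`) and every tensor shape are illustrative assumptions rather than values from this diff.

```python
import jax
import jax.numpy as jnp

# Hypothetical import path; adjust to wherever these Flax blocks live in your checkout.
from diffusers.models.unet_blocks_flax import FlaxCrossAttnDownBlock2D

block = FlaxCrossAttnDownBlock2D(
    in_channels=32,
    out_channels=64,
    num_layers=1,
    attn_num_head_channels=1,
    add_downsample=True,
)

# Flax blocks use channels-last (NHWC) activations; shapes here are illustrative.
hidden_states = jnp.zeros((1, 16, 16, 32))       # channels match in_channels
temb = jnp.zeros((1, 128))                       # time embedding, projected internally
encoder_hidden_states = jnp.zeros((1, 77, 768))  # e.g. a text-encoder sequence

params = block.init(jax.random.PRNGKey(0), hidden_states, temb, encoder_hidden_states)
hidden_states, output_states = block.apply(
    params, hidden_states, temb, encoder_hidden_states, deterministic=True
)
# `output_states` collects the per-layer residuals (plus the downsampled output)
# that the matching up block later consumes as skip connections.
```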
@@ -73,6 +93,23 @@ def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=Tru
 
 
 class FlaxDownBlock2D(nn.Module):
+    r"""
+    Flax 2D downsampling block
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of ResNet block layers
+        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add a downsampling layer at the end of the block
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            The `dtype` of the parameters
+    """
     in_channels: int
     out_channels: int
     dropout: float = 0.0
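The attention-free counterpart can be exercised the same way, minus `encoder_hidden_states`. As before, the import path and shapes are assumptions; the call signature matches the hunk header below.

```python
import jax
import jax.numpy as jnp

# Hypothetical import path, assumed for illustration.
from diffusers.models.unet_blocks_flax import FlaxDownBlock2D

block = FlaxDownBlock2D(in_channels=32, out_channels=64, num_layers=2, add_downsample=True)

hidden_states = jnp.zeros((1, 16, 16, 32))  # NHWC, channels match in_channels
temb = jnp.zeros((1, 128))                  # time embedding

params = block.init(jax.random.PRNGKey(0), hidden_states, temb)
hidden_states, output_states = block.apply(params, hidden_states, temb, deterministic=True)
# hidden_states is (1, 8, 8, 64) after the trailing downsampler;
# output_states again collects the skip tensors for the decoder side.
```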
@@ -113,6 +150,26 @@ def __call__(self, hidden_states, temb, deterministic=True):
 
 
 class FlaxCrossAttnUpBlock2D(nn.Module):
+    r"""
+    Cross-attention 2D upsampling block - original architecture from the UNet transformer paper:
+    https://arxiv.org/abs/2103.06104
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads in each spatial transformer block
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add an upsampling layer at the end of the block
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            The `dtype` of the parameters
+    """
     in_channels: int
     out_channels: int
     prev_output_channel: int
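A sketch of the decoder-side block follows. The channel bookkeeping is the subtle part: each layer concatenates one skip tensor from `res_hidden_states_tuple` onto the channel axis before its ResNet. The import path and all shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical import path, assumed for illustration.
from diffusers.models.unet_blocks_flax import FlaxCrossAttnUpBlock2D

block = FlaxCrossAttnUpBlock2D(
    in_channels=32,          # with num_layers=1, this must match the skip tensor's channels
    out_channels=64,
    prev_output_channel=64,  # channels of the incoming hidden_states
    num_layers=1,
    attn_num_head_channels=1,
    add_upsample=True,
)

hidden_states = jnp.zeros((1, 8, 8, 64))               # prev_output_channel channels
res_hidden_states_tuple = (jnp.zeros((1, 8, 8, 32)),)  # one skip tensor per layer
temb = jnp.zeros((1, 128))
encoder_hidden_states = jnp.zeros((1, 77, 768))

params = block.init(
    jax.random.PRNGKey(0), hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states
)
hidden_states = block.apply(
    params, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True
)
# Each layer pops the last skip tensor and concatenates it on the channel axis;
# the optional upsampler then doubles the spatial resolution at the end.
```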
@@ -170,6 +227,25 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_
 
 
 class FlaxUpBlock2D(nn.Module):
+    r"""
+    Flax 2D upsampling block
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        out_channels (:obj:`int`):
+            Output channels
+        prev_output_channel (:obj:`int`):
+            Output channels from the previous block
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of ResNet block layers
+        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
+            Whether to add an upsampling layer at the end of the block
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            The `dtype` of the parameters
+    """
     in_channels: int
     out_channels: int
     prev_output_channel: int
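The plain upsampling block follows the same pattern without cross-attention; the flag controlling the trailing upsampler is left at its default here. Import path and shapes are again assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical import path, assumed for illustration.
from diffusers.models.unet_blocks_flax import FlaxUpBlock2D

block = FlaxUpBlock2D(in_channels=32, out_channels=64, prev_output_channel=64, num_layers=1)

hidden_states = jnp.zeros((1, 8, 8, 64))               # prev_output_channel channels
res_hidden_states_tuple = (jnp.zeros((1, 8, 8, 32)),)  # one skip tensor per layer
temb = jnp.zeros((1, 128))

params = block.init(jax.random.PRNGKey(0), hidden_states, res_hidden_states_tuple, temb)
hidden_states = block.apply(params, hidden_states, res_hidden_states_tuple, temb, deterministic=True)
# With the default trailing upsampler, the output is (1, 16, 16, 64).
```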
@@ -214,6 +290,21 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T
 
 
 class FlaxUNetMidBlock2DCrossAttn(nn.Module):
+    r"""
+    Cross-attention 2D mid-level block - original architecture from the UNet transformer paper: https://arxiv.org/abs/2103.06104
+
+    Parameters:
+        in_channels (:obj:`int`):
+            Input channels
+        dropout (:obj:`float`, *optional*, defaults to 0.0):
+            Dropout rate
+        num_layers (:obj:`int`, *optional*, defaults to 1):
+            Number of attention block layers
+        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):
+            Number of attention heads in each spatial transformer block
+        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+            The `dtype` of the parameters
+    """
     in_channels: int
     dropout: float = 0.0
     num_layers: int = 1
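Finally, a sketch for the mid block. Its `__call__` signature is not visible in this diff, so the call below assumes it mirrors `FlaxCrossAttnDownBlock2D` (hidden states, time embedding, encoder hidden states); treat it as an unverified sketch with assumed import path and shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical import path, assumed for illustration.
from diffusers.models.unet_blocks_flax import FlaxUNetMidBlock2DCrossAttn

block = FlaxUNetMidBlock2DCrossAttn(in_channels=64, num_layers=1, attn_num_head_channels=1)

hidden_states = jnp.zeros((1, 8, 8, 64))  # NHWC, channels match in_channels
temb = jnp.zeros((1, 128))
encoder_hidden_states = jnp.zeros((1, 77, 768))

# Assumed call convention, mirroring the cross-attention down block; the mid
# block keeps the spatial resolution and channel count unchanged.
params = block.init(jax.random.PRNGKey(0), hidden_states, temb, encoder_hidden_states)
hidden_states = block.apply(params, hidden_states, temb, encoder_hidden_states, deterministic=True)
```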