
Commit fd362dc

Replace FeedForward with MLPBlock (#201)

* Update monai-weekly prerelease and replace FeedForward with MLPBlock
* Add dropout

Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
1 parent 9dca9e4 commit fd362dc

File tree

3 files changed: +6 additions, -48 deletions

generative/networks/nets/diffusion_model_unet.py
requirements.txt
setup.py

generative/networks/nets/diffusion_model_unet.py

Lines changed: 4 additions & 46 deletions

@@ -35,7 +35,7 @@
 
 import torch
 import torch.nn.functional as F
-from monai.networks.blocks import Convolution
+from monai.networks.blocks import Convolution, MLPBlock
 from monai.networks.layers.factories import Pool
 from torch import nn
 
@@ -66,46 +66,6 @@ def zero_module(module: nn.Module) -> nn.Module:
     return module
 
 
-class GEGLU(nn.Module):
-    """
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Args:
-        dim_in: number of channels in the input.
-        dim_out: number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int) -> None:
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-
-
-class FeedForward(nn.Module):
-    """
-    A feed-forward layer.
-
-    Args:
-        num_channels: number of channels in the input.
-        dim_out: number of channels in the output. If not given, defaults to `dim`.
-        mult: multiplier to use for the hidden dimension.
-        dropout: dropout probability to use.
-    """
-
-    def __init__(self, num_channels: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0) -> None:
-        super().__init__()
-        inner_dim = int(num_channels * mult)
-        dim_out = dim_out if dim_out is not None else num_channels
-
-        self.net = nn.Sequential(GEGLU(num_channels, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-
 class CrossAttention(nn.Module):
     """
     A cross attention layer.
@@ -239,7 +199,7 @@ def __init__(
             dropout=dropout,
             upcast_attention=upcast_attention,
         )  # is a self-attention
-        self.ff = FeedForward(num_channels, dropout=dropout)
+        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttention(
             query_dim=num_channels,
             cross_attention_dim=cross_attention_dim,
@@ -1677,10 +1637,8 @@ def __init__(
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
             raise ValueError(
-                (
-                    "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
-                    "when using with_conditioning."
-                )
+                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
             )
         if cross_attention_dim is not None and with_conditioning is False:
             raise ValueError(
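
For orientation (a sketch, not part of the commit): the removed FeedForward(num_channels, dropout=dropout) used a hidden width of num_channels * 4 with a GEGLU gate and projected back to num_channels, which is what the new MLPBlock call reproduces through hidden_size, mlp_dim, act="GEGLU", and dropout_rate. The snippet below assumes the monai-weekly build pinned in this commit is installed; num_channels and dropout are illustrative values, not taken from the diff.

import torch
from monai.networks.blocks import MLPBlock

num_channels = 64  # illustrative only
dropout = 0.1      # illustrative only

# Same construction as in the diff: hidden width 4x the channel count,
# gated GELU ("GEGLU") activation, and the same dropout probability.
ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)

x = torch.randn(2, 16, num_channels)  # (batch, tokens, channels)
print(ff(x).shape)  # torch.Size([2, 16, 64]) -- channel count preserved, as with the old FeedForward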

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 numpy>=1.17
 torch>=1.8
-monai-weekly==1.1.dev2248
+monai-weekly==1.2.dev2304

setup.py

Lines changed: 1 addition & 1 deletion

@@ -17,6 +17,6 @@
     version="0.1.0",
     description="Installer to help to use the prototypes from MONAI generative models in other projects.",
     install_requires=[
-        "monai-weekly==1.1.dev2248",
+        "monai-weekly==1.2.dev2304",
     ],
 )
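
Since the monai-weekly pin above is bumped alongside the switch to MLPBlock, a quick environment check can confirm the installed build provides it. A minimal sketch, assuming the pinned (or a newer) weekly build is installed:

# Sanity check: the installed monai-weekly should expose the block the UNet now uses.
import monai
from monai.networks.blocks import MLPBlock  # an ImportError here means the new pin was not installed

print(monai.__version__)  # expected to report a 1.2.dev* weekly build, e.g. 1.2.dev2304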

0 commit comments
