@@ -35,7 +35,7 @@
 
 import torch
 import torch.nn.functional as F
-from monai.networks.blocks import Convolution
+from monai.networks.blocks import Convolution, MLPBlock
 from monai.networks.layers.factories import Pool
 from torch import nn
 
@@ -66,46 +66,6 @@ def zero_module(module: nn.Module) -> nn.Module:
     return module
 
 
-class GEGLU(nn.Module):
-    """
-    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
-
-    Args:
-        dim_in: number of channels in the input.
-        dim_out: number of channels in the output.
-    """
-
-    def __init__(self, dim_in: int, dim_out: int) -> None:
-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out * 2)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-
-
-class FeedForward(nn.Module):
-    """
-    A feed-forward layer.
-
-    Args:
-        num_channels: number of channels in the input.
-        dim_out: number of channels in the output. If not given, defaults to `dim`.
-        mult: multiplier to use for the hidden dimension.
-        dropout: dropout probability to use.
-    """
-
-    def __init__(self, num_channels: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0) -> None:
-        super().__init__()
-        inner_dim = int(num_channels * mult)
-        dim_out = dim_out if dim_out is not None else num_channels
-
-        self.net = nn.Sequential(GEGLU(num_channels, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-
 class CrossAttention(nn.Module):
     """
     A cross attention layer.
@@ -239,7 +199,7 @@ def __init__(
             dropout=dropout,
             upcast_attention=upcast_attention,
         )  # is a self-attention
-        self.ff = FeedForward(num_channels, dropout=dropout)
+        self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
         self.attn2 = CrossAttention(
             query_dim=num_channels,
             cross_attention_dim=cross_attention_dim,
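
For reference, the MLPBlock call above is intended as a drop-in for the removed FeedForward (a GEGLU projection, dropout, then a linear layer back to num_channels). A minimal sketch of that usage, assuming a MONAI version whose MLPBlock accepts act="GEGLU"; the hidden width follows the same num_channels * 4 convention as the old mult=4 default, and the values below are illustrative only:

import torch
from monai.networks.blocks import MLPBlock

num_channels, dropout = 64, 0.1  # illustrative values
ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)

x = torch.randn(2, 16, num_channels)  # (batch, sequence, channels), as seen inside the transformer block
out = ff(x)
print(out.shape)  # torch.Size([2, 16, 64]) -- channel width preserved, matching the old FeedForward
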
@@ -1677,10 +1637,8 @@ def __init__(
         super().__init__()
         if with_conditioning is True and cross_attention_dim is None:
             raise ValueError(
-                (
-                    "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
-                    "when using with_conditioning."
-                )
+                "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+                "when using with_conditioning."
             )
         if cross_attention_dim is not None and with_conditioning is False:
             raise ValueError(
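
The two constructor checks touched above couple with_conditioning and cross_attention_dim: enabling conditioning requires a cross-attention dimension, and supplying a cross-attention dimension requires conditioning. A self-contained sketch of that validation logic (the helper name is illustrative, and the second error message is paraphrased because the hunk ends before it):

from typing import Optional


def check_conditioning_args(with_conditioning: bool, cross_attention_dim: Optional[int]) -> None:
    # Mirrors the two checks in DiffusionModelUNet.__init__ shown in the hunk above.
    if with_conditioning is True and cross_attention_dim is None:
        raise ValueError(
            "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
            "when using with_conditioning."
        )
    if cross_attention_dim is not None and with_conditioning is False:
        # Paraphrased: the original message is not shown in this hunk.
        raise ValueError("cross_attention_dim should only be provided when with_conditioning=True.")


check_conditioning_args(with_conditioning=True, cross_attention_dim=64)  # valid pairing
# check_conditioning_args(with_conditioning=True, cross_attention_dim=None)  # would raise ValueError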