Skip to content

Commit c967b88

Browse files
authored
Update unet.py (#1955)
1 parent d0ec11b commit c967b88

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pl_examples/models/unet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
9999
super().__init__()
100100
self.upsample = None
101101
if bilinear:
102-
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
102+
self.upsample = nn.Sequential(
103+
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
104+
nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
105+
)
103106
else:
104107
self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
105108

0 commit comments

Comments
 (0)