Skip to content

Commit 7ef7788

Browse files
committed
Fix CUDA crash w/ channels-last + CSP models. Remove use of chunk()
1 parent 317ea3e commit 7ef7788

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

timm/models/cspnet.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,11 @@ def forward(self, x):
264264
if self.conv_down is not None:
265265
x = self.conv_down(x)
266266
x = self.conv_exp(x)
267-
xs, xb = x.chunk(2, dim=1)
267+
split = x.shape[1] // 2
268+
xs, xb = x[:, :split], x[:, split:]
268269
xb = self.blocks(xb)
269-
out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1))
270+
xb = self.conv_transition_b(xb).contiguous()
271+
out = self.conv_transition(torch.cat([xs, xb], dim=1))
270272
return out
271273

272274

0 commit comments

Comments
 (0)