Skip to content

Commit d2f7486

Browse files
authored
convert torch.split return to list in RAFT (#7597)
1 parent 689ff29 commit d2f7486

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchvision/prototype/models/depth/stereo/raft_stereo.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,12 @@ def forward(
450450
hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1)
451451
hidden_states.append(torch.tanh(hidden_state))
452452
contexts.append(
453-
torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1)
453+
# mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with
454+
# `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by
455+
# JIT and thus we have to keep the wrong annotation here and silence mypy.
456+
torch.split( # type: ignore[arg-type]
457+
context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1
458+
)
454459
)
455460

456461
_, Cf, Hf, Wf = fmap1.shape

0 commit comments

Comments
 (0)