diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index 20ef077a615..1519356316b 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -450,7 +450,12 @@ def forward( hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1) hidden_states.append(torch.tanh(hidden_state)) contexts.append( - torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1) + # mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with + # `List[int]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by + # JIT and thus we have to keep the wrong annotation here and silence mypy. + torch.split( # type: ignore[arg-type] + context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1 + ) ) _, Cf, Hf, Wf = fmap1.shape