From e6646dbedb85a136241992d4e51e14ea908262d7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 17 May 2023 12:46:55 +0200 Subject: [PATCH 1/4] fix annotation --- torchvision/prototype/models/depth/stereo/raft_stereo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index 20ef077a615..b80e16912f7 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -442,7 +442,7 @@ def forward( hidden_dims = self.update_block.hidden_dims context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))] hidden_states: List[Tensor] = [] - contexts: List[List[Tensor]] = [] + contexts: List[Tuple[Tensor, ...]] = [] for i, context_conv in enumerate(self.context_convs): # As in the original paper, the actual output of the context encoder is split in 2 parts: # - one part is used to initialize the hidden state of the recurent units of the update block From 62b76341af81396f9816a1c7b1bafbc0ddd1b490 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 17 May 2023 13:01:42 +0200 Subject: [PATCH 2/4] Revert "fix annotation" This reverts commit e6646dbedb85a136241992d4e51e14ea908262d7. 
--- torchvision/prototype/models/depth/stereo/raft_stereo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index b80e16912f7..20ef077a615 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -442,7 +442,7 @@ def forward( hidden_dims = self.update_block.hidden_dims context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))] hidden_states: List[Tensor] = [] - contexts: List[Tuple[Tensor, ...]] = [] + contexts: List[List[Tensor]] = [] for i, context_conv in enumerate(self.context_convs): # As in the original paper, the actual output of the context encoder is split in 2 parts: # - one part is used to initialize the hidden state of the recurent units of the update block From 79e0c6faef797a4db44ae5dc0471e8c4ad1623d4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 17 May 2023 13:02:29 +0200 Subject: [PATCH 3/4] convert output of torch.split to list for JIT --- torchvision/prototype/models/depth/stereo/raft_stereo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index 20ef077a615..fd4d6e4a04f 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -450,7 +450,9 @@ def forward( hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1) hidden_states.append(torch.tanh(hidden_state)) contexts.append( - torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1) + list( + torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1) + ) ) _, Cf, Hf, Wf = fmap1.shape From 
25545b9ecd66e8cde06f2e587b0f21bbf080be1a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 19 May 2023 13:24:43 +0200 Subject: [PATCH 4/4] silence mypy instead of runtime changes --- torchvision/prototype/models/depth/stereo/raft_stereo.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index fd4d6e4a04f..1519356316b 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -450,8 +450,11 @@ def forward( hidden_state, context = torch.split(context_outs[i], [hidden_dims[i], context_out_channels[i]], dim=1) hidden_states.append(torch.tanh(hidden_state)) contexts.append( - list( - torch.split(context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1) + # mypy is technically correct here. The return type of `torch.split` was incorrectly annotated with + # `List[Tensor]` although it should have been `Tuple[Tensor, ...]`. However, the latter is not supported by + # JIT and thus we have to keep the wrong annotation here and silence mypy. + torch.split( # type: ignore[arg-type] + context_conv(F.relu(context)), [hidden_dims[i], hidden_dims[i], hidden_dims[i]], dim=1 ) )