diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index 1ee0542c9c6..4a8fadfe345 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -4,6 +4,7 @@ """ from collections import OrderedDict +from typing import Union import torch from torch import nn import warnings @@ -35,7 +36,7 @@ def __init__(self, backbone, rpn, roi_heads, transform): @torch.jit.unused def eager_outputs(self, losses, detections): - # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]] + # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]] if self.training: return losses @@ -85,11 +86,11 @@ def forward(self, images, targets=None): boxes = target["boxes"] degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] if degenerate_boxes.any(): - # print the first degenrate box + # print the first degenerate box bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0] degen_bb: List[float] = boxes[bb_idx].tolist() raise ValueError("All bounding boxes should have positive height and width." - " Found invaid box {} for target at index {}." + " Found invalid box {} for target at index {}." .format(degen_bb, target_idx)) features = self.backbone(images.tensors) diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index 32734cff86a..bf9412af056 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,4 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +from typing import Union + import torch import torch.nn.functional as F from torch import nn, Tensor @@ -119,7 +121,7 @@ class MultiScaleRoIAlign(nn.Module): def __init__( self, featmap_names: List[str], - output_size: List[int], + output_size: Union[int, Tuple[int], List[int]], sampling_ratio: int, ): super(MultiScaleRoIAlign, self).__init__()