test/test_models.py (14 additions, 0 deletions)
@@ -146,6 +146,20 @@ def test_mobilenetv2_residual_setting(self):
         out = model(x)
         self.assertEqual(out.shape[-1], 1000)

+    def test_fasterrcnn_double(self):
+        model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
+        model.double()
+        model.eval()
+        input_shape = (3, 300, 300)
+        x = torch.rand(input_shape, dtype=torch.float64)
+        model_input = [x]
+        out = model(model_input)
+        self.assertIs(model_input[0], x)
+        self.assertEqual(len(out), 1)
+        self.assertTrue("boxes" in out[0])
+        self.assertTrue("scores" in out[0])
+        self.assertTrue("labels" in out[0])
+

 for model_name in get_available_classification_models():
     # for-loop bodies don't define scopes, so we have to save the variables
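For context on the new test (editor's note, not part of the diff): double precision matters because tools such as torch.autograd.gradcheck operate on float64 inputs, and inference only succeeds if every internal tensor (anchors, proposals, boxes) follows the model's dtype. A minimal sketch of the same inference path outside the test harness, using only torchvision's public detection API:

import torch
import torchvision

# Mirrors what test_fasterrcnn_double exercises: a Faster R-CNN cast to
# float64 should accept float64 inputs and produce float64 detections.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    num_classes=50, pretrained_backbone=False
)
model.double()  # cast all parameters and buffers to torch.float64
model.eval()

x = torch.rand(3, 300, 300, dtype=torch.float64)
with torch.no_grad():
    detections = model([x])  # the model takes a list of CHW image tensors

print(detections[0]["boxes"].dtype)  # torch.float64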
torchvision/models/detection/roi_heads.py (2 additions, 1 deletion)
@@ -444,7 +444,8 @@ def check_targets(self, targets):

     def select_training_samples(self, proposals, targets):
         self.check_targets(targets)
-        gt_boxes = [t["boxes"] for t in targets]
+        dtype = proposals[0].dtype
+        gt_boxes = [t["boxes"].to(dtype) for t in targets]
         gt_labels = [t["labels"] for t in targets]

         # append ground-truth bboxes to proposals
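Why the cast is needed (editor's note): proposals coming out of the RPN follow the model's dtype, while user-supplied targets may still default to float32; select_training_samples later appends the ground-truth boxes to the proposal list (add_gt_proposals), so the two must share a dtype. A small sketch of the fix, with hypothetical shapes:

import torch

# After model.double(), RPN proposals are float64, but targets built by the
# user may still be float32.
proposals = [torch.rand(5, 4, dtype=torch.float64)]
targets = [{"boxes": torch.rand(2, 4), "labels": torch.zeros(2, dtype=torch.int64)}]

# The fix from this diff: align ground-truth boxes with the proposals' dtype
# before the two tensors are concatenated.
dtype = proposals[0].dtype
gt_boxes = [t["boxes"].to(dtype) for t in targets]

merged = torch.cat([proposals[0], gt_boxes[0]], dim=0)
print(merged.shape, merged.dtype)  # torch.Size([7, 4]) torch.float64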
torchvision/models/detection/rpn.py (7 additions, 5 deletions)
@@ -49,9 +49,9 @@ def __init__(
         self._cache = {}

     @staticmethod
-    def generate_anchors(scales, aspect_ratios, device="cpu"):
-        scales = torch.as_tensor(scales, dtype=torch.float32, device=device)
-        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=torch.float32, device=device)
+    def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
+        scales = torch.as_tensor(scales, dtype=dtype, device=device)
+        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
         w_ratios = 1 / h_ratios

@@ -61,13 +61,14 @@ def generate_anchors(scales, aspect_ratios, device="cpu"):
         base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
         return base_anchors.round()

-    def set_cell_anchors(self, device):
+    def set_cell_anchors(self, dtype, device):
         if self.cell_anchors is not None:
             return self.cell_anchors
         cell_anchors = [
             self.generate_anchors(
                 sizes,
                 aspect_ratios,
+                dtype,
                 device
             )
             for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
@@ -114,7 +115,8 @@ def forward(self, image_list, feature_maps):
         grid_sizes = tuple([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         strides = tuple((image_size[0] / g[0], image_size[1] / g[1]) for g in grid_sizes)
-        self.set_cell_anchors(feature_maps[0].device)
+        dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = []
         for i, (image_height, image_width) in enumerate(image_list.image_sizes):
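Taken together, the rpn.py changes thread the feature maps' dtype from forward through set_cell_anchors into generate_anchors, so the base anchors are built in float64 whenever the model runs in double precision. Below is a standalone sketch of the patched anchor math; the ws/hs lines elided from the hunks above are filled in from the surrounding torchvision source and should be read as a reconstruction, not the diff itself:

import torch

# Standalone version of the patched generate_anchors: dtype now flows into
# torch.as_tensor, so the anchors match the model's precision.
def generate_anchors(scales, aspect_ratios, dtype=torch.float32, device="cpu"):
    scales = torch.as_tensor(scales, dtype=dtype, device=device)
    aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
    h_ratios = torch.sqrt(aspect_ratios)
    w_ratios = 1 / h_ratios
    # reconstructed from the surrounding torchvision code
    ws = (w_ratios[:, None] * scales[None, :]).view(-1)
    hs = (h_ratios[:, None] * scales[None, :]).view(-1)
    base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
    return base_anchors.round()

anchors = generate_anchors((32, 64, 128), (0.5, 1.0, 2.0), dtype=torch.float64)
print(anchors.shape, anchors.dtype)  # torch.Size([9, 4]) torch.float64

One caveat visible in the set_cell_anchors hunk: it returns the cached cell_anchors when they already exist, so anchors generated once in one dtype are not regenerated if a later forward pass runs in another.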