Skip to content

Commit fe1b384

Browse files
committed
convert gt_boxes to right dtype
1 parent edc875f commit fe1b384

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/models/detection/roi_heads.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,8 @@ def check_targets(self, targets):
444444

445445
def select_training_samples(self, proposals, targets):
446446
self.check_targets(targets)
447-
gt_boxes = [t["boxes"] for t in targets]
447+
dtype = proposals[0].dtype
448+
gt_boxes = [t["boxes"].to(dtype) for t in targets]
448449
gt_labels = [t["labels"] for t in targets]
449450

450451
# append ground-truth bboxes to propos

0 commit comments

Comments
 (0)