34 changes: 34 additions & 0 deletions torchvision/models/detection/_utils.py
@@ -270,6 +270,40 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:

         return targets
 
+    def encode_all(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
+        """
+        Vectorized version of `encode_single`.
+        Args:
+            reference_boxes (Tensor): reference boxes
+            proposals (Tensor): boxes to be encoded
+
+        Returns:
+            Tensor: the encoded relative box offsets that can be used to
+                decode the boxes.
+
+        """
+
+        # get the center of reference_boxes
+        reference_boxes_ctr_x = 0.5 * (reference_boxes[..., 0] + reference_boxes[..., 2])
+        reference_boxes_ctr_y = 0.5 * (reference_boxes[..., 1] + reference_boxes[..., 3])
+
+        # get box regression transformation deltas
+        target_l = reference_boxes_ctr_x - proposals[..., 0]
+        target_t = reference_boxes_ctr_y - proposals[..., 1]
+        target_r = proposals[..., 2] - reference_boxes_ctr_x
+        target_b = proposals[..., 3] - reference_boxes_ctr_y
+
+        targets = torch.stack((target_l, target_t, target_r, target_b), dim=-1)
+
+        if self.normalize_by_size:
+            reference_boxes_w = reference_boxes[..., 2] - reference_boxes[..., 0]
+            reference_boxes_h = reference_boxes[..., 3] - reference_boxes[..., 1]
+            reference_boxes_size = torch.stack(
+                (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=-1
+            )
+            targets = targets / reference_boxes_size
+        return targets
+
     def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
         """
         From a set of original boxes and encoded relative box offsets,
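A quick way to sanity-check the new vectorized path is to compare it against a per-image loop over `encode_single`, which is exactly the pattern it replaces in `fcos.py`. A minimal sketch, assuming this branch is installed; the box values and shapes are illustrative only, and `BoxLinearCoder` is imported from the private `_utils` module:

```python
import torch
from torchvision.models.detection._utils import BoxLinearCoder

torch.manual_seed(0)
coder = BoxLinearCoder(normalize_by_size=True)

def make_boxes(n):
    # well-formed (x1, y1, x2, y2) boxes with positive width and height
    xy = torch.rand(n, 2) * 100
    wh = torch.rand(n, 2) * 50 + 1.0
    return torch.cat([xy, xy + wh], dim=-1)

# two "images", five anchors / matched GT boxes each -> shape (2, 5, 4)
anchors = torch.stack([make_boxes(5) for _ in range(2)])
gt_boxes = torch.stack([make_boxes(5) for _ in range(2)])

# one batched call over the leading image dimension
batched = coder.encode_all(anchors, gt_boxes)

# reference: the per-image loop over encode_single that this PR removes
looped = torch.stack([coder.encode_single(a, g) for a, g in zip(anchors, gt_boxes)])

torch.testing.assert_close(batched, looped)
```

Because `encode_all` only uses `[..., i]` indexing and `dim=-1` stacks, it accepts arbitrary leading batch dimensions, so the same call also works on a single `(N, 4)` tensor.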
13 changes: 7 additions & 6 deletions torchvision/models/detection/fcos.py
@@ -88,19 +88,20 @@ def compute_loss(

         pred_boxes = self.box_coder.decode_all(bbox_regression, anchors)
 
+        # List[Tensor] to Tensor conversion of `all_gt_boxes_targets` and `anchors`
+        all_gt_boxes_targets, anchors = torch.stack(all_gt_boxes_targets), torch.stack(anchors)
+
         # amp issue: pred_boxes need to convert float
         loss_bbox_reg = generalized_box_iou_loss(
             pred_boxes[foregroud_mask],
-            torch.stack(all_gt_boxes_targets)[foregroud_mask],
+            all_gt_boxes_targets[foregroud_mask],
             reduction="sum",
         )
 
         # ctrness loss
-        bbox_reg_targets = [
-            self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image)
-            for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets)
-        ]
-        bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
+
+        bbox_reg_targets = self.box_coder.encode_all(anchors, all_gt_boxes_targets)
+
         if len(bbox_reg_targets) == 0:
             gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
         else:
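On the `fcos.py` side, the main shape change is that `all_gt_boxes_targets` and `anchors` become `(N, A, 4)` tensors, so `foregroud_mask` (a boolean `(N, A)` mask, keeping the existing spelling) can index them directly and both operands of the GIoU loss flatten to `(K, 4)`, where `K` is the number of foreground anchors in the batch. A toy sketch of that indexing pattern, with made-up names, shapes, and mask:

```python
import torch
from torchvision.ops import generalized_box_iou_loss

torch.manual_seed(0)

def make_boxes(*shape):
    # well-formed (x1, y1, x2, y2) boxes
    xy = torch.rand(*shape, 2) * 100
    wh = torch.rand(*shape, 2) * 50 + 1.0
    return torch.cat([xy, xy + wh], dim=-1)

N, A = 2, 6                              # batch size, anchors per image (toy values)
pred_boxes = make_boxes(N, A)            # (N, A, 4)
all_gt_boxes_targets = make_boxes(N, A)  # (N, A, 4)
foregroud_mask = torch.rand(N, A) > 0.5  # boolean (N, A) mask

# boolean indexing flattens both tensors to (K, 4), K = foregroud_mask.sum()
loss_bbox_reg = generalized_box_iou_loss(
    pred_boxes[foregroud_mask],
    all_gt_boxes_targets[foregroud_mask],
    reduction="sum",
)
print(loss_bbox_reg)
```

The `torch.stack` calls (and hence the single `encode_all` call) assume every image in the batch contributes anchors and matched targets of identical shape, which is what lets the per-image list comprehension be dropped.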