
Commit 0e7cc3a

Merge branch 'main' into revamp-prototype-features-transforms
2 parents be67431 + 11d903e

4 files changed (+88, -40 lines)


.github/process_commit.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@
     "module: video",
     "Perf",
     "Revert(ed)",
+    "topic: build",
 }
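For context, process_commit.py keeps allow-lists of PR labels and classifies commits by checking membership against them; this hunk just teaches the script about the "topic: build" label. A minimal sketch of that membership pattern follows; the set name and the filter_labels helper are illustrative assumptions, not code from the script:

    # Sketch only: the set mirrors the diff context above; filter_labels is
    # a hypothetical helper, not part of process_commit.py.
    KNOWN_LABELS = {
        "module: video",
        "Perf",
        "Revert(ed)",
        "topic: build",  # newly recognized by this commit
    }

    def filter_labels(pr_labels):
        # Keep only labels the tooling knows how to bucket.
        return [label for label in pr_labels if label in KNOWN_LABELS]

    print(filter_labels(["topic: build", "cla signed"]))  # -> ['topic: build']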

test/test_models_detection_negative_samples.py

Lines changed: 11 additions & 0 deletions

@@ -143,6 +143,17 @@ def test_forward_negative_sample_retinanet(self):
 
         assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
 
+    def test_forward_negative_sample_fcos(self):
+        model = torchvision.models.detection.fcos_resnet50_fpn(
+            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
+        )
+
+        images, targets = self._make_empty_sample()
+        loss_dict = model(images, targets)
+
+        assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0))
+        assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0))
+
     def test_forward_negative_sample_ssd(self):
         model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False)
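The new FCOS case mirrors the RetinaNet and SSD negative-sample tests around it: feed the model an image whose target contains zero boxes and assert that the box-regression and centerness losses are exactly zero. A minimal sketch of what such an empty sample plausibly looks like, assuming the standard torchvision detection target schema (the actual tensors come from the test class's _make_empty_sample helper, defined elsewhere in this file):

    import torch
    import torchvision

    # One image, and a target whose tensors keep the usual detection schema
    # but with a zero-length first dimension (no ground-truth objects).
    images = [torch.rand(3, 100, 100)]
    targets = [
        {
            "boxes": torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int64),
        }
    ]

    model = torchvision.models.detection.fcos_resnet50_fpn(
        num_classes=2, min_size=100, max_size=100, pretrained_backbone=False
    )
    model.train()
    # In training mode the detection models return a dict of losses.
    loss_dict = model(images, targets)
    print(loss_dict["bbox_regression"], loss_dict["bbox_ctrness"])  # both 0.0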

test/test_videoapi.py

Lines changed: 62 additions & 31 deletions

@@ -56,37 +56,68 @@ def test_frame_reading(self):
         for test_video, config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
 
-            av_reader = av.open(full_path)
-
-            if av_reader.streams.video:
-                video_reader = VideoReader(full_path, "video")
-                for av_frame in av_reader.decode(av_reader.streams.video[0]):
-                    vr_frame = next(video_reader)
-
-                    assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1)
-
-                    av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1)
-                    vr_array = vr_frame["data"]
-                    mean_delta = torch.mean(torch.abs(av_array.float() - vr_array.float()))
-                    # on average the difference is very small and caused
-                    # by decoding (around 1%)
-                    # TODO: assess empirically how to set this? currently it's 1%
-                    # averaged over all frames
-                    assert mean_delta.item() < 2.5
-
-            av_reader = av.open(full_path)
-            if av_reader.streams.audio:
-                video_reader = VideoReader(full_path, "audio")
-                for av_frame in av_reader.decode(av_reader.streams.audio[0]):
-                    vr_frame = next(video_reader)
-                    assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1)
-
-                    av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0)
-                    vr_array = vr_frame["data"]
-
-                    max_delta = torch.max(torch.abs(av_array.float() - vr_array.float()))
-                    # we ensure that there is never more than 1% difference in signal
-                    assert max_delta.item() < 0.001
+            with av.open(full_path) as av_reader:
+                is_video = True if av_reader.streams.video else False
+
+                if is_video:
+                    av_frames, vr_frames = [], []
+                    av_pts, vr_pts = [], []
+                    # get av frames
+                    for av_frame in av_reader.decode(av_reader.streams.video[0]):
+                        av_frames.append(torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1))
+                        av_pts.append(av_frame.pts * av_frame.time_base)
+
+                    # get vr frames
+                    video_reader = VideoReader(full_path, "video")
+                    for vr_frame in video_reader:
+                        vr_frames.append(vr_frame["data"])
+                        vr_pts.append(vr_frame["pts"])
+
+                    # same number of frames
+                    assert len(vr_frames) == len(av_frames)
+                    assert len(vr_pts) == len(av_pts)
+
+                    # compare the frames and pts values
+                    for i in range(len(vr_frames)):
+                        assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
+                        mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
+                        # on average the difference is very small and caused
+                        # by decoding (around 1%)
+                        # TODO: assess empirically how to set this? currently it's 1%
+                        # averaged over all frames
+                        assert mean_delta.item() < 2.55
+
+                    del vr_frames, av_frames, vr_pts, av_pts
+
+            # test audio reading compared to PyAV
+            with av.open(full_path) as av_reader:
+                is_audio = True if av_reader.streams.audio else False
+
+                if is_audio:
+                    av_frames, vr_frames = [], []
+                    av_pts, vr_pts = [], []
+                    # get av frames
+                    for av_frame in av_reader.decode(av_reader.streams.audio[0]):
+                        av_frames.append(torch.tensor(av_frame.to_ndarray()).permute(1, 0))
+                        av_pts.append(av_frame.pts * av_frame.time_base)
+                    av_reader.close()
+
+                    # get vr frames
+                    video_reader = VideoReader(full_path, "audio")
+                    for vr_frame in video_reader:
+                        vr_frames.append(vr_frame["data"])
+                        vr_pts.append(vr_frame["pts"])
+
+                    # same number of frames
+                    assert len(vr_frames) == len(av_frames)
+                    assert len(vr_pts) == len(av_pts)
+
+                    # compare the frames and pts values
+                    for i in range(len(vr_frames)):
+                        assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
+                        max_delta = torch.max(torch.abs(av_frames[i].float() - vr_frames[i].float()))
+                        # we ensure that there is never more than 1% difference in signal
+                        assert max_delta.item() < 0.001
 
     def test_metadata(self):
         """

torchvision/models/detection/fcos.py

Lines changed: 14 additions & 9 deletions

@@ -59,9 +59,13 @@ def compute_loss(
         all_gt_classes_targets = []
         all_gt_boxes_targets = []
         for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs):
-            gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
+            if len(targets_per_image["labels"]) == 0:
+                gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),))
+                gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4))
+            else:
+                gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
+                gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
             gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
-            gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
             all_gt_classes_targets.append(gt_classes_targets)
             all_gt_boxes_targets.append(gt_boxes_targets)

@@ -95,13 +99,14 @@ def compute_loss(
         ]
         bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0)
         if len(bbox_reg_targets) == 0:
-            bbox_reg_targets.new_zeros(len(bbox_reg_targets))
-        left_right = bbox_reg_targets[:, :, [0, 2]]
-        top_bottom = bbox_reg_targets[:, :, [1, 3]]
-        gt_ctrness_targets = torch.sqrt(
-            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
-            * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
-        )
+            gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1])
+        else:
+            left_right = bbox_reg_targets[:, :, [0, 2]]
+            top_bottom = bbox_reg_targets[:, :, [1, 3]]
+            gt_ctrness_targets = torch.sqrt(
+                (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
+                * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+            )
         pred_centerness = bbox_ctrness.squeeze(dim=2)
         loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits(
             pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"
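Both hunks guard degenerate inputs. The first substitutes all-zero class and box targets when an image has no ground-truth labels, instead of indexing into empty tensors; the second returns all-zero centerness targets when there are no regression targets (the old bbox_reg_targets.new_zeros(...) call discarded its result, so that guard previously had no effect). The else branch is the standard FCOS centerness target, sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))), computed from the (left, top, right, bottom) regression distances. A small worked example with made-up values:

    import torch

    # Two hypothetical locations with (l, t, r, b) distances to their
    # matched box: one at the exact center, one off-center.
    bbox_reg_targets = torch.tensor(
        [[[10.0, 10.0, 10.0, 10.0],  # centered   -> centerness 1.0
          [2.0, 5.0, 18.0, 15.0]]]   # off-center -> centerness ~0.19
    )

    left_right = bbox_reg_targets[:, :, [0, 2]]
    top_bottom = bbox_reg_targets[:, :, [1, 3]]
    gt_ctrness_targets = torch.sqrt(
        (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0])
        * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    )
    print(gt_ctrness_targets)  # tensor([[1.0000, 0.1925]])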
