Skip to content

Commit cd91360

Browse files
committed
Handle corrupted video headers in io
1 parent ed5b2dc commit cd91360

File tree

2 files changed

+64
-27
lines changed

2 files changed

+64
-27
lines changed

test/test_io.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,23 @@ def test_read_partial_video_pts_unit_sec(self):
236236
self.assertEqual(len(lv), 4)
237237
self.assertTrue(data[4:8].equal(lv))
238238

239+
def test_read_video_corrupted_file(self):
    """Check that io.read_video tolerates a file that is not a valid video.

    The call is expected to return two empty tensors (video, audio) and an
    empty info dict.
    """
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        f.write(b'This is not an mpg4 file')
        # write() goes through a buffered file object; flush so the junk
        # bytes are on disk before the file is reopened by name below.
        f.flush()
        video, audio, info = io.read_video(f.name)
        self.assertIsInstance(video, torch.Tensor)
        self.assertIsInstance(audio, torch.Tensor)
        self.assertEqual(video.numel(), 0)
        self.assertEqual(audio.numel(), 0)
        self.assertEqual(info, {})
248+
249+
def test_read_video_timestamps_corrupted_file(self):
    """Check that io.read_video_timestamps tolerates a non-video file.

    The call is expected to return an empty list of timestamps and None
    for the frame rate.
    """
    with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
        f.write(b'This is not an mpg4 file')
        # write() goes through a buffered file object; flush so the junk
        # bytes are on disk before the file is reopened by name below.
        f.flush()
        video_pts, video_fps = io.read_video_timestamps(f.name)
        self.assertEqual(video_pts, [])
        self.assertIs(video_fps, None)
255+
239256
# TODO add tests for audio
240257

241258

torchvision/io/video.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -193,25 +193,36 @@ def read_video(filename, start_pts=0, end_pts=None, pts_unit='pts'):
193193
raise ValueError("end_pts should be larger than start_pts, got "
194194
"start_pts={} and end_pts={}".format(start_pts, end_pts))
195195

196-
container = av.open(filename, metadata_errors='ignore')
197196
info = {}
198-
199197
video_frames = []
200-
if container.streams.video:
201-
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
202-
container.streams.video[0], {'video': 0})
203-
info["video_fps"] = float(container.streams.video[0].average_rate)
204198
audio_frames = []
205-
if container.streams.audio:
206-
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
207-
container.streams.audio[0], {'audio': 0})
208-
info["audio_fps"] = container.streams.audio[0].rate
209199

210-
container.close()
200+
try:
201+
container = av.open(filename, metadata_errors='ignore')
202+
except av.AVError:
203+
# TODO raise a warning?
204+
pass
205+
else:
206+
if container.streams.video:
207+
video_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
208+
container.streams.video[0], {'video': 0})
209+
info["video_fps"] = float(container.streams.video[0].average_rate)
210+
211+
if container.streams.audio:
212+
audio_frames = _read_from_stream(container, start_pts, end_pts, pts_unit,
213+
container.streams.audio[0], {'audio': 0})
214+
info["audio_fps"] = container.streams.audio[0].rate
215+
216+
container.close()
211217

212218
vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
213219
aframes = [frame.to_ndarray() for frame in audio_frames]
214-
vframes = torch.as_tensor(np.stack(vframes))
220+
221+
if vframes:
222+
vframes = torch.as_tensor(np.stack(vframes))
223+
else:
224+
vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
225+
215226
if aframes:
216227
aframes = np.concatenate(aframes, 1)
217228
aframes = torch.as_tensor(aframes)
@@ -255,21 +266,30 @@ def read_video_timestamps(filename, pts_unit='pts'):
255266
"""
256267
_check_av_available()
257268

258-
container = av.open(filename, metadata_errors='ignore')
259-
260269
video_frames = []
261270
video_fps = None
262-
if container.streams.video:
263-
video_stream = container.streams.video[0]
264-
video_time_base = video_stream.time_base
265-
if _can_read_timestamps_from_packets(container):
266-
# fast path
267-
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
268-
else:
269-
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
270-
video_stream, {'video': 0})
271-
video_fps = float(video_stream.average_rate)
272-
container.close()
271+
272+
try:
273+
container = av.open(filename, metadata_errors='ignore')
274+
except av.AVError:
275+
# TODO add a warning
276+
pass
277+
else:
278+
if container.streams.video:
279+
video_stream = container.streams.video[0]
280+
video_time_base = video_stream.time_base
281+
if _can_read_timestamps_from_packets(container):
282+
# fast path
283+
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
284+
else:
285+
video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
286+
video_stream, {'video': 0})
287+
video_fps = float(video_stream.average_rate)
288+
container.close()
289+
290+
pts = [x.pts for x in video_frames]
291+
273292
if pts_unit == 'sec':
274-
return [x.pts * video_time_base for x in video_frames], video_fps
275-
return [x.pts for x in video_frames], video_fps
293+
pts = [x * video_time_base for x in pts]
294+
295+
return pts, video_fps

0 commit comments

Comments
 (0)