Skip to content

Commit 5d1372c

Browse files
authored
Add VideoClips and Kinetics dataset (#1077)
* Add VideoClips and Kinetics dataset
* Lint + add back missing line
* Add ClipSampler following Bruno's comment
* Change name following Bruno's suggestion
* Enable specifying a target framerate
* Fix test_io for new interface
* Add comment mentioning drop_last behavior
* Make compute_clips more robust
* Flake8
* Fix for Python2
1 parent 2b81ad8 commit 5d1372c

File tree

6 files changed

+385
-4
lines changed

6 files changed

+385
-4
lines changed

test/test_datasets_video_utils.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
import contextlib
2+
import os
3+
import torch
4+
import unittest
5+
6+
from torchvision import io
7+
from torchvision.datasets.video_utils import VideoClips, unfold, RandomClipSampler
8+
9+
from common_utils import get_tmp_dir
10+
11+
12+
@contextlib.contextmanager
def get_list_of_videos(num_videos=5, sizes=None, fps=None):
    """Write random ``.mp4`` files into a temporary directory and yield their paths.

    Args:
        num_videos (int): how many video files to write.
        sizes (sequence of int, optional): number of frames of each video,
            indexed by video position. When ``None``, video ``i`` gets
            ``5 * (i + 1)`` frames.
        fps (sequence of int, optional): frame rate of each video, indexed by
            video position. When ``None``, every video is written at 5 fps.

    Yields:
        list of str: paths of the written files, in creation order.

    The paths live inside the directory provided by ``get_tmp_dir()``, so they
    should only be used inside the ``with`` block (presumably the directory is
    removed on exit -- confirm against ``common_utils``).
    """
    with get_tmp_dir() as tmp_dir:
        video_paths = []
        for idx in range(num_videos):
            num_frames = 5 * (idx + 1) if sizes is None else sizes[idx]
            frame_rate = 5 if fps is None else fps[idx]
            # random uint8 frames laid out as (frames, 300, 400, 3)
            frames = torch.randint(0, 255, (num_frames, 300, 400, 3), dtype=torch.uint8)
            path = os.path.join(tmp_dir, "{}.mp4".format(idx))
            video_paths.append(path)
            io.write_video(path, frames, fps=frame_rate)

        yield video_paths
31+
32+
33+
class Tester(unittest.TestCase):
    """Tests for ``unfold``, ``VideoClips`` and ``RandomClipSampler``."""

    def _assert_clip_locations(self, video_clips, expected):
        """Check ``get_clip_location`` for each ``(clip, (video, clip_in_video))`` entry."""
        for global_idx, (want_video, want_clip) in expected:
            video_idx, clip_idx = video_clips.get_clip_location(global_idx)
            self.assertEqual(video_idx, want_video)
            self.assertEqual(clip_idx, want_clip)

    def test_unfold(self):
        """``unfold`` on ``arange(7)`` yields the expected rows for three argument sets."""
        seq = torch.arange(7)
        cases = [
            ((3, 3, 1), [[0, 1, 2],
                         [3, 4, 5]]),
            ((3, 2, 1), [[0, 1, 2],
                         [2, 3, 4],
                         [4, 5, 6]]),
            ((3, 2, 2), [[0, 2, 4],
                         [2, 4, 6]]),
        ]
        for args, rows in cases:
            windows = unfold(seq, *args)
            self.assertTrue(windows.equal(torch.tensor(rows)))

    def test_video_clips(self):
        """Clip counts and clip -> (video, clip) mapping for videos of 5, 10 and 15 frames."""
        with get_list_of_videos(num_videos=3) as video_list:
            # 5 frames per clip, step 5: one, two and three clips respectively
            video_clips = VideoClips(video_list, 5, 5)
            self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
            self._assert_clip_locations(
                video_clips,
                enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]))

            # 6 frames per clip, step 6: the 5-frame video contributes no clip
            video_clips = VideoClips(video_list, 6, 6)
            self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
            self._assert_clip_locations(
                video_clips,
                enumerate([(1, 0), (2, 0), (2, 1)]))

            # 6 frames per clip, step 1: spot-check a few global indices
            video_clips = VideoClips(video_list, 6, 1)
            self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
            self._assert_clip_locations(
                video_clips,
                [(0, (1, 0)), (4, (1, 4)), (5, (2, 0)), (6, (2, 1))])

    def test_video_sampler(self):
        """With 5 clips in every video, the sampler yields 3 indices per video."""
        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 3 * 3)
            sampled = torch.tensor(list(sampler))
            # every video owns 5 consecutive clip indices
            owning_video = sampled // 5
            v_idxs, count = torch.unique(owning_video, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))

    def test_video_sampler_unequal(self):
        """A video with only 2 clips contributes both; the other videos contribute 3 each."""
        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
            video_clips = VideoClips(video_list, 5, 5)
            sampler = RandomClipSampler(video_clips, 3)
            self.assertEqual(len(sampler), 2 + 3 + 3)
            sampled = list(sampler)
            first_video_clips = (0, 1)
            for clip in first_video_clips:
                self.assertIn(clip, sampled)
            # remove elements of the first video, to simplify testing
            for clip in first_video_clips:
                sampled.remove(clip)
            remaining = torch.tensor(sampled) - 2
            owning_video = remaining // 5
            v_idxs, count = torch.unique(owning_video, return_counts=True)
            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
            self.assertTrue(count.equal(torch.tensor([3, 3])))

    def test_video_clips_custom_fps(self):
        """Clips requested at a target fps have the requested length and report that fps."""
        with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
            num_frames = 4
            for target_fps in [1, 3, 4, 10]:
                video_clips = VideoClips(video_list, num_frames, num_frames, target_fps)
                for clip_no in range(video_clips.num_clips()):
                    video, audio, info, video_idx = video_clips.get_clip(clip_no)
                    self.assertEqual(video.shape[0], num_frames)
                    self.assertEqual(info["video_fps"], target_fps)
                    # TODO add tests checking that the content is right

    def test_compute_clips_for_video(self):
        """``compute_clips_for_video`` agrees with ``_resample_video_idx`` on 30 pts at 30 fps."""
        video_pts = torch.arange(30)
        orig_fps = 30
        duration = float(len(video_pts)) / orig_fps

        def compute(num_frames, new_fps):
            # step between clips equals the clip length, as in both cases below
            clips, idxs = VideoClips.compute_clips_for_video(
                video_pts, num_frames, num_frames, orig_fps, new_fps)
            resampled_idxs = VideoClips._resample_video_idx(
                int(duration * new_fps), orig_fps, new_fps)
            return clips, idxs, resampled_idxs

        # case 1: single clip
        clips, idxs, resampled_idxs = compute(13, 13)
        self.assertEqual(len(clips), 1)
        self.assertTrue(clips.equal(idxs))
        self.assertTrue(idxs[0].equal(resampled_idxs))

        # case 2: all frames appear only once
        clips, idxs, resampled_idxs = compute(4, 12)
        self.assertEqual(len(clips), 3)
        self.assertTrue(clips.equal(idxs))
        self.assertTrue(idxs.flatten().equal(resampled_idxs))
147+
148+
149+
if __name__ == '__main__':
150+
unittest.main()

test/test_io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_read_timestamps(self):
4444
data = self._create_video_frames(10, 300, 300)
4545
io.write_video(f.name, data, fps=5)
4646

47-
pts = io.read_video_timestamps(f.name)
47+
pts, _ = io.read_video_timestamps(f.name)
4848

4949
# note: not all formats/codecs provide accurate information for computing the
5050
# timestamps. For the format that we use here, this information is available,
@@ -63,7 +63,7 @@ def test_read_partial_video(self):
6363
data = self._create_video_frames(10, 300, 300)
6464
io.write_video(f.name, data, fps=5)
6565

66-
pts = io.read_video_timestamps(f.name)
66+
pts, _ = io.read_video_timestamps(f.name)
6767

6868
for start in range(5):
6969
for l in range(1, 4):

torchvision/datasets/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .sbd import SBDataset
2020
from .vision import VisionDataset
2121
from .usps import USPS
22+
from .kinetics import KineticsVideo
2223

2324
__all__ = ('LSUN', 'LSUNClass',
2425
'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -28,4 +29,4 @@
2829
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
2930
'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet',
3031
'Caltech101', 'Caltech256', 'CelebA', 'SBDataset', 'VisionDataset',
31-
'USPS')
32+
'USPS', 'KineticsVideo')

torchvision/datasets/kinetics.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from .video_utils import VideoClips
2+
from .utils import list_dir
3+
from .folder import make_dataset
4+
from .vision import VisionDataset
5+
6+
7+
class KineticsVideo(VisionDataset):
    """Video dataset that serves clips cut from the video files under ``root``.

    Every entry returned by ``list_dir(root)`` is treated as a class name
    (presumably the immediate sub-directories of ``root`` -- confirm against
    ``list_dir``); classes are numbered in sorted order. The files found by
    ``make_dataset`` are handed to ``VideoClips``, and the dataset is indexed
    by clip rather than by video.

    Args:
        root (str): root directory of the dataset.
        frames_per_clip (int): forwarded to ``VideoClips`` as its second
            argument.
        step_between_clips (int): forwarded to ``VideoClips`` as its third
            argument. Defaults to 1.
        extensions (tuple of str): file extensions forwarded to
            ``make_dataset`` to select video files. Defaults to ``('avi',)``,
            the value that used to be hard-coded.
    """

    def __init__(self, root, frames_per_clip, step_between_clips=1, extensions=('avi',)):
        super(KineticsVideo, self).__init__(root)

        classes = sorted(list_dir(root))
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
        self.classes = classes
        # first element of every sample is handed to VideoClips as the video file
        video_list = [sample[0] for sample in self.samples]
        self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)

    def __len__(self):
        """Return the total number of clips over all videos."""
        return self.video_clips.num_clips()

    def __getitem__(self, idx):
        """Return ``(video, audio, label)`` for clip number ``idx``.

        ``label`` is the second element of the sample the clip was cut from;
        the ``info`` value returned by ``VideoClips.get_clip`` is discarded.
        """
        video, audio, info, video_idx = self.video_clips.get_clip(idx)
        label = self.samples[video_idx][1]

        return video, audio, label

0 commit comments

Comments
 (0)