
Commit 976f299

NicolasHug authored and facebook-github-bot committed
[fbsync] Port 'examples/python/video_api.ipynb' to gallery (#4241)
Summary:

* Port 'examples/python/video_api.ipynb' to gallery
* Address rst formatting suggestions

Reviewed By: vmoens

Differential Revision: D30096845

fbshipit-source-id: 1c3d543e476df4d70b29e4ad58a699721456ae90
1 parent 82930ad commit 976f299

File tree

1 file changed: +341 -0 lines changed


gallery/plot_video_api.py

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
"""
=======================
Video API
=======================

This example illustrates some of the APIs that torchvision offers for
videos, together with examples of how to build datasets and more.
"""

####################################
# 1. Introduction: building a new video object and examining the properties
# -------------------------------------------------------------------------
# First we select a video to test the object out. For the sake of argument
# we're using one from the kinetics400 dataset.
# To create it, we need to define the path and the stream we want to use.

######################################
# Chosen video statistics:
#
# - WUzgd7C1pWA.mp4
#     - source:
#         - kinetics-400
#     - video:
#         - H-264
#         - MPEG-4 AVC (part 10) (avc1)
#         - fps: 29.97
#     - audio:
#         - MPEG AAC audio (mp4a)
#         - sample rate: 48K Hz
#

import torch
import torchvision
from torchvision.datasets.utils import download_url

# Download the sample video
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
    ".",
    "WUzgd7C1pWA.mp4"
)
video_path = "./WUzgd7C1pWA.mp4"

######################################
# Streams are defined in a similar fashion as torch devices. We encode them as strings in the form
# ``stream_type:stream_id``, where ``stream_type`` is a string and ``stream_id`` a long int.
# The constructor accepts passing a ``stream_type`` only, in which case the stream is auto-discovered.
# Firstly, let's get the metadata for our particular video:

stream = "video"
video = torchvision.io.VideoReader(video_path, stream)
video.get_metadata()

######################################
# Here we can see that the video has two streams - a video and an audio stream.
# Currently available stream types include ['video', 'audio'].
# Each descriptor consists of two parts: the stream type (e.g. 'video') and a unique stream id
# (which is determined by the video encoding).
# In this way, if the video container contains multiple streams of the same type,
# users can access the one they want.
# If only the stream type is passed, the decoder auto-detects the first stream of that type and returns it.

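######################################
# A specific stream can also be requested explicitly with the
# ``stream_type:stream_id`` syntax. This is just a sketch, assuming (as is
# typical for a container like this one) that the first video stream has id 0:

video_from_stream0 = torchvision.io.VideoReader(video_path, "video:0")
print(video_from_stream0.get_metadata())
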
######################################
# Let's read all the frames from the audio stream. By default, the return value of
# ``next(video_reader)`` is a dict containing the following fields.
#
# The return fields are:
#
# - ``data``: containing a torch.tensor
# - ``pts``: containing a float timestamp of this particular frame

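######################################
# As a minimal illustration, a single call to ``next`` yields one such dict.
# (We create a separate reader here so that the reader used below is not advanced.)

sample_frame = next(torchvision.io.VideoReader(video_path, "video"))
print(sample_frame['data'].shape, sample_frame['pts'])
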
metadata = video.get_metadata()
video.set_current_stream("audio")

frames = []  # we are going to save the frames here.
ptss = []  # pts is a presentation timestamp in seconds (float) of each frame
for frame in video:
    frames.append(frame['data'])
    ptss.append(frame['pts'])

print("PTS for first five frames ", ptss[:5])
print("Total number of frames: ", len(frames))
approx_nf = metadata['audio']['duration'][0] * metadata['audio']['framerate'][0]
print("Approx total number of datapoints we can expect: ", approx_nf)
print("Read data size: ", frames[0].size(0) * len(frames))

######################################
# But what if we only want to read a certain time segment of the video?
# That can be done easily using the combination of our ``seek`` function, and the fact that each call
# to next returns the presentation timestamp of the returned frame in seconds.
#
# Given that our implementation relies on python iterators,
# we can leverage itertools to simplify the process and make it more pythonic.
#
# For example, if we wanted to read ten frames starting from the second second:


import itertools
video.set_current_stream("video")

frames = []  # we are going to save the frames here.

# We seek into the second second of the video and use islice to get the 10 frames that follow
for frame in itertools.islice(video.seek(2), 10):
    frames.append(frame['data'])

print("Total number of frames: ", len(frames))

######################################
# Or, if we wanted to read from the 2nd to the 5th second,
# we seek into the second second of the video,
# then utilize ``itertools.takewhile`` to get the
# correct number of frames:

video.set_current_stream("video")
frames = []  # we are going to save the frames here.
video = video.seek(2)

for frame in itertools.takewhile(lambda x: x['pts'] <= 5, video):
    frames.append(frame['data'])

print("Total number of frames: ", len(frames))
approx_nf = (5 - 2) * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())

####################################
# 2. Building a sample read_video function
# ----------------------------------------------------------------------------------------
# We can utilize the methods above to build a read video function that follows
# the same API as the existing ``read_video`` function.


def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):
    if end is None:
        end = float("inf")
    if end < start:
        raise ValueError(
            "end time should be larger than start time, got "
            "start time={} and end time={}".format(start, end)
        )

    video_frames = torch.empty(0)
    video_pts = []
    if read_video:
        video_object.set_current_stream("video")
        frames = []
        for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
            frames.append(frame['data'])
            video_pts.append(frame['pts'])
        if len(frames) > 0:
            video_frames = torch.stack(frames, 0)

    audio_frames = torch.empty(0)
    audio_pts = []
    if read_audio:
        video_object.set_current_stream("audio")
        frames = []
        for frame in itertools.takewhile(lambda x: x['pts'] <= end, video_object.seek(start)):
            frames.append(frame['data'])
            audio_pts.append(frame['pts'])
        if len(frames) > 0:
            audio_frames = torch.cat(frames, 0)

    return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()


# Total number of frames should be 327 for video and 523264 datapoints for audio
vf, af, info, meta = example_read_video(video)
print(vf.size(), af.size())

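####################################
# As a quick sanity check of the same helper, we can decode just a segment by
# passing ``start`` and ``end`` (the exact sizes printed depend on the video):

vf_seg, af_seg, _, _ = example_read_video(video, start=1, end=3)
print(vf_seg.size(), af_seg.size())
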
####################################
# 3. Building an example randomly sampled dataset (can be applied to training dataset of kinetics400)
# -------------------------------------------------------------------------------------------------------
# Cool, so now we can use the same principle to make a sample dataset.
# We suggest trying out an iterable dataset for this purpose.
# Here, we are going to build an example dataset that reads a randomly selected clip of frames from each video.

####################################
# Make sample dataset
import os
os.makedirs("./dataset", exist_ok=True)
os.makedirs("./dataset/1", exist_ok=True)
os.makedirs("./dataset/2", exist_ok=True)

####################################
# Download the videos
from torchvision.datasets.utils import download_url
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/WUzgd7C1pWA.mp4?raw=true",
    "./dataset/1", "WUzgd7C1pWA.mp4"
)
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/RATRACE_wave_f_nm_np1_fr_goo_37.avi?raw=true",
    "./dataset/1",
    "RATRACE_wave_f_nm_np1_fr_goo_37.avi"
)
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/SOX5yA1l24A.mp4?raw=true",
    "./dataset/2",
    "SOX5yA1l24A.mp4"
)
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g23_c01.avi?raw=true",
    "./dataset/2",
    "v_SoccerJuggling_g23_c01.avi"
)
download_url(
    "https://github.com/pytorch/vision/blob/master/test/assets/videos/v_SoccerJuggling_g24_c01.avi?raw=true",
    "./dataset/2",
    "v_SoccerJuggling_g24_c01.avi"
)

####################################
# Housekeeping and utilities
import os
import random

from torchvision.datasets.folder import make_dataset
from torchvision import transforms as t


def _find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def get_samples(root, extensions=(".mp4", ".avi")):
    _, class_to_idx = _find_classes(root)
    return make_dataset(root, class_to_idx, extensions=extensions)

####################################
# We are going to define the dataset and some basic arguments.
# We assume the structure of the FolderDataset, and add the following parameters:
#
# - ``clip_len``: length of a clip in frames
# - ``frame_transform``: transform for every frame individually
# - ``video_transform``: transform on a video sequence
#
# .. note::
#     We actually add an epoch size, as using the :class:`~torch.utils.data.IterableDataset`
#     class allows us to naturally oversample clips or images from each video if needed.


class RandomDataset(torch.utils.data.IterableDataset):
    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()

        self.samples = get_samples(root)

        # Allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size

        self.clip_len = clip_len
        self.frame_transform = frame_transform
        self.video_transform = video_transform

    def __iter__(self):
        for i in range(self.epoch_size):
            # Get random sample
            path, target = random.choice(self.samples)
            # Get video object
            vid = torchvision.io.VideoReader(path, "video")
            metadata = vid.get_metadata()
            video_frames = []  # video frame buffer

            # Seek and return frames
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = random.uniform(0., max_seek)
            for frame in itertools.islice(vid.seek(start), self.clip_len):
                video_frames.append(self.frame_transform(frame['data']))
                current_pts = frame['pts']
            # Stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output

####################################
# Given a path of videos in a folder structure, i.e.:
#
# - dataset
#     - class 1
#         - file 0
#         - file 1
#         - ...
#     - class 2
#         - file 0
#         - file 1
#         - ...
#     - ...
#
# We can generate a dataloader and test the dataset.


transforms = [t.Resize((112, 112))]
frame_transform = t.Compose(transforms)

dataset = RandomDataset("./dataset", epoch_size=None, frame_transform=frame_transform)

####################################
from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=12)
data = {"video": [], 'start': [], 'end': [], 'tensorsize': []}
for batch in loader:
    for i in range(len(batch['path'])):
        data['video'].append(batch['path'][i])
        data['start'].append(batch['start'][i].item())
        data['end'].append(batch['end'][i].item())
        data['tensorsize'].append(batch['video'][i].size())
print(data)

####################################
# 4. Data Visualization
# ----------------------------------
# Example of visualized video

import matplotlib.pyplot as plt

plt.figure(figsize=(12, 12))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(batch["video"][0, i, ...].permute(1, 2, 0))
    plt.axis("off")

####################################
# Cleanup the video and dataset:
import os
import shutil
os.remove("./WUzgd7C1pWA.mp4")
shutil.rmtree("./dataset")
