Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import shutil

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


def read(*names, **kwargs):
Expand Down Expand Up @@ -121,6 +121,10 @@ def get_extensions():
include_dirs = [extensions_dir]
tests_include_dirs = [test_dir, models_dir]

# TorchVision video reader
video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'cpu', 'video_reader')
video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp"))

ext_modules = [
extension(
'torchvision._C',
Expand All @@ -135,7 +139,25 @@ def get_extensions():
include_dirs=tests_include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
),
CppExtension(
'torchvision.video_reader',
video_reader_src,
include_dirs=[
video_reader_src_dir,
'/home/zyan3/local/anaconda3/envs/pytorch_py3/include',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fmassa , I will remove this line.

For header files from ffmpeg, we need to ensure they are installed at default header file search path.

],
libraries=[
'glog',
'avcodec',
'avformat',
'avutil',
'swresample',
'swscale',
],
extra_compile_args=["-std=c++14"],
extra_link_args=["-std=c++14"],
),
]

return ext_modules
Expand Down Expand Up @@ -176,5 +198,8 @@ def run(self):
"scipy": ["scipy"],
},
ext_modules=get_extensions(),
cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension, 'clean': clean}
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
'clean': clean,
}
)
Binary file added test/assets/videos/R6llTwEh07w.mp4
Binary file not shown.
Binary file not shown.
59 changes: 59 additions & 0 deletions test/assets/videos/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
Video meta-information Notation

Video File Name
video: codec, fps
audio: codec, bits per sample, sample rate

Test videos are listed below.
--------------------------------

- RATRACE_wave_f_nm_np1_fr_goo_37.avi
- source: hmdb51
- video: DivX MPEG-4
- fps: 30
- audio: N/A

- SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi
- source: hmdb51
- video: DivX MPEG-4
- fps: 30
- audio: N/A

- TrumanShow_wave_f_nm_np1_fr_med_26.avi
- source: hmdb51
- video: DivX MPEG-4
- fps: 30
- audio: N/A

- v_SoccerJuggling_g23_c01.avi
- source: ucf101
- video: Xvid MPEG-4
- fps: 29.97
- audio: N/A

- v_SoccerJuggling_g24_c01.avi
- source: ucf101
- video: Xvid MPEG-4
- fps: 29.97
- audio: N/A

- R6llTwEh07w.mp4
- source: kinetics-400
- video: H-264 - MPEG-4 AVC (part 10) (avc1)
- fps: 30
- audio: MPEG AAC audio (mp4a)
- sample rate: 44.1K Hz

- SOX5yA1l24A.mp4
- source: kinetics-400
- video: H-264 - MPEG-4 AVC (part 10) (avc1)
- fps: 29.97
- audio: MPEG AAC audio (mp4a)
- sample rate: 48K Hz

- WUzgd7C1pWA.mp4
- source: kinetics-400
- video: H-264 - MPEG-4 AVC (part 10) (avc1)
- fps: 29.97
- audio: MPEG AAC audio (mp4a)
- sample rate: 48K Hz
Binary file added test/assets/videos/SOX5yA1l24A.mp4
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added test/assets/videos/WUzgd7C1pWA.mp4
Binary file not shown.
Binary file added test/assets/videos/v_SoccerJuggling_g23_c01.avi
Binary file not shown.
Binary file added test/assets/videos/v_SoccerJuggling_g24_c01.avi
Binary file not shown.
46 changes: 28 additions & 18 deletions test/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import contextlib
import logging
import tempfile
import torch
import torchvision.datasets.utils as utils
Expand All @@ -23,6 +24,9 @@
av = None


log = logging.getLogger(__name__)


def _create_video_frames(num_frames, height, width):
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
data = []
Expand All @@ -44,7 +48,9 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
options = {'crf': '0'}

if video_codec is None:
video_codec = 'libx264'
# when video_codec is not set, we assume it is libx264rgb which accepts
# RGB pixel formats as input instead of YUV
video_codec = 'libx264rgb'
if options is None:
options = {}

Expand All @@ -62,14 +68,15 @@ class Tester(unittest.TestCase):

def test_write_read_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
lv, _, info = io.read_video(f_name)
lv, _, info = io.read_video_from_file(f_name)

self.assertTrue(data.equal(lv))
self.assertEqual(info["video_fps"], 5)

def test_read_timestamps(self):
with temp_video(10, 300, 300, 5) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# pts, _ = io.read_video_timestamps(f_name)
video_pts, _, _ = io.read_video_timestamps_from_file(f_name)

# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
Expand All @@ -80,46 +87,48 @@ def test_read_timestamps(self):
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
expected_pts = [i * pts_step for i in range(num_frames)]

self.assertEqual(pts, expected_pts)
self.assertEqual(video_pts, expected_pts)

def test_read_partial_video(self):
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# pts, _ = io.read_video_timestamps(f_name)
video_pts, _, _ = io.read_video_timestamps_from_file(f_name)
for start in range(5):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
# lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
lv, _, _ = io.read_video_from_file(
f_name,
video_pts_range=(video_pts[start], video_pts[start + l - 1]),
)
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue(s_data.equal(lv))

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue(data[4:8].equal(lv))

def test_read_partial_video_bframes(self):
# do not use lossless encoding, to test the presence of B-frames
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
with temp_video(100, 300, 300, 5, options=options) as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
video_pts, _, _ = io.read_video_timestamps_from_file(f_name)
for start in range(0, 80, 20):
for l in range(1, 4):
lv, _, _ = io.read_video(f_name, pts[start], pts[start + l - 1])
lv, _, _ = io.read_video_from_file(
f_name,
video_pts_range=(video_pts[start], video_pts[start + l - 1]),
)
s_data = data[start:(start + l)]
self.assertEqual(len(lv), l)
self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7])
self.assertEqual(len(lv), 4)
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

def test_read_packed_b_frames_divx_file(self):
with get_tmp_dir() as temp_dir:
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
f_name = os.path.join(temp_dir, name)
url = "https://download.pytorch.org/vision_tests/io/" + name
try:
utils.download_url(url, temp_dir)
pts, fps = io.read_video_timestamps(f_name)
# pts, fps = io.read_video_timestamps(f_name)
pts, _, info = io.read_video_timestamps_from_file(f_name)
fps = info["video_fps"]
self.assertEqual(pts, sorted(pts))
self.assertEqual(fps, 30)
except URLError:
Expand All @@ -129,7 +138,8 @@ def test_read_packed_b_frames_divx_file(self):

def test_read_timestamps_from_packet(self):
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
pts, _ = io.read_video_timestamps(f_name)
# pts, _ = io.read_video_timestamps(f_name)
pts, _, _ = io.read_video_timestamps_from_file(f_name)

# note: not all formats/codecs provide accurate information for computing the
# timestamps. For the format that we use here, this information is available,
Expand Down
4 changes: 2 additions & 2 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def get_available_video_models():
"fcn_resnet101": False,
"googlenet": False,
"densenet121": False,
"resnet18": False,
"resnet18": True,
"alexnet": True,
"shufflenet_v2_x1_0": False,
"shufflenet_v2_x1_0": True,
"squeezenet1_0": True,
"vgg11": True,
"inception_v3": False,
Expand Down
Loading