Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions test/test_video_gpu_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

import pytest
import torch
import torchvision
from torchvision import _HAS_GPU_VIDEO_DECODER
from torchvision.io import VideoReader
from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader

try:
import av
Expand All @@ -31,9 +29,8 @@ class TestVideoGPUDecoder:
],
)
def test_frame_reading(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path)
decoder = VideoReader(full_path, device="cuda")
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
Expand All @@ -57,8 +54,7 @@ def test_frame_reading(self, video_file):
],
)
def test_seek_reading(self, keyframes, full_path, duration):
torchvision.set_video_backend("cuda")
decoder = VideoReader(full_path)
decoder = VideoReader(full_path, device="cuda")
time = duration / 2
decoder.seek(time, keyframes_only=keyframes)
with av.open(full_path) as container:
Expand All @@ -83,9 +79,8 @@ def test_seek_reading(self, keyframes, full_path, duration):
],
)
def test_metadata(self, video_file):
torchvision.set_video_backend("cuda")
full_path = os.path.join(VIDEO_DIR, video_file)
decoder = VideoReader(full_path)
decoder = VideoReader(full_path, device="cuda")
video_metadata = decoder.get_metadata()["video"]
with av.open(full_path) as container:
video = container.streams.video[0]
Expand Down
98 changes: 40 additions & 58 deletions test/test_videoapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def fate(name, path="."):
class TestVideoApi:
@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_frame_reading(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
with av.open(full_path) as av_reader:
if av_reader.streams.video:
Expand Down Expand Up @@ -119,70 +117,58 @@ def test_frame_reading(self, test_video, backend):

@pytest.mark.parametrize("stream", ["video", "audio"])
@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_frame_reading_mem_vs_file(self, test_video, stream, backend):
torchvision.set_video_backend(backend)
def test_frame_reading_mem_vs_file(self, test_video, stream):
full_path = os.path.join(VIDEO_DIR, test_video)

reader = VideoReader(full_path)
reader_md = reader.get_metadata()

if stream in reader_md:
# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])

# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)

for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])

# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)

# compare the frames and ptss
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: assess empirically how to set this? atm it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55

del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
else:
del reader, reader_md
# Test video reading from file vs from memory
vr_frames, vr_frames_mem = [], []
vr_pts, vr_pts_mem = [], []
# get vr frames
video_reader = VideoReader(full_path, stream)
for vr_frame in video_reader:
vr_frames.append(vr_frame["data"])
vr_pts.append(vr_frame["pts"])

# get vr frames = read from memory
f = open(full_path, "rb")
fbytes = f.read()
f.close()
video_reader_from_mem = VideoReader(fbytes, stream)

for vr_frame_from_mem in video_reader_from_mem:
vr_frames_mem.append(vr_frame_from_mem["data"])
vr_pts_mem.append(vr_frame_from_mem["pts"])

# same number of frames
assert len(vr_frames) == len(vr_frames_mem)
assert len(vr_pts) == len(vr_pts_mem)

# compare the frames and ptss
for i in range(len(vr_frames)):
assert vr_pts[i] == vr_pts_mem[i]
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
# on average the difference is very small and caused
# by decoding (around 1%)
# TODO: assess empirically how to set this? atm it's 1%
# averaged over all frames
assert mean_delta.item() < 2.55

del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem

@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_metadata(self, test_video, config, backend):
def test_metadata(self, test_video, config):
"""
Test that the metadata returned via pyav corresponds to the one returned
by the new video decoder API
"""
torchvision.set_video_backend(backend)
full_path = os.path.join(VIDEO_DIR, test_video)
reader = VideoReader(full_path, "video")
reader_md = reader.get_metadata()
assert config.video_fps == approx(reader_md["video"]["fps"][0], abs=0.0001)
assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5)

@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader", "pyav"])
def test_seek_start(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_seek_start(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
video_reader = VideoReader(full_path, "video")
num_frames = 0
Expand All @@ -208,9 +194,7 @@ def test_seek_start(self, test_video, backend):
assert start_num_frames == num_frames

@pytest.mark.parametrize("test_video", test_videos.keys())
@pytest.mark.parametrize("backend", ["video_reader"])
def test_accurateseek_middle(self, test_video, backend):
torchvision.set_video_backend(backend)
def test_accurateseek_middle(self, test_video):
full_path = os.path.join(VIDEO_DIR, test_video)
stream = "video"
video_reader = VideoReader(full_path, stream)
Expand Down Expand Up @@ -249,9 +233,7 @@ def test_fate_suite(self):

@pytest.mark.skipif(av is None, reason="PyAV unavailable")
@pytest.mark.parametrize("test_video,config", test_videos.items())
@pytest.mark.parametrize("backend", ["pyav", "video_reader"])
def test_keyframe_reading(self, test_video, config, backend):
torchvision.set_video_backend(backend)
def test_keyframe_reading(self, test_video, config):
full_path = os.path.join(VIDEO_DIR, test_video)

av_reader = av.open(full_path)
Expand Down
21 changes: 4 additions & 17 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
import os
import warnings
from modulefinder import Module

import torch
from torchvision import datasets, io, models, ops, transforms, utils

from .extension import _HAS_OPS, _load_library
from .extension import _HAS_OPS

try:
from .version import __version__ # noqa: F401
except ImportError:
pass

try:
_load_library("Decoder")
_HAS_GPU_VIDEO_DECODER = True
except (ImportError, OSError, ModuleNotFoundError):
_HAS_GPU_VIDEO_DECODER = False


# Check if torchvision is being imported within the root folder
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
os.path.realpath(os.getcwd()), "torchvision"
Expand Down Expand Up @@ -74,16 +66,11 @@ def set_video_backend(backend):
backend, please compile torchvision from source.
"""
global _video_backend
if backend not in ["pyav", "video_reader", "cuda"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
if backend not in ["pyav", "video_reader"]:
raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend)
if backend == "video_reader" and not io._HAS_VIDEO_OPT:
# TODO: better messages
message = "video_reader video backend is not available. Please compile torchvision from source and try again"
raise RuntimeError(message)
elif backend == "cuda" and not _HAS_GPU_VIDEO_DECODER:
# TODO: better messages
message = "cuda video backend is not available."
raise RuntimeError(message)
warnings.warn(message)
else:
_video_backend = backend

Expand Down
5 changes: 5 additions & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

from ..utils import _log_api_usage_once

try:
from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
_HAS_GPU_VIDEO_DECODER = False
from ._video_opt import (
_HAS_VIDEO_OPT,
_probe_video_from_file,
Expand Down Expand Up @@ -43,6 +47,7 @@
"_read_video_timestamps_from_memory",
"_probe_video_from_memory",
"_HAS_VIDEO_OPT",
"_HAS_GPU_VIDEO_DECODER",
"_read_video_clip_from_memory",
"_read_video_meta_data",
"VideoMetaData",
Expand Down
8 changes: 8 additions & 0 deletions torchvision/io/_load_gpu_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from ..extension import _load_library


# Probe for the optional "Decoder" extension library at import time and record
# the outcome in a module-level flag instead of letting the import fail.
try:
    _load_library("Decoder")
except (ImportError, OSError):
    # Loading failed with one of the tolerated errors: mark the decoder
    # as unavailable rather than propagating the exception.
    _HAS_GPU_VIDEO_DECODER = False
else:
    _HAS_GPU_VIDEO_DECODER = True
Loading