diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index d987db6ddeb..66c9aa04a46 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -3,7 +3,9 @@ import pytest import torch -from torchvision.io import _HAS_GPU_VIDEO_DECODER, VideoReader +import torchvision +from torchvision import _HAS_GPU_VIDEO_DECODER +from torchvision.io import VideoReader try: import av @@ -29,8 +31,9 @@ class TestVideoGPUDecoder: ], ) def test_frame_reading(self, video_file): + torchvision.set_video_backend("cuda") full_path = os.path.join(VIDEO_DIR, video_file) - decoder = VideoReader(full_path, device="cuda") + decoder = VideoReader(full_path) with av.open(full_path) as container: for av_frame in container.decode(container.streams.video[0]): av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) @@ -54,7 +57,8 @@ def test_frame_reading(self, video_file): ], ) def test_seek_reading(self, keyframes, full_path, duration): - decoder = VideoReader(full_path, device="cuda") + torchvision.set_video_backend("cuda") + decoder = VideoReader(full_path) time = duration / 2 decoder.seek(time, keyframes_only=keyframes) with av.open(full_path) as container: @@ -79,8 +83,9 @@ def test_seek_reading(self, keyframes, full_path, duration): ], ) def test_metadata(self, video_file): + torchvision.set_video_backend("cuda") full_path = os.path.join(VIDEO_DIR, video_file) - decoder = VideoReader(full_path, device="cuda") + decoder = VideoReader(full_path) video_metadata = decoder.get_metadata()["video"] with av.open(full_path) as container: video = container.streams.video[0] diff --git a/test/test_videoapi.py b/test/test_videoapi.py index 4688e5a640b..c1bfb9012c4 100644 --- a/test/test_videoapi.py +++ b/test/test_videoapi.py @@ -53,7 +53,9 @@ def fate(name, path="."): class TestVideoApi: @pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.parametrize("test_video", test_videos.keys()) - def test_frame_reading(self, test_video): + @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) + def test_frame_reading(self, test_video, backend): + torchvision.set_video_backend(backend) full_path = os.path.join(VIDEO_DIR, test_video) with av.open(full_path) as av_reader: if av_reader.streams.video: @@ -117,50 +119,60 @@ def test_frame_reading(self, test_video): @pytest.mark.parametrize("stream", ["video", "audio"]) @pytest.mark.parametrize("test_video", test_videos.keys()) - def test_frame_reading_mem_vs_file(self, test_video, stream): + @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) + def test_frame_reading_mem_vs_file(self, test_video, stream, backend): + torchvision.set_video_backend(backend) full_path = os.path.join(VIDEO_DIR, test_video) - # Test video reading from file vs from memory - vr_frames, vr_frames_mem = [], [] - vr_pts, vr_pts_mem = [], [] - # get vr frames - video_reader = VideoReader(full_path, stream) - for vr_frame in video_reader: - vr_frames.append(vr_frame["data"]) - vr_pts.append(vr_frame["pts"]) - - # get vr frames = read from memory - f = open(full_path, "rb") - fbytes = f.read() - f.close() - video_reader_from_mem = VideoReader(fbytes, stream) - - for vr_frame_from_mem in video_reader_from_mem: - vr_frames_mem.append(vr_frame_from_mem["data"]) - vr_pts_mem.append(vr_frame_from_mem["pts"]) - - # same number of frames - assert len(vr_frames) == len(vr_frames_mem) - assert len(vr_pts) == len(vr_pts_mem) - - # compare the frames and ptss - for i in range(len(vr_frames)): - assert vr_pts[i] == vr_pts_mem[i] - mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float())) - # on average the difference is very small and caused - # by decoding (around 1%) - # TODO: asses empirically how to set this? atm it's 1% - # averaged over all frames - assert mean_delta.item() < 2.55 - - del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem + reader = VideoReader(full_path) + reader_md = reader.get_metadata() + + if stream in reader_md: + # Test video reading from file vs from memory + vr_frames, vr_frames_mem = [], [] + vr_pts, vr_pts_mem = [], [] + # get vr frames + video_reader = VideoReader(full_path, stream) + for vr_frame in video_reader: + vr_frames.append(vr_frame["data"]) + vr_pts.append(vr_frame["pts"]) + + # get vr frames = read from memory + f = open(full_path, "rb") + fbytes = f.read() + f.close() + video_reader_from_mem = VideoReader(fbytes, stream) + + for vr_frame_from_mem in video_reader_from_mem: + vr_frames_mem.append(vr_frame_from_mem["data"]) + vr_pts_mem.append(vr_frame_from_mem["pts"]) + + # same number of frames + assert len(vr_frames) == len(vr_frames_mem) + assert len(vr_pts) == len(vr_pts_mem) + + # compare the frames and ptss + for i in range(len(vr_frames)): + assert vr_pts[i] == vr_pts_mem[i] + mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float())) + # on average the difference is very small and caused + # by decoding (around 1%) + # TODO: asses empirically how to set this? atm it's 1% + # averaged over all frames + assert mean_delta.item() < 2.55 + + del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem + else: + del reader, reader_md @pytest.mark.parametrize("test_video,config", test_videos.items()) - def test_metadata(self, test_video, config): + @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) + def test_metadata(self, test_video, config, backend): """ Test that the metadata returned via pyav corresponds to the one returned by the new video decoder API """ + torchvision.set_video_backend(backend) full_path = os.path.join(VIDEO_DIR, test_video) reader = VideoReader(full_path, "video") reader_md = reader.get_metadata() @@ -168,7 +180,9 @@ def test_metadata(self, test_video, config): assert config.duration == approx(reader_md["video"]["duration"][0], abs=0.5) @pytest.mark.parametrize("test_video", test_videos.keys()) - def test_seek_start(self, test_video): + @pytest.mark.parametrize("backend", ["video_reader", "pyav"]) + def test_seek_start(self, test_video, backend): + torchvision.set_video_backend(backend) full_path = os.path.join(VIDEO_DIR, test_video) video_reader = VideoReader(full_path, "video") num_frames = 0 @@ -194,7 +208,9 @@ def test_seek_start(self, test_video): assert start_num_frames == num_frames @pytest.mark.parametrize("test_video", test_videos.keys()) - def test_accurateseek_middle(self, test_video): + @pytest.mark.parametrize("backend", ["video_reader"]) + def test_accurateseek_middle(self, test_video, backend): + torchvision.set_video_backend(backend) full_path = os.path.join(VIDEO_DIR, test_video) stream = "video" video_reader = VideoReader(full_path, stream) @@ -233,7 +249,9 @@ def test_fate_suite(self): @pytest.mark.skipif(av is None, reason="PyAV unavailable") @pytest.mark.parametrize("test_video,config", test_videos.items()) - def test_keyframe_reading(self, test_video, config): + @pytest.mark.parametrize("backend", ["pyav", "video_reader"]) + def test_keyframe_reading(self, test_video, config, backend): + torchvision.set_video_backend(backend) full_path = os.path.join(VIDEO_DIR, test_video) av_reader = av.open(full_path) diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 739f79407b3..def7e82b840 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -1,16 +1,24 @@ import os import warnings +from modulefinder import Module import torch from torchvision import datasets, io, models, ops, transforms, utils -from .extension import _HAS_OPS +from .extension import _HAS_OPS, _load_library try: from .version import __version__ # noqa: F401 except ImportError: pass +try: + _load_library("Decoder") + _HAS_GPU_VIDEO_DECODER = True +except (ImportError, OSError, ModuleNotFoundError): + _HAS_GPU_VIDEO_DECODER = False + + # Check if torchvision is being imported within the root folder if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join( os.path.realpath(os.getcwd()), "torchvision" @@ -66,11 +74,16 @@ def set_video_backend(backend): backend, please compile torchvision from source. """ global _video_backend - if backend not in ["pyav", "video_reader"]: - raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend) + if backend not in ["pyav", "video_reader", "cuda"]: + raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend) if backend == "video_reader" and not io._HAS_VIDEO_OPT: + # TODO: better messages message = "video_reader video backend is not available. Please compile torchvision from source and try again" - warnings.warn(message) + raise RuntimeError(message) + elif backend == "cuda" and not _HAS_GPU_VIDEO_DECODER: + # TODO: better messages + message = "cuda video backend is not available." + raise RuntimeError(message) else: _video_backend = backend diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index ba7d4f69f26..0787b8230e0 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -4,10 +4,6 @@ from ..utils import _log_api_usage_once -try: - from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER -except ModuleNotFoundError: - _HAS_GPU_VIDEO_DECODER = False from ._video_opt import ( _HAS_VIDEO_OPT, _probe_video_from_file, @@ -47,7 +43,6 @@ "_read_video_timestamps_from_memory", "_probe_video_from_memory", "_HAS_VIDEO_OPT", - "_HAS_GPU_VIDEO_DECODER", "_read_video_clip_from_memory", "_read_video_meta_data", "VideoMetaData", diff --git a/torchvision/io/_load_gpu_decoder.py b/torchvision/io/_load_gpu_decoder.py deleted file mode 100644 index f7869f0a9d1..00000000000 --- a/torchvision/io/_load_gpu_decoder.py +++ /dev/null @@ -1,8 +0,0 @@ -from ..extension import _load_library - - -try: - _load_library("Decoder") - _HAS_GPU_VIDEO_DECODER = True -except (ImportError, OSError): - _HAS_GPU_VIDEO_DECODER = False diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py index 0449d6d1ea4..764b82dfe42 100644 --- a/torchvision/io/video_reader.py +++ b/torchvision/io/video_reader.py @@ -1,14 +1,12 @@ +import io import warnings + from typing import Any, Dict, Iterator, Optional import torch from ..utils import _log_api_usage_once -try: - from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER -except ModuleNotFoundError: - _HAS_GPU_VIDEO_DECODER = False from ._video_opt import _HAS_VIDEO_OPT if _HAS_VIDEO_OPT: @@ -22,11 +20,37 @@ def _has_video_opt() -> bool: return False +try: + import av + + av.logging.set_level(av.logging.ERROR) + if not hasattr(av.video.frame.VideoFrame, "pict_type"): + av = ImportError( + """\ +Your version of PyAV is too old for the necessary video operations in torchvision. +If you are on Python 3.5, you will have to build from source (the conda-forge +packages are not up-to-date). See +https://github.com/mikeboers/PyAV#installation for instructions on how to +install PyAV on your system. +""" + ) +except ImportError: + av = ImportError( + """\ +PyAV is not installed, and is necessary for the video operations in torchvision. +See https://github.com/mikeboers/PyAV#installation for instructions on how to +install PyAV on your system. +""" + ) + + class VideoReader: """ Fine-grained video-reading API. Supports frame-by-frame reading of various streams from a single video - container. + container. Much like previous video_reader API it supports the following + backends: video_reader, pyav, and cuda. + Backends can be set via `torchvision.set_video_backend` function. .. betastatus:: VideoReader class @@ -88,16 +112,11 @@ class VideoReader: Default value (0) enables multithreading with codec-dependent heuristic. The performance will depend on the version of FFMPEG codecs supported. - device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``. - To use GPU decoding, pass ``device="cuda"``. path (str, optional): .. warning: This parameter was deprecated in ``0.15`` and will be removed in ``0.17``. Please use ``src`` instead. - - - """ def __init__( @@ -105,45 +124,59 @@ def __init__( src: str = "", stream: str = "video", num_threads: int = 0, - device: str = "cpu", path: Optional[str] = None, ) -> None: _log_api_usage_once(self) - self.is_cuda = False - device = torch.device(device) - if device.type == "cuda": - if not _HAS_GPU_VIDEO_DECODER: - raise RuntimeError("Not compiled with GPU decoder support.") - self.is_cuda = True - self._c = torch.classes.torchvision.GPUDecoder(src, device) - return - if not _has_video_opt(): - raise RuntimeError( - "Not compiled with video_reader support, " - + "to enable video_reader support, please install " - + "ffmpeg (version 4.2 is currently supported) and " - + "build torchvision from source." - ) - - if src == "": - if path is None: - raise TypeError("src cannot be empty") - src = path - warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead") - - elif isinstance(src, bytes): - src = torch.frombuffer(src, dtype=torch.uint8) + from .. import get_video_backend + self.backend = get_video_backend() if isinstance(src, str): - self._c = torch.classes.torchvision.Video(src, stream, num_threads) + if src == "": + if path is None: + raise TypeError("src cannot be empty") + src = path + warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead") + elif isinstance(src, bytes): + if self.backend in ["cuda"]: + raise RuntimeError( + "VideoReader cannot be initialized from bytes object when using cuda or pyav backend." + ) + elif self.backend == "pyav": + src = io.BytesIO(src) + else: + src = torch.frombuffer(src, dtype=torch.uint8) elif isinstance(src, torch.Tensor): - if self.is_cuda: - raise RuntimeError("GPU VideoReader cannot be initialized from Tensor or bytes object.") - self._c = torch.classes.torchvision.Video("", "", 0) - self._c.init_from_memory(src, stream, num_threads) + if self.backend in ["cuda", "pyav"]: + raise RuntimeError( + "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend." + ) else: raise TypeError("`src` must be either string, Tensor or bytes object.") + if self.backend == "cuda": + device = torch.device("cuda") + self._c = torch.classes.torchvision.GPUDecoder(src, device) + + elif self.backend == "video_reader": + if isinstance(src, str): + self._c = torch.classes.torchvision.Video(src, stream, num_threads) + elif isinstance(src, torch.Tensor): + self._c = torch.classes.torchvision.Video("", "", 0) + self._c.init_from_memory(src, stream, num_threads) + + elif self.backend == "pyav": + self.container = av.open(src, metadata_errors="ignore") + # TODO: load metadata + stream_type = stream.split(":")[0] + stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1]) + self.pyav_stream = {stream_type: stream_id} + self._c = self.container.decode(**self.pyav_stream) + + # TODO: add extradata exception + + else: + raise RuntimeError("Unknown video backend: {}".format(self.backend)) + def __next__(self) -> Dict[str, Any]: """Decodes and returns the next frame of the current stream. Frames are encoded as a dict with mandatory @@ -156,14 +189,29 @@ def __next__(self) -> Dict[str, Any]: and corresponding timestamp (``pts``) in seconds """ - if self.is_cuda: + if self.backend == "cuda": frame = self._c.next() if frame.numel() == 0: raise StopIteration - return {"data": frame} - frame, pts = self._c.next() + return {"data": frame, "pts": None} + elif self.backend == "video_reader": + frame, pts = self._c.next() + else: + try: + frame = next(self._c) + pts = float(frame.pts * frame.time_base) + if "video" in self.pyav_stream: + frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1) + elif "audio" in self.pyav_stream: + frame = torch.tensor(frame.to_ndarray()).permute(1, 0) + else: + frame = None + except av.error.EOFError: + raise StopIteration + if frame.numel() == 0: raise StopIteration + return {"data": frame, "pts": pts} def __iter__(self) -> Iterator[Dict[str, Any]]: @@ -182,7 +230,18 @@ def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader": frame with the exact timestamp if it exists or the first frame with timestamp larger than ``time_s``. """ - self._c.seek(time_s, keyframes_only) + if self.backend in ["cuda", "video_reader"]: + self._c.seek(time_s, keyframes_only) + else: + # handle special case as pyav doesn't catch it + if time_s < 0: + time_s = 0 + temp_str = self.container.streams.get(**self.pyav_stream)[0] + offset = int(round(time_s / temp_str.time_base)) + if not keyframes_only: + warnings.warn("Accurate seek is not implemented for pyav backend") + self.container.seek(offset, backward=True, any_frame=False, stream=temp_str) + self._c = self.container.decode(**self.pyav_stream) return self def get_metadata(self) -> Dict[str, Any]: @@ -191,6 +250,21 @@ def get_metadata(self) -> Dict[str, Any]: Returns: (dict): dictionary containing duration and frame rate for every stream """ + if self.backend == "pyav": + metadata = {} # type: Dict[str, Any] + for stream in self.container.streams: + if stream.type not in metadata: + if stream.type == "video": + rate_n = "fps" + else: + rate_n = "framerate" + metadata[stream.type] = {rate_n: [], "duration": []} + + rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate + + metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base)) + metadata[stream.type][rate_n].append(float(rate)) + return metadata return self._c.get_metadata() def set_current_stream(self, stream: str) -> bool: @@ -210,6 +284,12 @@ def set_current_stream(self, stream: str) -> bool: Returns: (bool): True on succes, False otherwise """ - if self.is_cuda: - print("GPU decoding only works with video stream.") + if self.backend == "cuda": + warnings.warn("GPU decoding only works with video stream.") + if self.backend == "pyav": + stream_type = stream.split(":")[0] + stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1]) + self.pyav_stream = {stream_type: stream_id} + self._c = self.container.decode(**self.pyav_stream) + return True return self._c.set_current_stream(stream)