diff --git a/setup.py b/setup.py
index a62895d76c8..9678a2c17b7 100644
--- a/setup.py
+++ b/setup.py
@@ -427,6 +427,64 @@ def get_extensions():
             )
         )

+    # Locating video codec
+    # CUDA_HOME should be set to the cuda root directory.
+    # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to
+    # video codec header files and libraries respectively.
+    video_codec_found = (
+        extension is CUDAExtension
+        and CUDA_HOME is not None
+        and any(os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include)
+        and any(os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include)
+        and any(os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs)
+    )
+
+    print(f"video codec found: {video_codec_found}")
+
+    # `has_ffmpeg and ...` short-circuits, so ffmpeg_include_dir is only read when
+    # ffmpeg itself was detected.
+    bsf_header_found = has_ffmpeg and any(
+        os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir
+    )
+
+    if video_codec_found and bsf_header_found:
+        gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu")
+        gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp"))
+        cuda_libs = os.path.join(CUDA_HOME, "lib64")
+        cuda_inc = os.path.join(CUDA_HOME, "include")
+
+        ext_modules.append(
+            extension(
+                "torchvision.Decoder",
+                gpu_decoder_src,
+                include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir,
+                library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs],
+                libraries=[
+                    "avcodec",
+                    "avformat",
+                    "avutil",
+                    "swresample",
+                    "swscale",
+                    "nvcuvid",
+                    "cuda",
+                    "cudart",
+                    "z",
+                    "pthread",
+                    "dl",
+                ],
+                extra_compile_args=extra_compile_args,
+            )
+        )
+    elif video_codec_found and has_ffmpeg:
+        # Only blame bsf.h when it is the one prerequisite that is actually missing.
+        print(
+            "The installed version of ffmpeg is missing the header file 'bsf.h' which is "
+            "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:"
+            " `conda install -c conda-forge ffmpeg`."
+        )
+    elif video_codec_found:
+        print("ffmpeg was not found; the GPU video decoder extension will not be built.")
+
     return ext_modules


diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py
new file mode 100644
index 00000000000..84309e3e217
--- /dev/null
+++ b/test/test_video_gpu_decoder.py
@@ -0,0 +1,41 @@
+import os
+
+import pytest
+import torch
+from torchvision.io import _HAS_VIDEO_DECODER, VideoReader
+
+try:
+    import av
+except ImportError:
+    av = None
+
+VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
+
+test_videos = [
+    "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
+    "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
+    "v_SoccerJuggling_g23_c01.avi",
+    "v_SoccerJuggling_g24_c01.avi",
+    "R6llTwEh07w.mp4",
+    "SOX5yA1l24A.mp4",
+    "WUzgd7C1pWA.mp4",
+]
+
+
+@pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
+class TestVideoGPUDecoder:
+    @pytest.mark.skipif(av is None, reason="PyAV unavailable")
+    def test_frame_reading(self):
+        for test_video in test_videos:
+            full_path = os.path.join(VIDEO_DIR, test_video)
+            decoder = VideoReader(full_path, device="cuda:0")
+            with av.open(full_path) as container:
+                for av_frame in container.decode(container.streams.video[0]):
+                    av_frames = torch.tensor(av_frame.to_ndarray().flatten())
+                    vision_frames = next(decoder)["data"]
+                    mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))
+                    assert mean_delta < 0.1
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/torchvision/csrc/io/decoder/gpu/README.rst b/torchvision/csrc/io/decoder/gpu/README.rst
new file mode 100644
index 00000000000..cebd31cb557
--- /dev/null
+++ b/torchvision/csrc/io/decoder/gpu/README.rst
@@ -0,0 +1,21 @@
+GPU Decoder
+===========
+
+GPU decoder depends on ffmpeg for demuxing, uses NVDECODE APIs from the nvidia-video-codec sdk and uses cuda for processing on gpu. In order to use this, please follow the following steps:
+
+* Download the latest `nvidia-video-codec-sdk <https://developer.nvidia.com/nvidia-video-codec-sdk/download>`_
+* Extract the zipped file.
+* Set TORCHVISION_INCLUDE environment variable to the location of the video codec headers(`nvcuvid.h` and `cuviddec.h`), which would be under `Interface` directory.
+* Set TORCHVISION_LIBRARY environment variable to the location of the video codec library(`libnvcuvid.so`), which would be under `Lib/linux/stubs/x86_64` directory.
+* Install the latest ffmpeg from `conda-forge` channel.
+
+.. code:: bash
+
+    conda install -c conda-forge ffmpeg
+
+* Set CUDA_HOME environment variable to the cuda root directory.
+* Build torchvision from source:
+
+.. code:: bash
+
+    python setup.py install
diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp
new file mode 100644
index 00000000000..4471fd6b783
--- /dev/null
+++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp
@@ -0,0 +1,417 @@
+#include "decoder.h"
+#include
+#include
+#include
+#include
+
+/* Ratio of chroma plane height to luma height: 1.0 for the two YUV444 surface
+ * formats, 0.5 for every other format.
+ */
+static float chroma_height_factor(cudaVideoSurfaceFormat surface_format) {
+  switch (surface_format) {
+    case cudaVideoSurfaceFormat_YUV444:
+    case cudaVideoSurfaceFormat_YUV444_16Bit:
+      return 1.0;
+    default:
+      return 0.5;
+  }
+}
+
+/* Number of chroma planes stored after the luma plane: 2 for the two YUV444
+ * surface formats, 1 for every other format.
+ */
+static int chroma_plane_count(cudaVideoSurfaceFormat surface_format) {
+  switch (surface_format) {
+    case cudaVideoSurfaceFormat_YUV444:
+    case cudaVideoSurfaceFormat_YUV444_16Bit:
+      return 2;
+    default:
+      return 1;
+  }
+}
+
+/* Initialise cu_context and video_codec, create context lock and create parser
+ * object.
+ */
+void Decoder::init(CUcontext context, cudaVideoCodec codec) {
+  video_codec = codec;
+  // cu_context is published only after the lock exists: the destructor treats
+  // a non-NULL cu_context as proof that ctx_lock holds a valid handle.
+  check_for_cuda_errors(
+      cuvidCtxLockCreate(&ctx_lock, context), __LINE__, __FILE__);
+  cu_context = context;
+
+  CUVIDPARSERPARAMS parser_params = {};
+  parser_params.CodecType = codec;
+  parser_params.ulMaxNumDecodeSurfaces = 1;
+  parser_params.ulClockRate = 1000;
+  parser_params.ulMaxDisplayDelay = 0u;
+  // The static *_handler callbacks receive this pointer back as user_data.
+  parser_params.pUserData = this;
+  parser_params.pfnSequenceCallback = video_sequence_handler;
+  parser_params.pfnDecodePicture = picture_decode_handler;
+  parser_params.pfnDisplayPicture = picture_display_handler;
+  parser_params.pfnGetOperatingPoint = operating_point_handler;
+
+  check_for_cuda_errors(
+      cuvidCreateVideoParser(&parser, &parser_params), __LINE__, __FILE__);
+}
+
+/* Destroy parser object and context lock.
+ * Both steps are skipped for handles that init() never created, so destroying
+ * a Decoder whose init() was not called (or failed early) is safe.
+ */
+Decoder::~Decoder() {
+  if (parser) {
+    cuvidDestroyVideoParser(parser);
+  }
+  if (cu_context) {
+    cuvidCtxLockDestroy(ctx_lock);
+  }
+}
+
+/* Destroy CUvideodecoder object and free up all the unreturned decoded frames.
+ * Return codes are ignored on purpose: this runs during teardown, where
+ * throwing is not an option.
+ */
+void Decoder::release() {
+  while (!decoded_frames.empty()) {
+    decoded_frames.pop();
+  }
+  if (!decoder) {
+    // No decoder was ever created, so there is nothing to destroy and no
+    // reason to push a (possibly NULL) context.
+    return;
+  }
+  cuCtxPushCurrent(cu_context);
+  cuvidDestroyDecoder(decoder);
+  // Cleared so a second release() cannot destroy the same handle twice.
+  decoder = NULL;
+  cuCtxPopCurrent(NULL);
+}
+
+/* Trigger video decoding.
+ * A NULL pointer or a zero size is submitted as an end-of-stream packet.
+ */
+void Decoder::decode(const uint8_t* data, unsigned long size) {
+  TORCH_CHECK(parser, "Decoder::decode() called before init()");
+  CUVIDSOURCEDATAPACKET pkt = {};
+  pkt.flags = CUVID_PKT_TIMESTAMP;
+  pkt.payload_size = size;
+  pkt.payload = data;
+  pkt.timestamp = 0;
+  if (!data || size == 0) {
+    pkt.flags |= CUVID_PKT_ENDOFSTREAM;
+  }
+  check_for_cuda_errors(cuvidParseVideoData(parser, &pkt), __LINE__, __FILE__);
+  cuvidStream = 0;
+}
+
+/* Fetch a decoded frame and remove it from the queue.
+ * Returns an empty uint8 CUDA tensor when no frame is queued.
+ */
+torch::Tensor Decoder::fetch_frame() {
+  if (decoded_frames.empty()) {
+    auto options =
+        torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA);
+    return torch::zeros({0}, options);
+  }
+  torch::Tensor frame = decoded_frames.front();
+  decoded_frames.pop();
+  return frame;
+}
+
+/* Called when a picture is ready to be decoded.
+ */
+int Decoder::handle_picture_decode(CUVIDPICPARAMS* pic_params) {
+  TORCH_CHECK(decoder, "Uninitialised decoder");
+  // pic_num_in_decode_order is a fixed-size table; reject an index outside it
+  // instead of writing past the end.
+  constexpr int kPicTableSize =
+      sizeof(pic_num_in_decode_order) / sizeof(pic_num_in_decode_order[0]);
+  TORCH_CHECK(
+      pic_params->CurrPicIdx >= 0 && pic_params->CurrPicIdx < kPicTableSize,
+      "Picture index out of range: ",
+      pic_params->CurrPicIdx);
+  pic_num_in_decode_order[pic_params->CurrPicIdx] = decode_pic_count++;
+  check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__);
+  check_for_cuda_errors(
+      cuvidDecodePicture(decoder, pic_params), __LINE__, __FILE__);
+  check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
+  return 1;
+}
+
+/* Process the decoded data and copy it to a cuda memory location.
+ * The frame is mapped, copied plane by plane into a freshly allocated uint8
+ * CUDA tensor, pushed onto decoded_frames and then unmapped. Always returns 1;
+ * failures are raised through check_for_cuda_errors.
+ */
+int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) {
+  CUVIDPROCPARAMS proc_params = {};
+  proc_params.progressive_frame = disp_info->progressive_frame;
+  proc_params.second_field = disp_info->repeat_first_field + 1;
+  proc_params.top_field_first = disp_info->top_field_first;
+  proc_params.unpaired_field = disp_info->repeat_first_field < 0;
+  proc_params.output_stream = cuvidStream;
+
+  CUdeviceptr source_frame = 0;
+  unsigned int source_pitch = 0;
+  check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__);
+  check_for_cuda_errors(
+      cuvidMapVideoFrame(
+          decoder,
+          disp_info->picture_index,
+          &source_frame,
+          &source_pitch,
+          &proc_params),
+      __LINE__,
+      __FILE__);
+
+  CUVIDGETDECODESTATUS decode_status;
+  memset(&decode_status, 0, sizeof(decode_status));
+  CUresult result =
+      cuvidGetDecodeStatus(decoder, disp_info->picture_index, &decode_status);
+  if (result == CUDA_SUCCESS &&
+      (decode_status.decodeStatus == cuvidDecodeStatus_Error ||
+       decode_status.decodeStatus == cuvidDecodeStatus_Error_Concealed)) {
+    // Decode errors are only logged; the frame is still copied out below.
+    constexpr int kPicTableSize =
+        sizeof(pic_num_in_decode_order) / sizeof(pic_num_in_decode_order[0]);
+    const int picture_index = disp_info->picture_index;
+    if (picture_index >= 0 && picture_index < kPicTableSize) {
+      VLOG(1) << "Decode Error occurred for picture "
+              << pic_num_in_decode_order[picture_index];
+    } else {
+      // Never read the table with an index it cannot hold.
+      VLOG(1) << "Decode Error occurred for picture index " << picture_index;
+    }
+  }
+
+  auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA);
+  torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options);
+  uint8_t* frame_ptr = decoded_frame.data_ptr<uint8_t>();
+
+  // Copy luma plane
+  CUDA_MEMCPY2D m = {0};
+  m.srcMemoryType = CU_MEMORYTYPE_DEVICE;
+  m.srcDevice = source_frame;
+  m.srcPitch = source_pitch;
+  m.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+  m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr);
+  // The destination rows are packed back to back, so pitch equals row width.
+  m.dstPitch = get_width() * bytes_per_pixel;
+  m.WidthInBytes = get_width() * bytes_per_pixel;
+  m.Height = luma_height;
+  check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
+
+  // Copy chroma plane
+  // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning
+  // height
+  m.srcDevice =
+      (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1));
+  m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height);
+  m.Height = chroma_height;
+  check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
+
+  if (num_chroma_planes == 2) {
+    // Second chroma plane: one more aligned surface height further into the
+    // source, one more luma_height further into the destination.
+    m.srcDevice =
+        (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2);
+    m.dstDevice =
+        (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2);
+    m.Height = chroma_height;
+    check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
+  }
+  // The copies were asynchronous; wait for them before publishing the frame.
+  check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__);
+  decoded_frames.push(decoded_frame);
+  check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
+
+  check_for_cuda_errors(
+      cuvidUnmapVideoFrame(decoder, source_frame), __LINE__, __FILE__);
+  return 1;
+}
+
+/* Query the capabilities of the underlying hardware video decoder and
+ * verify if the hardware supports decoding the passed video.
+ * Raises through TORCH_CHECK when the codec, the coded resolution or the
+ * 16x16 block count is not supported. May overwrite video_output_format with
+ * a fallback surface format that the hardware reports as supported.
+ */
+void Decoder::query_hardware(CUVIDEOFORMAT* video_format) {
+  CUVIDDECODECAPS decode_caps = {};
+  decode_caps.eCodecType = video_format->codec;
+  decode_caps.eChromaFormat = video_format->chroma_format;
+  decode_caps.nBitDepthMinus8 = video_format->bit_depth_luma_minus8;
+
+  check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__);
+  check_for_cuda_errors(cuvidGetDecoderCaps(&decode_caps), __LINE__, __FILE__);
+  check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
+
+  if (!decode_caps.bIsSupported) {
+    TORCH_CHECK(false, "Codec not supported on this GPU");
+  }
+  if ((video_format->coded_width > decode_caps.nMaxWidth) ||
+      (video_format->coded_height > decode_caps.nMaxHeight)) {
+    TORCH_CHECK(
+        false,
+        "Resolution : ",
+        video_format->coded_width,
+        "x",
+        video_format->coded_height,
+        "\nMax Supported (wxh) : ",
+        decode_caps.nMaxWidth,
+        "x",
+        decode_caps.nMaxHeight,
+        "\nResolution not supported on this GPU");
+  }
+  // (coded_width / 16) * (coded_height / 16) is compared against nMaxMBCount.
+  if ((video_format->coded_width >> 4) * (video_format->coded_height >> 4) >
+      decode_caps.nMaxMBCount) {
+    TORCH_CHECK(
+        false,
+        "MBCount : ",
+        (video_format->coded_width >> 4) * (video_format->coded_height >> 4),
+        "\nMax Supported mbcnt : ",
+        decode_caps.nMaxMBCount,
+        "\nMBCount not supported on this GPU");
+  }
+  // Check if output format supported. If not, check fallback options
+  // Fallbacks are tried in this order: NV12, P016, YUV444, YUV444_16Bit.
+  if (!(decode_caps.nOutputFormatMask & (1 << video_output_format))) {
+    if (decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) {
+      video_output_format = cudaVideoSurfaceFormat_NV12;
+    } else if (
+        decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) {
+      video_output_format = cudaVideoSurfaceFormat_P016;
+    } else if (
+        decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) {
+      video_output_format = cudaVideoSurfaceFormat_YUV444;
+    } else if (
+        decode_caps.nOutputFormatMask &
+        (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) {
+      video_output_format = cudaVideoSurfaceFormat_YUV444_16Bit;
+    } else {
+      TORCH_CHECK(false, "No supported output format found");
+    }
+  }
+}
+
+/* Called before decoding frames and/or whenever there is a configuration
+ * change.
+ * On the first call this creates the CUvideodecoder and returns
+ * min_num_decode_surfaces. Once width, luma_height and chroma_height are all
+ * non-zero, later calls are forwarded to reconfigure_decoder() and return its
+ * result.
+ */
+int Decoder::handle_video_sequence(CUVIDEOFORMAT* video_format) {
+  // video_codec has been set in init(). Here it's set
+  // again for potential correction.
+  video_codec = video_format->codec;
+  video_chroma_format = video_format->chroma_format;
+  bit_depth_minus8 = video_format->bit_depth_luma_minus8;
+  bytes_per_pixel = bit_depth_minus8 > 0 ? 2 : 1;
+  // Set the output surface format same as chroma format
+  switch (video_chroma_format) {
+    case cudaVideoChromaFormat_Monochrome:
+    case cudaVideoChromaFormat_420:
+      video_output_format = video_format->bit_depth_luma_minus8
+          ? cudaVideoSurfaceFormat_P016
+          : cudaVideoSurfaceFormat_NV12;
+      break;
+    case cudaVideoChromaFormat_444:
+      video_output_format = video_format->bit_depth_luma_minus8
+          ? cudaVideoSurfaceFormat_YUV444_16Bit
+          : cudaVideoSurfaceFormat_YUV444;
+      break;
+    // Last case of the switch, so no break is needed.
+    case cudaVideoChromaFormat_422:
+      video_output_format = cudaVideoSurfaceFormat_NV12;
+  }
+
+  // Throws when the stream cannot be decoded; may adjust video_output_format.
+  query_hardware(video_format);
+
+  if (width && luma_height && chroma_height) {
+    // cuvidCreateDecoder() has been called before and now there's possible
+    // config change.
+    return reconfigure_decoder(video_format);
+  }
+
+  // Remembered so reconfigure_decoder() can compare later formats against it.
+  cu_video_format = *video_format;
+  unsigned long decode_surface = video_format->min_num_decode_surfaces;
+  cudaVideoDeinterlaceMode deinterlace_mode = cudaVideoDeinterlaceMode_Adaptive;
+
+  if (video_format->progressive_sequence) {
+    deinterlace_mode = cudaVideoDeinterlaceMode_Weave;
+  }
+
+  CUVIDDECODECREATEINFO video_decode_create_info = {};
+  video_decode_create_info.ulWidth = video_format->coded_width;
+  video_decode_create_info.ulHeight = video_format->coded_height;
+  video_decode_create_info.ulNumDecodeSurfaces = decode_surface;
+  video_decode_create_info.CodecType = video_format->codec;
+  video_decode_create_info.ChromaFormat = video_format->chroma_format;
+  // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded
+  // by NVDEC hardware
+  video_decode_create_info.ulCreationFlags = cudaVideoCreate_PreferCUVID;
+  video_decode_create_info.bitDepthMinus8 = video_format->bit_depth_luma_minus8;
+  video_decode_create_info.OutputFormat = video_output_format;
+  video_decode_create_info.DeinterlaceMode = deinterlace_mode;
+  video_decode_create_info.ulNumOutputSurfaces = 2;
+  video_decode_create_info.vidLock = ctx_lock;
+
+  // AV1 has max width/height of sequence in sequence header
+  if (video_format->codec == cudaVideoCodec_AV1 &&
+      video_format->seqhdr_data_length > 0) {
+    CUVIDEOFORMATEX* video_format_ex = (CUVIDEOFORMATEX*)video_format;
+    max_width = video_format_ex->av1.max_width;
+    max_height = video_format_ex->av1.max_height;
+  }
+  // max_width / max_height never end up smaller than the coded size.
+  if (max_width < video_format->coded_width) {
+    max_width = video_format->coded_width;
+  }
+  if (max_height < video_format->coded_height) {
+    max_height = video_format->coded_height;
+  }
+  video_decode_create_info.ulMaxWidth = max_width;
+  video_decode_create_info.ulMaxHeight = max_height;
+  // Output frame geometry comes from the display area, while the decode
+  // target below uses the coded size.
+  width = video_format->display_area.right - video_format->display_area.left;
+  luma_height =
+      video_format->display_area.bottom - video_format->display_area.top;
+  video_decode_create_info.ulTargetWidth = video_format->coded_width;
+  video_decode_create_info.ulTargetHeight = video_format->coded_height;
+  chroma_height =
+      (int)(ceil(luma_height * chroma_height_factor(video_output_format)));
+  num_chroma_planes = chroma_plane_count(video_output_format);
+  surface_height = video_decode_create_info.ulTargetHeight;
+  surface_width = video_decode_create_info.ulTargetWidth;
+  // NOTE(review): video_decode_create_info.display_area is never assigned in
+  // this function (the struct is value-initialised above), so display_rect
+  // appears to end up all zero - confirm this is intended.
+  display_rect.bottom = video_decode_create_info.display_area.bottom;
+  display_rect.top = video_decode_create_info.display_area.top;
+  display_rect.left = video_decode_create_info.display_area.left;
+  display_rect.right = video_decode_create_info.display_area.right;
+
+  check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__);
+  check_for_cuda_errors(
+      cuvidCreateDecoder(&decoder, &video_decode_create_info),
+      __LINE__,
+      __FILE__);
+  check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
+  return decode_surface;
+}
+
+/* Adapt the existing decoder to a changed video format.
+ * Raises through TORCH_CHECK for a bit depth or chroma format change, and for
+ * a coded size above max_width/max_height unless the codec is VP9. Returns 1
+ * when cuvidReconfigureDecoder() is not called, otherwise
+ * min_num_decode_surfaces.
+ */
+int Decoder::reconfigure_decoder(CUVIDEOFORMAT* video_format) {
+  if (video_format->bit_depth_luma_minus8 !=
+          cu_video_format.bit_depth_luma_minus8 ||
+      video_format->bit_depth_chroma_minus8 !=
+          cu_video_format.bit_depth_chroma_minus8) {
+    TORCH_CHECK(false, "Reconfigure not supported for bit depth change");
+  }
+  if (video_format->chroma_format != cu_video_format.chroma_format) {
+    TORCH_CHECK(false, "Reconfigure not supported for chroma format change");
+  }
+
+  bool decode_res_change =
+      !(video_format->coded_width == cu_video_format.coded_width &&
+        video_format->coded_height == cu_video_format.coded_height);
+  bool display_rect_change =
+      !(video_format->display_area.bottom ==
+            cu_video_format.display_area.bottom &&
+        video_format->display_area.top == cu_video_format.display_area.top &&
+        video_format->display_area.left == cu_video_format.display_area.left &&
+        video_format->display_area.right == cu_video_format.display_area.right);
+
+  unsigned int decode_surface = video_format->min_num_decode_surfaces;
+
+  if ((video_format->coded_width > max_width) ||
+      (video_format->coded_height > max_height)) {
+    // For VP9, let driver handle the change if new width/height >
+    // maxwidth/maxheight
+    if (video_codec != cudaVideoCodec_VP9) {
+      TORCH_CHECK(
+          false,
+          "Reconfigure not supported when width/height > maxwidth/maxheight");
+    }
+    return 1;
+  }
+
+  if (!decode_res_change) {
+    // If the coded_width/coded_height hasn't changed but display resolution has
+    // changed, then need to update width/height for correct output without
+    // cropping. Example : 1920x1080 vs 1920x1088.
+    // NOTE(review): cu_video_format.display_area is not refreshed here, so the
+    // next comparison is still made against the original display area -
+    // confirm this is intended.
+    if (display_rect_change) {
+      width =
+          video_format->display_area.right - video_format->display_area.left;
+      luma_height =
+          video_format->display_area.bottom - video_format->display_area.top;
+      chroma_height =
+          (int)ceil(luma_height * chroma_height_factor(video_output_format));
+      num_chroma_planes = chroma_plane_count(video_output_format);
+    }
+    return 1;
+  }
+  // Coded size changed: remember it and reconfigure, keeping the target size
+  // and display rectangle recorded when the decoder was created.
+  cu_video_format.coded_width = video_format->coded_width;
+  cu_video_format.coded_height = video_format->coded_height;
+  CUVIDRECONFIGUREDECODERINFO reconfig_params = {};
+  reconfig_params.ulWidth = video_format->coded_width;
+  reconfig_params.ulHeight = video_format->coded_height;
+  reconfig_params.ulTargetWidth = surface_width;
+  reconfig_params.ulTargetHeight = surface_height;
+  reconfig_params.ulNumDecodeSurfaces = decode_surface;
+  reconfig_params.display_area.bottom = display_rect.bottom;
+  reconfig_params.display_area.top = display_rect.top;
+  reconfig_params.display_area.left = display_rect.left;
+  reconfig_params.display_area.right = display_rect.right;
+
+  check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__);
+  check_for_cuda_errors(
+      cuvidReconfigureDecoder(decoder, &reconfig_params), __LINE__, __FILE__);
+  check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
+
+  return decode_surface;
+}
+
+/* Called from AV1 sequence header to get operating point of an AV1 bitstream.
+ */
+int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* oper_point_info) {
+  return oper_point_info->codec == cudaVideoCodec_AV1 &&
+          oper_point_info->av1.operating_points_cnt > 1
+      ? 0
+      : -1;
+}
diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h
new file mode 100644
index 00000000000..c3064eb1663
--- /dev/null
+++ b/torchvision/csrc/io/decoder/gpu/decoder.h
@@ -0,0 +1,113 @@
+#pragma once
+
+#include <cuda.h>
+#include <cuviddec.h>
+#include <nvcuvid.h>
+#include <torch/torch.h>
+#include <cstdint>
+#include <queue>
+#include <string>
+
+// Throw (through TORCH_CHECK) when a CUDA driver / NVDEC call returned anything
+// other than CUDA_SUCCESS. line_num and file_name identify the call site.
+static auto check_for_cuda_errors =
+    [](CUresult result, int line_num, const std::string& file_name) {
+      if (CUDA_SUCCESS != result) {
+        const char* error_name = nullptr;
+
+        // TORCH_CHECK fires when its condition is false, i.e. when the name
+        // lookup succeeded, so this message carries the symbolic error name.
+        TORCH_CHECK(
+            CUDA_SUCCESS != cuGetErrorName(result, &error_name),
+            "CUDA error: ",
+            error_name,
+            " in ",
+            file_name,
+            " at line ",
+            line_num);
+        // The name lookup failed: report the numeric code instead.
+        TORCH_CHECK(
+            false, "Error: ", result, " in ", file_name, " at line ", line_num);
+      }
+    };
+
+struct Rect {
+  int left, top, right, bottom;
+};
+
+// Parses a compressed video stream and queues each decoded frame as a uint8
+// CUDA tensor. The static *_handler functions are the parser callbacks.
+class Decoder {
+ public:
+  Decoder() {}
+  ~Decoder();
+  void init(CUcontext, cudaVideoCodec);
+  void release();
+  void decode(const uint8_t*, unsigned long);
+  torch::Tensor fetch_frame();
+  // Number of bytes in one decoded frame: luma rows plus every chroma plane.
+  int get_frame_size() const {
+    return get_width() * (luma_height + (chroma_height * num_chroma_planes)) *
+        bytes_per_pixel;
+  }
+  // Frame width; rounded up to an even value for the NV12 and P016 formats.
+  int get_width() const {
+    return (video_output_format == cudaVideoSurfaceFormat_NV12 ||
+            video_output_format == cudaVideoSurfaceFormat_P016)
+        ? (width + 1) & ~1
+        : width;
+  }
+  int get_height() const {
+    return luma_height;
+  }
+
+ private:
+  unsigned int width = 0, luma_height = 0, chroma_height = 0;
+  unsigned int surface_height = 0, surface_width = 0;
+  unsigned int max_width = 0, max_height = 0;
+  unsigned int num_chroma_planes = 0;
+  int bit_depth_minus8 = 0, bytes_per_pixel = 1;
+  // Zero-filled so a slot that was never written still reads as a defined value.
+  int decode_pic_count = 0, pic_num_in_decode_order[32] = {};
+  std::queue<torch::Tensor> decoded_frames;
+  CUcontext cu_context = NULL;
+  // NULL until init() creates the lock, so the handle is never indeterminate.
+  CUvideoctxlock ctx_lock = NULL;
+  CUvideoparser parser = NULL;
+  CUvideodecoder decoder = NULL;
+  CUstream cuvidStream = 0;
+  cudaVideoCodec video_codec = cudaVideoCodec_NumCodecs;
+  cudaVideoChromaFormat video_chroma_format = cudaVideoChromaFormat_420;
+  cudaVideoSurfaceFormat video_output_format = cudaVideoSurfaceFormat_NV12;
+  CUVIDEOFORMAT cu_video_format = {};
+  Rect display_rect = {};
+
+  // Callback trampolines: user_data is the Decoder registered in init().
+  static int video_sequence_handler(
+      void* user_data,
+      CUVIDEOFORMAT* video_format) {
+    return ((Decoder*)user_data)->handle_video_sequence(video_format);
+  }
+  static int picture_decode_handler(
+      void* user_data,
+      CUVIDPICPARAMS* pic_params) {
+    return ((Decoder*)user_data)->handle_picture_decode(pic_params);
+  }
+  static int picture_display_handler(
+      void* user_data,
+      CUVIDPARSERDISPINFO* disp_info) {
+    return ((Decoder*)user_data)->handle_picture_display(disp_info);
+  }
+  static int operating_point_handler(
+      void* user_data,
+      CUVIDOPERATINGPOINTINFO* operating_info) {
+    return ((Decoder*)user_data)->get_operating_point(operating_info);
+  }
+
+  void query_hardware(CUVIDEOFORMAT*);
+  int reconfigure_decoder(CUVIDEOFORMAT*);
+  int handle_video_sequence(CUVIDEOFORMAT*);
+  int handle_picture_decode(CUVIDPICPARAMS*);
+  int handle_picture_display(CUVIDPARSERDISPINFO*);
+  int get_operating_point(CUVIDOPERATINGPOINTINFO*);
+};
diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h
new file mode 100644
index 00000000000..75d4765dd79
--- /dev/null
+++
b/torchvision/csrc/io/decoder/gpu/demuxer.h
@@ -0,0 +1,257 @@
+extern "C" {
+#include <libavcodec/avcodec.h>
+#include <libavcodec/bsf.h>
+#include <libavformat/avformat.h>
+#include <libavformat/avio.h>
+}
+
+/* Reads the best video stream of a file through FFmpeg and hands out one
+ * compressed packet per demux() call. H264/HEVC streams found in MOV, FLV or
+ * Matroska containers are run through the matching *_mp4toannexb bitstream
+ * filter before being returned.
+ */
+class Demuxer {
+ private:
+  AVFormatContext* fmtCtx = NULL;
+  AVBSFContext* bsfCtx = NULL;
+  AVPacket pkt, pktFiltered;
+  AVCodecID eVideoCodec;
+  uint8_t* dataWithHeader = NULL;
+  bool bMp4H264 = false, bMp4HEVC = false, bMp4MPEG4 = false;
+  unsigned int frameCount = 0;
+  int iVideoStream;
+  int64_t userTimeScale = 0;
+  double timeBase = 0.0;
+
+  // Put a packet into the empty state that cleanup() and demux() test for.
+  static void reset_packet(AVPacket* packet) {
+    av_init_packet(packet);
+    packet->data = NULL;
+    packet->size = 0;
+  }
+
+  // True for the three container formats that get codec-specific handling.
+  bool is_mp4_like_container() const {
+    const char* name = fmtCtx->iformat->long_name;
+    return !strcmp(name, "QuickTime / MOV") ||
+        !strcmp(name, "FLV (Flash Video)") ||
+        !strcmp(name, "Matroska / WebM");
+  }
+
+  // Allocate and initialise the named bitstream filter for the video stream.
+  void init_bitstream_filter(const char* filter_name) {
+    const AVBitStreamFilter* bsf = av_bsf_get_by_name(filter_name);
+    TORCH_CHECK(
+        bsf != NULL,
+        "av_bsf_get_by_name() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+    TORCH_CHECK(
+        0 <= av_bsf_alloc(bsf, &bsfCtx),
+        "av_bsf_alloc() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+    TORCH_CHECK(
+        0 <=
+            avcodec_parameters_copy(
+                bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar),
+        "avcodec_parameters_copy() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+    TORCH_CHECK(
+        0 <= av_bsf_init(bsfCtx),
+        "av_bsf_init() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+  }
+
+  // Open the file, select the video stream and set up per-codec handling.
+  // Throws on failure; the caller is responsible for calling cleanup().
+  void open_input(const char* filePath, int64_t timeScale) {
+    avformat_network_init();
+    TORCH_CHECK(
+        0 <= avformat_open_input(&fmtCtx, filePath, NULL, NULL),
+        "avformat_open_input() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+    TORCH_CHECK(
+        fmtCtx != NULL,
+        "Encountered NULL AVFormatContext at line ",
+        __LINE__,
+        " in demuxer.h\n");
+
+    TORCH_CHECK(
+        0 <= avformat_find_stream_info(fmtCtx, NULL),
+        "avformat_find_stream_info() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+    iVideoStream =
+        av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0);
+    TORCH_CHECK(
+        iVideoStream >= 0,
+        "av_find_best_stream() failed at line ",
+        __LINE__,
+        " in demuxer.h\n");
+
+    eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id;
+    AVRational rTimeBase = fmtCtx->streams[iVideoStream]->time_base;
+    timeBase = av_q2d(rTimeBase);
+    userTimeScale = timeScale;
+
+    const bool mp4Like = is_mp4_like_container();
+    bMp4H264 = mp4Like && eVideoCodec == AV_CODEC_ID_H264;
+    bMp4HEVC = mp4Like && eVideoCodec == AV_CODEC_ID_HEVC;
+    bMp4MPEG4 = mp4Like && eVideoCodec == AV_CODEC_ID_MPEG4;
+
+    if (bMp4H264) {
+      init_bitstream_filter("h264_mp4toannexb");
+    } else if (bMp4HEVC) {
+      init_bitstream_filter("hevc_mp4toannexb");
+    }
+  }
+
+  // Release every FFmpeg resource held. Each step is guarded, so this is safe
+  // on a demuxer that was only partially opened.
+  void cleanup() {
+    if (pkt.data) {
+      av_packet_unref(&pkt);
+    }
+    if (pktFiltered.data) {
+      av_packet_unref(&pktFiltered);
+    }
+    if (bsfCtx) {
+      av_bsf_free(&bsfCtx);
+    }
+    if (fmtCtx) {
+      avformat_close_input(&fmtCtx);
+    }
+    if (dataWithHeader) {
+      av_free(dataWithHeader);
+      dataWithHeader = NULL;
+    }
+  }
+
+ public:
+  Demuxer(const char* filePath, int64_t timeScale = 1000 /*Hz*/) {
+    // Packets are reset first so cleanup() never inspects an uninitialised
+    // AVPacket.
+    reset_packet(&pkt);
+    reset_packet(&pktFiltered);
+    try {
+      open_input(filePath, timeScale);
+    } catch (...) {
+      // The destructor does not run for a constructor that throws, so free
+      // whatever was acquired before re-raising.
+      cleanup();
+      throw;
+    }
+  }
+  // Owns raw FFmpeg handles, so copying would lead to double frees.
+  Demuxer(const Demuxer&) = delete;
+  Demuxer& operator=(const Demuxer&) = delete;
+  ~Demuxer() {
+    cleanup();
+  }
+
+  AVCodecID get_video_codec() {
+    return eVideoCodec;
+  }
+
+  /* Fetch the next packet of the video stream.
+   * On success returns true with *video / *videoBytes describing the payload;
+   * the memory stays owned by the demuxer and is reused by the next call.
+   * Returns false (with *videoBytes == 0) when no further packet can be read.
+   */
+  bool demux(uint8_t** video, unsigned long* videoBytes) {
+    if (!fmtCtx) {
+      return false;
+    }
+    *videoBytes = 0;
+
+    if (pkt.data) {
+      av_packet_unref(&pkt);
+    }
+    int e = 0;
+    while ((e = av_read_frame(fmtCtx, &pkt)) >= 0 &&
+           pkt.stream_index != iVideoStream) {
+      av_packet_unref(&pkt);
+    }
+    if (e < 0) {
+      return false;
+    }
+
+    const int extraDataSize =
+        fmtCtx->streams[iVideoStream]->codecpar->extradata_size;
+    if (bMp4H264 || bMp4HEVC) {
+      if (pktFiltered.data) {
+        av_packet_unref(&pktFiltered);
+      }
+      TORCH_CHECK(
+          0 <= av_bsf_send_packet(bsfCtx, &pkt),
+          "av_bsf_send_packet() failed at line ",
+          __LINE__,
+          " in demuxer.h\n");
+      TORCH_CHECK(
+          0 <= av_bsf_receive_packet(bsfCtx, &pktFiltered),
+          "av_bsf_receive_packet() failed at line ",
+          __LINE__,
+          " in demuxer.h\n");
+      *video = pktFiltered.data;
+      *videoBytes = pktFiltered.size;
+    } else if (
+        bMp4MPEG4 && frameCount == 0 && extraDataSize > 0 && pkt.size >= 3) {
+      // First MPEG4 packet: emit the codec extradata followed by the packet
+      // minus its first three bytes. The size is computed in int so a short
+      // packet cannot wrap around in unsigned arithmetic.
+      const int payloadSize = pkt.size - 3;
+      dataWithHeader = (uint8_t*)av_malloc(extraDataSize + payloadSize);
+      TORCH_CHECK(
+          dataWithHeader != NULL,
+          "av_malloc() failed at line ",
+          __LINE__,
+          " in demuxer.h\n");
+      memcpy(
+          dataWithHeader,
+          fmtCtx->streams[iVideoStream]->codecpar->extradata,
+          extraDataSize);
+      memcpy(dataWithHeader + extraDataSize, pkt.data + 3, payloadSize);
+      *video = dataWithHeader;
+      *videoBytes = extraDataSize + payloadSize;
+    } else {
+      // Every other packet, including a first MPEG4 packet that has no
+      // extradata to prepend, is passed through untouched.
+      *video = pkt.data;
+      *videoBytes = pkt.size;
+    }
+    frameCount++;
+    return true;
+  }
+};
+
+inline cudaVideoCodec ffmpeg_to_codec(AVCodecID id) {
+  switch (id) {
+    case AV_CODEC_ID_MPEG1VIDEO:
+      return cudaVideoCodec_MPEG1;
+    case AV_CODEC_ID_MPEG2VIDEO:
+      return cudaVideoCodec_MPEG2;
+    case AV_CODEC_ID_MPEG4:
+      return cudaVideoCodec_MPEG4;
+    case AV_CODEC_ID_WMV3:
+    case AV_CODEC_ID_VC1:
+      return cudaVideoCodec_VC1;
+    case AV_CODEC_ID_H264:
+      return cudaVideoCodec_H264;
+    case AV_CODEC_ID_HEVC:
+      return cudaVideoCodec_HEVC;
+    case AV_CODEC_ID_VP8:
+      return cudaVideoCodec_VP8;
+    case AV_CODEC_ID_VP9:
+      return cudaVideoCodec_VP9;
+    case AV_CODEC_ID_MJPEG:
+      return cudaVideoCodec_JPEG;
+    case AV_CODEC_ID_AV1:
+      return cudaVideoCodec_AV1;
+    default:
+      return cudaVideoCodec_NumCodecs;
+  }
+}
diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
new file mode 100644
index 00000000000..e6255aab5aa
--- /dev/null
+++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp
@@ -0,0 +1,119 @@
+#include "gpu_decoder.h"
+#include <c10/cuda/CUDAGuard.h>
+#include <cstring>
+#include <vector>
+
+/* Set cuda device, create cuda context and initialise the demuxer and decoder.
+ */
+GPUDecoder::GPUDecoder(std::string src_file, int64_t dev)
+    : demuxer(src_file.c_str()), device(dev) {
+  at::cuda::CUDAGuard device_guard(device);
+  check_for_cuda_errors(
+      cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
+  try {
+    decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
+  } catch (...) {
+    // The destructor is skipped when the constructor throws, so give the
+    // retained primary context back here before re-raising.
+    cuDevicePrimaryCtxRelease(device);
+    throw;
+  }
+  initialised = true;
+}
+
+/* Tear down the decoder and release the primary context retained by the
+ * constructor. Failures are reported as warnings because a destructor must
+ * not throw.
+ */
+GPUDecoder::~GPUDecoder() {
+  at::cuda::CUDAGuard device_guard(device);
+  decoder.release();
+  if (initialised) {
+    const CUresult result = cuDevicePrimaryCtxRelease(device);
+    if (result != CUDA_SUCCESS) {
+      TORCH_WARN("cuDevicePrimaryCtxRelease() failed with error ", result);
+    }
+  }
+}
+
+/* Fetch a decoded frame tensor after demuxing and decoding.
+ * Packets are fed to the decoder until it yields a frame; an empty tensor is
+ * returned once the demuxer has no more data.
+ */
+torch::Tensor GPUDecoder::decode() {
+  unsigned long videoBytes = 0;
+  uint8_t* video = nullptr;
+  at::cuda::CUDAGuard device_guard(device);
+  torch::Tensor frame;
+  do {
+    if (!demuxer.demux(&video, &videoBytes)) {
+      // Nothing left to read: submit an empty payload, which
+      // Decoder::decode() turns into an end-of-stream packet.
+      video = nullptr;
+      videoBytes = 0;
+    }
+    decoder.decode(video, videoBytes);
+    frame = decoder.fetch_frame();
+  } while (frame.numel() == 0 && videoBytes > 0);
+  return frame;
+}
+
+/* Convert a tensor with data in NV12 format to a tensor with data in YUV420
+ * format in-place.
+ * The tensor must be a contiguous uint8 CPU tensor large enough for the
+ * current frame geometry; anything else is rejected with an error.
+ */
+torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) {
+  const int width = decoder.get_width(), height = decoder.get_height();
+  const int pitch = width;
+  const int chromaWidth = (width + 1) / 2;
+  const int chromaHeight = (height + 1) / 2;
+
+  // sizes of source surface plane
+  const int64_t sizePlaneY = (int64_t)pitch * height;
+  const int64_t sizePlaneU = (int64_t)chromaWidth * chromaHeight;
+  const int64_t sizePlaneV = sizePlaneU;
+
+  // The loops below work on raw host pointers, so validate the tensor first.
+  TORCH_CHECK(
+      frameTensor.device().is_cpu(), "Expected a CPU tensor for reformat()");
+  TORCH_CHECK(
+      frameTensor.scalar_type() == torch::kU8,
+      "Expected a uint8 tensor for reformat()");
+  TORCH_CHECK(
+      frameTensor.is_contiguous(),
+      "Expected a contiguous tensor for reformat()");
+  TORCH_CHECK(
+      frameTensor.numel() >= sizePlaneY + sizePlaneU + sizePlaneV,
+      "Tensor is too small for a ",
+      width,
+      "x",
+      height,
+      " frame");
+
+  uint8_t* frame = frameTensor.data_ptr<uint8_t>();
+  uint8_t* uv = frame + sizePlaneY;
+  uint8_t* u = uv;
+  uint8_t* v = uv + sizePlaneU;
+  std::vector<uint8_t> vPlane(sizePlaneV);
+
+  // split chroma from interleave to planar. U is compacted in place (its write
+  // index never passes the read index); V is staged in a scratch buffer
+  // because its destination still holds samples that have not been read.
+  for (int y = 0; y < chromaHeight; y++) {
+    for (int x = 0; x < chromaWidth; x++) {
+      u[y * chromaWidth + x] = uv[y * pitch + x * 2];
+      vPlane[y * chromaWidth + x] = uv[y * pitch + x * 2 + 1];
+    }
+  }
+  // pitch is always equal to width here, so V is one contiguous copy.
+  memcpy(v, vPlane.data(), sizePlaneV * sizeof(uint8_t));
+  return frameTensor;
+}
+
+TORCH_LIBRARY(torchvision, m) {
+  m.class_<GPUDecoder>("GPUDecoder")
+      .def(torch::init<std::string, int64_t>())
+      .def("next", &GPUDecoder::decode)
+      .def("reformat", &GPUDecoder::nv12_to_yuv420);
+}
diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h
new file mode 100644
index 00000000000..02b14fda99e
--- /dev/null
+++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h
@@ -0,0 +1,21 @@
+#include <torch/custom_class.h>
+#include <torch/torch.h>
+#include "decoder.h"
+#include "demuxer.h"
+
+// TorchScript custom class pairing a Demuxer with a Decoder: decode() yields
+// frames, nv12_to_yuv420() repacks a frame on the host.
+class GPUDecoder : public torch::CustomClassHolder {
+ public:
+  GPUDecoder(std::string, int64_t);
+  ~GPUDecoder();
+  torch::Tensor decode();
+  torch::Tensor nv12_to_yuv420(torch::Tensor);
+
+ private:
+  Demuxer demuxer;
+  CUcontext ctx = nullptr;
+  Decoder decoder;
+  int64_t device;
+  bool initialised = false;
+};
diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py
index f2ae6dff51e..410ec5bfc2c 100644
---
a/torchvision/io/__init__.py
+++ b/torchvision/io/__init__.py
@@ -6,6 +6,7 @@
 from ._video_opt import (
     Timebase,
     VideoMetaData,
+    _HAS_VIDEO_DECODER,
     _HAS_VIDEO_OPT,
     _probe_video_from_file,
     _probe_video_from_memory,
@@ -104,10 +105,23 @@ class VideoReader:
         num_threads (int, optional): number of threads used by the codec to decode video.
             Default value (0) enables multithreading with codec-dependent heuristic. The performance
             will depend on the version of FFMPEG codecs supported.
+
+        device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
+
     """

-    def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> None:
+    def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
         _log_api_usage_once(self)
+        self.is_cuda = False
+        device = torch.device(device)
+        if device.type == "cuda":
+            if not _HAS_VIDEO_DECODER:
+                raise RuntimeError("Not compiled with GPU decoder support.")
+            self.is_cuda = True
+            if device.index is None:
+                raise RuntimeError("Invalid cuda device!")
+            self._c = torch.classes.torchvision.GPUDecoder(path, device.index)
+            return
         if not _has_video_opt():
             raise RuntimeError(
                 "Not compiled with video_reader support, "
@@ -115,6 +129,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> No
                 + "ffmpeg (version 4.2 is currently supported) and "
                 + "build torchvision from source."
             )
+
         self._c = torch.classes.torchvision.Video(path, stream, num_threads)

     def __next__(self) -> Dict[str, Any]:
@@ -129,6 +144,11 @@ def __next__(self) -> Dict[str, Any]:
             and corresponding timestamp (``pts``) in seconds

         """
+        if self.is_cuda:
+            frame = self._c.next()
+            if frame.numel() == 0:
+                raise StopIteration
+            return {"data": frame}
         frame, pts = self._c.next()
         if frame.numel() == 0:
             raise StopIteration
@@ -150,6 +170,8 @@ def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader":
             frame with the exact timestamp if it exists or
             the first frame with timestamp larger than ``time_s``.
         """
+        if self.is_cuda:
+            raise RuntimeError("seek() not yet supported with GPU decoding.")
         self._c.seek(time_s, keyframes_only)
         return self

@@ -159,6 +181,8 @@ def get_metadata(self) -> Dict[str, Any]:
         Returns:
             (dict): dictionary containing duration and frame rate for every stream
         """
+        if self.is_cuda:
+            raise RuntimeError("get_metadata() not yet supported with GPU decoding.")
         return self._c.get_metadata()

     def set_current_stream(self, stream: str) -> bool:
@@ -178,8 +202,39 @@ def set_current_stream(self, stream: str) -> bool:
         Returns:
             (bool): True on succes, False otherwise
         """
+        if self.is_cuda:
+            print("GPU decoding only works with video stream.")
+            # GPUDecoder has no set_current_stream(); the video stream is the
+            # only one it decodes, so only a request for it counts as success.
+            return stream.split(":")[0] == "video"
         return self._c.set_current_stream(stream)

+    def _reformat(self, tensor, output_format: str = "yuv420"):
+        """Convert a frame produced by the GPU decoder to ``output_format``.
+
+        Args:
+            tensor (Tensor): frame to convert; it is moved to the CPU if needed and converted in place there.
+            output_format (str, optional): target format. Only ``"yuv420"`` is supported.
+
+        Returns:
+            (Tensor): CPU tensor holding the converted frame.
+
+        Raises:
+            RuntimeError: if the format is unsupported, ``tensor`` is not a tensor,
+                or the reader was not created with a CUDA device.
+        """
+        supported_formats = [
+            "yuv420",
+        ]
+        if output_format not in supported_formats:
+            raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}")
+        if not isinstance(tensor, torch.Tensor):
+            raise RuntimeError("Expected tensor as input parameter!")
+        if not self.is_cuda:
+            # Only GPUDecoder implements reformat().
+            raise RuntimeError("_reformat() is only supported with GPU decoding.")
+        return self._c.reformat(tensor.cpu())
+

 __all__ = [
     "write_video",
@@ -192,6 +247,7 @@ def set_current_stream(self, stream: str) -> bool:
     "_read_video_timestamps_from_memory",
     "_probe_video_from_memory",
     "_HAS_VIDEO_OPT",
+    "_HAS_VIDEO_DECODER",
     "_read_video_clip_from_memory",
     "_read_video_meta_data",
     "VideoMetaData",
diff --git
a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index de4b25bb7b5..5ef975e3586 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -14,4 +14,11 @@
 except (ImportError, OSError):
     _HAS_VIDEO_OPT = False

+# Probe for the optional GPU decoder extension; a failed load only clears the flag.
+try:
+    _load_library("Decoder")
+except (ImportError, OSError):
+    _HAS_VIDEO_DECODER = False
+else:
+    _HAS_VIDEO_DECODER = True
+
 default_timebase = Fraction(0, 1)