diff --git a/test/test_io.py b/test/test_io.py
index 66f91cae779..9bfc2aa403e 100644
--- a/test/test_io.py
+++ b/test/test_io.py
@@ -87,6 +87,22 @@ def test_write_read_video(self):
             self.assertTrue(data.equal(lv))
             self.assertEqual(info["video_fps"], 5)
 
+    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
+    def test_probe_video_from_file(self):
+        with temp_video(10, 300, 300, 5) as (f_name, data):
+            video_info = io._probe_video_from_file(f_name)
+            self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
+            self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
+
+    @unittest.skipIf(not io._HAS_VIDEO_OPT, "video_reader backend is not chosen")
+    def test_probe_video_from_memory(self):
+        with temp_video(10, 300, 300, 5) as (f_name, data):
+            with open(f_name, "rb") as fp:
+                filebuffer = fp.read()
+            video_info = io._probe_video_from_memory(filebuffer)
+            self.assertAlmostEqual(video_info["video_duration"], 2, delta=0.1)
+            self.assertAlmostEqual(video_info["video_fps"], 5, delta=0.1)
+
     def test_read_timestamps(self):
         with temp_video(10, 300, 300, 5) as (f_name, data):
             if _video_backend == "pyav":
diff --git a/test/test_video_reader.py b/test/test_video_reader.py
index b9ed3ebcf46..ffefe40840d 100644
--- a/test/test_video_reader.py
+++ b/test/test_video_reader.py
@@ -31,6 +31,7 @@
 VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
 
 CheckerConfig = [
+    "duration",
     "video_fps",
     "audio_sample_rate",
     # We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are
@@ -44,6 +45,7 @@
 )
 
 all_check_config = GroundTruth(
+    duration=0,
     video_fps=0,
     audio_sample_rate=0,
     check_aframes=True,
@@ -52,36 +54,42 @@
 
 test_videos = {
     "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(
+        duration=2.0,
         video_fps=30.0,
         audio_sample_rate=None,
         check_aframes=True,
         check_aframe_pts=True,
     ),
     "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth(
+        duration=2.0,
         video_fps=30.0,
         audio_sample_rate=None,
         check_aframes=True,
         check_aframe_pts=True,
     ),
     "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(
+        duration=2.0,
         video_fps=30.0,
         audio_sample_rate=None,
         check_aframes=True,
         check_aframe_pts=True,
     ),
     "v_SoccerJuggling_g23_c01.avi": GroundTruth(
+        duration=8.0,
         video_fps=29.97,
         audio_sample_rate=None,
         check_aframes=True,
         check_aframe_pts=True,
     ),
     "v_SoccerJuggling_g24_c01.avi": GroundTruth(
+        duration=8.0,
         video_fps=29.97,
         audio_sample_rate=None,
         check_aframes=True,
         check_aframe_pts=True,
     ),
     "R6llTwEh07w.mp4": GroundTruth(
+        duration=10.0,
         video_fps=30.0,
         audio_sample_rate=44100,
         # PyAv miss one audio frame at the beginning (pts=0)
@@ -89,6 +97,7 @@
         check_aframe_pts=False,
     ),
     "SOX5yA1l24A.mp4": GroundTruth(
+        duration=11.0,
         video_fps=29.97,
         audio_sample_rate=48000,
         # PyAv miss one audio frame at the beginning (pts=0)
@@ -96,6 +105,7 @@
         check_aframe_pts=False,
     ),
     "WUzgd7C1pWA.mp4": GroundTruth(
+        duration=11.0,
         video_fps=29.97,
         audio_sample_rate=48000,
         # PyAv miss one audio frame at the beginning (pts=0)
@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
     def check_separate_decoding_result(self, tv_result, config):
         """check the decoding results from TorchVision decoder
         """
-        vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
-            tv_result
+        vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
+            atimebase, asample_rate, aduration = tv_result
+
+        video_duration = vduration.item() * Fraction(
+            vtimebase[0].item(), vtimebase[1].item()
         )
+        self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
         self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
         if asample_rate.numel() > 0:
             self.assertEqual(asample_rate.item(), config.audio_sample_rate)
+            audio_duration = aduration.item() * Fraction(
+                atimebase[0].item(), atimebase[1].item()
+            )
+            self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
+
 
         # check if pts of video frames are sorted in ascending order
         for i in range(len(vframe_pts) - 1):
             self.assertEqual(vframe_pts[i] < vframe_pts[i + 1], True)
@@ -288,6 +307,20 @@ def check_separate_decoding_result(self, tv_result, config):
         for i in range(len(aframe_pts) - 1):
             self.assertEqual(aframe_pts[i] < aframe_pts[i + 1], True)
 
+    def check_probe_result(self, result, config):
+        vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+        video_duration = vduration.item() * Fraction(
+            vtimebase[0].item(), vtimebase[1].item()
+        )
+        self.assertAlmostEqual(video_duration, config.duration, delta=0.5)
+        self.assertAlmostEqual(vfps.item(), config.video_fps, delta=0.5)
+        if asample_rate.numel() > 0:
+            self.assertEqual(asample_rate.item(), config.audio_sample_rate)
+            audio_duration = aduration.item() * Fraction(
+                atimebase[0].item(), atimebase[1].item()
+            )
+            self.assertAlmostEqual(audio_duration, config.duration, delta=0.5)
+
     def compare_decoding_result(self, tv_result, ref_result, config=all_check_config):
         """
         Compare decoding results from two sources.
@@ -297,18 +330,17 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config
             decoder or TorchVision decoder with getPtsOnly = 1
         config: config of decoding results checker
         """
-        vframes, vframe_pts, vtimebase, _vfps, aframes, aframe_pts, atimebase, _asample_rate = (
-            tv_result
-        )
+        vframes, vframe_pts, vtimebase, _vfps, _vduration, aframes, aframe_pts, \
+            atimebase, _asample_rate, _aduration = tv_result
         if isinstance(ref_result, list):
             # the ref_result is from new video_reader decoder
             ref_result = DecoderResult(
                 vframes=ref_result[0],
                 vframe_pts=ref_result[1],
                 vtimebase=ref_result[2],
-                aframes=ref_result[4],
-                aframe_pts=ref_result[5],
-                atimebase=ref_result[6],
+                aframes=ref_result[5],
+                aframe_pts=ref_result[6],
+                atimebase=ref_result[7],
             )
 
         if vframes.numel() > 0 and ref_result.vframes.numel() > 0:
@@ -351,12 +383,12 @@ def test_stress_test_read_video_from_file(self):
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1
 
-        for i in range(num_iter):
-            for test_video, config in test_videos.items():
+        for _i in range(num_iter):
+            for test_video, _config in test_videos.items():
                 full_path = os.path.join(VIDEO_DIR, test_video)
 
                 # pass 1: decode all frames using new decoder
-                _ = torch.ops.video_reader.read_video_from_file(
+                torch.ops.video_reader.read_video_from_file(
                     full_path,
                     seek_frame_margin,
                     0,  # getPtsOnly
@@ -460,9 +492,8 @@ def test_read_video_from_file_read_single_stream_only(self):
                     audio_timebase_den,
                 )
 
-                vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
-                    tv_result
-                )
+                vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
+                    atimebase, asample_rate, aduration = tv_result
 
                 self.assertEqual(vframes.numel() > 0, readVideoStream)
                 self.assertEqual(vframe_pts.numel() > 0, readVideoStream)
@@ -489,7 +520,7 @@ def test_read_video_from_file_rescale_min_dimension(self):
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1
 
-        for test_video, config in test_videos.items():
+        for test_video, _config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
 
             tv_result = torch.ops.video_reader.read_video_from_file(
@@ -528,7 +559,7 @@ def test_read_video_from_file_rescale_width(self):
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1
 
-        for test_video, config in test_videos.items():
+        for test_video, _config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
 
             tv_result = torch.ops.video_reader.read_video_from_file(
@@ -567,7 +598,7 @@ def test_read_video_from_file_rescale_height(self):
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1
 
-        for test_video, config in test_videos.items():
+        for test_video, _config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
 
             tv_result = torch.ops.video_reader.read_video_from_file(
@@ -606,7 +637,7 @@ def test_read_video_from_file_rescale_width_and_height(self):
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1
 
-        for test_video, config in test_videos.items():
+        for test_video, _config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
 
             tv_result = torch.ops.video_reader.read_video_from_file(
@@ -651,7 +682,7 @@ def test_read_video_from_file_audio_resampling(self):
         audio_start_pts, audio_end_pts = 0, -1
         audio_timebase_num, audio_timebase_den = 0, 1
 
-        for test_video, config in test_videos.items():
+        for test_video, _config in test_videos.items():
             full_path = os.path.join(VIDEO_DIR, test_video)
 
             tv_result = torch.ops.video_reader.read_video_from_file(
@@ -674,18 +705,17 @@ def test_read_video_from_file_audio_resampling(self):
                 audio_timebase_num,
                 audio_timebase_den,
             )
-            vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, a_sample_rate = (
-                tv_result
-            )
+            vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
+                atimebase, asample_rate, aduration = tv_result
             if aframes.numel() > 0:
-                self.assertEqual(samples, a_sample_rate.item())
+                self.assertEqual(samples, asample_rate.item())
                 self.assertEqual(1, aframes.size(1))
                 # when audio stream is found
                 duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1])
                 self.assertAlmostEqual(
                     aframes.size(0),
-                    int(duration * a_sample_rate.item()),
-                    delta=0.1 * a_sample_rate.item(),
+                    int(duration * asample_rate.item()),
+                    delta=0.1 * asample_rate.item(),
                 )
 
     def test_compare_read_video_from_memory_and_file(self):
@@ -859,7 +889,7 @@ def test_read_video_from_memory_get_pts_only(self):
             )
 
             self.assertEqual(tv_result_pts_only[0].numel(), 0)
-            self.assertEqual(tv_result_pts_only[4].numel(), 0)
+            self.assertEqual(tv_result_pts_only[5].numel(), 0)
             self.compare_decoding_result(tv_result, tv_result_pts_only)
 
     def test_read_video_in_range_from_memory(self):
@@ -899,9 +929,8 @@ def test_read_video_in_range_from_memory(self):
                 audio_timebase_num,
                 audio_timebase_den,
             )
-            vframes, vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = (
-                tv_result
-            )
+            vframes, vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
+                atimebase, asample_rate, aduration = tv_result
            self.assertAlmostEqual(config.video_fps, vfps.item(), delta=0.01)
 
             for num_frames in [4, 8, 16, 32, 64, 128]:
@@ -997,6 +1026,24 @@ def test_read_video_in_range_from_memory(self):
             # and PyAv
             self.compare_decoding_result(tv_result, pyav_result, config)
 
+    def test_probe_video_from_file(self):
+        """
+        Test the case when decoder probes a video file
+        """
+        for test_video, config in test_videos.items():
+            full_path = os.path.join(VIDEO_DIR, test_video)
+            probe_result = torch.ops.video_reader.probe_video_from_file(full_path)
+            self.check_probe_result(probe_result, config)
+
+    def test_probe_video_from_memory(self):
+        """
+        Test the case when decoder probes a video in memory
+        """
+        for test_video, config in test_videos.items():
+            full_path, video_tensor = _get_video_tensor(VIDEO_DIR, test_video)
+            probe_result = torch.ops.video_reader.probe_video_from_memory(video_tensor)
+            self.check_probe_result(probe_result, config)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp b/torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp
index 4862e0c2c7e..b5b1e2fbda5 100644
--- a/torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp
+++ b/torchvision/csrc/cpu/video_reader/FfmpegAudioStream.cpp
@@ -49,6 +49,7 @@ void FfmpegAudioStream::updateStreamDecodeParams() {
     mediaFormat_.format.audio.timeBaseDen =
         inputCtx_->streams[index_]->time_base.den;
   }
+  mediaFormat_.format.audio.duration = inputCtx_->streams[index_]->duration;
 }
 
 int FfmpegAudioStream::initFormat() {
diff --git a/torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp b/torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp
index c2d7f809ec1..fb4d302cc03 100644
--- a/torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp
+++ b/torchvision/csrc/cpu/video_reader/FfmpegDecoder.cpp
@@ -220,6 +220,30 @@ int FfmpegDecoder::decodeMemory(
   return ret;
 }
 
+int FfmpegDecoder::probeFile(
+    unique_ptr<DecoderParameters> params,
+    const string& fileName,
+    DecoderOutput& decoderOutput) {
+  VLOG(1) << "probe file: " << fileName;
+  FfmpegAvioContext ioctx;
+  return probeVideo(std::move(params), fileName, true, ioctx, decoderOutput);
+}
+
+int FfmpegDecoder::probeMemory(
+    unique_ptr<DecoderParameters> params,
+    const uint8_t* buffer,
+    int64_t size,
+    DecoderOutput& decoderOutput) {
+  VLOG(1) << "probe video data in memory";
+  FfmpegAvioContext ioctx;
+  int ret = ioctx.initAVIOContext(buffer, size);
+  if (ret == 0) {
+    ret =
+        probeVideo(std::move(params), string(""), false, ioctx, decoderOutput);
+  }
+  return ret;
+}
+
 void FfmpegDecoder::cleanUp() {
   if (formatCtx_) {
     for (auto& stream : streams_) {
@@ -320,6 +344,16 @@ int FfmpegDecoder::decodeLoop(
   return ret;
 }
 
+int FfmpegDecoder::probeVideo(
+    unique_ptr<DecoderParameters> params,
+    const std::string& filename,
+    bool isDecodeFile,
+    FfmpegAvioContext& ioctx,
+    DecoderOutput& decoderOutput) {
+  params_ = std::move(params);
+  return init(filename, isDecodeFile, ioctx, decoderOutput);
+}
+
 bool FfmpegDecoder::initStreams() {
   for (auto it = params_->formats.begin(); it != params_->formats.end(); ++it) {
     AVMediaType mediaType;
diff --git a/torchvision/csrc/cpu/video_reader/FfmpegDecoder.h b/torchvision/csrc/cpu/video_reader/FfmpegDecoder.h
index 5dc6f0c47e8..a0a564a4214 100644
--- a/torchvision/csrc/cpu/video_reader/FfmpegDecoder.h
+++ b/torchvision/csrc/cpu/video_reader/FfmpegDecoder.h
@@ -75,6 +75,19 @@ class FfmpegDecoder {
       const uint8_t* buffer,
       int64_t size,
       DecoderOutput& decoderOutput);
+  // return 0 on success
+  // return negative number on failure
+  int probeFile(
+      std::unique_ptr<DecoderParameters> params,
+      const std::string& filename,
+      DecoderOutput& decoderOutput);
+  // return 0 on success
+  // return negative number on failure
+  int probeMemory(
+      std::unique_ptr<DecoderParameters> params,
+      const uint8_t* buffer,
+      int64_t size,
+      DecoderOutput& decoderOutput);
 
   void cleanUp();
 
@@ -95,6 +108,13 @@
       FfmpegAvioContext& ioctx,
       DecoderOutput& decoderOutput);
 
+  int probeVideo(
+      std::unique_ptr<DecoderParameters> params,
+      const std::string& filename,
+      bool isDecodeFile,
+      FfmpegAvioContext& ioctx,
+      DecoderOutput& decoderOutput);
+
   bool initStreams();
 
   void flushStreams(DecoderOutput& decoderOutput);
diff --git a/torchvision/csrc/cpu/video_reader/FfmpegVideoStream.cpp b/torchvision/csrc/cpu/video_reader/FfmpegVideoStream.cpp
index b0b11a683db..7a429249a71 100644
--- a/torchvision/csrc/cpu/video_reader/FfmpegVideoStream.cpp
+++ b/torchvision/csrc/cpu/video_reader/FfmpegVideoStream.cpp
@@ -48,6 +48,7 @@ void FfmpegVideoStream::updateStreamDecodeParams() {
     mediaFormat_.format.video.timeBaseDen =
         inputCtx_->streams[index_]->time_base.den;
   }
+  mediaFormat_.format.video.duration = inputCtx_->streams[index_]->duration;
 }
 
 int FfmpegVideoStream::initFormat() {
diff --git a/torchvision/csrc/cpu/video_reader/Interface.h b/torchvision/csrc/cpu/video_reader/Interface.h
index 98c1dfee517..e137008ce7b 100644
--- a/torchvision/csrc/cpu/video_reader/Interface.h
+++ b/torchvision/csrc/cpu/video_reader/Interface.h
@@ -48,6 +48,7 @@ struct VideoFormat {
   int timeBaseNum{0};
   int timeBaseDen{1}; // numerator and denominator of time base
   float fps{0.0};
+  int64_t duration{0}; // duration of the stream, in stream time base
 };
 
 struct AudioFormat {
@@ -60,6 +61,7 @@ struct AudioFormat {
   int64_t startPts{0}, endPts{0}; // Start and end presentation timestamp
   int timeBaseNum{0};
   int timeBaseDen{1}; // numerator and denominator of time base
+  int64_t duration{0}; // duration of the stream, in stream time base
 };
 
 union FormatUnion {
diff --git a/torchvision/csrc/cpu/video_reader/VideoReader.cpp b/torchvision/csrc/cpu/video_reader/VideoReader.cpp
index a300340ccba..dfe7f46bf39 100644
--- a/torchvision/csrc/cpu/video_reader/VideoReader.cpp
+++ b/torchvision/csrc/cpu/video_reader/VideoReader.cpp
@@ -27,8 +27,6 @@ PyMODINIT_FUNC PyInit_video_reader(void) {
 
 namespace video_reader {
 
-bool glog_initialized = false;
-
 class UnknownPixelFormatException : public exception {
   const char* what() const throw() override {
     return "Unknown pixel format";
@@ -167,11 +165,6 @@ torch::List<torch::Tensor> readVideo(
     int64_t audioEndPts,
     int64_t audioTimeBaseNum,
     int64_t audioTimeBaseDen) {
-  if (!glog_initialized) {
-    glog_initialized = true;
-    // google::InitGoogleLogging("VideoReader");
-  }
-
   unique_ptr<DecoderParameters> params = util::getDecoderParams(
       seekFrameMargin,
       getPtsOnly,
@@ -209,6 +202,8 @@ torch::List<torch::Tensor> readVideo(
   torch::Tensor videoFramePts = torch::zeros({0}, torch::kLong);
   torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
   torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
+  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
+
   if (readVideoStream == 1) {
     auto it = decoderOutput.media_data_.find(TYPE_VIDEO);
     if (it != decoderOutput.media_data_.end()) {
@@ -236,6 +231,10 @@
       videoFps = torch::zeros({1}, torch::kFloat);
       float* videoFpsData = videoFps.data_ptr<float>();
       videoFpsData[0] = it->second.format_.video.fps;
+
+      videoDuration = torch::zeros({1}, torch::kLong);
+      int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
+      videoDurationData[0] = it->second.format_.video.duration;
     } else {
       VLOG(1) << "Miss video stream";
     }
@@ -246,6 +245,7 @@
   torch::Tensor audioFramePts = torch::zeros({0}, torch::kLong);
   torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
   torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
+  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
   if (readAudioStream == 1) {
     auto it = decoderOutput.media_data_.find(TYPE_AUDIO);
     if (it != decoderOutput.media_data_.end()) {
@@ -275,6 +275,10 @@ torch::List<torch::Tensor> readVideo(
       audioSampleRate = torch::zeros({1}, torch::kInt);
       int* audioSampleRateData = audioSampleRate.data_ptr<int>();
      audioSampleRateData[0] = it->second.format_.audio.samples;
+
+      audioDuration = torch::zeros({1}, torch::kLong);
+      int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
+      audioDurationData[0] = it->second.format_.audio.duration;
     } else {
       VLOG(1) << "Miss audio stream";
     }
@@ -285,10 +289,12 @@
   result.push_back(std::move(videoFramePts));
   result.push_back(std::move(videoTimeBase));
   result.push_back(std::move(videoFps));
+  result.push_back(std::move(videoDuration));
   result.push_back(std::move(audioFrame));
   result.push_back(std::move(audioFramePts));
   result.push_back(std::move(audioTimeBase));
   result.push_back(std::move(audioSampleRate));
+  result.push_back(std::move(audioDuration));
 
   return result;
 }
@@ -378,10 +384,117 @@ torch::List<torch::Tensor> readVideoFromFile(
       audioTimeBaseDen);
 }
 
+torch::List<torch::Tensor> probeVideo(
+    bool isReadFile,
+    const torch::Tensor& input_video,
+    std::string videoPath) {
+  unique_ptr<DecoderParameters> params = util::getDecoderParams(
+      0, // seekFrameMargin
+      0, // getPtsOnly
+      1, // readVideoStream
+      0, // width
+      0, // height
+      0, // minDimension
+      0, // videoStartPts
+      0, // videoEndPts
+      0, // videoTimeBaseNum
+      1, // videoTimeBaseDen
+      1, // readAudioStream
+      0, // audioSamples
+      0, // audioChannels
+      0, // audioStartPts
+      0, // audioEndPts
+      0, // audioTimeBaseNum
+      1 // audioTimeBaseDen
+      );
+
+  FfmpegDecoder decoder;
+  DecoderOutput decoderOutput;
+  if (isReadFile) {
+    decoder.probeFile(std::move(params), videoPath, decoderOutput);
+  } else {
+    decoder.probeMemory(
+        std::move(params),
+        input_video.data_ptr<uint8_t>(),
+        input_video.size(0),
+        decoderOutput);
+  }
+  // video section
+  torch::Tensor videoTimeBase = torch::zeros({0}, torch::kInt);
+  torch::Tensor videoFps = torch::zeros({0}, torch::kFloat);
+  torch::Tensor videoDuration = torch::zeros({0}, torch::kLong);
+
+  auto it = decoderOutput.media_data_.find(TYPE_VIDEO);
+  if (it != decoderOutput.media_data_.end()) {
+    VLOG(1) << "Find video stream";
+    videoTimeBase = torch::zeros({2}, torch::kInt);
+    int* videoTimeBaseData = videoTimeBase.data_ptr<int>();
+    videoTimeBaseData[0] = it->second.format_.video.timeBaseNum;
+    videoTimeBaseData[1] = it->second.format_.video.timeBaseDen;
+
+    videoFps = torch::zeros({1}, torch::kFloat);
+    float* videoFpsData = videoFps.data_ptr<float>();
+    videoFpsData[0] = it->second.format_.video.fps;
+
+    videoDuration = torch::zeros({1}, torch::kLong);
+    int64_t* videoDurationData = videoDuration.data_ptr<int64_t>();
+    videoDurationData[0] = it->second.format_.video.duration;
+  } else {
+    VLOG(1) << "Miss video stream";
+  }
+
+  // audio section
+  torch::Tensor audioTimeBase = torch::zeros({0}, torch::kInt);
+  torch::Tensor audioSampleRate = torch::zeros({0}, torch::kInt);
+  torch::Tensor audioDuration = torch::zeros({0}, torch::kLong);
+
+  it = decoderOutput.media_data_.find(TYPE_AUDIO);
+  if (it != decoderOutput.media_data_.end()) {
+    VLOG(1) << "Find audio stream";
+    audioTimeBase = torch::zeros({2}, torch::kInt);
+    int* audioTimeBaseData = audioTimeBase.data_ptr<int>();
+    audioTimeBaseData[0] = it->second.format_.audio.timeBaseNum;
+    audioTimeBaseData[1] = it->second.format_.audio.timeBaseDen;
+
+    audioSampleRate = torch::zeros({1}, torch::kInt);
+    int* audioSampleRateData = audioSampleRate.data_ptr<int>();
+    audioSampleRateData[0] = it->second.format_.audio.samples;
+
+    audioDuration = torch::zeros({1}, torch::kLong);
+    int64_t* audioDurationData = audioDuration.data_ptr<int64_t>();
+    audioDurationData[0] = it->second.format_.audio.duration;
+  } else {
+    VLOG(1) << "Miss audio stream";
+  }
+
+  torch::List<torch::Tensor> result;
+  result.push_back(std::move(videoTimeBase));
+  result.push_back(std::move(videoFps));
+  result.push_back(std::move(videoDuration));
+  result.push_back(std::move(audioTimeBase));
+  result.push_back(std::move(audioSampleRate));
+  result.push_back(std::move(audioDuration));
+
+  return result;
+}
+
+torch::List<torch::Tensor> probeVideoFromMemory(torch::Tensor input_video) {
+  return probeVideo(false, input_video, "");
+}
+
+torch::List<torch::Tensor> probeVideoFromFile(std::string videoPath) {
+  torch::Tensor dummy_input_video = torch::ones({0});
+  return probeVideo(true, dummy_input_video, videoPath);
+}
+
 } // namespace video_reader
 
 static auto registry = torch::RegisterOperators()
                            .op("video_reader::read_video_from_memory",
                                &video_reader::readVideoFromMemory)
                            .op("video_reader::read_video_from_file",
-                               &video_reader::readVideoFromFile);
+                               &video_reader::readVideoFromFile)
+                           .op("video_reader::probe_video_from_memory",
+                               &video_reader::probeVideoFromMemory)
+                           .op("video_reader::probe_video_from_file",
+                               &video_reader::probeVideoFromFile);
diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py
index 13958fb9ab4..768befde412 100644
--- a/torchvision/io/__init__.py
+++ b/torchvision/io/__init__.py
@@ -2,15 +2,17 @@
 from ._video_opt import (
     _read_video_from_file,
     _read_video_timestamps_from_file,
+    _probe_video_from_file,
     _read_video_from_memory,
     _read_video_timestamps_from_memory,
+    _probe_video_from_memory,
     _HAS_VIDEO_OPT,
 )
 
 
 __all__ = [
     'write_video', 'read_video', 'read_video_timestamps',
-    '_read_video_from_file', '_read_video_timestamps_from_file',
-    '_read_video_from_memory', '_read_video_timestamps_from_memory',
+    '_read_video_from_file', '_read_video_timestamps_from_file', '_probe_video_from_file',
+    '_read_video_from_memory', '_read_video_timestamps_from_memory', '_probe_video_from_memory',
     '_HAS_VIDEO_OPT',
 ]
diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index f3edab1a957..5971f23c9c0 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -26,14 +26,20 @@ def _validate_pts(pts_range):
         start pts: %d and end pts: %d""" % (pts_range[0], pts_range[1])
 
 
-def _fill_info(vtimebase, vfps, atimebase, asample_rate):
+def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
     info = {}
     if vtimebase.numel() > 0:
         info["video_timebase"] = Fraction(vtimebase[0].item(), vtimebase[1].item())
+    if vduration.numel() > 0:
+        video_duration = vduration.item() * info["video_timebase"]
+        info["video_duration"] = video_duration
     if vfps.numel() > 0:
         info["video_fps"] = vfps.item()
     if atimebase.numel() > 0:
         info["audio_timebase"] = Fraction(atimebase[0].item(), atimebase[1].item())
+    if aduration.numel() > 0:
+        audio_duration = aduration.item() * info["audio_timebase"]
+        info["audio_duration"] = audio_duration
     if asample_rate.numel() > 0:
         info["audio_sample_rate"] = asample_rate.item()
 
@@ -141,8 +147,9 @@ def _read_video_from_file(
         audio_timebase.numerator,
         audio_timebase.denominator,
     )
-    vframes, _vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = result
-    info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, \
+        asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
     if aframes.numel() > 0:
         # when audio stream is found
         aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
@@ -175,16 +182,30 @@ def _read_video_timestamps_from_file(filename):
         0,  # audio_timebase_num
         1,  # audio_timebase_den
     )
-    _vframes, vframe_pts, vtimebase, vfps, _aframes, aframe_pts, atimebase, asample_rate = result
-    info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, \
+        asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
 
     vframe_pts = vframe_pts.numpy().tolist()
     aframe_pts = aframe_pts.numpy().tolist()
     return vframe_pts, aframe_pts, info
 
+
+def _probe_video_from_file(filename):
+    """
+    Probe a video file.
+    Return:
+        info [dict]: contains video meta information, including video_timebase,
+            video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
+    """
+    result = torch.ops.video_reader.probe_video_from_file(filename)
+    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    return info
+
+
 def _read_video_from_memory(
-    file_buffer,
+    video_data,
     seek_frame_margin=0.25,
     read_video_stream=1,
     video_width=0,
@@ -204,8 +225,8 @@ def _read_video_from_memory(
 
     Args
    ----------
-    file_buffer : buffer
-        buffer of compressed video content
+    video_data : data type could be 1) torch.Tensor, dtype=torch.uint8 or 2) python bytes
+        compressed video content stored in either 1) torch.Tensor or 2) python bytes
     seek_frame_margin: double, optional
         seeking frame in the stream is imprecise. Thus, when video_start_pts is specified,
         we seek the pts earlier by seek_frame_margin seconds
@@ -252,10 +273,11 @@ def _read_video_from_memory(
     _validate_pts(video_pts_range)
     _validate_pts(audio_pts_range)
 
-    video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
+    if not isinstance(video_data, torch.Tensor):
+        video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
 
     result = torch.ops.video_reader.read_video_from_memory(
-        video_tensor,
+        video_data,
         seek_frame_margin,
         0,  # getPtsOnly
         read_video_stream,
@@ -275,24 +297,25 @@ def _read_video_from_memory(
         audio_timebase.denominator,
     )
 
-    vframes, _vframe_pts, vtimebase, vfps, aframes, aframe_pts, atimebase, asample_rate = result
-    info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, \
+        atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
     if aframes.numel() > 0:
         # when audio stream is found
         aframes = _align_audio_frames(aframes, aframe_pts, audio_pts_range)
     return vframes, aframes, info
 
 
-def _read_video_timestamps_from_memory(file_buffer):
+def _read_video_timestamps_from_memory(video_data):
     """
     Decode all frames in the video. Only pts (presentation timestamp) is returned.
     The actual frame pixel data is not copied. Thus, read_video_timestamps(...)
     is much faster than read_video(...)
     """
-
-    video_tensor = torch.from_numpy(np.frombuffer(file_buffer, dtype=np.uint8))
+    if not isinstance(video_data, torch.Tensor):
+        video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
     result = torch.ops.video_reader.read_video_from_memory(
-        video_tensor,
+        video_data,
         0,  # seek_frame_margin
         1,  # getPtsOnly
         1,  # read_video_stream
@@ -311,9 +334,25 @@ def _read_video_timestamps_from_memory(file_buffer):
         0,  # audio_timebase_num
         1,  # audio_timebase_den
     )
-    _vframes, vframe_pts, vtimebase, vfps, _aframes, aframe_pts, atimebase, asample_rate = result
-    info = _fill_info(vtimebase, vfps, atimebase, asample_rate)
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, \
+        atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
 
     vframe_pts = vframe_pts.numpy().tolist()
     aframe_pts = aframe_pts.numpy().tolist()
     return vframe_pts, aframe_pts, info
+
+
+def _probe_video_from_memory(video_data):
+    """
+    Probe a video in memory.
+    Return:
+        info [dict]: contains video meta information, including video_timebase,
+            video_duration, video_fps, audio_timebase, audio_duration, audio_sample_rate
+    """
+    if not isinstance(video_data, torch.Tensor):
+        video_data = torch.from_numpy(np.frombuffer(video_data, dtype=np.uint8))
+    result = torch.ops.video_reader.probe_video_from_memory(video_data)
+    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
+    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
+    return info