diff --git a/test/torchaudio_unittest/assets/RATRACE_wave_f_nm_np1_fr_goo_37.avi b/test/torchaudio_unittest/assets/RATRACE_wave_f_nm_np1_fr_goo_37.avi new file mode 100644 index 0000000000..6cccfb416f Binary files /dev/null and b/test/torchaudio_unittest/assets/RATRACE_wave_f_nm_np1_fr_goo_37.avi differ diff --git a/test/torchaudio_unittest/assets/README.md b/test/torchaudio_unittest/assets/README.md new file mode 100644 index 0000000000..34137615cc --- /dev/null +++ b/test/torchaudio_unittest/assets/README.md @@ -0,0 +1,5 @@ +* RATRACE_wave_f_nm_np1_fr_goo_37.avi + * Source: HMDB-51 dataset ("wave" subset) + https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads + * License: Creative Commons Attribution 4.0 International License. + * Note: This file does not have proper PTS values, which makes it useful for testing seek on such files. diff --git a/test/torchaudio_unittest/io/stream_reader_test.py b/test/torchaudio_unittest/io/stream_reader_test.py index 1327250784..239e9550c2 100644 --- a/test/torchaudio_unittest/io/stream_reader_test.py +++ b/test/torchaudio_unittest/io/stream_reader_test.py @@ -446,35 +446,39 @@ def test_seek_invalid_mode(self): # Test keyframe seek # The source mp4 video has two key frames the first frame and 203rd frame at 8.08 second. # If the seek time stamp is smaller than 8.08, it will seek into the first frame at 0.0 second. - ("nasa_13013.mp4", "key", 0.2, (0, 0)), - ("nasa_13013.mp4", "key", 8.04, (0, 0)), - ("nasa_13013.mp4", "key", 8.08, (0, 202)), - ("nasa_13013.mp4", "key", 8.12, (0, 202)), + ("nasa_13013.mp4", "key", 0.2, (0, slice(None))), + ("nasa_13013.mp4", "key", 8.04, (0, slice(None))), + ("nasa_13013.mp4", "key", 8.08, (0, slice(202, None))), + ("nasa_13013.mp4", "key", 8.12, (0, slice(202, None))), # The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds. # if we seek to a time stamp smaller than 0.4004 it will seek into the first frame at 0.0 second. 
- ("nasa_13013.avi", "key", 0.2, (0, 0)), - ("nasa_13013.avi", "key", 1.01, (0, 24)), - ("nasa_13013.avi", "key", 7.37, (0, 216)), - ("nasa_13013.avi", "key", 7.7, (0, 216)), + ("nasa_13013.avi", "key", 0.2, (0, slice(None))), + ("nasa_13013.avi", "key", 1.01, (0, slice(24, None))), + ("nasa_13013.avi", "key", 7.37, (0, slice(216, None))), + ("nasa_13013.avi", "key", 7.7, (0, slice(216, None))), # Test precise seek - ("nasa_13013.mp4", "precise", 0.0, (0, 0)), - ("nasa_13013.mp4", "precise", 0.2, (0, 5)), - ("nasa_13013.mp4", "precise", 8.04, (0, 201)), - ("nasa_13013.mp4", "precise", 8.08, (0, 202)), - ("nasa_13013.mp4", "precise", 8.12, (0, 203)), - ("nasa_13013.avi", "precise", 0.0, (0, 0)), - ("nasa_13013.avi", "precise", 0.2, (0, 1)), - ("nasa_13013.avi", "precise", 8.1, (0, 238)), - ("nasa_13013.avi", "precise", 8.14, (0, 239)), - ("nasa_13013.avi", "precise", 8.17, (0, 240)), + ("nasa_13013.mp4", "precise", 0.0, (0, slice(None))), + ("nasa_13013.mp4", "precise", 0.2, (0, slice(5, None))), + ("nasa_13013.mp4", "precise", 8.04, (0, slice(201, None))), + ("nasa_13013.mp4", "precise", 8.08, (0, slice(202, None))), + ("nasa_13013.mp4", "precise", 8.12, (0, slice(203, None))), + ("nasa_13013.avi", "precise", 0.0, (0, slice(None))), + ("nasa_13013.avi", "precise", 0.2, (0, slice(1, None))), + ("nasa_13013.avi", "precise", 8.1, (0, slice(238, None))), + ("nasa_13013.avi", "precise", 8.14, (0, slice(239, None))), + ("nasa_13013.avi", "precise", 8.17, (0, slice(240, None))), + # Test precise seek on video with invalid PTS + ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", "precise", 0.0, (0, slice(None))), + ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", "precise", 0.2, (0, slice(4, -1))), + ("RATRACE_wave_f_nm_np1_fr_goo_37.avi", "precise", 0.3, (0, slice(7, -1))), # Test any seek # The source avi video has one keyframe every twelve frames 0, 12, 24,.. or every 0.4004 seconds. 
- ("nasa_13013.avi", "any", 0.0, (0, 0)), - ("nasa_13013.avi", "any", 0.56, (0, 12)), - ("nasa_13013.avi", "any", 7.77, (0, 228)), - ("nasa_13013.avi", "any", 0.2002, (11, 12)), - ("nasa_13013.avi", "any", 0.233567, (10, 12)), - ("nasa_13013.avi", "any", 0.266933, (9, 12)), + ("nasa_13013.avi", "any", 0.0, (0, slice(None))), + ("nasa_13013.avi", "any", 0.56, (0, slice(12, None))), + ("nasa_13013.avi", "any", 7.77, (0, slice(228, None))), + ("nasa_13013.avi", "any", 0.2002, (11, slice(12, None))), + ("nasa_13013.avi", "any", 0.233567, (10, slice(12, None))), + ("nasa_13013.avi", "any", 0.266933, (9, slice(12, None))), ] ) def test_seek_modes(self, src, mode, seek_time, ref_indices): @@ -506,7 +510,9 @@ def test_seek_modes(self, src, mode, seek_time, ref_indices): hyp_index, ref_index = ref_indices - self.assertEqual(frame[hyp_index:], ref_frames[ref_index:]) + # ref_index is a slice object, so the expected reference frames may be open-ended or bounded. + hyp, ref = frame[hyp_index:], ref_frames[ref_index] + self.assertEqual(hyp, ref) def _to_fltp(original): diff --git a/torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp b/torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp index f89cbb53a8..3458f5f5ad 100644 --- a/torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp +++ b/torchaudio/csrc/ffmpeg/stream_reader/stream_processor.cpp @@ -102,7 +102,8 @@ int StreamProcessor::process_packet(AVPacket* packet) { // and just not discard any. // // Note: discard_before_pts < 0 is UB. - if (discard_before_pts <= 0 || pFrame1->pts >= discard_before_pts) { + if (discard_before_pts <= 0 || pFrame1->pts >= discard_before_pts || + pFrame1->best_effort_timestamp >= discard_before_pts) { send_frame(pFrame1); }