3131VIDEO_DIR  =  os .path .join (os .path .dirname (os .path .abspath (__file__ )), "assets" , "videos" )
3232
3333CheckerConfig  =  [
34+     "duration" ,
3435    "video_fps" ,
3536    "audio_sample_rate" ,
3637    # We find for some videos (e.g. HMDB51 videos), the decoded audio frames and pts are 
4445)
4546
4647all_check_config  =  GroundTruth (
48+     duration = 0 ,
4749    video_fps = 0 ,
4850    audio_sample_rate = 0 ,
4951    check_aframes = True ,
5254
5355test_videos  =  {
5456    "RATRACE_wave_f_nm_np1_fr_goo_37.avi" : GroundTruth (
57+         duration = 2.0 ,
5558        video_fps = 30.0 ,
5659        audio_sample_rate = None ,
5760        check_aframes = True ,
5861        check_aframe_pts = True ,
5962    ),
6063    "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi" : GroundTruth (
64+         duration = 2.0 ,
6165        video_fps = 30.0 ,
6266        audio_sample_rate = None ,
6367        check_aframes = True ,
6468        check_aframe_pts = True ,
6569    ),
6670    "TrumanShow_wave_f_nm_np1_fr_med_26.avi" : GroundTruth (
71+         duration = 2.0 ,
6772        video_fps = 30.0 ,
6873        audio_sample_rate = None ,
6974        check_aframes = True ,
7075        check_aframe_pts = True ,
7176    ),
7277    "v_SoccerJuggling_g23_c01.avi" : GroundTruth (
78+         duration = 8.0 ,
7379        video_fps = 29.97 ,
7480        audio_sample_rate = None ,
7581        check_aframes = True ,
7682        check_aframe_pts = True ,
7783    ),
7884    "v_SoccerJuggling_g24_c01.avi" : GroundTruth (
85+         duration = 8.0 ,
7986        video_fps = 29.97 ,
8087        audio_sample_rate = None ,
8188        check_aframes = True ,
8289        check_aframe_pts = True ,
8390    ),
8491    "R6llTwEh07w.mp4" : GroundTruth (
92+         duration = 10.0 ,
8593        video_fps = 30.0 ,
8694        audio_sample_rate = 44100 ,
8795        # PyAv miss one audio frame at the beginning (pts=0) 
8896        check_aframes = False ,
8997        check_aframe_pts = False ,
9098    ),
9199    "SOX5yA1l24A.mp4" : GroundTruth (
100+         duration = 11.0 ,
92101        video_fps = 29.97 ,
93102        audio_sample_rate = 48000 ,
94103        # PyAv miss one audio frame at the beginning (pts=0) 
95104        check_aframes = False ,
96105        check_aframe_pts = False ,
97106    ),
98107    "WUzgd7C1pWA.mp4" : GroundTruth (
108+         duration = 11.0 ,
99109        video_fps = 29.97 ,
100110        audio_sample_rate = 48000 ,
101111        # PyAv miss one audio frame at the beginning (pts=0) 
@@ -272,13 +282,22 @@ class TestVideoReader(unittest.TestCase):
272282    def  check_separate_decoding_result (self , tv_result , config ):
273283        """check the decoding results from TorchVision decoder 
274284        """ 
275-         vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , asample_rate  =  (
276-             tv_result 
285+         vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
286+             atimebase , asample_rate , aduration  =  tv_result 
287+ 
288+         video_duration  =  vduration .item () *  Fraction (
289+             vtimebase [0 ].item (), vtimebase [1 ].item ()
277290        )
291+         self .assertAlmostEqual (video_duration , config .duration , delta = 0.5 )
278292
279293        self .assertAlmostEqual (vfps .item (), config .video_fps , delta = 0.5 )
280294        if  asample_rate .numel () >  0 :
281295            self .assertEqual (asample_rate .item (), config .audio_sample_rate )
296+             audio_duration  =  aduration .item () *  Fraction (
297+                 atimebase [0 ].item (), atimebase [1 ].item ()
298+             )
299+             self .assertAlmostEqual (audio_duration , config .duration , delta = 0.5 )
300+ 
282301        # check if pts of video frames are sorted in ascending order 
283302        for  i  in  range (len (vframe_pts ) -  1 ):
284303            self .assertEqual (vframe_pts [i ] <  vframe_pts [i  +  1 ], True )
@@ -288,6 +307,20 @@ def check_separate_decoding_result(self, tv_result, config):
288307            for  i  in  range (len (aframe_pts ) -  1 ):
289308                self .assertEqual (aframe_pts [i ] <  aframe_pts [i  +  1 ], True )
290309
310+     def  check_probe_result (self , result , config ):
311+         vtimebase , vfps , vduration , atimebase , asample_rate , aduration  =  result 
312+         video_duration  =  vduration .item () *  Fraction (
313+             vtimebase [0 ].item (), vtimebase [1 ].item ()
314+         )
315+         self .assertAlmostEqual (video_duration , config .duration , delta = 0.5 )
316+         self .assertAlmostEqual (vfps .item (), config .video_fps , delta = 0.5 )
317+         if  asample_rate .numel () >  0 :
318+             self .assertEqual (asample_rate .item (), config .audio_sample_rate )
319+             audio_duration  =  aduration .item () *  Fraction (
320+                 atimebase [0 ].item (), atimebase [1 ].item ()
321+             )
322+             self .assertAlmostEqual (audio_duration , config .duration , delta = 0.5 )
323+ 
291324    def  compare_decoding_result (self , tv_result , ref_result , config = all_check_config ):
292325        """ 
293326        Compare decoding results from two sources. 
@@ -297,18 +330,17 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config
297330                        decoder or TorchVision decoder with getPtsOnly = 1 
298331            config: config of decoding results checker 
299332        """ 
300-         vframes , vframe_pts , vtimebase , _vfps , aframes , aframe_pts , atimebase , _asample_rate  =  (
301-             tv_result 
302-         )
333+         vframes , vframe_pts , vtimebase , _vfps , _vduration , aframes , aframe_pts , \
334+             atimebase , _asample_rate , _aduration  =  tv_result 
303335        if  isinstance (ref_result , list ):
304336            # the ref_result is from new video_reader decoder 
305337            ref_result  =  DecoderResult (
306338                vframes = ref_result [0 ],
307339                vframe_pts = ref_result [1 ],
308340                vtimebase = ref_result [2 ],
309-                 aframes = ref_result [4 ],
310-                 aframe_pts = ref_result [5 ],
311-                 atimebase = ref_result [6 ],
341+                 aframes = ref_result [5 ],
342+                 aframe_pts = ref_result [6 ],
343+                 atimebase = ref_result [7 ],
312344            )
313345
314346        if  vframes .numel () >  0  and  ref_result .vframes .numel () >  0 :
@@ -351,12 +383,12 @@ def test_stress_test_read_video_from_file(self):
351383        audio_start_pts , audio_end_pts  =  0 , - 1 
352384        audio_timebase_num , audio_timebase_den  =  0 , 1 
353385
354-         for  i  in  range (num_iter ):
355-             for  test_video , config  in  test_videos .items ():
386+         for  _i  in  range (num_iter ):
387+             for  test_video , _config  in  test_videos .items ():
356388                full_path  =  os .path .join (VIDEO_DIR , test_video )
357389
358390                # pass 1: decode all frames using new decoder 
359-                 _   =   torch .ops .video_reader .read_video_from_file (
391+                 torch .ops .video_reader .read_video_from_file (
360392                    full_path ,
361393                    seek_frame_margin ,
362394                    0 ,  # getPtsOnly 
@@ -460,9 +492,8 @@ def test_read_video_from_file_read_single_stream_only(self):
460492                    audio_timebase_den ,
461493                )
462494
463-                 vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , asample_rate  =  (
464-                     tv_result 
465-                 )
495+                 vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
496+                     atimebase , asample_rate , aduration  =  tv_result 
466497
467498                self .assertEqual (vframes .numel () >  0 , readVideoStream )
468499                self .assertEqual (vframe_pts .numel () >  0 , readVideoStream )
@@ -489,7 +520,7 @@ def test_read_video_from_file_rescale_min_dimension(self):
489520        audio_start_pts , audio_end_pts  =  0 , - 1 
490521        audio_timebase_num , audio_timebase_den  =  0 , 1 
491522
492-         for  test_video , config  in  test_videos .items ():
523+         for  test_video , _config  in  test_videos .items ():
493524            full_path  =  os .path .join (VIDEO_DIR , test_video )
494525
495526            tv_result  =  torch .ops .video_reader .read_video_from_file (
@@ -528,7 +559,7 @@ def test_read_video_from_file_rescale_width(self):
528559        audio_start_pts , audio_end_pts  =  0 , - 1 
529560        audio_timebase_num , audio_timebase_den  =  0 , 1 
530561
531-         for  test_video , config  in  test_videos .items ():
562+         for  test_video , _config  in  test_videos .items ():
532563            full_path  =  os .path .join (VIDEO_DIR , test_video )
533564
534565            tv_result  =  torch .ops .video_reader .read_video_from_file (
@@ -567,7 +598,7 @@ def test_read_video_from_file_rescale_height(self):
567598        audio_start_pts , audio_end_pts  =  0 , - 1 
568599        audio_timebase_num , audio_timebase_den  =  0 , 1 
569600
570-         for  test_video , config  in  test_videos .items ():
601+         for  test_video , _config  in  test_videos .items ():
571602            full_path  =  os .path .join (VIDEO_DIR , test_video )
572603
573604            tv_result  =  torch .ops .video_reader .read_video_from_file (
@@ -606,7 +637,7 @@ def test_read_video_from_file_rescale_width_and_height(self):
606637        audio_start_pts , audio_end_pts  =  0 , - 1 
607638        audio_timebase_num , audio_timebase_den  =  0 , 1 
608639
609-         for  test_video , config  in  test_videos .items ():
640+         for  test_video , _config  in  test_videos .items ():
610641            full_path  =  os .path .join (VIDEO_DIR , test_video )
611642
612643            tv_result  =  torch .ops .video_reader .read_video_from_file (
@@ -651,7 +682,7 @@ def test_read_video_from_file_audio_resampling(self):
651682            audio_start_pts , audio_end_pts  =  0 , - 1 
652683            audio_timebase_num , audio_timebase_den  =  0 , 1 
653684
654-             for  test_video , config  in  test_videos .items ():
685+             for  test_video , _config  in  test_videos .items ():
655686                full_path  =  os .path .join (VIDEO_DIR , test_video )
656687
657688                tv_result  =  torch .ops .video_reader .read_video_from_file (
@@ -674,18 +705,17 @@ def test_read_video_from_file_audio_resampling(self):
674705                    audio_timebase_num ,
675706                    audio_timebase_den ,
676707                )
677-                 vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , a_sample_rate  =  (
678-                     tv_result 
679-                 )
708+                 vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
709+                     atimebase , asample_rate , aduration  =  tv_result 
680710                if  aframes .numel () >  0 :
681-                     self .assertEqual (samples , a_sample_rate .item ())
711+                     self .assertEqual (samples , asample_rate .item ())
682712                    self .assertEqual (1 , aframes .size (1 ))
683713                    # when audio stream is found 
684714                    duration  =  float (aframe_pts [- 1 ]) *  float (atimebase [0 ]) /  float (atimebase [1 ])
685715                    self .assertAlmostEqual (
686716                        aframes .size (0 ),
687-                         int (duration  *  a_sample_rate .item ()),
688-                         delta = 0.1  *  a_sample_rate .item (),
717+                         int (duration  *  asample_rate .item ()),
718+                         delta = 0.1  *  asample_rate .item (),
689719                    )
690720
691721    def  test_compare_read_video_from_memory_and_file (self ):
@@ -859,7 +889,7 @@ def test_read_video_from_memory_get_pts_only(self):
859889            )
860890
861891            self .assertEqual (tv_result_pts_only [0 ].numel (), 0 )
862-             self .assertEqual (tv_result_pts_only [4 ].numel (), 0 )
892+             self .assertEqual (tv_result_pts_only [5 ].numel (), 0 )
863893            self .compare_decoding_result (tv_result , tv_result_pts_only )
864894
865895    def  test_read_video_in_range_from_memory (self ):
@@ -899,9 +929,8 @@ def test_read_video_in_range_from_memory(self):
899929                audio_timebase_num ,
900930                audio_timebase_den ,
901931            )
902-             vframes , vframe_pts , vtimebase , vfps , aframes , aframe_pts , atimebase , asample_rate  =  (
903-                 tv_result 
904-             )
932+             vframes , vframe_pts , vtimebase , vfps , vduration , aframes , aframe_pts , \
933+                 atimebase , asample_rate , aduration  =  tv_result 
905934            self .assertAlmostEqual (config .video_fps , vfps .item (), delta = 0.01 )
906935
907936            for  num_frames  in  [4 , 8 , 16 , 32 , 64 , 128 ]:
@@ -997,6 +1026,24 @@ def test_read_video_in_range_from_memory(self):
9971026                    # and PyAv 
9981027                    self .compare_decoding_result (tv_result , pyav_result , config )
9991028
1029+     def  test_probe_video_from_file (self ):
1030+         """ 
1031+         Test the case when decoder probes a video file 
1032+         """ 
1033+         for  test_video , config  in  test_videos .items ():
1034+             full_path  =  os .path .join (VIDEO_DIR , test_video )
1035+             probe_result  =  torch .ops .video_reader .probe_video_from_file (full_path )
1036+             self .check_probe_result (probe_result , config )
1037+ 
1038+     def  test_probe_video_from_memory (self ):
1039+         """ 
1040+         Test the case when decoder probes a video in memory 
1041+         """ 
1042+         for  test_video , config  in  test_videos .items ():
1043+             full_path , video_tensor  =  _get_video_tensor (VIDEO_DIR , test_video )
1044+             probe_result  =  torch .ops .video_reader .probe_video_from_memory (video_tensor )
1045+             self .check_probe_result (probe_result , config )
1046+ 
10001047
10011048if  __name__  ==  '__main__' :
10021049    unittest .main ()
0 commit comments