@@ -53,7 +53,9 @@ def fate(name, path="."):
5353class TestVideoApi :
5454 @pytest .mark .skipif (av is None , reason = "PyAV unavailable" )
5555 @pytest .mark .parametrize ("test_video" , test_videos .keys ())
56- def test_frame_reading (self , test_video ):
56+ @pytest .mark .parametrize ("backend" , ["video_reader" , "pyav" ])
57+ def test_frame_reading (self , test_video , backend ):
58+ torchvision .set_video_backend (backend )
5759 full_path = os .path .join (VIDEO_DIR , test_video )
5860 with av .open (full_path ) as av_reader :
5961 if av_reader .streams .video :
@@ -117,58 +119,70 @@ def test_frame_reading(self, test_video):
117119
118120 @pytest .mark .parametrize ("stream" , ["video" , "audio" ])
119121 @pytest .mark .parametrize ("test_video" , test_videos .keys ())
120- def test_frame_reading_mem_vs_file (self , test_video , stream ):
122+ @pytest .mark .parametrize ("backend" , ["video_reader" , "pyav" ])
123+ def test_frame_reading_mem_vs_file (self , test_video , stream , backend ):
124+ torchvision .set_video_backend (backend )
121125 full_path = os .path .join (VIDEO_DIR , test_video )
122126
123- # Test video reading from file vs from memory
124- vr_frames , vr_frames_mem = [], []
125- vr_pts , vr_pts_mem = [], []
126- # get vr frames
127- video_reader = VideoReader (full_path , stream )
128- for vr_frame in video_reader :
129- vr_frames .append (vr_frame ["data" ])
130- vr_pts .append (vr_frame ["pts" ])
131-
132- # get vr frames = read from memory
133- f = open (full_path , "rb" )
134- fbytes = f .read ()
135- f .close ()
136- video_reader_from_mem = VideoReader (fbytes , stream )
137-
138- for vr_frame_from_mem in video_reader_from_mem :
139- vr_frames_mem .append (vr_frame_from_mem ["data" ])
140- vr_pts_mem .append (vr_frame_from_mem ["pts" ])
141-
142- # same number of frames
143- assert len (vr_frames ) == len (vr_frames_mem )
144- assert len (vr_pts ) == len (vr_pts_mem )
145-
146- # compare the frames and ptss
147- for i in range (len (vr_frames )):
148- assert vr_pts [i ] == vr_pts_mem [i ]
149- mean_delta = torch .mean (torch .abs (vr_frames [i ].float () - vr_frames_mem [i ].float ()))
150- # on average the difference is very small and caused
151- # by decoding (around 1%)
152- # TODO: asses empirically how to set this? atm it's 1%
153- # averaged over all frames
154- assert mean_delta .item () < 2.55
155-
156- del vr_frames , vr_pts , vr_frames_mem , vr_pts_mem
127+ reader = VideoReader (full_path )
128+ reader_md = reader .get_metadata ()
129+
130+ if stream in reader_md :
131+ # Test video reading from file vs from memory
132+ vr_frames , vr_frames_mem = [], []
133+ vr_pts , vr_pts_mem = [], []
134+ # get vr frames
135+ video_reader = VideoReader (full_path , stream )
136+ for vr_frame in video_reader :
137+ vr_frames .append (vr_frame ["data" ])
138+ vr_pts .append (vr_frame ["pts" ])
139+
140+ # get vr frames = read from memory
141+ f = open (full_path , "rb" )
142+ fbytes = f .read ()
143+ f .close ()
144+ video_reader_from_mem = VideoReader (fbytes , stream )
145+
146+ for vr_frame_from_mem in video_reader_from_mem :
147+ vr_frames_mem .append (vr_frame_from_mem ["data" ])
148+ vr_pts_mem .append (vr_frame_from_mem ["pts" ])
149+
150+ # same number of frames
151+ assert len (vr_frames ) == len (vr_frames_mem )
152+ assert len (vr_pts ) == len (vr_pts_mem )
153+
154+ # compare the frames and ptss
155+ for i in range (len (vr_frames )):
156+ assert vr_pts [i ] == vr_pts_mem [i ]
157+ mean_delta = torch .mean (torch .abs (vr_frames [i ].float () - vr_frames_mem [i ].float ()))
158+ # on average the difference is very small and caused
159+ # by decoding (around 1%)
160+ # TODO: asses empirically how to set this? atm it's 1%
161+ # averaged over all frames
162+ assert mean_delta .item () < 2.55
163+
164+ del vr_frames , vr_pts , vr_frames_mem , vr_pts_mem
165+ else :
166+ del reader , reader_md
157167
158168 @pytest .mark .parametrize ("test_video,config" , test_videos .items ())
159- def test_metadata (self , test_video , config ):
169+ @pytest .mark .parametrize ("backend" , ["video_reader" , "pyav" ])
170+ def test_metadata (self , test_video , config , backend ):
160171 """
161172 Test that the metadata returned via pyav corresponds to the one returned
162173 by the new video decoder API
163174 """
175+ torchvision .set_video_backend (backend )
164176 full_path = os .path .join (VIDEO_DIR , test_video )
165177 reader = VideoReader (full_path , "video" )
166178 reader_md = reader .get_metadata ()
167179 assert config .video_fps == approx (reader_md ["video" ]["fps" ][0 ], abs = 0.0001 )
168180 assert config .duration == approx (reader_md ["video" ]["duration" ][0 ], abs = 0.5 )
169181
170182 @pytest .mark .parametrize ("test_video" , test_videos .keys ())
171- def test_seek_start (self , test_video ):
183+ @pytest .mark .parametrize ("backend" , ["video_reader" , "pyav" ])
184+ def test_seek_start (self , test_video , backend ):
185+ torchvision .set_video_backend (backend )
172186 full_path = os .path .join (VIDEO_DIR , test_video )
173187 video_reader = VideoReader (full_path , "video" )
174188 num_frames = 0
@@ -194,7 +208,9 @@ def test_seek_start(self, test_video):
194208 assert start_num_frames == num_frames
195209
196210 @pytest .mark .parametrize ("test_video" , test_videos .keys ())
197- def test_accurateseek_middle (self , test_video ):
211+ @pytest .mark .parametrize ("backend" , ["video_reader" ])
212+ def test_accurateseek_middle (self , test_video , backend ):
213+ torchvision .set_video_backend (backend )
198214 full_path = os .path .join (VIDEO_DIR , test_video )
199215 stream = "video"
200216 video_reader = VideoReader (full_path , stream )
@@ -233,7 +249,9 @@ def test_fate_suite(self):
233249
234250 @pytest .mark .skipif (av is None , reason = "PyAV unavailable" )
235251 @pytest .mark .parametrize ("test_video,config" , test_videos .items ())
236- def test_keyframe_reading (self , test_video , config ):
252+ @pytest .mark .parametrize ("backend" , ["pyav" , "video_reader" ])
253+ def test_keyframe_reading (self , test_video , config , backend ):
254+ torchvision .set_video_backend (backend )
237255 full_path = os .path .join (VIDEO_DIR , test_video )
238256
239257 av_reader = av .open (full_path )
0 commit comments