1- from typing import Any , Dict , Iterator
21import warnings
2+ from typing import Any , Dict , Iterator
33
44import torch
55
@@ -17,6 +17,7 @@ def _has_video_opt() -> bool:
1717 def _has_video_opt () -> bool :
1818 return False
1919
20+
2021try :
2122 import av
2223
@@ -37,7 +38,8 @@ def _has_video_opt() -> bool:
3738 PyAV is not installed, and is necessary for the video operations in torchvision.
3839See https://github.com/mikeboers/PyAV#installation for instructions on how to
3940install PyAV on your system.
40- """ )
41+ """
42+ )
4143
4244
4345class VideoReader :
@@ -108,19 +110,20 @@ class VideoReader:
108110 def __init__ (self , path : str , stream : str = "video" , num_threads : int = 0 ) -> None :
109111 _log_api_usage_once (self )
110112 from .. import get_video_backend
113+
111114 self .backend = get_video_backend ()
112115 print ("Initiated the backend" , self .backend )
113116 if self .backend == "cuda" :
114117 device = torch .device ("cuda" )
115118 self ._c = torch .classes .torchvision .GPUDecoder (path , device )
116119 return
117-
120+
118121 elif self .backend == "video_reader" :
119122 self ._c = torch .classes .torchvision .Video (path , stream , num_threads )
120-
123+
121124 elif self .backend == "pyav" :
122125 self .container = av .open (path , metadata_errors = "ignore" )
123- #TODO: load metadata
126+ # TODO: load metadata
124127 stream_type = stream .split (":" )[0 ]
125128 stream_id = 0 if len (stream .split (":" )) == 1 else int (stream .split (":" )[1 ])
126129 self .pyav_stream = {stream_type : stream_id }
@@ -129,8 +132,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> No
129132 # TODO: add extradata exception
130133
131134 else :
132- raise RuntimeError ("Unknown video backend: {}" .format (self .backend ))
133-
135+ raise RuntimeError ("Unknown video backend: {}" .format (self .backend ))
134136
135137 def __next__ (self ) -> Dict [str , Any ]:
136138 """Decodes and returns the next frame of the current stream.
@@ -163,7 +165,7 @@ def __next__(self) -> Dict[str, Any]:
163165 frame = None
164166 except av .error .EOFError :
165167 raise StopIteration
166-
168+
167169 if frame .numel () == 0 :
168170 raise StopIteration
169171
@@ -210,10 +212,10 @@ def get_metadata(self) -> Dict[str, Any]:
210212 for stream in self .container .streams :
211213 if stream .type not in metadata :
212214 metadata [stream .type ] = {"fps" : [], "duration" : []}
213-
215+
214216 rate = stream .average_rate if stream .average_rate is not None else stream .sample_rate
215-
216- metadata [stream .type ]["duration" ].append (float (stream .duration * stream .time_base ))
217+
218+ metadata [stream .type ]["duration" ].append (float (stream .duration * stream .time_base ))
217219 metadata [stream .type ]["fps" ].append (float (rate ))
218220 return metadata
219221 return self ._c .get_metadata ()
0 commit comments