22import gc
33import torch
44import numpy as np
5+ import math
6+ import warnings
57
68try :
79 import av
@@ -74,12 +76,20 @@ def write_video(filename, video_array, fps, video_codec='libx264', options=None)
7476 container .close ()
7577
7678
77- def _read_from_stream (container , start_offset , end_offset , stream , stream_name ):
79+ def _read_from_stream (container , start_offset , end_offset , pts_unit , stream , stream_name ):
7880 global _CALLED_TIMES , _GC_COLLECTION_INTERVAL
7981 _CALLED_TIMES += 1
8082 if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1 :
8183 gc .collect ()
8284
85+ if pts_unit == 'sec' :
86+ start_offset = int (math .floor (start_offset * (1 / stream .time_base )))
87+ if end_offset != float ("inf" ):
88+ end_offset = int (math .ceil (end_offset * (1 / stream .time_base )))
89+ else :
90+ warnings .warn ("The pts_unit 'pts' gives wrong results and will be removed in a " +
91+ "follow-up version. Please use pts_unit 'sec'." )
92+
8393 frames = {}
8494 should_buffer = False
8595 max_buffer_size = 5
@@ -145,7 +155,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
145155 return aframes [:, s_idx :e_idx ]
146156
147157
148- def read_video (filename , start_pts = 0 , end_pts = None ):
158+ def read_video (filename , start_pts = 0 , end_pts = None , pts_unit = 'pts' ):
149159 """
150160 Reads a video from a file, returning both the video frames as well as
151161 the audio frames
@@ -154,10 +164,14 @@ def read_video(filename, start_pts=0, end_pts=None):
154164 ----------
155165 filename : str
156166 path to the video file
157- start_pts : int, optional
167+ start_pts : int if pts_unit = 'pts', optional
168+ float / Fraction if pts_unit = 'sec', optional
158169 the start presentation time of the video
159- end_pts : int, optional
170+ end_pts : int if pts_unit = 'pts', optional
171+ float / Fraction if pts_unit = 'sec', optional
160172 the end presentation time
173+ pts_unit : str, optional
174+ unit in which start_pts and end_pts values will be interpreted, either 'pts' or 'sec'. Defaults to 'pts'.
161175
162176 Returns
163177 -------
@@ -184,12 +198,12 @@ def read_video(filename, start_pts=0, end_pts=None):
184198
185199 video_frames = []
186200 if container .streams .video :
187- video_frames = _read_from_stream (container , start_pts , end_pts ,
201+ video_frames = _read_from_stream (container , start_pts , end_pts , pts_unit ,
188202 container .streams .video [0 ], {'video' : 0 })
189203 info ["video_fps" ] = float (container .streams .video [0 ].average_rate )
190204 audio_frames = []
191205 if container .streams .audio :
192- audio_frames = _read_from_stream (container , start_pts , end_pts ,
206+ audio_frames = _read_from_stream (container , start_pts , end_pts , pts_unit ,
193207 container .streams .audio [0 ], {'audio' : 0 })
194208 info ["audio_fps" ] = container .streams .audio [0 ].rate
195209
@@ -217,7 +231,7 @@ def _can_read_timestamps_from_packets(container):
217231 return False
218232
219233
def read_video_timestamps(filename, pts_unit='pts'):
    """
    List the video frames timestamps.

    Parameters
    ----------
    filename : str
        path to the video file
    pts_unit : str, optional
        unit in which timestamp values will be returned, either 'pts' or 'sec'.
        Defaults to 'pts'.

    Returns
    -------
    pts : List[int] if pts_unit = 'pts'
        List[Fraction] if pts_unit = 'sec'
        presentation timestamps for each one of the frames in the video.
        Empty if the file has no video stream.
    video_fps : float or None
        the frame rate for the video, or None if the file has no video stream

    """
    _check_av_available()

    container = av.open(filename, metadata_errors='ignore')

    pts = []
    video_fps = None
    video_time_base = None
    # try/finally guarantees the container is released even if demuxing,
    # decoding or reading the stream metadata raises.
    try:
        if container.streams.video:
            video_stream = container.streams.video[0]
            video_time_base = video_stream.time_base
            if _can_read_timestamps_from_packets(container):
                # fast path: take the pts straight from the demuxed packets
                video_frames = [x for x in container.demux(video=0) if x.pts is not None]
            else:
                # fall back to _read_from_stream over the whole stream (0 .. inf)
                video_frames = _read_from_stream(container, 0, float("inf"), pts_unit,
                                                 video_stream, {'video': 0})
            video_fps = float(video_stream.average_rate)
            pts = [x.pts for x in video_frames]
    finally:
        container.close()

    if pts_unit == 'sec':
        # scale the integer pts by the stream time base to get seconds;
        # video_time_base is only None when pts is empty
        return [x * video_time_base for x in pts], video_fps
    return pts, video_fps
0 commit comments