Skip to content

Commit 06ad05f

Browse files
authored
Read video from memory newapi (#6771)
* add tensor as optional param * add init from memory * fix bug * fix bug * first working version * apply formatting and add tests * simplify tests * fix tests * fix wrong variable name * add path as optional parameter * add src as optional * address pr comments * Fix warning messages * address pr comments * make tests stricter * Revert "make tests stricter" This reverts commit 6c92e94.
1 parent 246de07 commit 06ad05f

File tree

5 files changed

+148
-19
lines changed

5 files changed

+148
-19
lines changed

test/test_videoapi.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_frame_reading(self, test_video):
7777
# compare the frames and ptss
7878
for i in range(len(vr_frames)):
7979
assert float(av_pts[i]) == approx(vr_pts[i], abs=0.1)
80+
8081
mean_delta = torch.mean(torch.abs(av_frames[i].float() - vr_frames[i].float()))
8182
# on average the difference is very small and caused
8283
# by decoding (around 1%)
@@ -114,6 +115,46 @@ def test_frame_reading(self, test_video):
114115
# we assure that there is never more than 1% difference in signal
115116
assert max_delta.item() < 0.001
116117

118+
@pytest.mark.parametrize("stream", ["video", "audio"])
119+
@pytest.mark.parametrize("test_video", test_videos.keys())
120+
def test_frame_reading_mem_vs_file(self, test_video, stream):
121+
full_path = os.path.join(VIDEO_DIR, test_video)
122+
123+
# Test video reading from file vs from memory
124+
vr_frames, vr_frames_mem = [], []
125+
vr_pts, vr_pts_mem = [], []
126+
# get vr frames
127+
video_reader = VideoReader(full_path, stream)
128+
for vr_frame in video_reader:
129+
vr_frames.append(vr_frame["data"])
130+
vr_pts.append(vr_frame["pts"])
131+
132+
# get vr frames = read from memory
133+
f = open(full_path, "rb")
134+
fbytes = f.read()
135+
f.close()
136+
video_reader_from_mem = VideoReader(fbytes, stream)
137+
138+
for vr_frame_from_mem in video_reader_from_mem:
139+
vr_frames_mem.append(vr_frame_from_mem["data"])
140+
vr_pts_mem.append(vr_frame_from_mem["pts"])
141+
142+
# same number of frames
143+
assert len(vr_frames) == len(vr_frames_mem)
144+
assert len(vr_pts) == len(vr_pts_mem)
145+
146+
# compare the frames and ptss
147+
for i in range(len(vr_frames)):
148+
assert vr_pts[i] == vr_pts_mem[i]
149+
mean_delta = torch.mean(torch.abs(vr_frames[i].float() - vr_frames_mem[i].float()))
150+
# on average the difference is very small and caused
151+
# by decoding (around 1%)
152+
# TODO: assess empirically how to set this? atm it's 1%
153+
# averaged over all frames
154+
assert mean_delta.item() < 2.55
155+
156+
del vr_frames, vr_pts, vr_frames_mem, vr_pts_mem
157+
117158
@pytest.mark.parametrize("test_video,config", test_videos.items())
118159
def test_metadata(self, test_video, config):
119160
"""

torchvision/csrc/io/decoder/defs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ struct MediaFormat {
165165
struct DecoderParameters {
166166
// local file, remote file, http url, rtmp stream uri, etc. anything that
167167
// ffmpeg can recognize
168-
std::string uri;
168+
std::string uri{std::string()};
169169
// timeout on getting bytes for decoding
170170
size_t timeoutMs{1000};
171171
// logging level, default AV_LOG_PANIC

torchvision/csrc/io/video/video.cpp

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,34 @@ void Video::_getDecoderParams(
156156

157157
} // _get decoder params
158158

159-
Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
160-
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
159+
void Video::initFromFile(
160+
std::string videoPath,
161+
std::string stream,
162+
int64_t numThreads) {
163+
TORCH_CHECK(!initialized, "Video object can only be initialized once");
164+
initialized = true;
165+
params.uri = videoPath;
166+
_init(stream, numThreads);
167+
}
168+
169+
void Video::initFromMemory(
170+
torch::Tensor videoTensor,
171+
std::string stream,
172+
int64_t numThreads) {
173+
TORCH_CHECK(!initialized, "Video object can only be initialized once");
174+
initialized = true;
175+
callback = MemoryBuffer::getCallback(
176+
videoTensor.data_ptr<uint8_t>(), videoTensor.size(0));
177+
_init(stream, numThreads);
178+
}
179+
180+
void Video::_init(std::string stream, int64_t numThreads) {
161181
// set number of threads global
162182
numThreads_ = numThreads;
163183
// parse stream information
164184
current_stream = _parseStream(stream);
165185
// note that in the initial call we want to get all streams
166-
Video::_getDecoderParams(
186+
_getDecoderParams(
167187
0, // video start
168188
0, // headerOnly
169189
std::get<0>(current_stream), // stream info - remove that
@@ -175,11 +195,6 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
175195

176196
std::string logMessage, logType;
177197

178-
// TODO: add read from memory option
179-
params.uri = videoPath;
180-
logType = "file";
181-
logMessage = videoPath;
182-
183198
// locals
184199
std::vector<double> audioFPS, videoFPS;
185200
std::vector<double> audioDuration, videoDuration, ccDuration, subsDuration;
@@ -190,7 +205,8 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
190205
c10::Dict<std::string, std::vector<double>> subsMetadata;
191206

192207
// callback and metadata defined in struct
193-
succeeded = decoder.init(params, std::move(callback), &metadata);
208+
DecoderInCallback tmp_callback = callback;
209+
succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
194210
if (succeeded) {
195211
for (const auto& header : metadata) {
196212
double fps = double(header.fps);
@@ -225,16 +241,24 @@ Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
225241
streamsMetadata.insert("subtitles", subsMetadata);
226242
streamsMetadata.insert("cc", ccMetadata);
227243

228-
succeeded = Video::setCurrentStream(stream);
244+
succeeded = setCurrentStream(stream);
229245
LOG(INFO) << "\nDecoder inited with: " << succeeded << "\n";
230246
if (std::get<1>(current_stream) != -1) {
231247
LOG(INFO)
232248
<< "Stream index set to " << std::get<1>(current_stream)
233249
<< ". If you encounter trouble, consider switching it to automatic stream discovery. \n";
234250
}
251+
}
252+
253+
Video::Video(std::string videoPath, std::string stream, int64_t numThreads) {
254+
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.video.video.Video");
255+
if (!videoPath.empty()) {
256+
initFromFile(videoPath, stream, numThreads);
257+
}
235258
} // video
236259

237260
bool Video::setCurrentStream(std::string stream = "video") {
261+
TORCH_CHECK(initialized, "Video object has to be initialized first");
238262
if ((!stream.empty()) && (_parseStream(stream) != current_stream)) {
239263
current_stream = _parseStream(stream);
240264
}
@@ -256,19 +280,23 @@ bool Video::setCurrentStream(std::string stream = "video") {
256280
);
257281

258282
// callback and metadata defined in Video.h
259-
return (decoder.init(params, std::move(callback), &metadata));
283+
DecoderInCallback tmp_callback = callback;
284+
return (decoder.init(params, std::move(tmp_callback), &metadata));
260285
}
261286

262287
std::tuple<std::string, int64_t> Video::getCurrentStream() const {
288+
TORCH_CHECK(initialized, "Video object has to be initialized first");
263289
return current_stream;
264290
}
265291

266292
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>> Video::
267293
getStreamMetadata() const {
294+
TORCH_CHECK(initialized, "Video object has to be initialized first");
268295
return streamsMetadata;
269296
}
270297

271298
void Video::Seek(double ts, bool fastSeek = false) {
299+
TORCH_CHECK(initialized, "Video object has to be initialized first");
272300
// initialize the class variables used for seeking and return
273301
_getDecoderParams(
274302
ts, // video start
@@ -282,11 +310,14 @@ void Video::Seek(double ts, bool fastSeek = false) {
282310
);
283311

284312
// callback and metadata defined in Video.h
285-
succeeded = decoder.init(params, std::move(callback), &metadata);
313+
DecoderInCallback tmp_callback = callback;
314+
succeeded = decoder.init(params, std::move(tmp_callback), &metadata);
315+
286316
LOG(INFO) << "Decoder init at seek " << succeeded << "\n";
287317
}
288318

289319
std::tuple<torch::Tensor, double> Video::Next() {
320+
TORCH_CHECK(initialized, "Video object has to be initialized first");
290321
// if failing to decode simply return a null tensor (note, should we
291322
// raise an exception?)
292323
double frame_pts_s;
@@ -345,6 +376,8 @@ std::tuple<torch::Tensor, double> Video::Next() {
345376
static auto registerVideo =
346377
torch::class_<Video>("torchvision", "Video")
347378
.def(torch::init<std::string, std::string, int64_t>())
379+
.def("init_from_file", &Video::initFromFile)
380+
.def("init_from_memory", &Video::initFromMemory)
348381
.def("get_current_stream", &Video::getCurrentStream)
349382
.def("set_current_stream", &Video::setCurrentStream)
350383
.def("get_metadata", &Video::getStreamMetadata)

torchvision/csrc/io/video/video.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,19 @@ struct Video : torch::CustomClassHolder {
1919
int64_t numThreads_{0};
2020

2121
public:
22-
Video(std::string videoPath, std::string stream, int64_t numThreads);
22+
Video(
23+
std::string videoPath = std::string(),
24+
std::string stream = std::string("video"),
25+
int64_t numThreads = 0);
26+
void initFromFile(
27+
std::string videoPath,
28+
std::string stream,
29+
int64_t numThreads);
30+
void initFromMemory(
31+
torch::Tensor videoTensor,
32+
std::string stream,
33+
int64_t numThreads);
34+
2335
std::tuple<std::string, int64_t> getCurrentStream() const;
2436
c10::Dict<std::string, c10::Dict<std::string, std::vector<double>>>
2537
getStreamMetadata() const;
@@ -34,6 +46,12 @@ struct Video : torch::CustomClassHolder {
3446
// time in combination with any_frame settings
3547
double seekTS = -1;
3648

49+
bool initialized = false;
50+
51+
void _init(
52+
std::string stream,
53+
int64_t numThreads); // expects params.uri OR callback to be set
54+
3755
void _getDecoderParams(
3856
double videoStartS,
3957
int64_t getPtsOnly,

torchvision/io/video_reader.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Dict, Iterator
1+
import warnings
2+
from typing import Any, Dict, Iterator, Optional
23

34
import torch
45

@@ -71,8 +72,13 @@ class VideoReader:
7172
If only stream type is passed, the decoder auto-detects first stream of that type.
7273
7374
Args:
75+
src (string, bytes object, or tensor): The media source.
76+
If string-type, it must be a file path supported by FFMPEG.
77+
If bytes, it should be an in-memory representation of a file supported by FFMPEG.
78+
If Tensor, it is interpreted internally as byte buffer.
79+
It must be one-dimensional, of type ``torch.uint8``.
80+
7481
75-
path (string): Path to the video file in supported format
7682
7783
stream (string, optional): descriptor of the required stream, followed by the stream id,
7884
in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
@@ -85,17 +91,31 @@ class VideoReader:
8591
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
8692
To use GPU decoding, pass ``device="cuda"``.
8793
94+
path (str, optional):
95+
.. warning::
96+
This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
97+
Please use ``src`` instead.
98+
99+
100+
88101
"""
89102

90-
def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
103+
def __init__(
104+
self,
105+
src: str = "",
106+
stream: str = "video",
107+
num_threads: int = 0,
108+
device: str = "cpu",
109+
path: Optional[str] = None,
110+
) -> None:
91111
_log_api_usage_once(self)
92112
self.is_cuda = False
93113
device = torch.device(device)
94114
if device.type == "cuda":
95115
if not _HAS_GPU_VIDEO_DECODER:
96116
raise RuntimeError("Not compiled with GPU decoder support.")
97117
self.is_cuda = True
98-
self._c = torch.classes.torchvision.GPUDecoder(path, device)
118+
self._c = torch.classes.torchvision.GPUDecoder(src, device)
99119
return
100120
if not _has_video_opt():
101121
raise RuntimeError(
@@ -105,7 +125,24 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0, devic
105125
+ "build torchvision from source."
106126
)
107127

108-
self._c = torch.classes.torchvision.Video(path, stream, num_threads)
128+
if src == "":
129+
if path is None:
130+
raise TypeError("src cannot be empty")
131+
src = path
132+
warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
133+
134+
elif isinstance(src, bytes):
135+
src = torch.frombuffer(src, dtype=torch.uint8)
136+
137+
if isinstance(src, str):
138+
self._c = torch.classes.torchvision.Video(src, stream, num_threads)
139+
elif isinstance(src, torch.Tensor):
140+
if self.is_cuda:
141+
raise RuntimeError("GPU VideoReader cannot be initialized from Tensor or bytes object.")
142+
self._c = torch.classes.torchvision.Video("", "", 0)
143+
self._c.init_from_memory(src, stream, num_threads)
144+
else:
145+
raise TypeError("`src` must be either string, Tensor or bytes object.")
109146

110147
def __next__(self) -> Dict[str, Any]:
111148
"""Decodes and returns the next frame of the current stream.

0 commit comments

Comments
 (0)