#include "src/torchcodec/_core/CpuDeviceInterface.h"

-extern "C" {
-#include <libavfilter/buffersink.h>
-#include <libavfilter/buffersrc.h>
-}
-
namespace facebook::torchcodec {
namespace {
@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(

} // namespace

-bool CpuDeviceInterface::DecodedFrameContext::operator==(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return decodedWidth == other.decodedWidth &&
-      decodedHeight == other.decodedHeight &&
-      decodedFormat == other.decodedFormat &&
-      expectedWidth == other.expectedWidth &&
-      expectedHeight == other.expectedHeight;
+bool CpuDeviceInterface::SwsFrameContext::operator==(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight;
}

-bool CpuDeviceInterface::DecodedFrameContext::operator!=(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
+bool CpuDeviceInterface::SwsFrameContext::operator!=(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
  return !(*this == other);
}
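For orientation, here is a minimal sketch of what the renamed struct's declaration in CpuDeviceInterface.h plausibly looks like after this change; the field names and operator signatures come from the code above, but the exact types and defaults are assumptions, not part of this diff:

    // Sketch only -- the real declaration lives in CpuDeviceInterface.h.
    struct SwsFrameContext {
      int inputWidth = 0;
      int inputHeight = 0;
      AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
      int outputWidth = 0;
      int outputHeight = 0;

      bool operator==(const SwsFrameContext& other) const;
      bool operator!=(const SwsFrameContext& other) const;
    };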
@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
  }

  torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
  enum AVPixelFormat frameFormat =
      static_cast<enum AVPixelFormat>(avFrame->format);
-  auto frameContext = DecodedFrameContext{
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      avFrame->sample_aspect_ratio,
-      expectedOutputWidth,
-      expectedOutputHeight};

  // By default, we want to use swscale for color conversion because it is
  // faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    SwsFrameContext swsFrameContext;
+
+    swsFrameContext.inputWidth = avFrame->width;
+    swsFrameContext.inputHeight = avFrame->height;
+    swsFrameContext.inputFormat = frameFormat;
+    swsFrameContext.outputWidth = expectedOutputWidth;
+    swsFrameContext.outputHeight = expectedOutputHeight;
+
    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
        expectedOutputHeight, expectedOutputWidth, torch::kCPU));

-    if (!swsContext_ || prevFrameContext_ != frameContext) {
-      createSwsContext(frameContext, avFrame->colorspace);
-      prevFrameContext_ = frameContext;
+    if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+      createSwsContext(swsFrameContext, avFrame->colorspace);
+      prevSwsFrameContext_ = swsFrameContext;
    }
    int resultHeight =
        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
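The check above caches the conversion state in two members; a sketch of the declarations this hunk implies (names from the diff; the types, and UniqueSwsContext as an RAII alias around SwsContext*, are assumptions):

    // Assumed member state backing the swscale cache:
    UniqueSwsContext swsContext_;          // rebuilt when the context changes
    SwsFrameContext prevSwsFrameContext_;  // inputs used to build swsContext_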
@@ -132,9 +126,29 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(

    frameOutput.data = outputTensor;
  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
-      createFilterGraph(frameContext, videoStreamOptions, timeBase);
-      prevFrameContext_ = frameContext;
+    // See comment above in the swscale branch about filterGraphContext_
+    // creation.
+    FiltersContext filtersContext;
+
+    filtersContext.inputWidth = avFrame->width;
+    filtersContext.inputHeight = avFrame->height;
+    filtersContext.inputFormat = frameFormat;
+    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+    filtersContext.outputWidth = expectedOutputWidth;
+    filtersContext.outputHeight = expectedOutputHeight;
+    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
+    filtersContext.timeBase = timeBase;
+
+    std::stringstream filters;
+    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+    filters << ":sws_flags=bilinear";
+
+    filtersContext.filtergraphStr = filters.str();
+
+    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
+      filterGraphContext_ =
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+      prevFiltersContext_ = std::move(filtersContext);
    }
    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
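To make the string construction concrete: with expectedOutputWidth = 640 and expectedOutputHeight = 480 (illustrative values), filtersContext.filtergraphStr comes out as

    scale=640:480:sws_flags=bilinear

which is the same single-scale-filter description the removed createFilterGraph below assembled by hand.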
@@ -187,14 +201,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
    const UniqueAVFrame& avFrame) {
-  int status = av_buffersrc_write_frame(
-      filterGraphContext_.sourceContext, avFrame.get());
-  TORCH_CHECK(
-      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);

-  UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  status = av_buffersink_get_frame(
-      filterGraphContext_.sinkContext, filteredAVFrame.get());
  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
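FilterGraph itself is defined elsewhere in this change; judging only from its two uses in this file, its interface is roughly the following (a sketch, not the actual header):

    class FilterGraph {
     public:
      FilterGraph(
          const FiltersContext& filtersContext,
          const VideoStreamOptions& videoStreamOptions);
      // Pushes a decoded frame through the configured graph and returns the
      // filtered result.
      UniqueAVFrame convert(const UniqueAVFrame& avFrame);
    };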
@@ -210,117 +218,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}

-void CpuDeviceInterface::createFilterGraph(
-    const DecodedFrameContext& frameContext,
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase) {
-  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
-
-  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterGraphContext_.filterGraph->nb_threads =
-        videoStreamOptions.ffmpegThreadCount.value();
-  }
-
-  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
-  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-
-  std::stringstream filterArgs;
-  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
-             << frameContext.decodedHeight;
-  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
-  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
-  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
-             << frameContext.decodedAspectRatio.den;
-
-  int status = avfilter_graph_create_filter(
-      &filterGraphContext_.sourceContext,
-      buffersrc,
-      "in",
-      filterArgs.str().c_str(),
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      filterArgs.str(),
-      ": ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status = avfilter_graph_create_filter(
-      &filterGraphContext_.sinkContext,
-      buffersink,
-      "out",
-      nullptr,
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
-
-  status = av_opt_set_int_list(
-      filterGraphContext_.sinkContext,
-      "pix_fmts",
-      pix_fmts,
-      AV_PIX_FMT_NONE,
-      AV_OPT_SEARCH_CHILDREN);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to set output pixel formats: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
-  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
-
-  outputs->name = av_strdup("in");
-  outputs->filter_ctx = filterGraphContext_.sourceContext;
-  outputs->pad_idx = 0;
-  outputs->next = nullptr;
-  inputs->name = av_strdup("out");
-  inputs->filter_ctx = filterGraphContext_.sinkContext;
-  inputs->pad_idx = 0;
-  inputs->next = nullptr;
-
-  std::stringstream description;
-  description << "scale=" << frameContext.expectedWidth << ":"
-              << frameContext.expectedHeight;
-  description << ":sws_flags=bilinear";
-
-  AVFilterInOut* outputsTmp = outputs.release();
-  AVFilterInOut* inputsTmp = inputs.release();
-  status = avfilter_graph_parse_ptr(
-      filterGraphContext_.filterGraph.get(),
-      description.str().c_str(),
-      &inputsTmp,
-      &outputsTmp,
-      nullptr);
-  outputs.reset(outputsTmp);
-  inputs.reset(inputsTmp);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to parse filter description: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status =
-      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to configure filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-}
-
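For reference, the buffer source arguments assembled by the removed code above: with an illustrative 1920x1080 yuv420p input, a 1/30000 time base, and square pixels (values not from this diff), filterArgs.str() would evaluate to

    video_size=1920x1080:pix_fmt=0:time_base=1/30000:pixel_aspect=1/1

Note that pix_fmt is streamed as the enum's numeric value (AV_PIX_FMT_YUV420P is 0). The new FilterGraph class now owns all of this setup.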
void CpuDeviceInterface::createSwsContext(
-    const DecodedFrameContext& frameContext,
+    const SwsFrameContext& swsFrameContext,
    const enum AVColorSpace colorspace) {
  SwsContext* swsContext = sws_getContext(
-      frameContext.decodedWidth,
-      frameContext.decodedHeight,
-      frameContext.decodedFormat,
-      frameContext.expectedWidth,
-      frameContext.expectedHeight,
+      swsFrameContext.inputWidth,
+      swsFrameContext.inputHeight,
+      swsFrameContext.inputFormat,
+      swsFrameContext.outputWidth,
+      swsFrameContext.outputHeight,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,