Skip to content

Commit ec878eb

Browse files
dvrogozh (Dmitry Rogozhkin) and NicolasHug (Nicolas Hug)
authored
Move filter graph to stand alone class (#831)
Co-authored-by: Dmitry Rogozhkin <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent d082a71 commit ec878eb

File tree

6 files changed

+269
-172
lines changed

6 files changed

+269
-172
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ function(make_torchcodec_libraries
8888
AVIOContextHolder.cpp
8989
AVIOTensorContext.cpp
9090
FFMPEGCommon.cpp
91+
FilterGraph.cpp
9192
Frame.cpp
9293
DeviceInterface.cpp
9394
CpuDeviceInterface.cpp

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 55 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,6 @@
66

77
#include "src/torchcodec/_core/CpuDeviceInterface.h"
88

9-
extern "C" {
10-
#include <libavfilter/buffersink.h>
11-
#include <libavfilter/buffersrc.h>
12-
}
13-
149
namespace facebook::torchcodec {
1510
namespace {
1611

@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(
2015

2116
} // namespace
2217

23-
bool CpuDeviceInterface::DecodedFrameContext::operator==(
24-
const CpuDeviceInterface::DecodedFrameContext& other) {
25-
return decodedWidth == other.decodedWidth &&
26-
decodedHeight == other.decodedHeight &&
27-
decodedFormat == other.decodedFormat &&
28-
expectedWidth == other.expectedWidth &&
29-
expectedHeight == other.expectedHeight;
18+
bool CpuDeviceInterface::SwsFrameContext::operator==(
19+
const CpuDeviceInterface::SwsFrameContext& other) const {
20+
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
21+
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
22+
outputHeight == other.outputHeight;
3023
}
3124

32-
bool CpuDeviceInterface::DecodedFrameContext::operator!=(
33-
const CpuDeviceInterface::DecodedFrameContext& other) {
25+
bool CpuDeviceInterface::SwsFrameContext::operator!=(
26+
const CpuDeviceInterface::SwsFrameContext& other) const {
3427
return !(*this == other);
3528
}
3629

@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
7568
}
7669

7770
torch::Tensor outputTensor;
78-
// We need to compare the current frame context with our previous frame
79-
// context. If they are different, then we need to re-create our colorspace
80-
// conversion objects. We create our colorspace conversion objects late so
81-
// that we don't have to depend on the unreliable metadata in the header.
82-
// And we sometimes re-create them because it's possible for frame
83-
// resolution to change mid-stream. Finally, we want to reuse the colorspace
84-
// conversion objects as much as possible for performance reasons.
8571
enum AVPixelFormat frameFormat =
8672
static_cast<enum AVPixelFormat>(avFrame->format);
87-
auto frameContext = DecodedFrameContext{
88-
avFrame->width,
89-
avFrame->height,
90-
frameFormat,
91-
avFrame->sample_aspect_ratio,
92-
expectedOutputWidth,
93-
expectedOutputHeight};
9473

9574
// By default, we want to use swscale for color conversion because it is
9675
// faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
11190
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
11291

11392
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
93+
// We need to compare the current frame context with our previous frame
94+
// context. If they are different, then we need to re-create our colorspace
95+
// conversion objects. We create our colorspace conversion objects late so
96+
// that we don't have to depend on the unreliable metadata in the header.
97+
// And we sometimes re-create them because it's possible for frame
98+
// resolution to change mid-stream. Finally, we want to reuse the colorspace
99+
// conversion objects as much as possible for performance reasons.
100+
SwsFrameContext swsFrameContext;
101+
102+
swsFrameContext.inputWidth = avFrame->width;
103+
swsFrameContext.inputHeight = avFrame->height;
104+
swsFrameContext.inputFormat = frameFormat;
105+
swsFrameContext.outputWidth = expectedOutputWidth;
106+
swsFrameContext.outputHeight = expectedOutputHeight;
107+
114108
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
115109
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
116110

117-
if (!swsContext_ || prevFrameContext_ != frameContext) {
118-
createSwsContext(frameContext, avFrame->colorspace);
119-
prevFrameContext_ = frameContext;
111+
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
112+
createSwsContext(swsFrameContext, avFrame->colorspace);
113+
prevSwsFrameContext_ = swsFrameContext;
120114
}
121115
int resultHeight =
122116
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -132,9 +126,29 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
132126

133127
frameOutput.data = outputTensor;
134128
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
135-
if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
136-
createFilterGraph(frameContext, videoStreamOptions, timeBase);
137-
prevFrameContext_ = frameContext;
129+
// See comment above in swscale branch about the filterGraphContext_
130+
// creation.
131+
FiltersContext filtersContext;
132+
133+
filtersContext.inputWidth = avFrame->width;
134+
filtersContext.inputHeight = avFrame->height;
135+
filtersContext.inputFormat = frameFormat;
136+
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
137+
filtersContext.outputWidth = expectedOutputWidth;
138+
filtersContext.outputHeight = expectedOutputHeight;
139+
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
140+
filtersContext.timeBase = timeBase;
141+
142+
std::stringstream filters;
143+
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
144+
filters << ":sws_flags=bilinear";
145+
146+
filtersContext.filtergraphStr = filters.str();
147+
148+
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
149+
filterGraphContext_ =
150+
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
151+
prevFiltersContext_ = std::move(filtersContext);
138152
}
139153
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
140154

@@ -187,14 +201,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
187201

188202
torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
189203
const UniqueAVFrame& avFrame) {
190-
int status = av_buffersrc_write_frame(
191-
filterGraphContext_.sourceContext, avFrame.get());
192-
TORCH_CHECK(
193-
status >= AVSUCCESS, "Failed to add frame to buffer source context");
204+
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);
194205

195-
UniqueAVFrame filteredAVFrame(av_frame_alloc());
196-
status = av_buffersink_get_frame(
197-
filterGraphContext_.sinkContext, filteredAVFrame.get());
198206
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
199207

200208
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,117 +218,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
210218
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
211219
}
212220

213-
void CpuDeviceInterface::createFilterGraph(
214-
const DecodedFrameContext& frameContext,
215-
const VideoStreamOptions& videoStreamOptions,
216-
const AVRational& timeBase) {
217-
filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
218-
TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
219-
220-
if (videoStreamOptions.ffmpegThreadCount.has_value()) {
221-
filterGraphContext_.filterGraph->nb_threads =
222-
videoStreamOptions.ffmpegThreadCount.value();
223-
}
224-
225-
const AVFilter* buffersrc = avfilter_get_by_name("buffer");
226-
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
227-
228-
std::stringstream filterArgs;
229-
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
230-
<< frameContext.decodedHeight;
231-
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
232-
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
233-
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
234-
<< frameContext.decodedAspectRatio.den;
235-
236-
int status = avfilter_graph_create_filter(
237-
&filterGraphContext_.sourceContext,
238-
buffersrc,
239-
"in",
240-
filterArgs.str().c_str(),
241-
nullptr,
242-
filterGraphContext_.filterGraph.get());
243-
TORCH_CHECK(
244-
status >= 0,
245-
"Failed to create filter graph: ",
246-
filterArgs.str(),
247-
": ",
248-
getFFMPEGErrorStringFromErrorCode(status));
249-
250-
status = avfilter_graph_create_filter(
251-
&filterGraphContext_.sinkContext,
252-
buffersink,
253-
"out",
254-
nullptr,
255-
nullptr,
256-
filterGraphContext_.filterGraph.get());
257-
TORCH_CHECK(
258-
status >= 0,
259-
"Failed to create filter graph: ",
260-
getFFMPEGErrorStringFromErrorCode(status));
261-
262-
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
263-
264-
status = av_opt_set_int_list(
265-
filterGraphContext_.sinkContext,
266-
"pix_fmts",
267-
pix_fmts,
268-
AV_PIX_FMT_NONE,
269-
AV_OPT_SEARCH_CHILDREN);
270-
TORCH_CHECK(
271-
status >= 0,
272-
"Failed to set output pixel formats: ",
273-
getFFMPEGErrorStringFromErrorCode(status));
274-
275-
UniqueAVFilterInOut outputs(avfilter_inout_alloc());
276-
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
277-
278-
outputs->name = av_strdup("in");
279-
outputs->filter_ctx = filterGraphContext_.sourceContext;
280-
outputs->pad_idx = 0;
281-
outputs->next = nullptr;
282-
inputs->name = av_strdup("out");
283-
inputs->filter_ctx = filterGraphContext_.sinkContext;
284-
inputs->pad_idx = 0;
285-
inputs->next = nullptr;
286-
287-
std::stringstream description;
288-
description << "scale=" << frameContext.expectedWidth << ":"
289-
<< frameContext.expectedHeight;
290-
description << ":sws_flags=bilinear";
291-
292-
AVFilterInOut* outputsTmp = outputs.release();
293-
AVFilterInOut* inputsTmp = inputs.release();
294-
status = avfilter_graph_parse_ptr(
295-
filterGraphContext_.filterGraph.get(),
296-
description.str().c_str(),
297-
&inputsTmp,
298-
&outputsTmp,
299-
nullptr);
300-
outputs.reset(outputsTmp);
301-
inputs.reset(inputsTmp);
302-
TORCH_CHECK(
303-
status >= 0,
304-
"Failed to parse filter description: ",
305-
getFFMPEGErrorStringFromErrorCode(status));
306-
307-
status =
308-
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
309-
TORCH_CHECK(
310-
status >= 0,
311-
"Failed to configure filter graph: ",
312-
getFFMPEGErrorStringFromErrorCode(status));
313-
}
314-
315221
void CpuDeviceInterface::createSwsContext(
316-
const DecodedFrameContext& frameContext,
222+
const SwsFrameContext& swsFrameContext,
317223
const enum AVColorSpace colorspace) {
318224
SwsContext* swsContext = sws_getContext(
319-
frameContext.decodedWidth,
320-
frameContext.decodedHeight,
321-
frameContext.decodedFormat,
322-
frameContext.expectedWidth,
323-
frameContext.expectedHeight,
225+
swsFrameContext.inputWidth,
226+
swsFrameContext.inputHeight,
227+
swsFrameContext.inputFormat,
228+
swsFrameContext.outputWidth,
229+
swsFrameContext.outputHeight,
324230
AV_PIX_FMT_RGB24,
325231
SWS_BILINEAR,
326232
nullptr,

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "src/torchcodec/_core/DeviceInterface.h"
1010
#include "src/torchcodec/_core/FFMPEGCommon.h"
11+
#include "src/torchcodec/_core/FilterGraph.h"
1112

1213
namespace facebook::torchcodec {
1314

@@ -41,40 +42,29 @@ class CpuDeviceInterface : public DeviceInterface {
4142
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
4243
const UniqueAVFrame& avFrame);
4344

44-
struct FilterGraphContext {
45-
UniqueAVFilterGraph filterGraph;
46-
AVFilterContext* sourceContext = nullptr;
47-
AVFilterContext* sinkContext = nullptr;
48-
};
49-
50-
struct DecodedFrameContext {
51-
int decodedWidth;
52-
int decodedHeight;
53-
AVPixelFormat decodedFormat;
54-
AVRational decodedAspectRatio;
55-
int expectedWidth;
56-
int expectedHeight;
57-
bool operator==(const DecodedFrameContext&);
58-
bool operator!=(const DecodedFrameContext&);
45+
struct SwsFrameContext {
46+
int inputWidth;
47+
int inputHeight;
48+
AVPixelFormat inputFormat;
49+
int outputWidth;
50+
int outputHeight;
51+
bool operator==(const SwsFrameContext&) const;
52+
bool operator!=(const SwsFrameContext&) const;
5953
};
6054

6155
void createSwsContext(
62-
const DecodedFrameContext& frameContext,
56+
const SwsFrameContext& swsFrameContext,
6357
const enum AVColorSpace colorspace);
6458

65-
void createFilterGraph(
66-
const DecodedFrameContext& frameContext,
67-
const VideoStreamOptions& videoStreamOptions,
68-
const AVRational& timeBase);
69-
7059
// color-conversion fields. Only one of FilterGraphContext and
7160
// UniqueSwsContext should be non-null.
72-
FilterGraphContext filterGraphContext_;
61+
std::unique_ptr<FilterGraph> filterGraphContext_;
7362
UniqueSwsContext swsContext_;
7463

7564
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
7665
// be created before decoding a new frame.
77-
DecodedFrameContext prevFrameContext_;
66+
SwsFrameContext prevSwsFrameContext_;
67+
FiltersContext prevFiltersContext_;
7868
};
7969

8070
} // namespace facebook::torchcodec

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
extern "C" {
1414
#include <libavcodec/avcodec.h>
1515
#include <libavfilter/avfilter.h>
16+
#include <libavfilter/buffersrc.h>
1617
#include <libavformat/avformat.h>
1718
#include <libavformat/avio.h>
1819
#include <libavutil/audio_fifo.h>
@@ -41,6 +42,15 @@ struct Deleterp {
4142
}
4243
};
4344

45+
template <typename T, typename R, R (*Fn)(void*)>
46+
struct Deleterv {
47+
inline void operator()(T* p) const {
48+
if (p) {
49+
Fn(&p);
50+
}
51+
}
52+
};
53+
4454
template <typename T, typename R, R (*Fn)(T*)>
4555
struct Deleter {
4656
inline void operator()(T* p) const {
@@ -78,6 +88,9 @@ using UniqueAVAudioFifo = std::
7888
unique_ptr<AVAudioFifo, Deleter<AVAudioFifo, void, av_audio_fifo_free>>;
7989
using UniqueAVBufferRef =
8090
std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;
91+
using UniqueAVBufferSrcParameters = std::unique_ptr<
92+
AVBufferSrcParameters,
93+
Deleterv<AVBufferSrcParameters, void, av_freep>>;
8194

8295
// These 2 classes share the same underlying AVPacket object. They are meant to
8396
// be used in tandem, like so:

0 commit comments

Comments (0)