Move filter graph to stand alone class #831
Changes from all commits: 7448d96, f150c6f, 028b612, a0ecb95, 72f6404
@@ -6,11 +6,6 @@
 #include "src/torchcodec/_core/CpuDeviceInterface.h"

-extern "C" {
-#include <libavfilter/buffersink.h>
-#include <libavfilter/buffersrc.h>
-}
-
 namespace facebook::torchcodec {
 namespace {
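(The libavfilter includes removed here are presumably pulled in by the new stand-alone FilterGraph class that this PR introduces, so this file no longer needs them.)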
@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(
 } // namespace

-bool CpuDeviceInterface::DecodedFrameContext::operator==(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
-  return decodedWidth == other.decodedWidth &&
-      decodedHeight == other.decodedHeight &&
-      decodedFormat == other.decodedFormat &&
-      expectedWidth == other.expectedWidth &&
-      expectedHeight == other.expectedHeight;
+bool CpuDeviceInterface::SwsFrameContext::operator==(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
+  return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
+      inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
+      outputHeight == other.outputHeight;
 }

-bool CpuDeviceInterface::DecodedFrameContext::operator!=(
-    const CpuDeviceInterface::DecodedFrameContext& other) {
+bool CpuDeviceInterface::SwsFrameContext::operator!=(
+    const CpuDeviceInterface::SwsFrameContext& other) const {
   return !(*this == other);
 }
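For reference, a minimal sketch of the new SwsFrameContext struct that these operators imply. It is inferred from this diff alone; the actual declaration lives in CpuDeviceInterface.h and may differ:

// Sketch only: field names and types inferred from the operators above,
// not copied from the actual header.
struct SwsFrameContext {
  int inputWidth = 0;
  int inputHeight = 0;
  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
  int outputWidth = 0;
  int outputHeight = 0;

  bool operator==(const SwsFrameContext& other) const;
  bool operator!=(const SwsFrameContext& other) const;
};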
@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }

   torch::Tensor outputTensor;
-  // We need to compare the current frame context with our previous frame
-  // context. If they are different, then we need to re-create our colorspace
-  // conversion objects. We create our colorspace conversion objects late so
-  // that we don't have to depend on the unreliable metadata in the header.
-  // And we sometimes re-create them because it's possible for frame
-  // resolution to change mid-stream. Finally, we want to reuse the colorspace
-  // conversion objects as much as possible for performance reasons.
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
-  auto frameContext = DecodedFrameContext{
-      avFrame->width,
-      avFrame->height,
-      frameFormat,
-      avFrame->sample_aspect_ratio,
-      expectedOutputWidth,
-      expectedOutputHeight};

   // By default, we want to use swscale for color conversion because it is
   // faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
       videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

   if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    SwsFrameContext swsFrameContext;
+
+    swsFrameContext.inputWidth = avFrame->width;
+    swsFrameContext.inputHeight = avFrame->height;
+    swsFrameContext.inputFormat = frameFormat;
+    swsFrameContext.outputWidth = expectedOutputWidth;
+    swsFrameContext.outputHeight = expectedOutputHeight;
+
     outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));

-    if (!swsContext_ || prevFrameContext_ != frameContext) {
-      createSwsContext(frameContext, avFrame->colorspace);
-      prevFrameContext_ = frameContext;
+    if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
+      createSwsContext(swsFrameContext, avFrame->colorspace);
+      prevSwsFrameContext_ = swsFrameContext;
     }
     int resultHeight =
         convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -132,9 +126,29 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     frameOutput.data = outputTensor;
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
-      createFilterGraph(frameContext, videoStreamOptions, timeBase);
-      prevFrameContext_ = frameContext;
+    // See comment above in the swscale branch about the filterGraphContext_
+    // creation.
+    FiltersContext filtersContext;
+
+    filtersContext.inputWidth = avFrame->width;
+    filtersContext.inputHeight = avFrame->height;
+    filtersContext.inputFormat = frameFormat;
+    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
+    filtersContext.outputWidth = expectedOutputWidth;
+    filtersContext.outputHeight = expectedOutputHeight;
+    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
+    filtersContext.timeBase = timeBase;
+
+    std::stringstream filters;
+    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
+    filters << ":sws_flags=bilinear";
Review comment: It's unfortunate that we now have to re-create the string for every single frame in order to compare the …
+    filtersContext.filtergraphStr = filters.str();
+
+    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
+      filterGraphContext_ =
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+      prevFiltersContext_ = std::move(filtersContext);
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
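For a 640x480 output, for example, filters.str() evaluates to scale=640:480:sws_flags=bilinear. This is the string that now gets rebuilt on every frame so that it can take part in the FiltersContext comparison, which is what the review comment above laments.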
@@ -187,14 +201,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(

 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame) {
-  int status = av_buffersrc_write_frame(
-      filterGraphContext_.sourceContext, avFrame.get());
-  TORCH_CHECK(
-      status >= AVSUCCESS, "Failed to add frame to buffer source context");
+  UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);

-  UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  status = av_buffersink_get_frame(
-      filterGraphContext_.sinkContext, filteredAVFrame.get());
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
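Taken together, the std::make_unique construction above and the convert() call in this hunk suggest an interface for the new stand-alone class roughly along these lines. This is a sketch inferred from the call sites in this diff, not the actual FilterGraph header:

// Sketch only: shape inferred from how FilterGraph and FiltersContext
// are used in this diff.
struct FiltersContext {
  int inputWidth = 0;
  int inputHeight = 0;
  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
  AVRational inputAspectRatio = {0, 1};
  int outputWidth = 0;
  int outputHeight = 0;
  AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
  AVRational timeBase = {0, 1};
  std::string filtergraphStr;

  bool operator==(const FiltersContext& other) const;
  bool operator!=(const FiltersContext& other) const;
};

class FilterGraph {
 public:
  // Builds and configures the AVFilterGraph described by filtersContext.
  FilterGraph(
      const FiltersContext& filtersContext,
      const VideoStreamOptions& videoStreamOptions);

  // Sends avFrame through the configured graph and returns the filtered frame.
  UniqueAVFrame convert(const UniqueAVFrame& avFrame);
};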
@@ -210,117 +218,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
       filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }

-void CpuDeviceInterface::createFilterGraph(
-    const DecodedFrameContext& frameContext,
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase) {
-  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
-
-  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterGraphContext_.filterGraph->nb_threads =
-        videoStreamOptions.ffmpegThreadCount.value();
-  }
-
-  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
-  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-
-  std::stringstream filterArgs;
-  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
-             << frameContext.decodedHeight;
-  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
-  filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
-  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
-             << frameContext.decodedAspectRatio.den;
-
-  int status = avfilter_graph_create_filter(
-      &filterGraphContext_.sourceContext,
-      buffersrc,
-      "in",
-      filterArgs.str().c_str(),
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      filterArgs.str(),
-      ": ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status = avfilter_graph_create_filter(
-      &filterGraphContext_.sinkContext,
-      buffersink,
-      "out",
-      nullptr,
-      nullptr,
-      filterGraphContext_.filterGraph.get());
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to create filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
-
-  status = av_opt_set_int_list(
-      filterGraphContext_.sinkContext,
-      "pix_fmts",
-      pix_fmts,
-      AV_PIX_FMT_NONE,
-      AV_OPT_SEARCH_CHILDREN);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to set output pixel formats: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
-  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
-
-  outputs->name = av_strdup("in");
-  outputs->filter_ctx = filterGraphContext_.sourceContext;
-  outputs->pad_idx = 0;
-  outputs->next = nullptr;
-  inputs->name = av_strdup("out");
-  inputs->filter_ctx = filterGraphContext_.sinkContext;
-  inputs->pad_idx = 0;
-  inputs->next = nullptr;
-
-  std::stringstream description;
-  description << "scale=" << frameContext.expectedWidth << ":"
-              << frameContext.expectedHeight;
-  description << ":sws_flags=bilinear";
-
-  AVFilterInOut* outputsTmp = outputs.release();
-  AVFilterInOut* inputsTmp = inputs.release();
-  status = avfilter_graph_parse_ptr(
-      filterGraphContext_.filterGraph.get(),
-      description.str().c_str(),
-      &inputsTmp,
-      &outputsTmp,
-      nullptr);
-  outputs.reset(outputsTmp);
-  inputs.reset(inputsTmp);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to parse filter description: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  status =
-      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
-  TORCH_CHECK(
-      status >= 0,
-      "Failed to configure filter graph: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-}
-
 void CpuDeviceInterface::createSwsContext(
-    const DecodedFrameContext& frameContext,
+    const SwsFrameContext& swsFrameContext,
     const enum AVColorSpace colorspace) {
   SwsContext* swsContext = sws_getContext(
-      frameContext.decodedWidth,
-      frameContext.decodedHeight,
-      frameContext.decodedFormat,
-      frameContext.expectedWidth,
-      frameContext.expectedHeight,
+      swsFrameContext.inputWidth,
+      swsFrameContext.inputHeight,
+      swsFrameContext.inputFormat,
+      swsFrameContext.outputWidth,
+      swsFrameContext.outputHeight,
       AV_PIX_FMT_RGB24,
       SWS_BILINEAR,
       nullptr,
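As a concrete illustration of the removed createFilterGraph above: for a hypothetical 1920x1080 yuv420p stream with a 1/30000 time base and square pixels, the buffer source arguments would read video_size=1920x1080:pix_fmt=0:time_base=1/30000:pixel_aspect=1/1 (streaming the AVPixelFormat enum prints its integer value, and AV_PIX_FMT_YUV420P is 0).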
Review comment: Curious if we should prefer using aggregate initialization here, instead of setting each field individually. I think I have a preference for aggregate initialization because if we add a new field to the SwsFrameContext struct, we'd get a loud compilation error if we forget to initialize the field (which is good). @scotts any pref?
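For illustration, the two styles under discussion, sketched against the SwsFrameContext fields from this diff (hypothetical snippet, not code from the PR):

// Option A: default-construct, then set each field (what the PR does).
// This still compiles if a newly added field is forgotten here.
SwsFrameContext fieldByField;
fieldByField.inputWidth = avFrame->width;
fieldByField.inputHeight = avFrame->height;
// ... remaining fields ...

// Option B: aggregate initialization. Every field is listed at the
// construction site, but note that C++ permits omitting trailing
// members, which are then value-initialized rather than flagged.
SwsFrameContext aggregate{
    avFrame->width,
    avFrame->height,
    frameFormat,
    expectedOutputWidth,
    expectedOutputHeight};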
Review reply: I generally prefer objects to be fully specified on construction, rather than default-constructed and then members set one-by-one. But I don't think aggregate initialization will save us here from forgetting a field, as I think that for types with a default, it will just use that.

What I think we should do is actually just create a constructor for all of our structs. And to make sure we don't miss any members, we should remove the default constructor. (When possible. If we're creating an array of something we won't be able to do that.)
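A sketch of that suggestion (hypothetical, not code from the PR): give the struct a single constructor taking every member and delete the default constructor, so construction sites cannot sidestep it and break loudly when the signature changes:

struct SwsFrameContext {
  int inputWidth;
  int inputHeight;
  AVPixelFormat inputFormat;
  int outputWidth;
  int outputHeight;

  // No default construction: every member must be supplied explicitly.
  SwsFrameContext() = delete;

  SwsFrameContext(
      int inputWidth,
      int inputHeight,
      AVPixelFormat inputFormat,
      int outputWidth,
      int outputHeight)
      : inputWidth(inputWidth),
        inputHeight(inputHeight),
        inputFormat(inputFormat),
        outputWidth(outputWidth),
        outputHeight(outputHeight) {}
};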