1 change: 1 addition & 0 deletions src/torchcodec/_core/CMakeLists.txt
@@ -88,6 +88,7 @@ function(make_torchcodec_libraries
AVIOContextHolder.cpp
AVIOTensorContext.cpp
FFMPEGCommon.cpp
FilterGraph.cpp
Frame.cpp
DeviceInterface.cpp
CpuDeviceInterface.cpp
204 changes: 55 additions & 149 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -6,11 +6,6 @@

#include "src/torchcodec/_core/CpuDeviceInterface.h"

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}

namespace facebook::torchcodec {
namespace {

@@ -20,17 +15,15 @@ static bool g_cpu = registerDeviceInterface(

} // namespace

bool CpuDeviceInterface::DecodedFrameContext::operator==(
const CpuDeviceInterface::DecodedFrameContext& other) {
return decodedWidth == other.decodedWidth &&
decodedHeight == other.decodedHeight &&
decodedFormat == other.decodedFormat &&
expectedWidth == other.expectedWidth &&
expectedHeight == other.expectedHeight;
bool CpuDeviceInterface::SwsFrameContext::operator==(
const CpuDeviceInterface::SwsFrameContext& other) const {
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
outputHeight == other.outputHeight;
}

bool CpuDeviceInterface::DecodedFrameContext::operator!=(
const CpuDeviceInterface::DecodedFrameContext& other) {
bool CpuDeviceInterface::SwsFrameContext::operator!=(
const CpuDeviceInterface::SwsFrameContext& other) const {
return !(*this == other);
}

@@ -75,22 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
}

torch::Tensor outputTensor;
// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
// that we don't have to depend on the unreliable metadata in the header.
// And we sometimes re-create them because it's possible for frame
// resolution to change mid-stream. Finally, we want to reuse the colorspace
// conversion objects as much as possible for performance reasons.
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(avFrame->format);
auto frameContext = DecodedFrameContext{
avFrame->width,
avFrame->height,
frameFormat,
avFrame->sample_aspect_ratio,
expectedOutputWidth,
expectedOutputHeight};

// By default, we want to use swscale for color conversion because it is
// faster. However, it has width requirements, so we may need to fall back
@@ -111,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);

if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
// We need to compare the current frame context with our previous frame
// context. If they are different, then we need to re-create our colorspace
// conversion objects. We create our colorspace conversion objects late so
// that we don't have to depend on the unreliable metadata in the header.
// And we sometimes re-create them because it's possible for frame
// resolution to change mid-stream. Finally, we want to reuse the colorspace
// conversion objects as much as possible for performance reasons.
SwsFrameContext swsFrameContext;

swsFrameContext.inputWidth = avFrame->width;
swsFrameContext.inputHeight = avFrame->height;
swsFrameContext.inputFormat = frameFormat;
swsFrameContext.outputWidth = expectedOutputWidth;
swsFrameContext.outputHeight = expectedOutputHeight;
Comment on lines +102 to +106
Member

Curious if we should prefer using aggregate initialization here, instead of setting each field individually. I think I have a preference for aggregate initialization because if we add a new field to the SwsFrameContext struct, we'd get a loud compilation error if we forget to initialize the field (which is good).

@scotts any pref?

Contributor

I generally prefer objects to be fully specified on construction, rather than default-constructed and then members set one-by-one. But I don't think aggregate initialization will save us here from forgetting a field, as I think that for types with a default, it will just use that.

What I think we should do is actually just create a constructor for all of our structs. And to make sure we don't miss any members, we should remove the default constructor. (When possible. If we're creating an array of something we won't be able to do that.)
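
A minimal sketch of the constructor approach described above (hypothetical; not part of this PR's diff):

extern "C" {
#include <libavutil/pixfmt.h>
}

// Hypothetical sketch: deleting the default constructor and requiring every
// member at construction turns a forgotten field into a compile error at the
// call site. The trade-off the comment mentions still applies: a member like
// prevSwsFrameContext_ can no longer be default-constructed and would need,
// e.g., std::optional or in-class initialization.
struct SwsFrameContext {
  int inputWidth;
  int inputHeight;
  AVPixelFormat inputFormat;
  int outputWidth;
  int outputHeight;

  SwsFrameContext() = delete;
  SwsFrameContext(
      int inputWidth,
      int inputHeight,
      AVPixelFormat inputFormat,
      int outputWidth,
      int outputHeight)
      : inputWidth(inputWidth),
        inputHeight(inputHeight),
        inputFormat(inputFormat),
        outputWidth(outputWidth),
        outputHeight(outputHeight) {}

  bool operator==(const SwsFrameContext&) const;
  bool operator!=(const SwsFrameContext&) const;
};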


outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
expectedOutputHeight, expectedOutputWidth, torch::kCPU));

if (!swsContext_ || prevFrameContext_ != frameContext) {
createSwsContext(frameContext, avFrame->colorspace);
prevFrameContext_ = frameContext;
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
createSwsContext(swsFrameContext, avFrame->colorspace);
prevSwsFrameContext_ = swsFrameContext;
}
int resultHeight =
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -132,9 +126,29 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(

frameOutput.data = outputTensor;
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
createFilterGraph(frameContext, videoStreamOptions, timeBase);
prevFrameContext_ = frameContext;
// See comment above in swscale branch about the filterGraphContext_
// creation.
FiltersContext filtersContext;

filtersContext.inputWidth = avFrame->width;
filtersContext.inputHeight = avFrame->height;
filtersContext.inputFormat = frameFormat;
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
filtersContext.outputWidth = expectedOutputWidth;
filtersContext.outputHeight = expectedOutputHeight;
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
filtersContext.timeBase = timeBase;

std::stringstream filters;
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
filters << ":sws_flags=bilinear";
Member

It's unfortunate that we now have to re-create the string for every single frame in order to compare the filtersContext objects. Maybe eventually we should separate the concerns between filter-graph creation parameters and comparison operators. But I guess for now this is cheap enough.
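
One possible shape for that separation (a hypothetical sketch, not what this PR implements): compare only the structured parameters on the per-frame path, and build the filtergraph string only when the FilterGraph actually has to be (re)created.

extern "C" {
#include <libavutil/pixfmt.h>
}
#include <sstream>
#include <string>

// Hypothetical sketch: operator== looks only at structured fields, so no
// string is built per decoded frame; the filtergraph description is produced
// on demand when a new FilterGraph is constructed. Some fields (aspect ratio,
// time base, output format) are omitted for brevity.
struct FiltersContext {
  int inputWidth = 0;
  int inputHeight = 0;
  AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
  int outputWidth = 0;
  int outputHeight = 0;

  bool operator==(const FiltersContext& other) const {
    return inputWidth == other.inputWidth &&
        inputHeight == other.inputHeight &&
        inputFormat == other.inputFormat &&
        outputWidth == other.outputWidth && outputHeight == other.outputHeight;
  }

  // Called only when (re)creating the FilterGraph, not on every frame.
  std::string buildFiltergraphStr() const {
    std::stringstream filters;
    filters << "scale=" << outputWidth << ":" << outputHeight
            << ":sws_flags=bilinear";
    return filters.str();
  }
};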


filtersContext.filtergraphStr = filters.str();

if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
filterGraphContext_ =
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
prevFiltersContext_ = std::move(filtersContext);
}
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);

@@ -187,14 +201,8 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(

torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame) {
int status = av_buffersrc_write_frame(
filterGraphContext_.sourceContext, avFrame.get());
TORCH_CHECK(
status >= AVSUCCESS, "Failed to add frame to buffer source context");
UniqueAVFrame filteredAVFrame = filterGraphContext_->convert(avFrame);

UniqueAVFrame filteredAVFrame(av_frame_alloc());
status = av_buffersink_get_frame(
filterGraphContext_.sinkContext, filteredAVFrame.get());
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
@@ -210,117 +218,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}

void CpuDeviceInterface::createFilterGraph(
const DecodedFrameContext& frameContext,
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase) {
filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);

if (videoStreamOptions.ffmpegThreadCount.has_value()) {
filterGraphContext_.filterGraph->nb_threads =
videoStreamOptions.ffmpegThreadCount.value();
}

const AVFilter* buffersrc = avfilter_get_by_name("buffer");
const AVFilter* buffersink = avfilter_get_by_name("buffersink");

std::stringstream filterArgs;
filterArgs << "video_size=" << frameContext.decodedWidth << "x"
<< frameContext.decodedHeight;
filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
filterArgs << ":time_base=" << timeBase.num << "/" << timeBase.den;
filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
<< frameContext.decodedAspectRatio.den;

int status = avfilter_graph_create_filter(
&filterGraphContext_.sourceContext,
buffersrc,
"in",
filterArgs.str().c_str(),
nullptr,
filterGraphContext_.filterGraph.get());
TORCH_CHECK(
status >= 0,
"Failed to create filter graph: ",
filterArgs.str(),
": ",
getFFMPEGErrorStringFromErrorCode(status));

status = avfilter_graph_create_filter(
&filterGraphContext_.sinkContext,
buffersink,
"out",
nullptr,
nullptr,
filterGraphContext_.filterGraph.get());
TORCH_CHECK(
status >= 0,
"Failed to create filter graph: ",
getFFMPEGErrorStringFromErrorCode(status));

enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

status = av_opt_set_int_list(
filterGraphContext_.sinkContext,
"pix_fmts",
pix_fmts,
AV_PIX_FMT_NONE,
AV_OPT_SEARCH_CHILDREN);
TORCH_CHECK(
status >= 0,
"Failed to set output pixel formats: ",
getFFMPEGErrorStringFromErrorCode(status));

UniqueAVFilterInOut outputs(avfilter_inout_alloc());
UniqueAVFilterInOut inputs(avfilter_inout_alloc());

outputs->name = av_strdup("in");
outputs->filter_ctx = filterGraphContext_.sourceContext;
outputs->pad_idx = 0;
outputs->next = nullptr;
inputs->name = av_strdup("out");
inputs->filter_ctx = filterGraphContext_.sinkContext;
inputs->pad_idx = 0;
inputs->next = nullptr;

std::stringstream description;
description << "scale=" << frameContext.expectedWidth << ":"
<< frameContext.expectedHeight;
description << ":sws_flags=bilinear";

AVFilterInOut* outputsTmp = outputs.release();
AVFilterInOut* inputsTmp = inputs.release();
status = avfilter_graph_parse_ptr(
filterGraphContext_.filterGraph.get(),
description.str().c_str(),
&inputsTmp,
&outputsTmp,
nullptr);
outputs.reset(outputsTmp);
inputs.reset(inputsTmp);
TORCH_CHECK(
status >= 0,
"Failed to parse filter description: ",
getFFMPEGErrorStringFromErrorCode(status));

status =
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
TORCH_CHECK(
status >= 0,
"Failed to configure filter graph: ",
getFFMPEGErrorStringFromErrorCode(status));
}

void CpuDeviceInterface::createSwsContext(
const DecodedFrameContext& frameContext,
const SwsFrameContext& swsFrameContext,
const enum AVColorSpace colorspace) {
SwsContext* swsContext = sws_getContext(
frameContext.decodedWidth,
frameContext.decodedHeight,
frameContext.decodedFormat,
frameContext.expectedWidth,
frameContext.expectedHeight,
swsFrameContext.inputWidth,
swsFrameContext.inputHeight,
swsFrameContext.inputFormat,
swsFrameContext.outputWidth,
swsFrameContext.outputHeight,
AV_PIX_FMT_RGB24,
SWS_BILINEAR,
nullptr,
36 changes: 13 additions & 23 deletions src/torchcodec/_core/CpuDeviceInterface.h
@@ -8,6 +8,7 @@

#include "src/torchcodec/_core/DeviceInterface.h"
#include "src/torchcodec/_core/FFMPEGCommon.h"
#include "src/torchcodec/_core/FilterGraph.h"

namespace facebook::torchcodec {

@@ -41,40 +42,29 @@ class CpuDeviceInterface : public DeviceInterface {
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame);

struct FilterGraphContext {
UniqueAVFilterGraph filterGraph;
AVFilterContext* sourceContext = nullptr;
AVFilterContext* sinkContext = nullptr;
};

struct DecodedFrameContext {
int decodedWidth;
int decodedHeight;
AVPixelFormat decodedFormat;
AVRational decodedAspectRatio;
int expectedWidth;
int expectedHeight;
bool operator==(const DecodedFrameContext&);
bool operator!=(const DecodedFrameContext&);
struct SwsFrameContext {
int inputWidth;
int inputHeight;
AVPixelFormat inputFormat;
int outputWidth;
int outputHeight;
bool operator==(const SwsFrameContext&) const;
bool operator!=(const SwsFrameContext&) const;
};

void createSwsContext(
const DecodedFrameContext& frameContext,
const SwsFrameContext& swsFrameContext,
const enum AVColorSpace colorspace);

void createFilterGraph(
const DecodedFrameContext& frameContext,
const VideoStreamOptions& videoStreamOptions,
const AVRational& timeBase);

// color-conversion fields. Only one of FilterGraphContext and
// UniqueSwsContext should be non-null.
FilterGraphContext filterGraphContext_;
std::unique_ptr<FilterGraph> filterGraphContext_;
UniqueSwsContext swsContext_;

// Used to know whether a new FilterGraphContext or UniqueSwsContext should
// be created before decoding a new frame.
DecodedFrameContext prevFrameContext_;
SwsFrameContext prevSwsFrameContext_;
FiltersContext prevFiltersContext_;
};

} // namespace facebook::torchcodec
13 changes: 13 additions & 0 deletions src/torchcodec/_core/FFMPEGCommon.h
@@ -13,6 +13,7 @@
extern "C" {
#include <libavcodec/avcodec.h>
#include <libavfilter/avfilter.h>
#include <libavfilter/buffersrc.h>
#include <libavformat/avformat.h>
#include <libavformat/avio.h>
#include <libavutil/audio_fifo.h>
@@ -41,6 +42,15 @@ struct Deleterp {
}
};

template <typename T, typename R, R (*Fn)(void*)>
struct Deleterv {
inline void operator()(T* p) const {
if (p) {
Fn(&p);
}
}
};

template <typename T, typename R, R (*Fn)(T*)>
struct Deleter {
inline void operator()(T* p) const {
Expand Down Expand Up @@ -78,6 +88,9 @@ using UniqueAVAudioFifo = std::
unique_ptr<AVAudioFifo, Deleter<AVAudioFifo, void, av_audio_fifo_free>>;
using UniqueAVBufferRef =
std::unique_ptr<AVBufferRef, Deleterp<AVBufferRef, void, av_buffer_unref>>;
using UniqueAVBufferSrcParameters = std::unique_ptr<
AVBufferSrcParameters,
Deleterv<AVBufferSrcParameters, void, av_freep>>;
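
A brief usage sketch for the new alias (assumed usage, not shown in this diff): av_freep() takes the address of the pointer and nulls it after freeing, which is why Deleterv calls Fn(&p) rather than Fn(p).

// Assumed usage sketch: av_buffersrc_parameters_alloc() returns a raw
// AVBufferSrcParameters* that FFmpeg expects to be released with
// av_free()/av_freep(); the Deleterv wrapper passes &p because av_freep()
// operates on a pointer-to-pointer.
UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
srcParams->format = AV_PIX_FMT_RGB24; // hypothetical values
srcParams->width = 640;
srcParams->height = 480;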

// These 2 classes share the same underlying AVPacket object. They are meant to
// be used in tandem, like so: