Skip to content

Commit 1ec5666

Browse files
authored
Support encoding to file-like object (#754)
1 parent 6b345aa commit 1ec5666

12 files changed

+330
-51
lines changed

src/torchcodec/_core/AVIOContextHolder.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ void AVIOContextHolder::createAVIOContext(
1414
AVIOWriteFunction write,
1515
AVIOSeekFunction seek,
1616
void* heldData,
17+
bool isForWriting,
1718
int bufferSize) {
1819
TORCH_CHECK(
1920
bufferSize > 0,
@@ -23,14 +24,18 @@ void AVIOContextHolder::createAVIOContext(
2324
buffer != nullptr,
2425
"Failed to allocate buffer of size " + std::to_string(bufferSize));
2526

26-
TORCH_CHECK(
27-
(seek != nullptr) && ((write != nullptr) ^ (read != nullptr)),
28-
"seek method must be defined, and either write or read must be defined. "
29-
"But not both!")
27+
TORCH_CHECK(seek != nullptr, "seek method must be defined");
28+
29+
if (isForWriting) {
30+
TORCH_CHECK(write != nullptr, "write method must be defined for writing");
31+
} else {
32+
TORCH_CHECK(read != nullptr, "read method must be defined for reading");
33+
}
34+
3035
avioContext_.reset(avioAllocContext(
3136
buffer,
3237
bufferSize,
33-
/*write_flag=*/write != nullptr,
38+
/*write_flag=*/isForWriting,
3439
heldData,
3540
read,
3641
write,

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class AVIOContextHolder {
5151
AVIOWriteFunction write,
5252
AVIOSeekFunction seek,
5353
void* heldData,
54+
bool isForWriting,
5455
int bufferSize = defaultBufferSize);
5556

5657
private:

src/torchcodec/_core/AVIOFileLikeContext.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,29 @@
99

1010
namespace facebook::torchcodec {
1111

12-
AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike)
12+
AVIOFileLikeContext::AVIOFileLikeContext(py::object fileLike, bool isForWriting)
1313
: fileLike_{UniquePyObject(new py::object(fileLike))} {
1414
{
1515
// TODO: Is it necessary to acquire the GIL here? Is it maybe even
1616
// harmful? At the moment, this is only called from within a pybind
1717
// function, and pybind guarantees we have the GIL.
1818
py::gil_scoped_acquire gil;
19-
TORCH_CHECK(
20-
py::hasattr(fileLike, "read"),
21-
"File like object must implement a read method.");
19+
20+
if (isForWriting) {
21+
TORCH_CHECK(
22+
py::hasattr(fileLike, "write"),
23+
"File like object must implement a write method for writing.");
24+
} else {
25+
TORCH_CHECK(
26+
py::hasattr(fileLike, "read"),
27+
"File like object must implement a read method for reading.");
28+
}
29+
2230
TORCH_CHECK(
2331
py::hasattr(fileLike, "seek"),
2432
"File like object must implement a seek method.");
2533
}
26-
createAVIOContext(&read, nullptr, &seek, &fileLike_);
34+
createAVIOContext(&read, &write, &seek, &fileLike_, isForWriting);
2735
}
2836

2937
int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
@@ -77,4 +85,12 @@ int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) {
7785
return py::cast<int64_t>((*fileLike)->attr("seek")(offset, whence));
7886
}
7987

88+
int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) {
89+
auto fileLike = static_cast<UniquePyObject*>(opaque);
90+
py::gil_scoped_acquire gil;
91+
py::bytes bytes_obj(reinterpret_cast<const char*>(buf), buf_size);
92+
93+
return py::cast<int64_t>((*fileLike)->attr("write")(bytes_obj));
94+
}
95+
8096
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOFileLikeContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ namespace facebook::torchcodec {
1919
// and seek calls back up to the methods on the Python object.
2020
class AVIOFileLikeContext : public AVIOContextHolder {
2121
public:
22-
explicit AVIOFileLikeContext(py::object fileLike);
22+
explicit AVIOFileLikeContext(py::object fileLike, bool isForWriting);
2323

2424
private:
2525
static int read(void* opaque, uint8_t* buf, int buf_size);
2626
static int64_t seek(void* opaque, int64_t offset, int whence);
27+
static int write(void* opaque, const uint8_t* buf, int buf_size);
2728

2829
// Note that we dynamically allocate the Python object because we need to
2930
// strictly control when its destructor is called. We must hold the GIL

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,14 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
105105
TORCH_CHECK(data.numel() > 0, "data must not be empty");
106106
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
107107
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
108-
createAVIOContext(&read, nullptr, &seek, &tensorContext_);
108+
createAVIOContext(
109+
&read, nullptr, &seek, &tensorContext_, /*isForWriting=*/false);
109110
}
110111

111112
AVIOToTensorContext::AVIOToTensorContext()
112113
: tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
113-
createAVIOContext(nullptr, &write, &seek, &tensorContext_);
114+
createAVIOContext(
115+
nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
114116
}
115117

116118
torch::Tensor AVIOToTensorContext::getOutputTensor() {

src/torchcodec/_core/Encoder.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ AudioEncoder::AudioEncoder(
149149
const torch::Tensor& samples,
150150
int sampleRate,
151151
std::string_view formatName,
152-
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
152+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
153153
const AudioStreamOptions& audioStreamOptions)
154154
: samples_(validateSamples(samples)),
155155
inSampleRate_(sampleRate),
@@ -248,9 +248,12 @@ void AudioEncoder::initializeEncoder(
248248
torch::Tensor AudioEncoder::encodeToTensor() {
249249
TORCH_CHECK(
250250
avioContextHolder_ != nullptr,
251-
"Cannot encode to tensor, avio context doesn't exist.");
251+
"Cannot encode to tensor, avio tensor context doesn't exist.");
252252
encode();
253-
return avioContextHolder_->getOutputTensor();
253+
auto avioToTensorContext =
254+
dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
255+
TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
256+
return avioToTensorContext->getOutputTensor();
254257
}
255258

256259
void AudioEncoder::encode() {
@@ -501,6 +504,7 @@ void AudioEncoder::maybeFlushSwrBuffers(AutoAVPacket& autoAVPacket) {
501504
void AudioEncoder::flushBuffers() {
502505
AutoAVPacket autoAVPacket;
503506
maybeFlushSwrBuffers(autoAVPacket);
507+
504508
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
505509
}
506510
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#pragma once
22
#include <torch/types.h>
3-
#include "src/torchcodec/_core/AVIOTensorContext.h"
3+
#include "src/torchcodec/_core/AVIOContextHolder.h"
44
#include "src/torchcodec/_core/FFMPEGCommon.h"
55
#include "src/torchcodec/_core/StreamOptions.h"
66

@@ -14,13 +14,16 @@ class AudioEncoder {
1414
int sampleRate,
1515
std::string_view fileName,
1616
const AudioStreamOptions& audioStreamOptions);
17+
1718
AudioEncoder(
1819
const torch::Tensor& samples,
1920
int sampleRate,
2021
std::string_view formatName,
21-
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
22+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
2223
const AudioStreamOptions& audioStreamOptions);
24+
2325
void encode();
26+
2427
torch::Tensor encodeToTensor();
2528

2629
private:
@@ -49,8 +52,7 @@ class AudioEncoder {
4952

5053
UniqueAVAudioFifo avAudioFifo_;
5154

52-
// Stores the AVIOContext for the output tensor buffer.
53-
std::unique_ptr<AVIOToTensorContext> avioContextHolder_;
55+
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
5456

5557
bool encodeWasCalled_ = false;
5658
int64_t lastEncodedAVFramePts_ = 0;

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
create_from_file_like,
2424
create_from_tensor,
2525
encode_audio_to_file,
26+
encode_audio_to_file_like,
2627
encode_audio_to_tensor,
2728
get_ffmpeg_library_versions,
2829
get_frame_at_index,

src/torchcodec/_core/ops.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,62 @@ def create_from_file_like(
151151
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))
152152

153153

154+
def encode_audio_to_file_like(
155+
samples: torch.Tensor,
156+
sample_rate: int,
157+
format: str,
158+
file_like: Union[io.RawIOBase, io.BufferedIOBase],
159+
bit_rate: Optional[int] = None,
160+
num_channels: Optional[int] = None,
161+
desired_sample_rate: Optional[int] = None,
162+
) -> None:
163+
"""Encode audio samples to a file-like object.
164+
165+
Args:
166+
samples: Audio samples tensor
167+
sample_rate: Sample rate in Hz
168+
format: Audio format (e.g., "wav", "mp3", "flac")
169+
file_like: File-like object that supports write() and seek() methods
170+
bit_rate: Optional bit rate for encoding
171+
num_channels: Optional number of output channels
172+
desired_sample_rate: Optional desired sample rate for the output.
173+
"""
174+
assert _pybind_ops is not None
175+
176+
if samples.dtype != torch.float32:
177+
raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}")
178+
179+
# We're having the same problem as with the decoder's create_from_file_like:
180+
# We should be able to pass a tensor directly, but this leads to a pybind
181+
# error. In order to work around this, we pass the pointer to the tensor's
182+
# data, and its shape, in order to re-construct it in C++. For this to work:
183+
# - the tensor must be float32
184+
# - the tensor must be contiguous, which is why we call contiguous().
185+
# In theory we could avoid this restriction by also passing the strides?
186+
# - IMPORTANT: the input samples tensor and its underlying data must be
187+
# alive during the call.
188+
#
189+
# A more elegant solution would be to cast the tensor into a py::object, but
190+
# casting the py::object back to a tensor in C++ seems to lead to the same
191+
# pybind error.
192+
193+
samples = samples.contiguous()
194+
_pybind_ops.encode_audio_to_file_like(
195+
samples.data_ptr(),
196+
list(samples.shape),
197+
sample_rate,
198+
format,
199+
file_like,
200+
bit_rate,
201+
num_channels,
202+
desired_sample_rate,
203+
)
204+
205+
# This check is useless but it's critical to keep it to ensure that samples
206+
# is still alive during the call to encode_audio_to_file_like.
207+
assert samples.is_contiguous()
208+
209+
154210
# ==============================
155211
# Abstract impl for the operators. Needed by torch.compile.
156212
# ==============================

src/torchcodec/_core/pybind_ops.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#include <string>
1111

1212
#include "src/torchcodec/_core/AVIOFileLikeContext.h"
13+
#include "src/torchcodec/_core/Encoder.h"
1314
#include "src/torchcodec/_core/SingleStreamDecoder.h"
15+
#include "src/torchcodec/_core/StreamOptions.h"
1416

1517
namespace py = pybind11;
1618

@@ -31,19 +33,55 @@ int64_t create_from_file_like(
3133
realSeek = seekModeFromString(seek_mode.value());
3234
}
3335

34-
auto avioContextHolder = std::make_unique<AVIOFileLikeContext>(file_like);
36+
auto avioContextHolder =
37+
std::make_unique<AVIOFileLikeContext>(file_like, /*isForWriting=*/false);
3538

3639
SingleStreamDecoder* decoder =
3740
new SingleStreamDecoder(std::move(avioContextHolder), realSeek);
3841
return reinterpret_cast<int64_t>(decoder);
3942
}
4043

44+
void encode_audio_to_file_like(
45+
int64_t data_ptr,
46+
const std::vector<int64_t>& shape,
47+
int64_t sample_rate,
48+
std::string_view format,
49+
py::object file_like,
50+
std::optional<int64_t> bit_rate = std::nullopt,
51+
std::optional<int64_t> num_channels = std::nullopt,
52+
std::optional<int64_t> desired_sample_rate = std::nullopt) {
53+
// We assume float32 *and* contiguity, this must be enforced by the caller.
54+
auto tensor_options = torch::TensorOptions().dtype(torch::kFloat32);
55+
auto samples = torch::from_blob(
56+
reinterpret_cast<void*>(data_ptr), shape, tensor_options);
57+
58+
// TODO Fix implicit int conversion:
59+
// https://github.com/pytorch/torchcodec/issues/679
60+
// same for sample_rate parameter below
61+
AudioStreamOptions audioStreamOptions;
62+
audioStreamOptions.bitRate = bit_rate;
63+
audioStreamOptions.numChannels = num_channels;
64+
audioStreamOptions.sampleRate = desired_sample_rate;
65+
66+
auto avioContextHolder =
67+
std::make_unique<AVIOFileLikeContext>(file_like, /*isForWriting=*/true);
68+
69+
AudioEncoder encoder(
70+
samples,
71+
static_cast<int>(sample_rate),
72+
format,
73+
std::move(avioContextHolder),
74+
audioStreamOptions);
75+
encoder.encode();
76+
}
77+
4178
#ifndef PYBIND_OPS_MODULE_NAME
4279
#error PYBIND_OPS_MODULE_NAME must be defined!
4380
#endif
4481

4582
PYBIND11_MODULE(PYBIND_OPS_MODULE_NAME, m) {
4683
m.def("create_from_file_like", &create_from_file_like);
84+
m.def("encode_audio_to_file_like", &encode_audio_to_file_like);
4785
}
4886

4987
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)