Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def save(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
- frames_per_chunk: int = 65536,
):
"""Save audio data to file.

Expand Down Expand Up @@ -115,8 +114,6 @@ def save(
``8`` is default and highest compression.
- OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest
quality. Default: ``3``.
- frames_per_chunk: The number of frames to process (convert to ``int32`` internally
- then write to file) at a time.
"""
if compression is None:
ext = str(filepath)[-3:].lower()
Expand All @@ -131,7 +128,7 @@ def save(
else:
raise RuntimeError(f'Unsupported file type: "{ext}"')
signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first)
- torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk)
+ torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)


load_wav = load
2 changes: 1 addition & 1 deletion torchaudio/csrc/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op(
static auto registerSaveAudioFile = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
- "torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression, int frames_per_chunk) -> ()")
+ "torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()")
.catchAllKernel<
decltype(sox_io::save_audio_file),
&sox_io::save_audio_file>());
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/sox_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal,
- const double compression,
- const int64_t frames_per_chunk) {
+ const double compression) {
const auto tensor = signal->getTensor();
const auto sample_rate = signal->getSampleRate();
const auto channels_first = signal->getChannelsFirst();
Expand Down Expand Up @@ -154,6 +153,7 @@ void save_audio_file(
tensor_ = tensor_.t();
}

+ const int64_t frames_per_chunk = 65536;
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
chunk = unnormalize_wav(chunk).contiguous();
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/csrc/sox_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
- const double compression = 0.,
- const int64_t frames_per_chunk = 65536);
+ const double compression = 0.);

} // namespace sox_io
} // namespace torchaudio

Expand Down