From 97623fc164213edc894eecf069fbe1223ecc77ed Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sat, 27 May 2017 13:15:40 +0100 Subject: [PATCH 1/3] Added ability to save tensors --- README.md | 11 +++++++ torchaudio/__init__.py | 26 +++++++++++++++++ torchaudio/src/generic/th_sox.c | 51 +++++++++++++++++++++++++++++++++ torchaudio/src/generic/th_sox.h | 1 + torchaudio/src/th_sox.h | 15 ++++++++++ 5 files changed, 104 insertions(+) diff --git a/README.md b/README.md index d422728d27..9bd2c9d9c1 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ Quick Usage ```python import torchaudio sound, sample_rate = torchaudio.load('foo.mp3') +torchaudio.save('foo_save.mp3', sound, sample_rate) # saves tensor to file ``` API Reference @@ -49,3 +50,13 @@ audio.load( ) ``` +torchaudio.save +``` +saves a tensor into an audio file. The extension of the given path is used as the saving format. +audio.save( + string, # path to file + tensor, # NSamples x NChannels 2D tensor + number, # sample_rate of the audio to be saved as +) +``` + diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index 4e7af32c2d..92e1f6f8d7 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -1,3 +1,5 @@ +import os + import torch from cffi import FFI @@ -30,3 +32,27 @@ def load(filename, out=None): func(bytes(filename), out, sample_rate_p) sample_rate = sample_rate_p[0] return out, sample_rate + + +def save(filepath, src, sample_rate): + assert torch.is_tensor(src) + assert not src.is_cuda + + filename, extension = os.path.splitext(filepath) + assert type(sample_rate) == int + + if isinstance(src, torch.FloatTensor): + func = th_sox.libthsox_Float_write_audio_file + elif isinstance(src, torch.DoubleTensor): + func = th_sox.libthsox_Double_write_audio_file + elif isinstance(src, torch.ByteTensor): + func = th_sox.libthsox_Byte_write_audio_file + elif isinstance(src, torch.CharTensor): + func = th_sox.libthsox_Char_write_audio_file + elif isinstance(src, torch.ShortTensor): + func = th_sox.libthsox_Short_write_audio_file + elif isinstance(src, torch.IntTensor): + func = th_sox.libthsox_Int_write_audio_file + elif isinstance(src, torch.LongTensor): + func = th_sox.libthsox_Long_write_audio_file + func(bytes(filepath), src, extension.replace('.', ''), sample_rate) diff --git a/torchaudio/src/generic/th_sox.c b/torchaudio/src/generic/th_sox.c index d37032a87f..2aa03c7fcf 100644 --- a/torchaudio/src/generic/th_sox.c +++ b/torchaudio/src/generic/th_sox.c @@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa sox_close(fd); } +void libthsox_(write_audio)(sox_format_t *fd, THTensor* src, + const char *extension, int sample_rate) +{ + long nchannels = src->size[1]; + long nsamples = src->size[0]; + real* data = THTensor_(data)(src); + + // convert audio to dest tensor + int x,k; + for (x=0; xsize[1]; + long nsamples = src->size[0]; + + sox_format_t *fd; + + // Create sox objects and write into int32_t buffer + sox_signalinfo_t sinfo; + sinfo.rate = sample_rate; + sinfo.channels = nchannels; + sinfo.length = nsamples * nchannels; + sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */ +#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0 + sinfo.mult = NULL; +#endif + fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL); + if (fd == NULL) + THError("[write_audio_file] Failure to open file for writing"); + + libthsox_(write_audio)(fd, src, extension, sample_rate); + + // free buffer and sox structures + sox_close(fd); + + return; +} + #endif diff --git a/torchaudio/src/generic/th_sox.h b/torchaudio/src/generic/th_sox.h index 21b471c943..5b329f9a3f 100644 --- a/torchaudio/src/generic/th_sox.h +++ b/torchaudio/src/generic/th_sox.h @@ -3,4 +3,5 @@ #else void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate); +void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate); #endif diff --git a/torchaudio/src/th_sox.h b/torchaudio/src/th_sox.h index 9f80dcaa90..098291501d 100644 --- a/torchaudio/src/th_sox.h +++ b/torchaudio/src/th_sox.h @@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor, void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate); void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate); void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate); + +void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension, + int sample_rate); \ No newline at end of file From d02dff76c1e8fcfcff277a4ea0a65f9c8259605e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sat, 27 May 2017 15:43:12 +0100 Subject: [PATCH 2/3] Refactors for better error checking/case checks --- torchaudio/__init__.py | 60 +++++++++++++++--------------------------- 1 file changed, 21 insertions(+), 39 deletions(-) diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index 92e1f6f8d7..0ab34ae4c2 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -3,56 +3,38 @@ import torch from cffi import FFI + ffi = FFI() from ._ext import th_sox + +def check_input(src): + if not torch.is_tensor(src): + raise TypeError('Expected a tensor, got %s' % type(src)) + if not src.__module__ == 'torch': + raise TypeError('Expected a CPU based tensor, got %s' % type(src)) + + def load(filename, out=None): if out is not None: - assert torch.is_tensor(out) - assert not out.is_cuda + check_input(out) else: out = torch.FloatTensor() - - if isinstance(out, torch.FloatTensor): - func = th_sox.libthsox_Float_read_audio_file - elif isinstance(out, torch.DoubleTensor): - func = th_sox.libthsox_Double_read_audio_file - elif isinstance(out, torch.ByteTensor): - func = th_sox.libthsox_Byte_read_audio_file - elif isinstance(out, torch.CharTensor): - func = th_sox.libthsox_Char_read_audio_file - elif isinstance(out, torch.ShortTensor): - func = th_sox.libthsox_Short_read_audio_file - elif isinstance(out, torch.IntTensor): - func = th_sox.libthsox_Int_read_audio_file - elif isinstance(out, torch.LongTensor): - func = th_sox.libthsox_Long_read_audio_file - - sample_rate_p = ffi.new('int*') + typename = type(out).__name__.replace('Tensor', '') + func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename)) + sample_rate_p = ffi.new('int*') func(bytes(filename), out, sample_rate_p) sample_rate = sample_rate_p[0] return out, sample_rate def save(filepath, src, sample_rate): - assert torch.is_tensor(src) - assert not src.is_cuda - filename, extension = os.path.splitext(filepath) - assert type(sample_rate) == int - - if isinstance(src, torch.FloatTensor): - func = th_sox.libthsox_Float_write_audio_file - elif isinstance(src, torch.DoubleTensor): - func = th_sox.libthsox_Double_write_audio_file - elif isinstance(src, torch.ByteTensor): - func = th_sox.libthsox_Byte_write_audio_file - elif isinstance(src, torch.CharTensor): - func = th_sox.libthsox_Char_write_audio_file - elif isinstance(src, torch.ShortTensor): - func = th_sox.libthsox_Short_write_audio_file - elif isinstance(src, torch.IntTensor): - func = th_sox.libthsox_Int_write_audio_file - elif isinstance(src, torch.LongTensor): - func = th_sox.libthsox_Long_write_audio_file - func(bytes(filepath), src, extension.replace('.', ''), sample_rate) + if type(sample_rate) != int: + raise TypeError('Sample rate should be a integer') + + check_input(src) + typename = type(src).__name__.replace('Tensor', '') + func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename)) + + func(bytes(filepath), src, extension[1:], sample_rate) From 1fdb6eae3aae10d0a823c9710c23d60d77d960a1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sat, 27 May 2017 15:58:13 +0100 Subject: [PATCH 3/3] Fixed space indent --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9bd2c9d9c1..c232a91d49 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ torchaudio.save ``` saves a tensor into an audio file. The extension of the given path is used as the saving format. audio.save( - string, # path to file + string, # path to file tensor, # NSamples x NChannels 2D tensor number, # sample_rate of the audio to be saved as )