diff --git a/README.md b/README.md index d422728d27..c232a91d49 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ Quick Usage ```python import torchaudio sound, sample_rate = torchaudio.load('foo.mp3') +torchaudio.save('foo_save.mp3', sound, sample_rate) # saves tensor to file ``` API Reference @@ -49,3 +50,13 @@ audio.load( ) ``` +torchaudio.save +``` +saves a tensor into an audio file. The extension of the given path is used as the saving format. +audio.save( + string, # path to file + tensor, # NSamples x NChannels 2D tensor + number, # sample_rate of the audio to be saved as +) +``` + diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index 4e7af32c2d..0ab34ae4c2 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -1,32 +1,40 @@ +import os + import torch from cffi import FFI + ffi = FFI() from ._ext import th_sox + +def check_input(src): + if not torch.is_tensor(src): + raise TypeError('Expected a tensor, got %s' % type(src)) + if not src.__module__ == 'torch': + raise TypeError('Expected a CPU based tensor, got %s' % type(src)) + + def load(filename, out=None): if out is not None: - assert torch.is_tensor(out) - assert not out.is_cuda + check_input(out) else: out = torch.FloatTensor() - - if isinstance(out, torch.FloatTensor): - func = th_sox.libthsox_Float_read_audio_file - elif isinstance(out, torch.DoubleTensor): - func = th_sox.libthsox_Double_read_audio_file - elif isinstance(out, torch.ByteTensor): - func = th_sox.libthsox_Byte_read_audio_file - elif isinstance(out, torch.CharTensor): - func = th_sox.libthsox_Char_read_audio_file - elif isinstance(out, torch.ShortTensor): - func = th_sox.libthsox_Short_read_audio_file - elif isinstance(out, torch.IntTensor): - func = th_sox.libthsox_Int_read_audio_file - elif isinstance(out, torch.LongTensor): - func = th_sox.libthsox_Long_read_audio_file - - sample_rate_p = ffi.new('int*') + typename = type(out).__name__.replace('Tensor', '') + func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename)) + sample_rate_p = ffi.new('int*') func(bytes(filename), out, sample_rate_p) sample_rate = sample_rate_p[0] return out, sample_rate + + +def save(filepath, src, sample_rate): + filename, extension = os.path.splitext(filepath) + if type(sample_rate) != int: + raise TypeError('Sample rate should be a integer') + + check_input(src) + typename = type(src).__name__.replace('Tensor', '') + func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename)) + + func(bytes(filepath), src, extension[1:], sample_rate) diff --git a/torchaudio/src/generic/th_sox.c b/torchaudio/src/generic/th_sox.c index d37032a87f..2aa03c7fcf 100644 --- a/torchaudio/src/generic/th_sox.c +++ b/torchaudio/src/generic/th_sox.c @@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa sox_close(fd); } +void libthsox_(write_audio)(sox_format_t *fd, THTensor* src, + const char *extension, int sample_rate) +{ + long nchannels = src->size[1]; + long nsamples = src->size[0]; + real* data = THTensor_(data)(src); + + // convert audio to dest tensor + int x,k; + for (x=0; xsize[1]; + long nsamples = src->size[0]; + + sox_format_t *fd; + + // Create sox objects and write into int32_t buffer + sox_signalinfo_t sinfo; + sinfo.rate = sample_rate; + sinfo.channels = nchannels; + sinfo.length = nsamples * nchannels; + sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */ +#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0 + sinfo.mult = NULL; +#endif + fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL); + if (fd == NULL) + THError("[write_audio_file] Failure to open file for writing"); + + libthsox_(write_audio)(fd, src, extension, sample_rate); + + // free buffer and sox structures + sox_close(fd); + + return; +} + #endif diff --git a/torchaudio/src/generic/th_sox.h b/torchaudio/src/generic/th_sox.h index 21b471c943..5b329f9a3f 100644 --- a/torchaudio/src/generic/th_sox.h +++ b/torchaudio/src/generic/th_sox.h @@ -3,4 +3,5 @@ #else void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate); +void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate); #endif diff --git a/torchaudio/src/th_sox.h b/torchaudio/src/th_sox.h index 9f80dcaa90..098291501d 100644 --- a/torchaudio/src/th_sox.h +++ b/torchaudio/src/th_sox.h @@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor, void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate); void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate); void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate); + +void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension, + int sample_rate); +void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension, + int sample_rate); \ No newline at end of file