Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Quick Usage
```python
import torchaudio
sound, sample_rate = torchaudio.load('foo.mp3')
torchaudio.save('foo_save.mp3', sound, sample_rate) # saves tensor to file
```

API Reference
Expand All @@ -49,3 +50,13 @@ audio.load(
)
```

torchaudio.save
```
Saves a tensor to an audio file. The output format is taken from the extension of the given path.
audio.save(
string, # path to file
tensor, # NSamples x NChannels 2D tensor
number, # sample rate to save the audio with
)
```

46 changes: 27 additions & 19 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,40 @@
import os

import torch

from cffi import FFI

ffi = FFI()
from ._ext import th_sox


def check_input(src):
    """Validate that ``src`` is a tensor whose type is defined in ``torch``.

    Returns ``None`` when ``src`` is acceptable.

    Raises:
        TypeError: if ``src`` is not a tensor, or if its ``__module__`` is
            anything other than ``'torch'``.
    """
    if torch.is_tensor(src):
        # Only types declared directly in the top-level ``torch`` module pass.
        if src.__module__ != 'torch':
            raise TypeError('Expected a CPU based tensor, got %s' % type(src))
        return
    raise TypeError('Expected a tensor, got %s' % type(src))


def load(filename, out=None):
    """Read an audio file into a tensor.

    Args:
        filename: path of the audio file to read (text or ``bytes``).
        out: optional tensor that receives the audio data. It is validated
            with ``check_input``. When omitted, a new ``torch.FloatTensor``
            is created.

    Returns:
        A ``(out, sample_rate)`` tuple; ``sample_rate`` is the integer the
        native reader stored through its ``int*`` argument.

    Raises:
        TypeError: if ``out`` is given and rejected by ``check_input``.
        AttributeError: if ``th_sox`` has no reader for the type of ``out``.
    """
    if out is None:
        out = torch.FloatTensor()
    else:
        # check_input raises TypeError itself, so no separate asserts are
        # needed (asserts would also vanish under ``python -O``).
        check_input(out)

    # The native reader takes a ``const char *``. ``bytes(str)`` raises on
    # Python 3, so text paths are encoded explicitly.
    if isinstance(filename, bytes):
        path = filename
    elif hasattr(filename, 'encode'):
        path = filename.encode('utf-8')
    else:
        path = bytes(filename)

    # Pick the reader by tensor type name,
    # e.g. FloatTensor -> libthsox_Float_read_audio_file.
    typename = type(out).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))

    sample_rate_p = ffi.new('int*')
    func(path, out, sample_rate_p)
    return out, sample_rate_p[0]


def save(filepath, src, sample_rate):
    """Write a tensor to an audio file.

    The extension of ``filepath`` (without the leading dot) is handed to the
    native writer as the output format.

    Args:
        filepath: destination path (text or ``bytes``).
        src: tensor to write; validated with ``check_input``. The native
            writer reads it as ``size[0]`` samples by ``size[1]`` channels.
        sample_rate: sample rate of the saved audio; must be an ``int``.

    Raises:
        TypeError: if ``sample_rate`` is not an integer, or if ``src`` is
            rejected by ``check_input``.
        AttributeError: if ``th_sox`` has no writer for the type of ``src``.
    """
    def _as_bytes(value):
        # The native writer takes ``const char *`` arguments. ``bytes(str)``
        # raises on Python 3, so text is encoded explicitly.
        if isinstance(value, bytes):
            return value
        if hasattr(value, 'encode'):
            return value.encode('utf-8')
        return bytes(value)

    extension = os.path.splitext(filepath)[1]

    # bool is a subclass of int but is never a meaningful sample rate.
    if not isinstance(sample_rate, int) or isinstance(sample_rate, bool):
        raise TypeError('Sample rate should be a integer')

    check_input(src)

    # Pick the writer by tensor type name,
    # e.g. FloatTensor -> libthsox_Float_write_audio_file.
    typename = type(src).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))

    # extension[1:] drops the leading dot returned by splitext.
    func(_as_bytes(filepath), src, _as_bytes(extension[1:]), sample_rate)
51 changes: 51 additions & 0 deletions torchaudio/src/generic/th_sox.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa
sox_close(fd);
}

/*
 * Write every element of `src` to the already-open sox stream `fd`.
 *
 * `src` is read as size[0] samples by size[1] channels, indexed as a dense
 * row-major buffer (data[x * nchannels + k]); the caller is expected to have
 * verified contiguity. Each element is cast to int32_t without scaling and
 * written with its own sox_write call.
 *
 * `extension` and `sample_rate` are not used by this function.
 *
 * Raises THError if sox_write does not report exactly one sample written.
 * NOTE(review): `fd` is not closed on that error path -- confirm whether the
 * caller is expected to clean up.
 */
void libthsox_(write_audio)(sox_format_t *fd, THTensor* src,
                            const char *extension, int sample_rate)
{
  long nchannels = src->size[1];
  long nsamples = src->size[0];
  real* data = THTensor_(data)(src);

  (void)extension;   /* unused */
  (void)sample_rate; /* unused */

  // Counters are long to match the tensor sizes; int counters could overflow
  // in the comparison and in the x*nchannels index for very large tensors.
  long x, k;
  for (x = 0; x < nsamples; x++) {
    for (k = 0; k < nchannels; k++) {
      // Convert one tensor element to a 32-bit sample and write it out.
      int32_t sample = (int32_t)(data[x*nchannels+k]);
      size_t samples_written = sox_write(fd, &sample, 1);
      if (samples_written != 1)
        THError("[write_audio_file] write failed in sox_write");
    }
  }
}

/*
 * Write the tensor `src` to the audio file `file_name`.
 *
 * `src` must be a contiguous 2D tensor laid out as size[0] samples by
 * size[1] channels. `extension` is passed to sox_open_write as the file
 * type, and `sample_rate` becomes the rate of the output signal. Samples are
 * written with 32-bit precision.
 *
 * Raises THError if `src` is not 2D, is not contiguous, or if the output
 * file cannot be opened.
 */
void libthsox_(write_audio_file)(const char *file_name, THTensor* src,
                                 const char *extension, int sample_rate)
{
  // size[1] is read below; guard against tensors that do not have it.
  if (THTensor_(nDimension)(src) != 2)
    THError("[write_audio_file] Input should be a 2D tensor (nsamples x nchannels)");

  if (THTensor_(isContiguous)(src) == 0)
    THError("[write_audio_file] Input should be contiguous tensors");

  long nchannels = src->size[1];
  long nsamples = src->size[0];

  sox_format_t *fd;

  // Describe the output signal for sox
  sox_signalinfo_t sinfo;
  sinfo.rate = sample_rate;
  sinfo.channels = nchannels;
  sinfo.length = nsamples * nchannels;   /* total samples across all channels */
  sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
  sinfo.mult = NULL;
#endif
  fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL);
  if (fd == NULL)
    THError("[write_audio_file] Failure to open file for writing");

  libthsox_(write_audio)(fd, src, extension, sample_rate);

  // release the sox stream
  sox_close(fd);
}

#endif
1 change: 1 addition & 0 deletions torchaudio/src/generic/th_sox.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
#else

/* Read the audio file `file_name` into `tensor`; the sample rate is stored through `sample_rate`. */
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate);
/* Write `src` to `file_name`; `extension` is the sox file type and `sample_rate` the output rate. */
void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate);
#endif
15 changes: 15 additions & 0 deletions torchaudio/src/th_sox.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor,
void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate);
void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate);
void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate);

/*
 * Per-tensor-type writers. The Python `save` function looks these up by name
 * as libthsox_<Type>_write_audio_file, so each name must match the tensor
 * type name with the "Tensor" suffix removed.
 * Presumably these are the per-type expansions of libthsox_(write_audio_file)
 * in generic/th_sox.c -- confirm against the build.
 */
void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension,
int sample_rate);