Skip to content

Commit 2936245

Browse files
authored
Merge pull request #3 from SeanNaren/save
Added ability to save tensors
2 parents 3510919 + 1fdb6ea commit 2936245

File tree

5 files changed

+105
-19
lines changed

5 files changed

+105
-19
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Quick Usage
3636
```python
3737
import torchaudio
3838
sound, sample_rate = torchaudio.load('foo.mp3')
39+
torchaudio.save('foo_save.mp3', sound, sample_rate) # saves tensor to file
3940
```
4041

4142
API Reference
@@ -49,3 +50,13 @@ audio.load(
4950
)
5051
```
5152

53+
torchaudio.save
54+
```
55+
saves a tensor into an audio file. The extension of the given path is used as the saving format.
56+
audio.save(
57+
string, # path to file
58+
tensor, # NSamples x NChannels 2D tensor
59+
number, # sample_rate of the audio to be saved as
60+
)
61+
```
62+

torchaudio/__init__.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,40 @@
1+
import os
2+
13
import torch
24

35
from cffi import FFI
6+
47
ffi = FFI()
58
from ._ext import th_sox
69

10+
11+
def check_input(src):
12+
if not torch.is_tensor(src):
13+
raise TypeError('Expected a tensor, got %s' % type(src))
14+
if not src.__module__ == 'torch':
15+
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
16+
17+
718
def load(filename, out=None):
819
if out is not None:
9-
assert torch.is_tensor(out)
10-
assert not out.is_cuda
20+
check_input(out)
1121
else:
1222
out = torch.FloatTensor()
13-
14-
if isinstance(out, torch.FloatTensor):
15-
func = th_sox.libthsox_Float_read_audio_file
16-
elif isinstance(out, torch.DoubleTensor):
17-
func = th_sox.libthsox_Double_read_audio_file
18-
elif isinstance(out, torch.ByteTensor):
19-
func = th_sox.libthsox_Byte_read_audio_file
20-
elif isinstance(out, torch.CharTensor):
21-
func = th_sox.libthsox_Char_read_audio_file
22-
elif isinstance(out, torch.ShortTensor):
23-
func = th_sox.libthsox_Short_read_audio_file
24-
elif isinstance(out, torch.IntTensor):
25-
func = th_sox.libthsox_Int_read_audio_file
26-
elif isinstance(out, torch.LongTensor):
27-
func = th_sox.libthsox_Long_read_audio_file
28-
29-
sample_rate_p = ffi.new('int*')
23+
typename = type(out).__name__.replace('Tensor', '')
24+
func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
25+
sample_rate_p = ffi.new('int*')
3026
func(bytes(filename), out, sample_rate_p)
3127
sample_rate = sample_rate_p[0]
3228
return out, sample_rate
29+
30+
31+
def save(filepath, src, sample_rate):
32+
filename, extension = os.path.splitext(filepath)
33+
if type(sample_rate) != int:
34+
raise TypeError('Sample rate should be a integer')
35+
36+
check_input(src)
37+
typename = type(src).__name__.replace('Tensor', '')
38+
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
39+
40+
func(bytes(filepath), src, extension[1:], sample_rate)

torchaudio/src/generic/th_sox.c

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa
4444
sox_close(fd);
4545
}
4646

47+
void libthsox_(write_audio)(sox_format_t *fd, THTensor* src,
48+
const char *extension, int sample_rate)
49+
{
50+
long nchannels = src->size[1];
51+
long nsamples = src->size[0];
52+
real* data = THTensor_(data)(src);
53+
54+
// convert audio to dest tensor
55+
int x,k;
56+
for (x=0; x<nsamples; x++) {
57+
for (k=0; k<nchannels; k++) {
58+
int32_t sample = (int32_t)(data[x*nchannels+k]);
59+
size_t samples_written = sox_write(fd, &sample, 1);
60+
if (samples_written != 1)
61+
THError("[write_audio_file] write failed in sox_write");
62+
}
63+
}
64+
}
65+
66+
void libthsox_(write_audio_file)(const char *file_name, THTensor* src,
67+
const char *extension, int sample_rate)
68+
{
69+
if (THTensor_(isContiguous)(src) == 0)
70+
THError("[write_audio_file] Input should be contiguous tensors");
71+
72+
long nchannels = src->size[1];
73+
long nsamples = src->size[0];
74+
75+
sox_format_t *fd;
76+
77+
// Create sox objects and write into int32_t buffer
78+
sox_signalinfo_t sinfo;
79+
sinfo.rate = sample_rate;
80+
sinfo.channels = nchannels;
81+
sinfo.length = nsamples * nchannels;
82+
sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */
83+
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
84+
sinfo.mult = NULL;
85+
#endif
86+
fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL);
87+
if (fd == NULL)
88+
THError("[write_audio_file] Failure to open file for writing");
89+
90+
libthsox_(write_audio)(fd, src, extension, sample_rate);
91+
92+
// free buffer and sox structures
93+
sox_close(fd);
94+
95+
return;
96+
}
97+
4798
#endif

torchaudio/src/generic/th_sox.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
#else
44

55
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate);
6+
void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate);
67
#endif

torchaudio/src/th_sox.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor,
1515
void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate);
1616
void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate);
1717
void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate);
18+
19+
void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension,
20+
int sample_rate);
21+
void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension,
22+
int sample_rate);
23+
void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension,
24+
int sample_rate);
25+
void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension,
26+
int sample_rate);
27+
void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension,
28+
int sample_rate);
29+
void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension,
30+
int sample_rate);
31+
void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension,
32+
int sample_rate);

0 commit comments

Comments
 (0)