|
| 1 | +import os |
| 2 | + |
1 | 3 | import torch |
2 | 4 |
|
3 | 5 | from cffi import FFI |
| 6 | + |
4 | 7 | ffi = FFI() |
5 | 8 | from ._ext import th_sox |
6 | 9 |
|
| 10 | + |
| 11 | +def check_input(src): |
| 12 | + if not torch.is_tensor(src): |
| 13 | + raise TypeError('Expected a tensor, got %s' % type(src)) |
| 14 | + if not src.__module__ == 'torch': |
| 15 | + raise TypeError('Expected a CPU based tensor, got %s' % type(src)) |
| 16 | + |
| 17 | + |
7 | 18 | def load(filename, out=None): |
8 | 19 | if out is not None: |
9 | | - assert torch.is_tensor(out) |
10 | | - assert not out.is_cuda |
| 20 | + check_input(out) |
11 | 21 | else: |
12 | 22 | out = torch.FloatTensor() |
13 | | - |
14 | | - if isinstance(out, torch.FloatTensor): |
15 | | - func = th_sox.libthsox_Float_read_audio_file |
16 | | - elif isinstance(out, torch.DoubleTensor): |
17 | | - func = th_sox.libthsox_Double_read_audio_file |
18 | | - elif isinstance(out, torch.ByteTensor): |
19 | | - func = th_sox.libthsox_Byte_read_audio_file |
20 | | - elif isinstance(out, torch.CharTensor): |
21 | | - func = th_sox.libthsox_Char_read_audio_file |
22 | | - elif isinstance(out, torch.ShortTensor): |
23 | | - func = th_sox.libthsox_Short_read_audio_file |
24 | | - elif isinstance(out, torch.IntTensor): |
25 | | - func = th_sox.libthsox_Int_read_audio_file |
26 | | - elif isinstance(out, torch.LongTensor): |
27 | | - func = th_sox.libthsox_Long_read_audio_file |
28 | | - |
29 | | - sample_rate_p = ffi.new('int*') |
| 23 | + typename = type(out).__name__.replace('Tensor', '') |
| 24 | + func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename)) |
| 25 | + sample_rate_p = ffi.new('int*') |
30 | 26 | func(bytes(filename), out, sample_rate_p) |
31 | 27 | sample_rate = sample_rate_p[0] |
32 | 28 | return out, sample_rate |
| 29 | + |
| 30 | + |
| 31 | +def save(filepath, src, sample_rate): |
| 32 | + filename, extension = os.path.splitext(filepath) |
| 33 | + if type(sample_rate) != int: |
| 34 | + raise TypeError('Sample rate should be a integer') |
| 35 | + |
| 36 | + check_input(src) |
| 37 | + typename = type(src).__name__.replace('Tensor', '') |
| 38 | + func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename)) |
| 39 | + |
| 40 | + func(bytes(filepath), src, extension[1:], sample_rate) |
0 commit comments