Skip to content

Commit ea0b9e0

Browse files
committed
Add test for sox effects
1 parent b6d1b86 commit ea0b9e0

File tree

10 files changed

+565
-16
lines changed

10 files changed

+565
-16
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
{"effects": [["allpass", "300", "10"]]}
2+
{"effects": [["band", "300", "10"]]}
3+
{"effects": [["bandpass", "300", "10"]]}
4+
{"effects": [["bandreject", "300", "10"]]}
5+
{"effects": [["bass", "-10"]]}
6+
{"effects": [["bend", ".35,180,.25", ".15,740,.53", "0,-520,.3"]]}
7+
{"effects": [["biquad", "0.4", "0.2", "0.9", "0.7", "0.2", "0.6"]]}
8+
{"effects": [["chorus", "0.7", "0.9", "55", "0.4", "0.25", "2", "-t"]]}
9+
{"effects": [["chorus", "0.6", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "1.3", "-s"]]}
10+
{"effects": [["chorus", "0.5", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "2.3", "-t", "40", "0.3", "0.3", "1.3", "-s"]]}
11+
{"effects": [["channels", "1"]]}
12+
{"effects": [["channels", "2"]]}
13+
{"effects": [["channels", "3"]]}
14+
{"effects": [["compand", "0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"]]}
15+
{"effects": [["compand", ".1,.2", "-inf,-50.1,-inf,-50,-50", "0", "-90", ".1"]]}
16+
{"effects": [["compand", ".1,.1", "-45.1,-45,-inf,0,-inf", "45", "-90", ".1"]]}
17+
{"effects": [["contrast", "0"]]}
18+
{"effects": [["contrast", "25"]]}
19+
{"effects": [["contrast", "50"]]}
20+
{"effects": [["contrast", "75"]]}
21+
{"effects": [["contrast", "100"]]}
22+
{"effects": [["dcshift", "1.0"]]}
23+
{"effects": [["dcshift", "-1.0"]]}
24+
{"effects": [["deemph"]], "input_sample_rate": 44100}
25+
{"effects": [["delay", "1.5", "+1"]]}
26+
{"effects": [["dither", "-s"]]}
27+
{"effects": [["dither", "-S"]]}
28+
{"effects": [["divide"]]}
29+
{"effects": [["downsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 4000}
30+
{"effects": [["earwax"]], "input_sample_rate": 44100}
31+
{"effects": [["echo", "0.8", "0.88", "60", "0.4"]]}
32+
{"effects": [["echo", "0.8", "0.88", "6", "0.4"]]}
33+
{"effects": [["echo", "0.8", "0.9", "1000", "0.3"]]}
34+
{"effects": [["echo", "0.8", "0.9", "1000", "0.3", "1800", "0.25"]]}
35+
{"effects": [["echos", "0.8", "0.7", "700", "0.25", "700", "0.3"]]}
36+
{"effects": [["echos", "0.8", "0.7", "700", "0.25", "900", "0.3"]]}
37+
{"effects": [["echos", "0.8", "0.7", "40", "0.25", "63", "0.3"]]}
38+
{"effects": [["equalizer", "300", "10", "5"]]}
39+
{"effects": [["fade", "q", "3"]]}
40+
{"effects": [["fade", "h", "3"]]}
41+
{"effects": [["fade", "t", "3"]]}
42+
{"effects": [["fade", "l", "3"]]}
43+
{"effects": [["fade", "p", "3"]]}
44+
{"effects": [["fir", "0.0195", "-0.082", "0.234", "0.891", "-0.145", "0.043"]]}
45+
{"effects": [["fir", "test/assets/sox_effect_test_fir_coeffs.txt"]]}
46+
{"effects": [["flanger"]]}
47+
{"effects": [["gain", "-n"]]}
48+
{"effects": [["gain", "-n", "-3"]]}
49+
{"effects": [["gain", "-l", "-6"]]}
50+
{"effects": [["highpass", "-1", "300"]]}
51+
{"effects": [["highpass", "-2", "300"]]}
52+
{"effects": [["hilbert"]]}
53+
{"effects": [["loudness"]]}
54+
{"effects": [["lowpass", "-1", "300"]]}
55+
{"effects": [["lowpass", "-2", "300"]]}
56+
{"effects": [["mcompand", "0.005,0.1 -47,-40,-34,-34,-17,-33", "100", "0.003,0.05 -47,-40,-34,-34,-17,-33", "400", "0.000625,0.0125 -47,-40,-34,-34,-15,-33", "1600", "0.0001,0.025 -47,-40,-34,-34,-31,-31,-0,-30", "6400", "0,0.025 -38,-31,-28,-28,-0,-25"]], "input_sample_rate": 44100}
57+
{"effects": [["norm"]]}
58+
{"effects": [["oops"]]}
59+
{"effects": [["overdrive"]]}
60+
{"effects": [["pad"]]}
61+
{"effects": [["phaser"]]}
62+
{"effects": [["pitch", "6.48"], ["rate", "8030"]], "output_sample_rate": 8030}
63+
{"effects": [["pitch", "-6.50"], ["rate", "7970"]], "output_sample_rate": 7970}
64+
{"effects": [["rate", "4567"]], "output_sample_rate": 4567}
65+
{"effects": [["remix", "6", "7", "8", "0"]], "num_channels": 8}
66+
{"effects": [["remix", "1-3,7", "3"]], "num_channels": 8}
67+
{"effects": [["repeat"]]}
68+
{"effects": [["reverb"]]}
69+
{"effects": [["reverse"]]}
70+
{"effects": [["riaa"]], "input_sample_rate": 44100}
71+
{"effects": [["silence", "0"]]}
72+
{"effects": [["sinc", "3k"]]}
73+
{"effects": [["speed", "1.3"]], "input_sample_rate": 4000, "output_sample_rate": 5200}
74+
{"effects": [["speed", "0.7"]], "input_sample_rate": 4000, "output_sample_rate": 2800}
75+
{"effects": [["stat"]]}
76+
{"effects": [["stats"]]}
77+
{"effects": [["stretch"]]}
78+
{"effects": [["swap"]]}
79+
{"effects": [["synth"]]}
80+
{"effects": [["tempo", "0.9"]]}
81+
{"effects": [["tempo", "1.1"]]}
82+
{"effects": [["treble", "3"]]}
83+
{"effects": [["tremolo", "300", "40"]]}
84+
{"effects": [["tremolo", "300", "50"]]}
85+
{"effects": [["trim", "0", "0.1"]]}
86+
{"effects": [["upsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 16000}
87+
{"effects": [["vad"]]}
88+
{"effects": [["vol", "3"]]}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0.0195 -0.082 0.234 0.891 -0.145 0.043

test/common_utils/data_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def get_sinusoid(
7272
n_channels: int = 1,
7373
dtype: Union[str, torch.dtype] = "float32",
7474
device: Union[str, torch.device] = "cpu",
75+
channels_first: bool = True,
7576
):
7677
"""Generate pseudo audio data with sine wave.
7778
@@ -91,4 +92,7 @@ def get_sinusoid(
9192
pie2 = 2 * 3.141592653589793
9293
end = pie2 * frequency * duration
9394
theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device)
94-
return torch.sin(theta, out=None).repeat([n_channels, 1])
95+
sin = torch.sin(theta, out=None).repeat([n_channels, 1])
96+
if not channels_first:
97+
sin = sin.t()
98+
return sin

test/common_utils/sox_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,24 @@ def convert_audio_file(
7777
command += [dst_path]
7878
print(' '.join(command))
7979
subprocess.run(command, check=True)
80+
81+
82+
def _flattern(effects):
83+
if not effects:
84+
return effects
85+
if isinstance(effects[0], str):
86+
return effects
87+
return [item for sublist in effects for item in sublist]
88+
89+
90+
def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None):
91+
"""Run sox effects"""
92+
effect = _flattern(effect)
93+
command = ['sox', '-V', '--no-dither', input_file]
94+
if output_bitdepth:
95+
command += ['--bits', str(output_bitdepth)]
96+
command += [output_file] + effect
97+
if output_sample_rate:
98+
command += ['rate', str(output_sample_rate)]
99+
print(' '.join(command))
100+
subprocess.run(command, check=True)

test/common_utils/test_case_utils.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,28 @@
1414
class TempDirMixin:
1515
"""Mixin to provide easy access to temp dir"""
1616
temp_dir_ = None
17-
base_temp_dir = None
18-
temp_dir = None
1917

20-
@classmethod
21-
def setUpClass(cls):
22-
super().setUpClass()
18+
@property
19+
def base_temp_dir(self):
2320
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
2421
# this is handy for debugging.
2522
key = 'TORCHAUDIO_TEST_TEMP_DIR'
2623
if key in os.environ:
27-
cls.base_temp_dir = os.environ[key]
28-
else:
29-
cls.temp_dir_ = tempfile.TemporaryDirectory()
30-
cls.base_temp_dir = cls.temp_dir_.name
24+
return os.environ[key]
25+
if self.__class__.temp_dir_ is None:
26+
self.__class__.temp_dir_ = tempfile.TemporaryDirectory()
27+
return self.__class__.temp_dir_.name
3128

3229
@classmethod
3330
def tearDownClass(cls):
3431
super().tearDownClass()
35-
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
32+
if cls.temp_dir_ is not None:
3633
cls.temp_dir_.cleanup()
37-
38-
def setUp(self):
39-
super().setUp()
40-
self.temp_dir = os.path.join(self.base_temp_dir, self.id())
34+
cls.temp_dir_ = None
4135

4236
def get_temp_path(self, *paths):
43-
path = os.path.join(self.temp_dir, *paths)
37+
temp_dir = os.path.join(self.base_temp_dir, self.id())
38+
path = os.path.join(temp_dir, *paths)
4439
os.makedirs(os.path.dirname(path), exist_ok=True)
4540
return path
4641

test/sox_effect/__init__.py

Whitespace-only changes.

test/sox_effect/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def name_func(func, _, params):
2+
if isinstance(params.args[0], str):
3+
args = "_".join([str(arg) for arg in params.args])
4+
else:
5+
args = "_".join([str(arg) for arg in params.args[0]])
6+
return f'{func.__name__}_{args}'

test/sox_effect/test_dataset.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from typing import List, Tuple
2+
3+
import numpy as np
4+
import torch
5+
import torchaudio
6+
7+
from ..common_utils import (
8+
TempDirMixin,
9+
PytorchTestCase,
10+
skipIfNoExtension,
11+
get_whitenoise,
12+
load_wav,
13+
save_wav,
14+
)
15+
16+
17+
class RandomPerturbationFile(torch.utils.data.Dataset):
18+
"""Given flist, apply random speed perturbation"""
19+
def __init__(self, flist: List[str], sample_rate: int):
20+
super().__init__()
21+
self.flist = flist
22+
self.sample_rate = sample_rate
23+
self.rng = None
24+
25+
def __getitem__(self, index):
26+
speed = self.rng.uniform(0.5, 2.0)
27+
effects = [
28+
['gain', '-n', '-10'],
29+
['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds.
30+
['rate', f'{self.sample_rate}'],
31+
['pad', '0', '1.5'], # add 1.5 seconds silence at the end
32+
['trim', '0', '2'], # get the first 2 seconds
33+
]
34+
data, _ = torchaudio.sox_effects.apply_effects_file(self.flist[index], effects)
35+
return data
36+
37+
def __len__(self):
38+
return len(self.flist)
39+
40+
41+
class RandomPerturbationTensor(torch.utils.data.Dataset):
42+
"""Apply speed purturbation to (synthetic) Tensor data"""
43+
def __init__(self, signals: List[Tuple[torch.Tensor, int]], sample_rate: int):
44+
super().__init__()
45+
self.signals = signals
46+
self.sample_rate = sample_rate
47+
self.rng = None
48+
49+
def __getitem__(self, index):
50+
speed = self.rng.uniform(0.5, 2.0)
51+
effects = [
52+
['gain', '-n', '-10'],
53+
['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds.
54+
['rate', f'{self.sample_rate}'],
55+
['pad', '0', '1.5'], # add 1.5 seconds silence at the end
56+
['trim', '0', '2'], # get the first 2 seconds
57+
]
58+
tensor, sample_rate = self.signals[index]
59+
data, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate, effects)
60+
return data
61+
62+
def __len__(self):
63+
return len(self.signals)
64+
65+
66+
def init_random_seed(worker_id):
67+
dataset = torch.utils.data.get_worker_info().dataset
68+
dataset.rng = np.random.RandomState(worker_id)
69+
70+
71+
@skipIfNoExtension
72+
class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase):
73+
"""Test `apply_effects_file` in multi-process dataloader setting"""
74+
75+
def _generate_dataset(self, num_samples=128):
76+
flist = []
77+
for i in range(num_samples):
78+
sample_rate = np.random.choice([8000, 16000, 44100])
79+
dtype = np.random.choice(['float32', 'int32', 'int16', 'uint8'])
80+
data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype=dtype)
81+
path = self.get_temp_path(f'{i:03d}_{dtype}_{sample_rate}.wav')
82+
save_wav(path, data, sample_rate)
83+
flist.append(path)
84+
return flist
85+
86+
def test_apply_effects_file(self):
87+
sample_rate = 12000
88+
flist = self._generate_dataset()
89+
dataset = RandomPerturbationFile(flist, sample_rate)
90+
loader = torch.utils.data.DataLoader(
91+
dataset, batch_size=32, num_workers=16,
92+
worker_init_fn=init_random_seed,
93+
)
94+
for batch in loader:
95+
assert batch.shape == (32, 2, 2 * sample_rate)
96+
97+
def _generate_signals(self, num_samples=128):
98+
signals = []
99+
for _ in range(num_samples):
100+
sample_rate = np.random.choice([8000, 16000, 44100])
101+
data = get_whitenoise(
102+
n_channels=2, sample_rate=sample_rate, duration=1, dtype='float32')
103+
signals.append((data, sample_rate))
104+
return signals
105+
106+
def test_apply_effects_tensor(self):
107+
sample_rate = 12000
108+
signals = self._generate_signals()
109+
dataset = RandomPerturbationTensor(signals, sample_rate)
110+
loader = torch.utils.data.DataLoader(
111+
dataset, batch_size=32, num_workers=16,
112+
worker_init_fn=init_random_seed,
113+
)
114+
for batch in loader:
115+
assert batch.shape == (32, 2, 2 * sample_rate)

0 commit comments

Comments
 (0)