Skip to content

Commit a20da5e

Browse files
authored
Refactor test utilities (#756)
1 parent 6b15905 commit a20da5e

17 files changed

+421
-371
lines changed

test/common_utils.py

Lines changed: 0 additions & 216 deletions
This file was deleted.

test/common_utils/__init__.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from .data_utils import (
2+
get_asset_path,
3+
get_whitenoise,
4+
get_sinusoid,
5+
)
6+
from .backend_utils import (
7+
set_audio_backend,
8+
BACKENDS,
9+
BACKENDS_MP3,
10+
)
11+
from .test_case_utils import (
12+
TempDirMixin,
13+
TestBaseMixin,
14+
PytorchTestCase,
15+
TorchaudioTestCase,
16+
skipIfNoCuda,
17+
skipIfNoExec,
18+
skipIfNoModule,
19+
skipIfNoExtension,
20+
skipIfNoSoxBackend,
21+
)
22+
from .wav_utils import (
23+
get_wav_data,
24+
normalize_wav,
25+
load_wav,
26+
save_wav,
27+
)
28+
from .parameterized_utils import (
29+
load_params,
30+
)
31+
from . import sox_utils

test/common_utils/backend_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import unittest
2+
3+
import torchaudio
4+
5+
from .import data_utils
6+
7+
8+
BACKENDS = torchaudio.list_audio_backends()
9+
10+
11+
def _filter_backends_with_mp3(backends):
12+
# Filter out backends that do not support mp3
13+
test_filepath = data_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')
14+
15+
def supports_mp3(backend):
16+
torchaudio.set_audio_backend(backend)
17+
try:
18+
torchaudio.load(test_filepath)
19+
return True
20+
except (RuntimeError, ImportError):
21+
return False
22+
23+
return [backend for backend in backends if supports_mp3(backend)]
24+
25+
26+
BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS)
27+
28+
29+
def set_audio_backend(backend):
30+
"""Allow additional backend value, 'default'"""
31+
if backend == 'default':
32+
if 'sox' in BACKENDS:
33+
be = 'sox'
34+
elif 'soundfile' in BACKENDS:
35+
be = 'soundfile'
36+
else:
37+
raise unittest.SkipTest('No default backend available')
38+
else:
39+
be = backend
40+
41+
torchaudio.set_audio_backend(be)

test/common_utils/data_utils.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import os.path
2+
from typing import Union
3+
4+
import torch
5+
6+
7+
_TEST_DIR_PATH = os.path.realpath(
8+
os.path.join(os.path.dirname(__file__), '..'))
9+
10+
11+
def get_asset_path(*paths):
12+
"""Return full path of a test asset"""
13+
return os.path.join(_TEST_DIR_PATH, 'assets', *paths)
14+
15+
16+
def get_whitenoise(
17+
*,
18+
sample_rate: int = 16000,
19+
duration: float = 1, # seconds
20+
n_channels: int = 1,
21+
seed: int = 0,
22+
dtype: Union[str, torch.dtype] = "float32",
23+
device: Union[str, torch.device] = "cpu",
24+
):
25+
"""Generate pseudo audio data with whitenoise
26+
27+
Args:
28+
sample_rate: Sampling rate
29+
duration: Length of the resulting Tensor in seconds.
30+
n_channels: Number of channels
31+
seed: Seed value used for random number generation.
32+
Note that this function does not modify global random generator state.
33+
dtype: Torch dtype
34+
device: device
35+
Returns:
36+
Tensor: shape of (n_channels, sample_rate * duration)
37+
"""
38+
if isinstance(dtype, str):
39+
dtype = getattr(torch, dtype)
40+
shape = [n_channels, sample_rate * duration]
41+
# According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices,
42+
# so we only folk on CPU, generate values and move the data to the given device
43+
with torch.random.fork_rng([]):
44+
torch.random.manual_seed(seed)
45+
tensor = torch.randn(shape, dtype=dtype, device='cpu')
46+
tensor /= 2.0
47+
tensor.clamp_(-1.0, 1.0)
48+
return tensor.to(device=device)
49+
50+
51+
def get_sinusoid(
52+
*,
53+
frequency: float = 300,
54+
sample_rate: int = 16000,
55+
duration: float = 1, # seconds
56+
n_channels: int = 1,
57+
dtype: Union[str, torch.dtype] = "float32",
58+
device: Union[str, torch.device] = "cpu",
59+
):
60+
"""Generate pseudo audio data with sine wave.
61+
62+
Args:
63+
frequency: Frequency of sine wave
64+
sample_rate: Sampling rate
65+
duration: Length of the resulting Tensor in seconds.
66+
n_channels: Number of channels
67+
dtype: Torch dtype
68+
device: device
69+
70+
Returns:
71+
Tensor: shape of (n_channels, sample_rate * duration)
72+
"""
73+
if isinstance(dtype, str):
74+
dtype = getattr(torch, dtype)
75+
pie2 = 2 * 3.141592653589793
76+
end = pie2 * frequency * duration
77+
theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device)
78+
return torch.sin(theta, out=None).repeat([n_channels, 1])
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import json
2+
3+
from parameterized import param
4+
5+
from .data_utils import get_asset_path
6+
7+
8+
def load_params(*paths):
9+
with open(get_asset_path(*paths), 'r') as file:
10+
return [param(json.loads(line)) for line in file]
File renamed without changes.

0 commit comments

Comments
 (0)