Skip to content

Commit 39bd543

Browse files
committed
Clean up common_utils
1 parent b56a27b commit 39bd543

File tree

3 files changed

+15
-23
lines changed

3 files changed

+15
-23
lines changed

test/test_librosa_compatibility.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from torch.testing._internal.common_utils import TestCase
88
import torchaudio
99
import torchaudio.functional as F
10-
from torchaudio.common_utils import IMPORT_LIBROSA
10+
from torchaudio.common_utils import _check_module_exists
1111

12-
if IMPORT_LIBROSA:
12+
LIBROSA_AVAILABLE = _check_module_exists('librosa')
13+
14+
if LIBROSA_AVAILABLE:
1315
import numpy as np
1416
import librosa
1517
import scipy
@@ -19,7 +21,7 @@
1921
from . import common_utils
2022

2123

22-
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
24+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
2325
class TestFunctional(TestCase):
2426
"""Test suite for functions in `functional` module."""
2527
def test_griffinlim(self):
@@ -115,12 +117,8 @@ def test_amplitude_to_DB(self):
115117
])
116118
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
117119
@pytest.mark.parametrize('hop_length', [256])
120+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
118121
def test_phase_vocoder(complex_specgrams, rate, hop_length):
119-
120-
# Using a decorator here causes parametrize to fail on Python 2
121-
if not IMPORT_LIBROSA:
122-
raise unittest.SkipTest('Librosa is not available')
123-
124122
# Due to cummulative sum, numerical error in using torch.float32 will
125123
# result in bottom right values of the stretched sectrogram to not
126124
# match with librosa.
@@ -158,7 +156,7 @@ def _load_audio_asset(*asset_paths, **kwargs):
158156
return sound, sample_rate
159157

160158

161-
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
159+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
162160
class TestTransforms(TestCase):
163161
"""Test suite for functions in `transforms` module."""
164162
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):

torchaudio/common_utils.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
11
import importlib.util
22

33

4-
def _check_module_exists(name: str) -> bool:
4+
def _check_module_exists(*modules: str) -> bool:
55
r"""Returns if a top-level module with :attr:`name` exists *without**
66
importing it. This is generally safer than try-catch block around a
77
`import X`. It avoids third party libraries breaking assumptions of some of
88
our tests, e.g., setting multiprocessing start method when imported
99
(see librosa/#747, torchvision/#544).
1010
"""
11-
spec = importlib.util.find_spec(name)
12-
return spec is not None
13-
14-
IMPORT_NUMPY = _check_module_exists('numpy')
15-
IMPORT_KALDI_IO = _check_module_exists('kaldi_io')
16-
IMPORT_SCIPY = _check_module_exists('scipy')
17-
IMPORT_LIBROSA = _check_module_exists('librosa')
11+
return all(importlib.util.find_spec(m) is not None for m in modules)

torchaudio/kaldi_io.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python)
22
# needs to be installed. This is a light wrapper around kaldi_io that returns
33
# torch.Tensors.
4-
from typing import Any, Callable, Iterable, Tuple, Union
4+
from typing import Any, Callable, Iterable, Tuple
55

66
import torch
77
from torch import Tensor
8-
from torchaudio.common_utils import IMPORT_KALDI_IO, IMPORT_NUMPY
8+
from torchaudio.common_utils import _check_module_exists
99

10-
if IMPORT_NUMPY:
11-
import numpy as np
10+
_KALDI_IO_AVAILABLE = _check_module_exists('kaldi_io', 'numpy')
1211

13-
if IMPORT_KALDI_IO:
12+
if _KALDI_IO_AVAILABLE:
13+
import numpy as np
1414
import kaldi_io
1515

1616

@@ -38,7 +38,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
3838
Returns:
3939
Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat
4040
"""
41-
if not IMPORT_KALDI_IO:
41+
if not _KALDI_IO_AVAILABLE:
4242
raise ImportError('Could not import kaldi_io. Did you install it?')
4343

4444
for key, np_arr in fn(file_or_fd):

0 commit comments

Comments
 (0)