Skip to content

Commit 2a4a33b

Browse files
committed
Clean up common_utils
1 parent b4284de commit 2a4a33b

File tree

3 files changed

+16
-45
lines changed

3 files changed

+16
-45
lines changed

test/test_librosa_compatibility.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from torch.testing._internal.common_utils import TestCase
88
import torchaudio
99
import torchaudio.functional as F
10-
from torchaudio.common_utils import IMPORT_LIBROSA
10+
from torchaudio.common_utils import _check_module_exists
1111

12-
if IMPORT_LIBROSA:
12+
LIBROSA_AVAILABLE = _check_module_exists('librosa')
13+
14+
if LIBROSA_AVAILABLE:
1315
import numpy as np
1416
import librosa
1517
import scipy
@@ -19,7 +21,7 @@
1921
from . import common_utils
2022

2123

22-
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
24+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
2325
class TestFunctional(TestCase):
2426
"""Test suite for functions in `functional` module."""
2527
def test_griffinlim(self):
@@ -115,12 +117,8 @@ def test_amplitude_to_DB(self):
115117
])
116118
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
117119
@pytest.mark.parametrize('hop_length', [256])
120+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
118121
def test_phase_vocoder(complex_specgrams, rate, hop_length):
119-
120-
# Using a decorator here causes parametrize to fail on Python 2
121-
if not IMPORT_LIBROSA:
122-
raise unittest.SkipTest('Librosa is not available')
123-
124122
# Due to cummulative sum, numerical error in using torch.float32 will
125123
# result in bottom right values of the stretched sectrogram to not
126124
# match with librosa.
@@ -158,7 +156,7 @@ def _load_audio_asset(*asset_paths, **kwargs):
158156
return sound, sample_rate
159157

160158

161-
@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
159+
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
162160
class TestTransforms(TestCase):
163161
"""Test suite for functions in `transforms` module."""
164162
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):

torchaudio/common_utils.py

Lines changed: 3 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,11 @@
1-
import sys
1+
import importlib.util
22

3-
PY3 = sys.version_info > (3, 0)
4-
PY34 = sys.version_info >= (3, 4)
53

6-
7-
def _check_module_exists(name: str) -> bool:
4+
def _check_module_exists(*modules: str) -> bool:
85
r"""Returns if a top-level module with :attr:`name` exists *without**
96
importing it. This is generally safer than try-catch block around a
107
`import X`. It avoids third party libraries breaking assumptions of some of
118
our tests, e.g., setting multiprocessing start method when imported
129
(see librosa/#747, torchvision/#544).
1310
"""
14-
if not PY3: # Python 2
15-
import imp
16-
try:
17-
imp.find_module(name)
18-
return True
19-
except ImportError:
20-
return False
21-
elif not PY34: # Python [3, 3.4)
22-
import importlib
23-
loader = importlib.find_loader(name)
24-
return loader is not None
25-
else: # Python >= 3.4
26-
import importlib
27-
import importlib.util
28-
spec = importlib.util.find_spec(name)
29-
return spec is not None
30-
31-
IMPORT_NUMPY = _check_module_exists('numpy')
32-
IMPORT_KALDI_IO = _check_module_exists('kaldi_io')
33-
IMPORT_SCIPY = _check_module_exists('scipy')
34-
35-
# On Py2, importing librosa 0.6.1 triggers a TypeError (if using newest joblib)
36-
# see librosa/librosa#729.
37-
# TODO: allow Py2 when librosa 0.6.2 releases
38-
IMPORT_LIBROSA = _check_module_exists('librosa') and PY3
11+
return all(importlib.util.find_spec(m) is not None for m in modules)

torchaudio/kaldi_io.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python)
22
# needs to be installed. This is a light wrapper around kaldi_io that returns
33
# torch.Tensors.
4-
from typing import Any, Callable, Iterable, Tuple, Union
4+
from typing import Any, Callable, Iterable, Tuple
55

66
import torch
77
from torch import Tensor
8-
from torchaudio.common_utils import IMPORT_KALDI_IO, IMPORT_NUMPY
8+
from torchaudio.common_utils import _check_module_exists
99

10-
if IMPORT_NUMPY:
11-
import numpy as np
10+
_KALDI_IO_AVAILABLE = _check_module_exists('kaldi_io', 'numpy')
1211

13-
if IMPORT_KALDI_IO:
12+
if _KALDI_IO_AVAILABLE:
13+
import numpy as np
1414
import kaldi_io
1515

1616

@@ -38,7 +38,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
3838
Returns:
3939
Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat
4040
"""
41-
if not IMPORT_KALDI_IO:
41+
if not _KALDI_IO_AVAILABLE:
4242
raise ImportError('Could not import kaldi_io. Did you install it?')
4343

4444
for key, np_arr in fn(file_or_fd):

0 commit comments

Comments
 (0)