diff --git a/test/test_librosa_compatibility.py b/test/test_librosa_compatibility.py index 665341d945..2011e5218f 100644 --- a/test/test_librosa_compatibility.py +++ b/test/test_librosa_compatibility.py @@ -7,9 +7,11 @@ from torch.testing._internal.common_utils import TestCase import torchaudio import torchaudio.functional as F -from torchaudio.common_utils import IMPORT_LIBROSA +from torchaudio.common_utils import _check_module_exists -if IMPORT_LIBROSA: +LIBROSA_AVAILABLE = _check_module_exists('librosa') + +if LIBROSA_AVAILABLE: import numpy as np import librosa import scipy @@ -19,7 +21,7 @@ from . import common_utils -@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available") +@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") class TestFunctional(TestCase): """Test suite for functions in `functional` module.""" def test_griffinlim(self): @@ -115,12 +117,8 @@ def test_amplitude_to_DB(self): ]) @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('hop_length', [256]) +@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") def test_phase_vocoder(complex_specgrams, rate, hop_length): - - # Using a decorator here causes parametrize to fail on Python 2 - if not IMPORT_LIBROSA: - raise unittest.SkipTest('Librosa is not available') - # Due to cummulative sum, numerical error in using torch.float32 will # result in bottom right values of the stretched sectrogram to not # match with librosa. @@ -158,7 +156,7 @@ def _load_audio_asset(*asset_paths, **kwargs): return sound, sample_rate -@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available") +@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") class TestTransforms(TestCase): """Test suite for functions in `transforms` module.""" def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): diff --git a/torchaudio/common_utils.py b/torchaudio/common_utils.py index 014f4392ce..a8d99cf0e9 100644 --- a/torchaudio/common_utils.py +++ b/torchaudio/common_utils.py @@ -1,17 +1,11 @@ import importlib.util -def _check_module_exists(name: str) -> bool: +def _check_module_exists(*modules: str) -> bool: r"""Returns if a top-level module with :attr:`name` exists *without** importing it. This is generally safer than try-catch block around a `import X`. It avoids third party libraries breaking assumptions of some of our tests, e.g., setting multiprocessing start method when imported (see librosa/#747, torchvision/#544). """ - spec = importlib.util.find_spec(name) - return spec is not None - -IMPORT_NUMPY = _check_module_exists('numpy') -IMPORT_KALDI_IO = _check_module_exists('kaldi_io') -IMPORT_SCIPY = _check_module_exists('scipy') -IMPORT_LIBROSA = _check_module_exists('librosa') + return all(importlib.util.find_spec(m) is not None for m in modules) diff --git a/torchaudio/kaldi_io.py b/torchaudio/kaldi_io.py index d32aaed357..25cce9ba1f 100644 --- a/torchaudio/kaldi_io.py +++ b/torchaudio/kaldi_io.py @@ -1,16 +1,16 @@ # To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python) # needs to be installed. This is a light wrapper around kaldi_io that returns # torch.Tensors. -from typing import Any, Callable, Iterable, Tuple, Union +from typing import Any, Callable, Iterable, Tuple import torch from torch import Tensor -from torchaudio.common_utils import IMPORT_KALDI_IO, IMPORT_NUMPY +from torchaudio.common_utils import _check_module_exists -if IMPORT_NUMPY: - import numpy as np +_KALDI_IO_AVAILABLE = _check_module_exists('kaldi_io', 'numpy') -if IMPORT_KALDI_IO: +if _KALDI_IO_AVAILABLE: + import numpy as np import kaldi_io @@ -38,7 +38,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any, Returns: Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat """ - if not IMPORT_KALDI_IO: + if not _KALDI_IO_AVAILABLE: raise ImportError('Could not import kaldi_io. Did you install it?') for key, np_arr in fn(file_or_fd):