Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions test/test_librosa_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from torchaudio.common_utils import IMPORT_LIBROSA
from torchaudio.common_utils import _check_module_exists

if IMPORT_LIBROSA:
LIBROSA_AVAILABLE = _check_module_exists('librosa')

if LIBROSA_AVAILABLE:
import numpy as np
import librosa
import scipy
Expand All @@ -19,7 +21,7 @@
from . import common_utils


@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctional(TestCase):
"""Test suite for functions in `functional` module."""
def test_griffinlim(self):
Expand Down Expand Up @@ -115,12 +117,8 @@ def test_amplitude_to_DB(self):
])
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('hop_length', [256])
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing comment

def test_phase_vocoder(complex_specgrams, rate, hop_length):

# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
raise unittest.SkipTest('Librosa is not available')

# Due to cummulative sum, numerical error in using torch.float32 will
# result in bottom right values of the stretched sectrogram to not
# match with librosa.
Expand Down Expand Up @@ -158,7 +156,7 @@ def _load_audio_asset(*asset_paths, **kwargs):
return sound, sample_rate


@unittest.skipIf(not IMPORT_LIBROSA, "Librosa not available")
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestTransforms(TestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
Expand Down
10 changes: 2 additions & 8 deletions torchaudio/common_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import importlib.util


def _check_module_exists(name: str) -> bool:
def _check_module_exists(*modules: str) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the typing for this works? If so, I like it :)

r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544).
"""
spec = importlib.util.find_spec(name)
return spec is not None

IMPORT_NUMPY = _check_module_exists('numpy')
IMPORT_KALDI_IO = _check_module_exists('kaldi_io')
IMPORT_SCIPY = _check_module_exists('scipy')
IMPORT_LIBROSA = _check_module_exists('librosa')
return all(importlib.util.find_spec(m) is not None for m in modules)
12 changes: 6 additions & 6 deletions torchaudio/kaldi_io.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python)
# needs to be installed. This is a light wrapper around kaldi_io that returns
# torch.Tensors.
from typing import Any, Callable, Iterable, Tuple, Union
from typing import Any, Callable, Iterable, Tuple

import torch
from torch import Tensor
from torchaudio.common_utils import IMPORT_KALDI_IO, IMPORT_NUMPY
from torchaudio.common_utils import _check_module_exists

if IMPORT_NUMPY:
import numpy as np
_KALDI_IO_AVAILABLE = _check_module_exists('kaldi_io', 'numpy')

if IMPORT_KALDI_IO:
if _KALDI_IO_AVAILABLE:
import numpy as np
import kaldi_io


Expand Down Expand Up @@ -38,7 +38,7 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
Returns:
Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat
"""
if not IMPORT_KALDI_IO:
if not _KALDI_IO_AVAILABLE:
raise ImportError('Could not import kaldi_io. Did you install it?')

for key, np_arr in fn(file_or_fd):
Expand Down