diff --git a/test/test_models.py b/test/test_models.py index 7bd3f3819d..57c86cc637 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,8 +1,10 @@ import torch from torchaudio.models import Wav2Letter, _MelResNet +from . import common_utils -class TestWav2Letter: + +class TestWav2Letter(common_utils.TorchaudioTestCase): def test_waveform(self): batch_size = 2 @@ -31,7 +33,7 @@ def test_mfcc(self): assert out.size() == (batch_size, num_classes, 2) -class TestMelResNet: +class TestMelResNet(common_utils.TorchaudioTestCase): def test_waveform(self):