diff --git a/test/test_models.py b/test/test_models.py index 57c86cc637..fa9f77e3da 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,3 +1,5 @@ +import unittest + import torch from torchaudio.models import Wav2Letter, _MelResNet @@ -51,3 +53,7 @@ def test_waveform(self): out = model(x) assert out.size() == (batch_size, output_dims, num_features - pad * 2) + + +if __name__ == "__main__": + unittest.main()