From faa4dfe6946947574eefdec912e26db54bc1f145 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 18 Jun 2020 08:27:20 -0700 Subject: [PATCH 1/3] add unittest in test_models --- test/test_models.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 7bd3f3819d..6f8d8e56ca 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,8 +1,10 @@ +import unittest + import torch from torchaudio.models import Wav2Letter, _MelResNet -class TestWav2Letter: +class TestWav2Letter(unittest.TestCase): def test_waveform(self): batch_size = 2 @@ -31,7 +33,7 @@ def test_mfcc(self): assert out.size() == (batch_size, num_classes, 2) -class TestMelResNet: +class TestMelResNet(unittest.TestCase): def test_waveform(self): @@ -49,3 +51,6 @@ def test_waveform(self): out = model(x) assert out.size() == (batch_size, output_dims, num_features - pad * 2) + +if __name__ == '__main__': + unittest.main() From 281a6666162c56936d93beab80535e9dfd6aef14 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Thu, 18 Jun 2020 10:00:47 -0700 Subject: [PATCH 2/3] update test method --- test/test_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index 6f8d8e56ca..e087921c46 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -3,8 +3,10 @@ import torch from torchaudio.models import Wav2Letter, _MelResNet +from . import common_utils -class TestWav2Letter(unittest.TestCase): + +class TestWav2Letter(common_utils.TorchaudioTestCase): def test_waveform(self): batch_size = 2 @@ -33,7 +35,7 @@ def test_mfcc(self): assert out.size() == (batch_size, num_classes, 2) -class TestMelResNet(unittest.TestCase): +class TestMelResNet(common_utils.TorchaudioTestCase): def test_waveform(self): From af85e41d85eca8fca76c5fe9e755ca451643dbf5 Mon Sep 17 00:00:00 2001 From: Ji Chen Date: Tue, 23 Jun 2020 08:35:02 -0700 Subject: [PATCH 3/3] remove unittest main function --- test/test_models.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index e087921c46..57c86cc637 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,5 +1,3 @@ -import unittest - import torch from torchaudio.models import Wav2Letter, _MelResNet @@ -53,6 +51,3 @@ def test_waveform(self): out = model(x) assert out.size() == (batch_size, output_dims, num_features - pad * 2) - -if __name__ == '__main__': - unittest.main()