|  | 
|  | 1 | +import io | 
| 1 | 2 | import itertools | 
| 2 | 3 | import unittest | 
| 3 | 4 | 
 | 
| @@ -85,3 +86,70 @@ def test_vorbis(self, sample_rate, num_channels, quality_level): | 
| 85 | 86 |     def test_flac(self, sample_rate, num_channels, compression_level): | 
| 86 | 87 |         """Run smoke test on flac format""" | 
| 87 | 88 |         self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level) | 
|  | 89 | + | 
|  | 90 | + | 
|  | 91 | +@skipIfNoExtension | 
|  | 92 | +class SmokeTestFileObj(TorchaudioTestCase): | 
|  | 93 | +    """Run smoke test on various audio format | 
|  | 94 | +
 | 
|  | 95 | +    The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit | 
|  | 96 | +    abnormal behaviors. | 
|  | 97 | +
 | 
|  | 98 | +    This test suite should be able to run without any additional tools (such as sox command), | 
|  | 99 | +    however without such tools, the correctness of each function cannot be verified. | 
|  | 100 | +    """ | 
|  | 101 | +    def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'): | 
|  | 102 | +        duration = 1 | 
|  | 103 | +        num_frames = sample_rate * duration | 
|  | 104 | +        original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames) | 
|  | 105 | + | 
|  | 106 | +        fileobj = io.BytesIO() | 
|  | 107 | +        # 1. run save | 
|  | 108 | +        sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext) | 
|  | 109 | +        # 2. run info | 
|  | 110 | +        fileobj.seek(0) | 
|  | 111 | +        info = sox_io_backend.info(fileobj, format=ext) | 
|  | 112 | +        assert info.sample_rate == sample_rate | 
|  | 113 | +        assert info.num_channels == num_channels | 
|  | 114 | +        # 3. run load | 
|  | 115 | +        fileobj.seek(0) | 
|  | 116 | +        loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext) | 
|  | 117 | +        assert sr == sample_rate | 
|  | 118 | +        assert loaded.shape[0] == num_channels | 
|  | 119 | + | 
|  | 120 | +    @parameterized.expand(list(itertools.product( | 
|  | 121 | +        ['float32', 'int32', 'int16', 'uint8'], | 
|  | 122 | +        [8000, 16000], | 
|  | 123 | +        [1, 2], | 
|  | 124 | +    )), name_func=name_func) | 
|  | 125 | +    def test_wav(self, dtype, sample_rate, num_channels): | 
|  | 126 | +        """Run smoke test on wav format""" | 
|  | 127 | +        self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype) | 
|  | 128 | + | 
|  | 129 | +    @parameterized.expand(list(itertools.product( | 
|  | 130 | +        [8000, 16000], | 
|  | 131 | +        [1, 2], | 
|  | 132 | +        [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], | 
|  | 133 | +    ))) | 
|  | 134 | +    @skipIfNoMP3 | 
|  | 135 | +    def test_mp3(self, sample_rate, num_channels, bit_rate): | 
|  | 136 | +        """Run smoke test on mp3 format""" | 
|  | 137 | +        self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate) | 
|  | 138 | + | 
|  | 139 | +    @parameterized.expand(list(itertools.product( | 
|  | 140 | +        [8000, 16000], | 
|  | 141 | +        [1, 2], | 
|  | 142 | +        [-1, 0, 1, 2, 3, 3.6, 5, 10], | 
|  | 143 | +    ))) | 
|  | 144 | +    def test_vorbis(self, sample_rate, num_channels, quality_level): | 
|  | 145 | +        """Run smoke test on vorbis format""" | 
|  | 146 | +        self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level) | 
|  | 147 | + | 
|  | 148 | +    @parameterized.expand(list(itertools.product( | 
|  | 149 | +        [8000, 16000], | 
|  | 150 | +        [1, 2], | 
|  | 151 | +        list(range(9)), | 
|  | 152 | +    )), name_func=name_func) | 
|  | 153 | +    def test_flac(self, sample_rate, num_channels, compression_level): | 
|  | 154 | +        """Run smoke test on flac format""" | 
|  | 155 | +        self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level) | 
0 commit comments