Skip to content

Commit 2e3bc3d

Browse files
committed
Add smoke test for sox_io fileobj
1 parent 99ed718 commit 2e3bc3d

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

test/torchaudio_unittest/sox_io_backend/smoke_test.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import io
12
import itertools
23
import unittest
34

@@ -85,3 +86,70 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
8586
def test_flac(self, sample_rate, num_channels, compression_level):
8687
"""Run smoke test on flac format"""
8788
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
89+
90+
91+
@skipIfNoExtension
92+
class SmokeTestFileObj(TorchaudioTestCase):
93+
"""Run smoke test on various audio format
94+
95+
The purpose of this test suite is to verify that sox_io_backend functionalities do not exhibit
96+
abnormal behaviors.
97+
98+
This test suite should be able to run without any additional tools (such as sox command),
99+
however without such tools, the correctness of each function cannot be verified.
100+
"""
101+
def run_smoke_test(self, ext, sample_rate, num_channels, *, compression=None, dtype='float32'):
102+
duration = 1
103+
num_frames = sample_rate * duration
104+
original = get_wav_data(dtype, num_channels, normalize=False, num_frames=num_frames)
105+
106+
fileobj = io.BytesIO()
107+
# 1. run save
108+
sox_io_backend.save(fileobj, original, sample_rate, compression=compression, format=ext)
109+
# 2. run info
110+
fileobj.seek(0)
111+
info = sox_io_backend.info(fileobj, format=ext)
112+
assert info.sample_rate == sample_rate
113+
assert info.num_channels == num_channels
114+
# 3. run load
115+
fileobj.seek(0)
116+
loaded, sr = sox_io_backend.load(fileobj, normalize=False, format=ext)
117+
assert sr == sample_rate
118+
assert loaded.shape[0] == num_channels
119+
120+
@parameterized.expand(list(itertools.product(
121+
['float32', 'int32', 'int16', 'uint8'],
122+
[8000, 16000],
123+
[1, 2],
124+
)), name_func=name_func)
125+
def test_wav(self, dtype, sample_rate, num_channels):
126+
"""Run smoke test on wav format"""
127+
self.run_smoke_test('wav', sample_rate, num_channels, dtype=dtype)
128+
129+
@parameterized.expand(list(itertools.product(
130+
[8000, 16000],
131+
[1, 2],
132+
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
133+
)))
134+
@skipIfNoMP3
135+
def test_mp3(self, sample_rate, num_channels, bit_rate):
136+
"""Run smoke test on mp3 format"""
137+
self.run_smoke_test('mp3', sample_rate, num_channels, compression=bit_rate)
138+
139+
@parameterized.expand(list(itertools.product(
140+
[8000, 16000],
141+
[1, 2],
142+
[-1, 0, 1, 2, 3, 3.6, 5, 10],
143+
)))
144+
def test_vorbis(self, sample_rate, num_channels, quality_level):
145+
"""Run smoke test on vorbis format"""
146+
self.run_smoke_test('vorbis', sample_rate, num_channels, compression=quality_level)
147+
148+
@parameterized.expand(list(itertools.product(
149+
[8000, 16000],
150+
[1, 2],
151+
list(range(9)),
152+
)), name_func=name_func)
153+
def test_flac(self, sample_rate, num_channels, compression_level):
154+
"""Run smoke test on flac format"""
155+
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)

0 commit comments

Comments
 (0)