|
| 1 | +import io |
1 | 2 | import itertools |
2 | 3 |
|
3 | 4 | from torchaudio.backend import sox_io_backend |
@@ -417,3 +418,88 @@ def test_tensor_preserve(self, dtype): |
417 | 418 | sox_io_backend.save(path, data, 8000) |
418 | 419 |
|
419 | 420 | self.assertEqual(data, expected) |
| 421 | + |
| 422 | + |
| 423 | +@skipIfNoExtension |
| 424 | +@skipIfNoExec('sox') |
| 425 | +class TestFileObject(SaveTestBase): |
| 426 | + """ |
| 427 | + We campare the result of file-like object input against file path input because |
| 428 | + `save` function is rigrously tested for file path inputs to match libsox's result, |
| 429 | + """ |
| 430 | + @parameterized.expand([ |
| 431 | + ('wav', None), |
| 432 | + ('mp3', 128), |
| 433 | + ('mp3', 320), |
| 434 | + ('flac', 0), |
| 435 | + ('flac', 5), |
| 436 | + ('flac', 8), |
| 437 | + ('vorbis', -1), |
| 438 | + ('vorbis', 10), |
| 439 | + ('amb', None), |
| 440 | + ]) |
| 441 | + def test_fileobj(self, ext, compression): |
| 442 | + """Saving audio to file object returns the same result as via file path.""" |
| 443 | + sample_rate = 16000 |
| 444 | + dtype = 'float32' |
| 445 | + num_channels = 2 |
| 446 | + num_frames = 16000 |
| 447 | + channels_first = True |
| 448 | + |
| 449 | + data = get_wav_data(dtype, num_channels, num_frames=num_frames) |
| 450 | + |
| 451 | + ref_path = self.get_temp_path(f'reference.{ext}') |
| 452 | + res_path = self.get_temp_path(f'test.{ext}') |
| 453 | + sox_io_backend.save( |
| 454 | + ref_path, data, channels_first=channels_first, |
| 455 | + sample_rate=sample_rate, compression=compression) |
| 456 | + with open(res_path, 'wb') as fileobj: |
| 457 | + sox_io_backend.save( |
| 458 | + fileobj, data, channels_first=channels_first, |
| 459 | + sample_rate=sample_rate, compression=compression, format=ext) |
| 460 | + |
| 461 | + expected_data, _ = sox_io_backend.load(ref_path) |
| 462 | + data, sr = sox_io_backend.load(res_path) |
| 463 | + |
| 464 | + assert sample_rate == sr |
| 465 | + self.assertEqual(expected_data, data) |
| 466 | + |
| 467 | + @parameterized.expand([ |
| 468 | + ('wav', None), |
| 469 | + ('mp3', 128), |
| 470 | + ('mp3', 320), |
| 471 | + ('flac', 0), |
| 472 | + ('flac', 5), |
| 473 | + ('flac', 8), |
| 474 | + ('vorbis', -1), |
| 475 | + ('vorbis', 10), |
| 476 | + ('amb', None), |
| 477 | + ]) |
| 478 | + def test_bytesio(self, ext, compression): |
| 479 | + """Saving audio to BytesIO object returns the same result as via file path.""" |
| 480 | + sample_rate = 16000 |
| 481 | + dtype = 'float32' |
| 482 | + num_channels = 2 |
| 483 | + num_frames = 16000 |
| 484 | + channels_first = True |
| 485 | + |
| 486 | + data = get_wav_data(dtype, num_channels, num_frames=num_frames) |
| 487 | + |
| 488 | + ref_path = self.get_temp_path(f'reference.{ext}') |
| 489 | + res_path = self.get_temp_path(f'test.{ext}') |
| 490 | + sox_io_backend.save( |
| 491 | + ref_path, data, channels_first=channels_first, |
| 492 | + sample_rate=sample_rate, compression=compression) |
| 493 | + fileobj = io.BytesIO() |
| 494 | + sox_io_backend.save( |
| 495 | + fileobj, data, channels_first=channels_first, |
| 496 | + sample_rate=sample_rate, compression=compression, format=ext) |
| 497 | + fileobj.seek(0) |
| 498 | + with open(res_path, 'wb') as file_: |
| 499 | + file_.write(fileobj.read()) |
| 500 | + |
| 501 | + expected_data, _ = sox_io_backend.load(ref_path) |
| 502 | + data, sr = sox_io_backend.load(res_path) |
| 503 | + |
| 504 | + assert sample_rate == sr |
| 505 | + self.assertEqual(expected_data, data) |
0 commit comments