|
| 1 | +import io |
1 | 2 | import itertools |
| 3 | +import tarfile |
2 | 4 |
|
3 | | -from torchaudio.backend import sox_io_backend |
4 | 5 | from parameterized import parameterized |
| 6 | +from torchaudio.backend import sox_io_backend |
| 7 | +from torchaudio._internal import module_utils as _mod_utils |
5 | 8 |
|
6 | 9 | from torchaudio_unittest.common_utils import ( |
7 | 10 | TempDirMixin, |
| 11 | + HttpServerMixin, |
8 | 12 | PytorchTestCase, |
9 | 13 | skipIfNoExec, |
10 | 14 | skipIfNoExtension, |
| 15 | + skipIfNoModule, |
11 | 16 | get_asset_path, |
12 | 17 | get_wav_data, |
13 | 18 | load_wav, |
|
19 | 24 | ) |
20 | 25 |
|
21 | 26 |
|
| 27 | +if _mod_utils.is_module_available("requests"): |
| 28 | + import requests |
| 29 | + |
| 30 | + |
22 | 31 | class LoadTestBase(TempDirMixin, PytorchTestCase): |
23 | 32 | def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): |
24 | 33 | """`sox_io_backend.load` can load wav format correctly. |
@@ -369,3 +378,156 @@ def test_mp3(self): |
369 | 378 | path = get_asset_path("mp3_without_ext") |
370 | 379 | _, sr = sox_io_backend.load(path, format="mp3") |
371 | 380 | assert sr == 16000 |
| 381 | + |
| 382 | + |
| 383 | +@skipIfNoExtension |
| 384 | +@skipIfNoExec('sox') |
| 385 | +class TestFileObject(TempDirMixin, PytorchTestCase): |
| 386 | + """ |
| 387 | + In this test suite, the result of file-like object input is compared against file path input, |
| 388 | + because `load` function is rigrously tested for file path inputs to match libsox's result, |
| 389 | + """ |
| 390 | + @parameterized.expand([ |
| 391 | + ('wav', None), |
| 392 | + ('mp3', 128), |
| 393 | + ('mp3', 320), |
| 394 | + ('flac', 0), |
| 395 | + ('flac', 5), |
| 396 | + ('flac', 8), |
| 397 | + ('vorbis', -1), |
| 398 | + ('vorbis', 10), |
| 399 | + ('amb', None), |
| 400 | + ]) |
| 401 | + def test_fileobj(self, ext, compression): |
| 402 | + """Loading audio via file object returns the same result as via file path.""" |
| 403 | + sample_rate = 16000 |
| 404 | + format_ = ext if ext in ['mp3'] else None |
| 405 | + path = self.get_temp_path(f'test.{ext}') |
| 406 | + |
| 407 | + sox_utils.gen_audio_file( |
| 408 | + path, sample_rate, num_channels=2, |
| 409 | + compression=compression) |
| 410 | + expected, _ = sox_io_backend.load(path) |
| 411 | + |
| 412 | + with open(path, 'rb') as fileobj: |
| 413 | + found, sr = sox_io_backend.load(fileobj, format=format_) |
| 414 | + |
| 415 | + assert sr == sample_rate |
| 416 | + self.assertEqual(expected, found) |
| 417 | + |
| 418 | + @parameterized.expand([ |
| 419 | + ('wav', None), |
| 420 | + ('mp3', 128), |
| 421 | + ('mp3', 320), |
| 422 | + ('flac', 0), |
| 423 | + ('flac', 5), |
| 424 | + ('flac', 8), |
| 425 | + ('vorbis', -1), |
| 426 | + ('vorbis', 10), |
| 427 | + ('amb', None), |
| 428 | + ]) |
| 429 | + def test_bytesio(self, ext, compression): |
| 430 | + """Loading audio via BytesIO object returns the same result as via file path.""" |
| 431 | + sample_rate = 16000 |
| 432 | + format_ = ext if ext in ['mp3'] else None |
| 433 | + path = self.get_temp_path(f'test.{ext}') |
| 434 | + |
| 435 | + sox_utils.gen_audio_file( |
| 436 | + path, sample_rate, num_channels=2, |
| 437 | + compression=compression) |
| 438 | + expected, _ = sox_io_backend.load(path) |
| 439 | + |
| 440 | + with open(path, 'rb') as file_: |
| 441 | + fileobj = io.BytesIO(file_.read()) |
| 442 | + found, sr = sox_io_backend.load(fileobj, format=format_) |
| 443 | + |
| 444 | + assert sr == sample_rate |
| 445 | + self.assertEqual(expected, found) |
| 446 | + |
| 447 | + @parameterized.expand([ |
| 448 | + ('wav', None), |
| 449 | + ('mp3', 128), |
| 450 | + ('mp3', 320), |
| 451 | + ('flac', 0), |
| 452 | + ('flac', 5), |
| 453 | + ('flac', 8), |
| 454 | + ('vorbis', -1), |
| 455 | + ('vorbis', 10), |
| 456 | + ('amb', None), |
| 457 | + ]) |
| 458 | + def test_tarfile(self, ext, compression): |
| 459 | + """Loading compressed audio via file-like object returns the same result as via file path.""" |
| 460 | + sample_rate = 16000 |
| 461 | + format_ = ext if ext in ['mp3'] else None |
| 462 | + audio_file = f'test.{ext}' |
| 463 | + audio_path = self.get_temp_path(audio_file) |
| 464 | + archive_path = self.get_temp_path('archive.tar.gz') |
| 465 | + |
| 466 | + sox_utils.gen_audio_file( |
| 467 | + audio_path, sample_rate, num_channels=2, |
| 468 | + compression=compression) |
| 469 | + expected, _ = sox_io_backend.load(audio_path) |
| 470 | + |
| 471 | + with tarfile.TarFile(archive_path, 'w') as tarobj: |
| 472 | + tarobj.add(audio_path, arcname=audio_file) |
| 473 | + with tarfile.TarFile(archive_path, 'r') as tarobj: |
| 474 | + fileobj = tarobj.extractfile(audio_file) |
| 475 | + found, sr = sox_io_backend.load(fileobj, format=format_) |
| 476 | + |
| 477 | + assert sr == sample_rate |
| 478 | + self.assertEqual(expected, found) |
| 479 | + |
| 480 | + |
| 481 | +@skipIfNoExtension |
| 482 | +@skipIfNoExec('sox') |
| 483 | +@skipIfNoModule("requests") |
| 484 | +class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): |
| 485 | + @parameterized.expand([ |
| 486 | + ('wav', None), |
| 487 | + ('mp3', 128), |
| 488 | + ('mp3', 320), |
| 489 | + ('flac', 0), |
| 490 | + ('flac', 5), |
| 491 | + ('flac', 8), |
| 492 | + ('vorbis', -1), |
| 493 | + ('vorbis', 10), |
| 494 | + ('amb', None), |
| 495 | + ]) |
| 496 | + def test_requests(self, ext, compression): |
| 497 | + sample_rate = 16000 |
| 498 | + format_ = ext if ext in ['mp3'] else None |
| 499 | + audio_file = f'test.{ext}' |
| 500 | + audio_path = self.get_temp_path(audio_file) |
| 501 | + |
| 502 | + sox_utils.gen_audio_file( |
| 503 | + audio_path, sample_rate, num_channels=2, compression=compression) |
| 504 | + expected, _ = sox_io_backend.load(audio_path) |
| 505 | + |
| 506 | + url = self.get_url(audio_file) |
| 507 | + with requests.get(url, stream=True) as resp: |
| 508 | + found, sr = sox_io_backend.load(resp.raw, format=format_) |
| 509 | + |
| 510 | + assert sr == sample_rate |
| 511 | + self.assertEqual(expected, found) |
| 512 | + |
| 513 | + @parameterized.expand(list(itertools.product( |
| 514 | + [0, 1, 10, 100, 1000], |
| 515 | + [-1, 1, 10, 100, 1000], |
| 516 | + )), name_func=name_func) |
| 517 | + def test_frame(self, frame_offset, num_frames): |
| 518 | + """num_frames and frame_offset correctly specify the region of data""" |
| 519 | + sample_rate = 8000 |
| 520 | + audio_file = 'test.wav' |
| 521 | + audio_path = self.get_temp_path(audio_file) |
| 522 | + |
| 523 | + original = get_wav_data('float32', num_channels=2) |
| 524 | + save_wav(audio_path, original, sample_rate) |
| 525 | + frame_end = None if num_frames == -1 else frame_offset + num_frames |
| 526 | + expected = original[:, frame_offset:frame_end] |
| 527 | + |
| 528 | + url = self.get_url(audio_file) |
| 529 | + with requests.get(url, stream=True) as resp: |
| 530 | + found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames) |
| 531 | + |
| 532 | + assert sr == sample_rate |
| 533 | + self.assertEqual(expected, found) |
0 commit comments