diff --git a/test/torchaudio_unittest/soundfile_backend/common.py b/test/torchaudio_unittest/soundfile_backend/common.py
index 4da0873fa1..8f991fb0f8 100644
--- a/test/torchaudio_unittest/soundfile_backend/common.py
+++ b/test/torchaudio_unittest/soundfile_backend/common.py
@@ -26,7 +26,7 @@ def skipIfFormatNotSupported(fmt):
         import soundfile
 
         fmts = soundfile.available_formats()
-        return skipIf(fmt not in fmts, f'"{fmt}" is not supported by sondfile')
+        return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile')
     return skipIf(True, '"soundfile" not available.')
 
 
diff --git a/test/torchaudio_unittest/soundfile_backend/info_test.py b/test/torchaudio_unittest/soundfile_backend/info_test.py
index 71acb20689..05c9ddc3ce 100644
--- a/test/torchaudio_unittest/soundfile_backend/info_test.py
+++ b/test/torchaudio_unittest/soundfile_backend/info_test.py
@@ -1,3 +1,6 @@
+from unittest.mock import patch
+import warnings
+
 import torch
 from torchaudio.backend import _soundfile_backend as soundfile_backend
 from torchaudio._internal import module_utils as _mod_utils
@@ -18,10 +21,11 @@
 @skipIfNoModule("soundfile")
 class TestInfo(TempDirMixin, PytorchTestCase):
     @parameterize(
-        ["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
+        [("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
     )
-    def test_wav(self, dtype, sample_rate, num_channels):
+    def test_wav(self, dtype_and_bit_depth, sample_rate, num_channels):
         """`soundfile_backend.info` can check wav file correctly"""
+        dtype, bits_per_sample = dtype_and_bit_depth
         duration = 1
         path = self.get_temp_path("data.wav")
         data = get_wav_data(
@@ -32,12 +36,14 @@ def test_wav(self, dtype, sample_rate, num_channels):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == bits_per_sample
 
     @parameterize(
-        ["float32", "int32", "int16", "uint8"], [8000, 16000], [4, 8, 16, 32],
+        [("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [4, 8, 16, 32],
     )
-    def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
+    def test_wav_multiple_channels(self, dtype_and_bit_depth, sample_rate, num_channels):
         """`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
+        dtype, bits_per_sample = dtype_and_bit_depth
         duration = 1
         path = self.get_temp_path("data.wav")
         data = get_wav_data(
@@ -48,6 +54,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == bits_per_sample
 
     @parameterize([8000, 16000], [1, 2])
     @skipIfFormatNotSupported("FLAC")
@@ -63,6 +70,7 @@ def test_flac(self, sample_rate, num_channels):
         assert info.sample_rate == sample_rate
         assert info.num_frames == num_frames
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 16
 
     @parameterize([8000, 16000], [1, 2])
     @skipIfFormatNotSupported("OGG")
@@ -78,18 +86,42 @@ def test_ogg(self, sample_rate, num_channels):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 0
 
-    @parameterize([8000, 16000], [1, 2])
+    @parameterize([8000, 16000], [1, 2], [('PCM_24', 24), ('PCM_32', 32)])
     @skipIfFormatNotSupported("NIST")
-    def test_sphere(self, sample_rate, num_channels):
+    def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
         """`soundfile_backend.info` can check sph file correctly"""
         duration = 1
         num_frames = sample_rate * duration
         data = torch.randn(num_frames, num_channels).numpy()
         path = self.get_temp_path("data.nist")
-        soundfile.write(path, data, sample_rate)
+        subtype, bits_per_sample = subtype_and_bit_depth
+        soundfile.write(path, data, sample_rate, subtype=subtype)
 
         info = soundfile_backend.info(path)
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == bits_per_sample
+
+    def test_unknown_subtype_warning(self):
+        """soundfile_backend.info issues a warning when the subtype is unknown
+
+        This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
+        dict should be updated.
+        """
+        def _mock_info_func(_):
+            class MockSoundFileInfo:
+                samplerate = 8000
+                frames = 356
+                channels = 2
+                subtype = 'UNSEEN_SUBTYPE'
+            return MockSoundFileInfo()
+
+        with patch("soundfile.info", _mock_info_func):
+            with warnings.catch_warnings(record=True) as w:
+                info = soundfile_backend.info("foo")
+                assert len(w) == 1
+                assert "UNSEEN_SUBTYPE subtype is unknown to TorchAudio" in str(w[-1].message)
+                assert info.bits_per_sample == 0
diff --git a/test/torchaudio_unittest/sox_io_backend/info_test.py b/test/torchaudio_unittest/sox_io_backend/info_test.py
index 49fc797354..2b3b8ffdb2 100644
--- a/test/torchaudio_unittest/sox_io_backend/info_test.py
+++ b/test/torchaudio_unittest/sox_io_backend/info_test.py
@@ -36,6 +36,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
 
     @parameterized.expand(list(itertools.product(
         ['float32', 'int32', 'int16', 'uint8'],
@@ -52,6 +53,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
 
     @parameterized.expand(list(itertools.product(
         [8000, 16000],
@@ -71,6 +73,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
         # mp3 does not preserve the number of samples
         # assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 0  # bits_per_sample is irrelevant for compressed formats
 
     @parameterized.expand(list(itertools.product(
         [8000, 16000],
@@ -89,6 +92,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 24  # FLAC standard
 
     @parameterized.expand(list(itertools.product(
         [8000, 16000],
@@ -107,20 +111,23 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 0  # bits_per_sample is irrelevant for compressed formats
 
     @parameterized.expand(list(itertools.product(
         [8000, 16000],
         [1, 2],
+        [16, 32],
     )), name_func=name_func)
-    def test_sphere(self, sample_rate, num_channels):
+    def test_sphere(self, sample_rate, num_channels, bits_per_sample):
         """`sox_io_backend.info` can check sph file correctly"""
         duration = 1
         path = self.get_temp_path('data.sph')
-        sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration)
+        sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
         info = sox_io_backend.info(path)
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == bits_per_sample
 
     @parameterized.expand(list(itertools.product(
         ['float32', 'int32', 'int16', 'uint8'],
@@ -131,13 +138,15 @@ def test_amb(self, dtype, sample_rate, num_channels):
         """`sox_io_backend.info` can check amb file correctly"""
         duration = 1
         path = self.get_temp_path('data.amb')
+        bits_per_sample = sox_utils.get_bit_depth(dtype)
         sox_utils.gen_audio_file(
             path, sample_rate, num_channels,
-            bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
+            bit_depth=bits_per_sample, duration=duration)
         info = sox_io_backend.info(path)
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == bits_per_sample
 
     def test_amr_nb(self):
         """`sox_io_backend.info` can check amr-nb file correctly"""
@@ -146,11 +155,13 @@ def test_amr_nb(self):
         sample_rate = 8000
         path = self.get_temp_path('data.amr-nb')
         sox_utils.gen_audio_file(
-            path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
+            path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16,
+            duration=duration)
         info = sox_io_backend.info(path)
         assert info.sample_rate == sample_rate
         assert info.num_frames == sample_rate * duration
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 0
 
 
 @skipIfNoExtension
@@ -167,6 +178,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
         assert info.sample_rate == 48000
         assert info.num_frames == 32768
         assert info.num_channels == num_channels
+        assert info.bits_per_sample == 0  # bits_per_sample is irrelevant for compressed formats
 
 
 @skipIfNoExtension
@@ -184,3 +196,4 @@ def test_mp3(self):
         path = get_asset_path("mp3_without_ext")
         sinfo = sox_io_backend.info(path, format="mp3")
         assert sinfo.sample_rate == 16000
+        assert sinfo.bits_per_sample == 0  # bits_per_sample is irrelevant for compressed formats
diff --git a/torchaudio/backend/_soundfile_backend.py b/torchaudio/backend/_soundfile_backend.py
index 3366780bdb..54d97547f0 100644
--- a/torchaudio/backend/_soundfile_backend.py
+++ b/torchaudio/backend/_soundfile_backend.py
@@ -11,6 +11,45 @@
     import soundfile
 
 
+# Mapping from soundfile subtype to number of bits per sample.
+# This is mostly heuristic and the value is set to 0 when it is irrelevant
+# (lossy formats) or when it can't be inferred.
+# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
+# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
+# the default seems to be 8 bits but it can be compressed further to 4 bits.
+# The dict is inspired from
+# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
+_SUBTYPE_TO_BITS_PER_SAMPLE = {
+    'PCM_S8': 8,  # Signed 8 bit data
+    'PCM_16': 16,  # Signed 16 bit data
+    'PCM_24': 24,  # Signed 24 bit data
+    'PCM_32': 32,  # Signed 32 bit data
+    'PCM_U8': 8,  # Unsigned 8 bit data (WAV and RAW only)
+    'FLOAT': 32,  # 32 bit float data
+    'DOUBLE': 64,  # 64 bit float data
+    'ULAW': 8,  # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+    'ALAW': 8,  # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
+    'IMA_ADPCM': 0,  # IMA ADPCM.
+    'MS_ADPCM': 0,  # Microsoft ADPCM.
+    'GSM610': 0,  # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
+    'VOX_ADPCM': 0,  # OKI / Dialogix ADPCM
+    'G721_32': 0,  # 32kbs G721 ADPCM encoding.
+    'G723_24': 0,  # 24kbs G723 ADPCM encoding.
+    'G723_40': 0,  # 40kbs G723 ADPCM encoding.
+    'DWVW_12': 12,  # 12 bit Delta Width Variable Word encoding.
+    'DWVW_16': 16,  # 16 bit Delta Width Variable Word encoding.
+    'DWVW_24': 24,  # 24 bit Delta Width Variable Word encoding.
+    'DWVW_N': 0,  # N bit Delta Width Variable Word encoding.
+    'DPCM_8': 8,  # 8 bit differential PCM (XI only)
+    'DPCM_16': 16,  # 16 bit differential PCM (XI only)
+    'VORBIS': 0,  # Xiph Vorbis encoding. (lossy)
+    'ALAC_16': 16,  # Apple Lossless Audio Codec (16 bit).
+    'ALAC_20': 20,  # Apple Lossless Audio Codec (20 bit).
+    'ALAC_24': 24,  # Apple Lossless Audio Codec (24 bit).
+    'ALAC_32': 32,  # Apple Lossless Audio Codec (32 bit).
+}
+
+
 @_mod_utils.requires_module("soundfile")
 def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
     """Get signal information of an audio file.
@@ -27,7 +66,15 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
         AudioMetaData: meta data of the given audio.
     """
     sinfo = soundfile.info(filepath)
-    return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels)
+    if sinfo.subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
+        warnings.warn(
+            f"The {sinfo.subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
+            "attribute will be set to 0. If you are seeing this warning, please "
+            "report by opening an issue on github (after checking for existing/closed ones). "
+            "You may otherwise ignore this warning."
+        )
+    bits_per_sample = _SUBTYPE_TO_BITS_PER_SAMPLE.get(sinfo.subtype, 0)
+    return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels, bits_per_sample=bits_per_sample)
 
 
 _SUBTYPE2DTYPE = {
diff --git a/torchaudio/backend/common.py b/torchaudio/backend/common.py
index 135a18caee..f9c6585bf8 100644
--- a/torchaudio/backend/common.py
+++ b/torchaudio/backend/common.py
@@ -12,11 +12,14 @@ class AudioMetaData:
     :ivar int sample_rate: Sample rate
     :ivar int num_frames: The number of frames
     :ivar int num_channels: The number of channels
+    :ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
+        or when it cannot be accurately inferred.
     """
-    def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
+    def __init__(self, sample_rate: int, num_frames: int, num_channels: int, bits_per_sample: int):
         self.sample_rate = sample_rate
         self.num_frames = num_frames
         self.num_channels = num_channels
+        self.bits_per_sample = bits_per_sample
 
 
 @_mod_utils.deprecated('Please migrate to `AudioMetaData`.', '0.9.0')
diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py
index 18296ef7e2..20dea70c52 100644
--- a/torchaudio/backend/sox_io_backend.py
+++ b/torchaudio/backend/sox_io_backend.py
@@ -32,7 +32,8 @@ def info(
     # Cast to str in case type is `pathlib.Path`
     filepath = str(filepath)
     sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
-    return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
+    return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(),
+                         sinfo.get_bits_per_sample())
 
 
 @_mod_utils.requires_module('torchaudio._torchaudio')
diff --git a/torchaudio/csrc/sox/io.cpp b/torchaudio/csrc/sox/io.cpp
index c6531eb77e..a71bed4450 100644
--- a/torchaudio/csrc/sox/io.cpp
+++ b/torchaudio/csrc/sox/io.cpp
@@ -13,10 +13,12 @@ namespace sox_io {
 SignalInfo::SignalInfo(
     const int64_t sample_rate_,
     const int64_t num_channels_,
-    const int64_t num_frames_)
+    const int64_t num_frames_,
+    const int64_t bits_per_sample_)
     : sample_rate(sample_rate_),
       num_channels(num_channels_),
-      num_frames(num_frames_){};
+      num_frames(num_frames_),
+      bits_per_sample(bits_per_sample_){};
 
 int64_t SignalInfo::getSampleRate() const {
   return sample_rate;
@@ -30,6 +32,10 @@ int64_t SignalInfo::getNumFrames() const {
   return num_frames;
 }
 
+int64_t SignalInfo::getBitsPerSample() const {
+  return bits_per_sample;
+}
+
 c10::intrusive_ptr<SignalInfo> get_info(
     const std::string& path,
     c10::optional<std::string>& format) {
@@ -46,7 +52,8 @@ c10::intrusive_ptr<SignalInfo> get_info(
   return c10::make_intrusive<SignalInfo>(
       static_cast<int64_t>(sf->signal.rate),
       static_cast<int64_t>(sf->signal.channels),
-      static_cast<int64_t>(sf->signal.length / sf->signal.channels));
+      static_cast<int64_t>(sf->signal.length / sf->signal.channels),
+      static_cast<int64_t>(sf->encoding.bits_per_sample));
 }
 
 namespace {
diff --git a/torchaudio/csrc/sox/io.h b/torchaudio/csrc/sox/io.h
index ac7191527f..0eefc793a5 100644
--- a/torchaudio/csrc/sox/io.h
+++ b/torchaudio/csrc/sox/io.h
@@ -15,14 +15,17 @@ struct SignalInfo : torch::CustomClassHolder {
   int64_t sample_rate;
   int64_t num_channels;
   int64_t num_frames;
+  int64_t bits_per_sample;
 
   SignalInfo(
       const int64_t sample_rate_,
       const int64_t num_channels_,
-      const int64_t num_frames_);
+      const int64_t num_frames_,
+      const int64_t bits_per_sample_);
   int64_t getSampleRate() const;
   int64_t getNumChannels() const;
   int64_t getNumFrames() const;
+  int64_t getBitsPerSample() const;
 };
 
 c10::intrusive_ptr<SignalInfo> get_info(
diff --git a/torchaudio/csrc/sox/register.cpp b/torchaudio/csrc/sox/register.cpp
index 7c65bebe2d..0f46af76d5 100644
--- a/torchaudio/csrc/sox/register.cpp
+++ b/torchaudio/csrc/sox/register.cpp
@@ -42,7 +42,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
       .def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
       .def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
-      .def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);
+      .def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
+      .def("get_bits_per_sample", &torchaudio::sox_io::SignalInfo::getBitsPerSample);
 
   m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
   m.def(