Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9bc0ca4
Added bits_per_sample field to sox_io backend
NicolasHug Jan 12, 2021
cd28304
Added tests for sox_io backend
NicolasHug Jan 13, 2021
c29834c
Added bits_per_sample support and tests to SoundFile backend
NicolasHug Jan 13, 2021
f99885b
fixed FLAC test
NicolasHug Jan 13, 2021
1f5781d
Fixed OGG test
NicolasHug Jan 13, 2021
e3f1f09
Merge branch 'master' of github.com:pytorch/audio into bits_per_sample
NicolasHug Jan 15, 2021
c80cfd4
Addressed comments: Added links to bit depths + handled unknown subtype
NicolasHug Jan 15, 2021
49d2c7c
Document bits_per_sample=0
NicolasHug Jan 15, 2021
c64ea12
fix amr-nb test
NicolasHug Jan 15, 2021
01c26f9
Addressed comments
NicolasHug Jan 18, 2021
17e1cc9
Merge branch 'master' of github.com:pytorch/audio into bits_per_sample
NicolasHug Jan 18, 2021
c91d0a3
Use f-strings instead of .format()
NicolasHug Jan 18, 2021
56a7536
Merge branch 'master' of github.com:pytorch/audio into bits_per_sample
NicolasHug Jan 19, 2021
04480b5
Addressed comments
NicolasHug Jan 19, 2021
04b1378
remove unused import
NicolasHug Jan 19, 2021
1d30db6
remove usage of pytest
NicolasHug Jan 20, 2021
2f955e8
Merge branch 'master' of github.com:pytorch/audio into bits_per_sample
NicolasHug Jan 21, 2021
41d2020
Addressed comments
NicolasHug Jan 21, 2021
ffee30a
use proper param name
NicolasHug Jan 21, 2021
532397f
expected bps is 0 for amr_nb?
NicolasHug Jan 21, 2021
dff1e04
Merge branch 'master' of github.com:pytorch/audio into bits_per_sample
NicolasHug Jan 22, 2021
9f86979
Merge branch 'master' of github.com:pytorch/audio into bits_per_sample
NicolasHug Jan 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/torchaudio_unittest/soundfile_backend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def skipIfFormatNotSupported(fmt):
import soundfile

fmts = soundfile.available_formats()
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by sondfile')
return skipIf(fmt not in fmts, f'"{fmt}" is not supported by soundfile')
return skipIf(True, '"soundfile" not available.')


Expand Down
46 changes: 39 additions & 7 deletions test/torchaudio_unittest/soundfile_backend/info_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from unittest.mock import patch
import warnings

import torch
from torchaudio.backend import _soundfile_backend as soundfile_backend
from torchaudio._internal import module_utils as _mod_utils
Expand All @@ -18,10 +21,11 @@
@skipIfNoModule("soundfile")
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2],
[("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
)
def test_wav(self, dtype, sample_rate, num_channels):
def test_wav(self, dtype_and_bit_depth, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file correctly"""
dtype, bits_per_sample = dtype_and_bit_depth
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
Expand All @@ -32,12 +36,14 @@ def test_wav(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

@parameterize(
["float32", "int32", "int16", "uint8"], [8000, 16000], [4, 8, 16, 32],
[("float32", 32), ("int32", 32), ("int16", 16), ("uint8", 8)], [8000, 16000], [1, 2],
)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
def test_wav_multiple_channels(self, dtype_and_bit_depth, sample_rate, num_channels):
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
dtype, bits_per_sample = dtype_and_bit_depth
duration = 1
path = self.get_temp_path("data.wav")
data = get_wav_data(
Expand All @@ -48,6 +54,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("FLAC")
Expand All @@ -63,6 +70,7 @@ def test_flac(self, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == num_frames
assert info.num_channels == num_channels
assert info.bits_per_sample == 16

@parameterize([8000, 16000], [1, 2])
@skipIfFormatNotSupported("OGG")
Expand All @@ -78,18 +86,42 @@ def test_ogg(self, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0

@parameterize([8000, 16000], [1, 2])
@parameterize([8000, 16000], [1, 2], [('PCM_24', 24), ('PCM_32', 32)])
@skipIfFormatNotSupported("NIST")
def test_sphere(self, sample_rate, num_channels):
def test_sphere(self, sample_rate, num_channels, subtype_and_bit_depth):
"""`soundfile_backend.info` can check sph file correctly"""
duration = 1
num_frames = sample_rate * duration
data = torch.randn(num_frames, num_channels).numpy()
path = self.get_temp_path("data.nist")
soundfile.write(path, data, sample_rate)
subtype, bits_per_sample = subtype_and_bit_depth
soundfile.write(path, data, sample_rate, subtype=subtype)

info = soundfile_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

def test_unknown_subtype_warning(self):
    """soundfile_backend.info issues a warning when the subtype is unknown

    This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
    dict should be updated.
    """
    expected_message = "UNSEEN_SUBTYPE subtype is unknown to TorchAudio"

    class MockSoundFileInfo:
        # Minimal stand-in exposing only the attributes read by soundfile_backend.info.
        samplerate = 8000
        frames = 356
        channels = 2
        subtype = 'UNSEEN_SUBTYPE'

    def _mock_info_func(_):
        # Accepts (and ignores) the file path passed by soundfile_backend.info.
        return MockSoundFileInfo()

    with patch("soundfile.info", _mock_info_func):
        with warnings.catch_warnings(record=True) as caught:
            # Make sure the warning is recorded regardless of the filters that are
            # active when the test runs (e.g. "once"/"ignore" rules set elsewhere).
            warnings.simplefilter("always")
            info = soundfile_backend.info("foo")

    # Only count the warning under test, so that unrelated warnings emitted
    # during the call do not make this test fail.
    matching = [item for item in caught if expected_message in str(item.message)]
    assert len(matching) == 1
    assert info.bits_per_sample == 0
21 changes: 17 additions & 4 deletions test/torchaudio_unittest/sox_io_backend/info_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_wav(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand All @@ -52,6 +53,7 @@ def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -71,6 +73,7 @@ def test_mp3(self, sample_rate, num_channels, bit_rate):
# mp3 does not preserve the number of samples
# assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -89,6 +92,7 @@ def test_flac(self, sample_rate, num_channels, compression_level):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 24 # FLAC standard

@parameterized.expand(list(itertools.product(
[8000, 16000],
Expand All @@ -107,20 +111,23 @@ def test_vorbis(self, sample_rate, num_channels, quality_level):
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats

@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[16, 32],
)), name_func=name_func)
def test_sphere(self, sample_rate, num_channels):
def test_sphere(self, sample_rate, num_channels, bits_per_sample):
"""`sox_io_backend.info` can check sph file correctly"""
duration = 1
path = self.get_temp_path('data.sph')
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration)
sox_utils.gen_audio_file(path, sample_rate, num_channels, duration=duration, bit_depth=bits_per_sample)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
Expand All @@ -131,13 +138,15 @@ def test_amb(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check amb file correctly"""
duration = 1
path = self.get_temp_path('data.amb')
bits_per_sample = sox_utils.get_bit_depth(dtype)
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype), duration=duration)
bit_depth=bits_per_sample, duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == bits_per_sample

def test_amr_nb(self):
"""`sox_io_backend.info` can check amr-nb file correctly"""
Expand All @@ -146,11 +155,13 @@ def test_amr_nb(self):
sample_rate = 8000
path = self.get_temp_path('data.amr-nb')
sox_utils.gen_audio_file(
path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16, duration=duration)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you try to parameterize bit_depth here?

path, sample_rate=sample_rate, num_channels=num_channels, bit_depth=16,
duration=duration)
info = sox_io_backend.info(path)
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == 0


@skipIfNoExtension
Expand All @@ -167,6 +178,7 @@ def test_opus(self, bitrate, num_channels, compression_level):
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
assert info.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats


@skipIfNoExtension
Expand All @@ -184,3 +196,4 @@ def test_mp3(self):
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
assert sinfo.bits_per_sample == 0 # bit_per_sample is irrelevant for compressed formats
49 changes: 48 additions & 1 deletion torchaudio/backend/_soundfile_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,45 @@
import soundfile


# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristical and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired from
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
_SUBTYPE_TO_BITS_PER_SAMPLE = {
'PCM_S8': 8, # Signed 8 bit data
'PCM_16': 16, # Signed 16 bit data
'PCM_24': 24, # Signed 24 bit data
'PCM_32': 32, # Signed 32 bit data
'PCM_U8': 8, # Unsigned 8 bit data (WAV and RAW only)
'FLOAT': 32, # 32 bit float data
'DOUBLE': 64, # 64 bit float data
'ULAW': 8, # U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'ALAW': 8, # A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'IMA_ADPCM': 0, # IMA ADPCM.
'MS_ADPCM': 0, # Microsoft ADPCM.
'GSM610': 0, # GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
'VOX_ADPCM': 0, # OKI / Dialogix ADPCM
'G721_32': 0, # 32kbs G721 ADPCM encoding.
'G723_24': 0, # 24kbs G723 ADPCM encoding.
'G723_40': 0, # 40kbs G723 ADPCM encoding.
'DWVW_12': 12, # 12 bit Delta Width Variable Word encoding.
'DWVW_16': 16, # 16 bit Delta Width Variable Word encoding.
'DWVW_24': 24, # 24 bit Delta Width Variable Word encoding.
'DWVW_N': 0, # N bit Delta Width Variable Word encoding.
'DPCM_8': 8, # 8 bit differential PCM (XI only)
'DPCM_16': 16, # 16 bit differential PCM (XI only)
'VORBIS': 0, # Xiph Vorbis encoding. (lossy)
'ALAC_16': 16, # Apple Lossless Audio Codec (16 bit).
'ALAC_20': 20, # Apple Lossless Audio Codec (20 bit).
'ALAC_24': 24, # Apple Lossless Audio Codec (24 bit).
'ALAC_32': 32, # Apple Lossless Audio Codec (32 bit).
}


@_mod_utils.requires_module("soundfile")
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
Expand All @@ -27,7 +66,15 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
AudioMetaData: meta data of the given audio.
"""
sinfo = soundfile.info(filepath)
return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels)
if sinfo.subtype not in _SUBTYPE_TO_BITS_PER_SAMPLE:
warnings.warn(
f"The {sinfo.subtype} subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
bits_per_sample = _SUBTYPE_TO_BITS_PER_SAMPLE.get(sinfo.subtype, 0)
return AudioMetaData(sinfo.samplerate, sinfo.frames, sinfo.channels, bits_per_sample=bits_per_sample)


_SUBTYPE2DTYPE = {
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/backend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ class AudioMetaData:
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
"""
def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
def __init__(self, sample_rate: int, num_frames: int, num_channels: int, bits_per_sample: int):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample


@_mod_utils.deprecated('Please migrate to `AudioMetaData`.', '0.9.0')
Expand Down
3 changes: 2 additions & 1 deletion torchaudio/backend/sox_io_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def info(
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels(),
sinfo.get_bits_per_sample())


@_mod_utils.requires_module('torchaudio._torchaudio')
Expand Down
13 changes: 10 additions & 3 deletions torchaudio/csrc/sox/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ namespace sox_io {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_)
const int64_t num_frames_,
const int64_t bits_per_sample_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_){};
num_frames(num_frames_),
bits_per_sample(bits_per_sample_){};

int64_t SignalInfo::getSampleRate() const {
return sample_rate;
Expand All @@ -30,6 +32,10 @@ int64_t SignalInfo::getNumFrames() const {
return num_frames;
}

int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample;
}

c10::intrusive_ptr<SignalInfo> get_info(
const std::string& path,
c10::optional<std::string>& format) {
Expand All @@ -46,7 +52,8 @@ c10::intrusive_ptr<SignalInfo> get_info(
return c10::make_intrusive<SignalInfo>(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample));
}

namespace {
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/csrc/sox/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
int64_t bits_per_sample;

SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_);
const int64_t num_frames_,
const int64_t bits_per_sample_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
int64_t getBitsPerSample() const;
};

c10::intrusive_ptr<SignalInfo> get_info(
Expand Down
3 changes: 2 additions & 1 deletion torchaudio/csrc/sox/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def("get_bits_per_sample", &torchaudio::sox_io::SignalInfo::getBitsPerSample);

m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def(
Expand Down