"""Test suite for compliance with the ITU-R BS.1770-4 recommendation"""
import zipfile

import pytest

import torch
import torchaudio
import torchaudio.functional as F


# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf
@pytest.mark.parametrize(
    "filename,url,expected",
    [
        (
            "1770-2_Comp_RelGateTest",
            "https://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip",
            -10.0,
        ),
        (
            "1770-2_Comp_AbsGateTest",
            "https://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip",
            -69.5,
        ),
        (
            "1770-2_Comp_24LKFS_500Hz_2ch",
            "https://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip",
            -24.0,
        ),
        (
            "1770-2 Conf Mono Voice+Music-24LKFS",
            "https://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip",
            -24.0,
        ),
    ],
)
def test_loudness(tmp_path, filename, url, expected):
    """Check ``F.loudness`` against one ITU compliance recording.

    Args:
        tmp_path: pytest-provided temporary directory used for the download.
        filename: base name of the archive; the extracted audio is expected
            to be ``<filename>.wav``.
        url: location of the zip archive holding the recording.
        expected: reference loudness in LKFS for the recording.
    """
    zippath = tmp_path / f"{filename}.zip"
    torch.hub.download_url_to_file(url, str(zippath), progress=False)
    with zipfile.ZipFile(zippath) as archive:
        archive.extractall(tmp_path)

    # Build the wav name by appending the extension: ``Path.with_suffix`` would
    # replace everything after the last dot of a fixture name that contains one.
    waveform, sample_rate = torchaudio.load(tmp_path / f"{filename}.wav")
    measured = F.loudness(waveform, sample_rate)
    reference = torch.tensor(expected, dtype=measured.dtype, device=measured.device)
    assert torch.allclose(measured, reference, rtol=0.01, atol=0.1), (
        f"{filename}: measured {measured.item():.2f} LKFS, expected {expected:.2f} LKFS"
    )


# The function below is a *method* added by the same patch to the TorchScript
# consistency test class in
# test/torchaudio_unittest/functional/torchscript_consistency_impl.py.
# It relies on that module's `unittest`, `common_utils` and `F` imports and on
# the `dtype`, `device` and `_assert_consistency` members of its test class.
def test_measure_loudness(self):
    """Check that scripted and eager ``F.loudness`` agree on a sinusoid."""
    if self.dtype == torch.float64:
        raise unittest.SkipTest("This test is known to fail for float64")

    sample_rate = 44100
    # NOTE(review): the waveform is created without passing self.dtype --
    # confirm which dtype get_sinusoid produces by default.
    waveform = common_utils.get_sinusoid(sample_rate=sample_rate, device=self.device)
    self._assert_consistency(F.loudness, (waveform, sample_rate))
# torchaudio/functional/functional.py
# Besides the module's existing `math`, `torch` and `Tensor` imports, this
# function needs the file-level import added by the same patch:
#     from .filtering import highpass_biquad, treble_biquad
def loudness(waveform: Tensor, sample_rate: int) -> Tensor:
    r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        waveform (torch.Tensor): audio waveform of dimension `(..., channels, time)`
        sample_rate (int): sampling rate of the waveform

    Returns:
        Tensor: loudness estimates (LKFS) of dimension `(...)`. The estimate is
        ``-inf`` when no block passes the gating stages (e.g. digital silence).

    Raises:
        ValueError: if ``waveform`` has fewer than two dimensions or more than
            five channels.

    Reference:
        - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
    """

    if waveform.dim() < 2:
        raise ValueError("Expected a waveform of dimension `(..., channels, time)`.")
    if waveform.size(-2) > 5:
        raise ValueError("Only up to 5 channels are supported.")

    gate_duration = 0.4  # length of one gating block, in seconds
    overlap = 0.75  # fraction shared by consecutive gating blocks
    gamma_abs = -70.0  # absolute gating threshold (LKFS)
    kweight_bias = -0.691  # constant offset of the loudness formula
    gate_samples = int(round(gate_duration * sample_rate))
    step = int(round(gate_samples * (1 - overlap)))

    # Apply K-weighting: a treble boost followed by a high-pass filter
    waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
    waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)

    # Mean-square energy of every overlapping block: (..., channels, blocks)
    energy = torch.square(waveform).unfold(-1, gate_samples, step)
    energy = torch.mean(energy, dim=-1)

    # Channel weights; the fourth and fifth channels are weighted by 1.41
    g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
    g = g[: energy.size(-2)]

    # Loudness of each individual block: (..., blocks)
    energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
    block_loudness = kweight_bias + 10 * torch.log10(energy_weighted)

    # Apply absolute gating of the blocks
    gated_blocks = (block_loudness > gamma_abs).unsqueeze(-2)
    # Clamp the block count so an empty selection gives 0 energy (-inf LKFS)
    # instead of 0 / 0 = NaN.
    block_count = torch.count_nonzero(gated_blocks, dim=-1).clamp(min=1)
    energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / block_count
    energy_weighted = torch.sum(g * energy_filtered, dim=-1)
    # The relative threshold sits 10 below the absolute-gated loudness
    gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10

    # Apply relative gating of the blocks (on top of the absolute gate)
    gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), block_loudness > gamma_rel.unsqueeze(-1))
    gated_blocks = gated_blocks.unsqueeze(-2)

    block_count = torch.count_nonzero(gated_blocks, dim=-1).clamp(min=1)
    energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / block_count
    energy_weighted = torch.sum(g * energy_filtered, dim=-1)
    LKFS = kweight_bias + 10 * torch.log10(energy_weighted)
    return LKFS


# torchaudio/transforms/_transforms.py
# `F` is that module's existing alias for `torchaudio.functional`.
class Loudness(torch.nn.Module):
    r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        sample_rate (int): Sample rate of audio signal.

    Reference:
        - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
    """
    __constants__ = ["sample_rate"]

    def __init__(self, sample_rate: int):
        super(Loudness, self).__init__()
        self.sample_rate = sample_rate

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (torch.Tensor): audio waveform of dimension `(..., channels, time)`

        Returns:
            Tensor: loudness estimates (LKFS)
        """
        return F.loudness(waveform, self.sample_rate)