From bc2cda3dd8b51a2c3cbc8ec043eabd634d88f4e3 Mon Sep 17 00:00:00 2001 From: bshall Date: Sun, 12 Jun 2022 18:46:08 +0200 Subject: [PATCH 01/14] initial commit and tests --- .../loudness_compliance_cpu_test.py | 8 +++ .../loudness_compliance_cuda_test.py | 9 +++ .../loudness_compliance_test_impl.py | 60 ++++++++++++++++++ .../torchscript_consistency_impl.py | 8 +++ torchaudio/functional/__init__.py | 6 ++ torchaudio/functional/functional.py | 63 +++++++++++++++++++ 6 files changed, 154 insertions(+) create mode 100644 test/torchaudio_unittest/functional/loudness_compliance_cpu_test.py create mode 100644 test/torchaudio_unittest/functional/loudness_compliance_cuda_test.py create mode 100644 test/torchaudio_unittest/functional/loudness_compliance_test_impl.py diff --git a/test/torchaudio_unittest/functional/loudness_compliance_cpu_test.py b/test/torchaudio_unittest/functional/loudness_compliance_cpu_test.py new file mode 100644 index 0000000000..f8dbaeef28 --- /dev/null +++ b/test/torchaudio_unittest/functional/loudness_compliance_cpu_test.py @@ -0,0 +1,8 @@ +import torch +from torchaudio_unittest import common_utils + +from .loudness_compliance_test_impl import Loudness + + +class TestLoudnessCPU(Loudness, common_utils.PytorchTestCase): + device = torch.device("cpu") diff --git a/test/torchaudio_unittest/functional/loudness_compliance_cuda_test.py b/test/torchaudio_unittest/functional/loudness_compliance_cuda_test.py new file mode 100644 index 0000000000..cd310e215f --- /dev/null +++ b/test/torchaudio_unittest/functional/loudness_compliance_cuda_test.py @@ -0,0 +1,9 @@ +import torch +from torchaudio_unittest import common_utils + +from .loudness_compliance_test_impl import Loudness + + +@common_utils.skipIfNoCuda +class TestLoudnessCUDA(Loudness, common_utils.PytorchTestCase): + device = torch.device("cuda") diff --git a/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py b/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py new file mode 100644 index 0000000000..4adddb32f2 --- /dev/null +++ b/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py @@ -0,0 +1,60 @@ +"""Test suite for compliance with the ITU-R BS.1770-4 recommendation""" +import os.path +import zipfile + +import torch +import torchaudio.functional as F +from torchaudio_unittest.common_utils import load_wav, TempDirMixin, TestBaseMixin + +# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf +_COMPLIANCE_FILE_URLS = { + "1770-2_Comp_RelGateTest": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip", + "1770-2_Comp_AbsGateTest": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip", + "1770-2_Comp_24LKFS_500Hz_2ch": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip", + "1770-2 Conf Mono Voice+Music-24LKFS": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip", +} + + +class Loudness(TempDirMixin, TestBaseMixin): + def download_and_extract_file(self, filename): + zippath = self.get_temp_path(filename + ".zip") + torch.hub.download_url_to_file(_COMPLIANCE_FILE_URLS[filename], zippath, progress=False) + with zipfile.ZipFile(zippath) as file: + file.extractall(os.path.dirname(zippath)) + return self.get_temp_path(filename + ".wav") + + def test_measure_loudness_relative_gate(self): + filepath = self.download_and_extract_file("1770-2_Comp_RelGateTest") + waveform, sample_rate = load_wav(filepath) + waveform = waveform.to(self.device) + + loudness = F.measure_loudness(waveform, sample_rate) + expected = torch.tensor(-10.0, dtype=loudness.dtype, device=self.device) + self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) + + def test_measure_loudness_absolute_gate(self): + filepath = self.download_and_extract_file("1770-2_Comp_AbsGateTest") + waveform, sample_rate = load_wav(filepath) + waveform = waveform.to(self.device) + + loudness = F.measure_loudness(waveform, sample_rate) + expected = torch.tensor(-69.5, dtype=loudness.dtype, device=self.device) + self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) + + def test_measure_loudness_two_channels(self): + filepath = filepath = self.download_and_extract_file("1770-2_Comp_24LKFS_500Hz_2ch") + waveform, sample_rate = load_wav(filepath) + waveform = waveform.to(self.device) + + loudness = F.measure_loudness(waveform, sample_rate) + expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device) + self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) + + def test_measure_loudness_mono_voice_music(self): + filepath = self.download_and_extract_file("1770-2 Conf Mono Voice+Music-24LKFS") + waveform, sample_rate = load_wav(filepath) + waveform = waveform.to(self.device) + + loudness = F.measure_loudness(waveform, sample_rate) + expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device) + self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 99c1894725..5ac2decd3c 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -120,6 +120,14 @@ def func(tensor): self._assert_consistency(func, (waveform,)) + def test_measure_loudness(self): + if self.dtype == torch.float64: + raise unittest.SkipTest("This test is known to fail for float64") + + sample_rate = 44100 + waveform = common_utils.get_sinusoid(sample_rate=sample_rate, device=self.device) + self._assert_consistency(F.measure_loudness, (waveform, sample_rate)) + def test_melscale_fbanks(self): if self.device != torch.device("cpu"): raise unittest.SkipTest("No need to perform test on device other than CPU") diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py index 06325da3fe..6d9617162d 100644 --- a/torchaudio/functional/__init__.py +++ b/torchaudio/functional/__init__.py @@ -30,11 +30,16 @@ compute_kaldi_pitch, create_dct, DB_to_amplitude, +<<<<<<< HEAD + measure_loudness, +======= +>>>>>>> 42509325... Fixed linting issues detect_pitch_frequency, edit_distance, griffinlim, inverse_spectrogram, linear_fbanks, + loudness, mask_along_axis, mask_along_axis_iid, melscale_fbanks, @@ -62,6 +67,7 @@ "melscale_fbanks", "linear_fbanks", "DB_to_amplitude", + "measure_loudness", "detect_pitch_frequency", "griffinlim", "mask_along_axis", diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 067cbc9330..211056a34e 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -11,6 +11,8 @@ from torch import Tensor from torchaudio._internal import module_utils as _mod_utils +from .filtering import highpass_biquad, treble_biquad + __all__ = [ "spectrogram", "inverse_spectrogram", @@ -35,6 +37,7 @@ "apply_codec", "resample", "edit_distance", + "measure_loudness", "pitch_shift", "rnnt_loss", "psd", @@ -1640,6 +1643,66 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int: return int(dold[-1]) +def measure_loudness(waveform: Tensor, sample_rate: int): + r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation. + + .. devices:: CPU CUDA + + .. properties:: TorchScript + + Args: + waveform(torch.Tensor): audio waveform of dimension of `(..., channels, time)` + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + + Returns: + Tensor: loudness estimates (LKFS) + + Reference: + - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en + """ + + if waveform.size(-2) > 5: + raise ValueError("Only up to 5 channels are supported.") + + gate_duration: float = 0.4 + overlap: float = 0.75 + gamma_abs: float = -70.0 + gate_samples = int(round(gate_duration * sample_rate)) + step = int(round(gate_samples * (1 - overlap))) + + # Apply K-weighting + waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2)) + waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5) + + # Compute the energy for each block + energy = torch.square(waveform).unfold(-1, gate_samples, step) + energy = torch.mean(energy, dim=-1) + + # Compute channel-weighted summation + g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device) + g = g[: energy.size(-2)] + + energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2) + loudness = -0.691 + 10 * torch.log10(energy_weighted) + + # Apply absolute gating of the blocks + gated_blocks = loudness > gamma_abs + gated_blocks = gated_blocks.unsqueeze(-2) + + energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1) + energy_weighted = torch.sum(g * energy_filtered, dim=-1) + gamma_rel = -0.691 + 10 * torch.log10(energy_weighted) - 10 + + # Apply relative gating of the blocks + gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1)) + gated_blocks = gated_blocks.unsqueeze(-2) + + energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1) + energy_weighted = torch.sum(g * energy_filtered, dim=-1) + LKFS = -0.691 + 10 * torch.log10(energy_weighted) + return LKFS + + def pitch_shift( waveform: Tensor, sample_rate: int, From 05fe1623e8671dddf70db42fb622dc4eb7985932 Mon Sep 17 00:00:00 2001 From: bshall Date: Fri, 22 Jul 2022 15:26:02 +0200 Subject: [PATCH 02/14] Fixed doc string and renamed function to --- .../functional/loudness_compliance_test_impl.py | 8 ++++---- .../functional/torchscript_consistency_impl.py | 2 +- torchaudio/functional/__init__.py | 6 +----- torchaudio/functional/functional.py | 8 ++++---- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py b/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py index 4adddb32f2..0fd27128cf 100644 --- a/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py +++ b/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py @@ -28,7 +28,7 @@ def test_measure_loudness_relative_gate(self): waveform, sample_rate = load_wav(filepath) waveform = waveform.to(self.device) - loudness = F.measure_loudness(waveform, sample_rate) + loudness = F.loudness(waveform, sample_rate) expected = torch.tensor(-10.0, dtype=loudness.dtype, device=self.device) self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) @@ -37,7 +37,7 @@ def test_measure_loudness_absolute_gate(self): waveform, sample_rate = load_wav(filepath) waveform = waveform.to(self.device) - loudness = F.measure_loudness(waveform, sample_rate) + loudness = F.loudness(waveform, sample_rate) expected = torch.tensor(-69.5, dtype=loudness.dtype, device=self.device) self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) @@ -46,7 +46,7 @@ def test_measure_loudness_two_channels(self): waveform, sample_rate = load_wav(filepath) waveform = waveform.to(self.device) - loudness = F.measure_loudness(waveform, sample_rate) + loudness = F.loudness(waveform, sample_rate) expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device) self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) @@ -55,6 +55,6 @@ def test_measure_loudness_mono_voice_music(self): waveform, sample_rate = load_wav(filepath) waveform = waveform.to(self.device) - loudness = F.measure_loudness(waveform, sample_rate) + loudness = F.loudness(waveform, sample_rate) expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device) self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 5ac2decd3c..83d77ec01d 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -126,7 +126,7 @@ def test_measure_loudness(self): sample_rate = 44100 waveform = common_utils.get_sinusoid(sample_rate=sample_rate, device=self.device) - self._assert_consistency(F.measure_loudness, (waveform, sample_rate)) + self._assert_consistency(F.loudness, (waveform, sample_rate)) def test_melscale_fbanks(self): if self.device != torch.device("cpu"): diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py index 6d9617162d..49f8c686e9 100644 --- a/torchaudio/functional/__init__.py +++ b/torchaudio/functional/__init__.py @@ -30,10 +30,6 @@ compute_kaldi_pitch, create_dct, DB_to_amplitude, -<<<<<<< HEAD - measure_loudness, -======= ->>>>>>> 42509325... Fixed linting issues detect_pitch_frequency, edit_distance, griffinlim, @@ -67,7 +63,7 @@ "melscale_fbanks", "linear_fbanks", "DB_to_amplitude", - "measure_loudness", + "loudness", "detect_pitch_frequency", "griffinlim", "mask_along_axis", diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 211056a34e..7e7c93a581 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -37,7 +37,7 @@ "apply_codec", "resample", "edit_distance", - "measure_loudness", + "loudness", "pitch_shift", "rnnt_loss", "psd", @@ -1643,7 +1643,7 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int: return int(dold[-1]) -def measure_loudness(waveform: Tensor, sample_rate: int): +def loudness(waveform: Tensor, sample_rate: int): r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation. .. devices:: CPU CUDA @@ -1651,8 +1651,8 @@ def measure_loudness(waveform: Tensor, sample_rate: int): .. properties:: TorchScript Args: - waveform(torch.Tensor): audio waveform of dimension of `(..., channels, time)` - sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) + waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)` + sample_rate (int): sampling rate of the waveform Returns: Tensor: loudness estimates (LKFS) From 2d91fa2cf1dafd556ad7fe89fc1240ffba8a9f7c Mon Sep 17 00:00:00 2001 From: bshall Date: Fri, 22 Jul 2022 15:50:46 +0200 Subject: [PATCH 03/14] Moved loudness test out of unittest directory --- .../loudness_compliance_cpu_test.py | 0 .../loudness_compliance_cuda_test.py | 0 .../loudness_compliance_test_impl.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename test/{torchaudio_unittest/functional => integration_tests}/loudness_compliance_cpu_test.py (100%) rename test/{torchaudio_unittest/functional => integration_tests}/loudness_compliance_cuda_test.py (100%) rename test/{torchaudio_unittest/functional => integration_tests}/loudness_compliance_test_impl.py (100%) diff --git a/test/torchaudio_unittest/functional/loudness_compliance_cpu_test.py b/test/integration_tests/loudness_compliance_cpu_test.py similarity index 100% rename from test/torchaudio_unittest/functional/loudness_compliance_cpu_test.py rename to test/integration_tests/loudness_compliance_cpu_test.py diff --git a/test/torchaudio_unittest/functional/loudness_compliance_cuda_test.py b/test/integration_tests/loudness_compliance_cuda_test.py similarity index 100% rename from test/torchaudio_unittest/functional/loudness_compliance_cuda_test.py rename to test/integration_tests/loudness_compliance_cuda_test.py diff --git a/test/torchaudio_unittest/functional/loudness_compliance_test_impl.py b/test/integration_tests/loudness_compliance_test_impl.py similarity index 100% rename from test/torchaudio_unittest/functional/loudness_compliance_test_impl.py rename to test/integration_tests/loudness_compliance_test_impl.py From 4243b5a79f0f813b66126e17e6a5f74804dade2c Mon Sep 17 00:00:00 2001 From: bshall Date: Fri, 22 Jul 2022 16:00:27 +0200 Subject: [PATCH 04/14] Added Loudness transform --- torchaudio/transforms/__init__.py | 2 ++ torchaudio/transforms/_transforms.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/torchaudio/transforms/__init__.py b/torchaudio/transforms/__init__.py index 527da5c7d2..2dc5bcf555 100644 --- a/torchaudio/transforms/__init__.py +++ b/torchaudio/transforms/__init__.py @@ -8,6 +8,7 @@ InverseMelScale, InverseSpectrogram, LFCC, + Loudness, MelScale, MelSpectrogram, MFCC, @@ -35,6 +36,7 @@ "InverseMelScale", "InverseSpectrogram", "LFCC", + "Loudness", "MFCC", "MVDR", "MelScale", diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py index bb5144d713..f89c47a030 100644 --- a/torchaudio/transforms/_transforms.py +++ b/torchaudio/transforms/_transforms.py @@ -1251,6 +1251,32 @@ def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0 super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p) +class Loudness(torch.nn.Module): + r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation. + + .. devices:: CPU CUDA + + .. properties:: TorchScript + + Reference: + - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en + """ + + def __init__(self): + super(Loudness, self).__init__() + + def forward(self, wavefrom: Tensor, sample_rate: int): + """ + Args: + waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)` + sample_rate (int): sampling rate of the waveform + + Returns: + Tensor: loudness estimates (LKFS) + """ + return F.loudness(wavefrom, sample_rate) + + class Vol(torch.nn.Module): r"""Add a volume to an waveform. From a02ccbc054e38be325c2d656f17f39c003290a71 Mon Sep 17 00:00:00 2001 From: bshall Date: Fri, 22 Jul 2022 16:17:46 +0200 Subject: [PATCH 05/14] Removed float types for contants --- torchaudio/functional/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 7e7c93a581..0dded5d2dc 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1664,9 +1664,9 @@ def loudness(waveform: Tensor, sample_rate: int): if waveform.size(-2) > 5: raise ValueError("Only up to 5 channels are supported.") - gate_duration: float = 0.4 - overlap: float = 0.75 - gamma_abs: float = -70.0 + gate_duration = 0.4 + overlap = 0.75 + gamma_abs = -70.0 gate_samples = int(round(gate_duration * sample_rate)) step = int(round(gate_samples * (1 - overlap))) From f968789308d3220324df80835be4a6880cf4c950 Mon Sep 17 00:00:00 2001 From: bshall Date: Fri, 22 Jul 2022 17:06:08 +0200 Subject: [PATCH 06/14] Moved sample_rate to __init__ in Loudness transform --- torchaudio/transforms/_transforms.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py index f89c47a030..89927e3b82 100644 --- a/torchaudio/transforms/_transforms.py +++ b/torchaudio/transforms/_transforms.py @@ -1258,23 +1258,27 @@ class Loudness(torch.nn.Module): .. properties:: TorchScript + Args: + sample_rate (int, optional): sampling rate of the waveform + Reference: - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en """ + __constants__ = ["sample_rate"] - def __init__(self): + def __init__(self, sample_rate: int = 16000): super(Loudness, self).__init__() + self.sample_rate = sample_rate - def forward(self, wavefrom: Tensor, sample_rate: int): - """ + def forward(self, wavefrom: Tensor): + r""" Args: waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)` - sample_rate (int): sampling rate of the waveform Returns: Tensor: loudness estimates (LKFS) """ - return F.loudness(wavefrom, sample_rate) + return F.loudness(wavefrom, self.sample_rate) class Vol(torch.nn.Module): From d8c24afcf08468f6a2101bb50b516984fb47a2e3 Mon Sep 17 00:00:00 2001 From: bshall Date: Tue, 26 Jul 2022 10:20:39 +0200 Subject: [PATCH 07/14] Fixed doc string --- torchaudio/transforms/_transforms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/transforms/_transforms.py b/torchaudio/transforms/_transforms.py index 89927e3b82..f6a7e73b0e 100644 --- a/torchaudio/transforms/_transforms.py +++ b/torchaudio/transforms/_transforms.py @@ -1259,14 +1259,14 @@ class Loudness(torch.nn.Module): .. properties:: TorchScript Args: - sample_rate (int, optional): sampling rate of the waveform + sample_rate (int): Sample rate of audio signal. Reference: - https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en """ __constants__ = ["sample_rate"] - def __init__(self, sample_rate: int = 16000): + def __init__(self, sample_rate: int): super(Loudness, self).__init__() self.sample_rate = sample_rate From 7e97e153b84333dfa0236bd4341cb9e71f35f710 Mon Sep 17 00:00:00 2001 From: bshall Date: Tue, 26 Jul 2022 10:29:51 +0200 Subject: [PATCH 08/14] Refactored loudness tests --- .../loudness_compliance_test_impl.py | 42 ++++++------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/test/integration_tests/loudness_compliance_test_impl.py b/test/integration_tests/loudness_compliance_test_impl.py index 0fd27128cf..67b5b2f602 100644 --- a/test/integration_tests/loudness_compliance_test_impl.py +++ b/test/integration_tests/loudness_compliance_test_impl.py @@ -4,6 +4,7 @@ import torch import torchaudio.functional as F +from parameterized import parameterized from torchaudio_unittest.common_utils import load_wav, TempDirMixin, TestBaseMixin # Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf @@ -23,38 +24,19 @@ def download_and_extract_file(self, filename): file.extractall(os.path.dirname(zippath)) return self.get_temp_path(filename + ".wav") - def test_measure_loudness_relative_gate(self): - filepath = self.download_and_extract_file("1770-2_Comp_RelGateTest") + @parameterized.expand( + [ + ("1770-2_Comp_RelGateTest", -10.0), + ("1770-2_Comp_AbsGateTest", -69.5), + ("1770-2_Comp_24LKFS_500Hz_2ch", -24.0), + ("1770-2 Conf Mono Voice+Music-24LKFS", -24.0), + ] + ) + def test_loudness(self, filename, expected): + filepath = self.download_and_extract_file(filename) waveform, sample_rate = load_wav(filepath) waveform = waveform.to(self.device) loudness = F.loudness(waveform, sample_rate) - expected = torch.tensor(-10.0, dtype=loudness.dtype, device=self.device) - self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) - - def test_measure_loudness_absolute_gate(self): - filepath = self.download_and_extract_file("1770-2_Comp_AbsGateTest") - waveform, sample_rate = load_wav(filepath) - waveform = waveform.to(self.device) - - loudness = F.loudness(waveform, sample_rate) - expected = torch.tensor(-69.5, dtype=loudness.dtype, device=self.device) - self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) - - def test_measure_loudness_two_channels(self): - filepath = filepath = self.download_and_extract_file("1770-2_Comp_24LKFS_500Hz_2ch") - waveform, sample_rate = load_wav(filepath) - waveform = waveform.to(self.device) - - loudness = F.loudness(waveform, sample_rate) - expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device) - self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) - - def test_measure_loudness_mono_voice_music(self): - filepath = self.download_and_extract_file("1770-2 Conf Mono Voice+Music-24LKFS") - waveform, sample_rate = load_wav(filepath) - waveform = waveform.to(self.device) - - loudness = F.loudness(waveform, sample_rate) - expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device) + expected = torch.tensor(expected, dtype=loudness.dtype, device=self.device) self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) From 4e72bcab238db4d2b45fd92d03beb19fffd956fc Mon Sep 17 00:00:00 2001 From: bshall Date: Tue, 26 Jul 2022 10:35:40 +0200 Subject: [PATCH 09/14] Added loudness entries to functional.rst and transforms.rst --- docs/source/functional.rst | 5 +++++ docs/source/transforms.rst | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/docs/source/functional.rst b/docs/source/functional.rst index edec942731..2945f5cb2c 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -69,6 +69,11 @@ resample .. autofunction:: resample +loudness +-------- + +.. autofunction:: loudness + :hidden:`Filtering` ~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 4fe5acce26..e2d046fc39 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -90,6 +90,13 @@ Transforms are common audio transforms. They can be chained together using :clas .. automethod:: forward +:hidden:`Loudness` +------------- + +.. autoclass:: Loudness + + .. automethod:: forward + :hidden:`Feature Extractions` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 93f350fd07b0697a32d54099d3b2c9bd782fadc4 Mon Sep 17 00:00:00 2001 From: bshall Date: Tue, 26 Jul 2022 10:38:15 +0200 Subject: [PATCH 10/14] Replaced constant bias with variable --- torchaudio/functional/functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 0dded5d2dc..b179979185 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -1667,6 +1667,7 @@ def loudness(waveform: Tensor, sample_rate: int): gate_duration = 0.4 overlap = 0.75 gamma_abs = -70.0 + kweight_bias = -0.691 gate_samples = int(round(gate_duration * sample_rate)) step = int(round(gate_samples * (1 - overlap))) @@ -1691,7 +1692,7 @@ def loudness(waveform: Tensor, sample_rate: int): energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1) energy_weighted = torch.sum(g * energy_filtered, dim=-1) - gamma_rel = -0.691 + 10 * torch.log10(energy_weighted) - 10 + gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10 # Apply relative gating of the blocks gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1)) @@ -1699,7 +1700,7 @@ def loudness(waveform: Tensor, sample_rate: int): energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1) energy_weighted = torch.sum(g * energy_filtered, dim=-1) - LKFS = -0.691 + 10 * torch.log10(energy_weighted) + LKFS = kweight_bias + 10 * torch.log10(energy_weighted) return LKFS From 01f979f1d8651677ff43b17fe86b4fd851e841da Mon Sep 17 00:00:00 2001 From: bshall Date: Thu, 28 Jul 2022 11:13:17 +0200 Subject: [PATCH 11/14] Added expecttest and parameterized to integration test workflow --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 4eef95dddb..13e82ab73d 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -29,7 +29,7 @@ jobs: run: | python -m pip install --quiet --upgrade pip python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - python -m pip install --quiet pytest requests cmake ninja deep-phonemizer sentencepiece + python -m pip install --quiet pytest requests expecttest parameterized cmake ninja deep-phonemizer sentencepiece python setup.py install env: USE_FFMPEG: true From b173b58883f2874a6fb4498e2cfb4a38256cae02 Mon Sep 17 00:00:00 2001 From: bshall Date: Thu, 28 Jul 2022 17:20:30 +0200 Subject: [PATCH 12/14] Added scipy to integration test workflow --- .github/workflows/integration-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 13e82ab73d..cb777cb4eb 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -29,7 +29,7 @@ jobs: run: | python -m pip install --quiet --upgrade pip python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - python -m pip install --quiet pytest requests expecttest parameterized cmake ninja deep-phonemizer sentencepiece + python -m pip install --quiet pytest requests expecttest parameterized scipy cmake ninja deep-phonemizer sentencepiece python setup.py install env: USE_FFMPEG: true From 9c54cd7a67053457475b8fdbba66919e449e536e Mon Sep 17 00:00:00 2001 From: bshall Date: Tue, 2 Aug 2022 15:40:10 +0200 Subject: [PATCH 13/14] Removed torchaudio_unittest module from loudness test --- .github/workflows/integration-test.yml | 2 +- .../loudness_compliance_cpu_test.py | 8 ---- .../loudness_compliance_cuda_test.py | 9 ---- .../loudness_compliance_test.py | 45 +++++++++++++++++++ .../loudness_compliance_test_impl.py | 42 ----------------- 5 files changed, 46 insertions(+), 60 deletions(-) delete mode 100644 test/integration_tests/loudness_compliance_cpu_test.py delete mode 100644 test/integration_tests/loudness_compliance_cuda_test.py create mode 100644 test/integration_tests/loudness_compliance_test.py delete mode 100644 test/integration_tests/loudness_compliance_test_impl.py diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index cb777cb4eb..4eef95dddb 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -29,7 +29,7 @@ jobs: run: | python -m pip install --quiet --upgrade pip python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - python -m pip install --quiet pytest requests expecttest parameterized scipy cmake ninja deep-phonemizer sentencepiece + python -m pip install --quiet pytest requests cmake ninja deep-phonemizer sentencepiece python setup.py install env: USE_FFMPEG: true diff --git a/test/integration_tests/loudness_compliance_cpu_test.py b/test/integration_tests/loudness_compliance_cpu_test.py deleted file mode 100644 index f8dbaeef28..0000000000 --- a/test/integration_tests/loudness_compliance_cpu_test.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch -from torchaudio_unittest import common_utils - -from .loudness_compliance_test_impl import Loudness - - -class TestLoudnessCPU(Loudness, common_utils.PytorchTestCase): - device = torch.device("cpu") diff --git a/test/integration_tests/loudness_compliance_cuda_test.py b/test/integration_tests/loudness_compliance_cuda_test.py deleted file mode 100644 index cd310e215f..0000000000 --- a/test/integration_tests/loudness_compliance_cuda_test.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from torchaudio_unittest import common_utils - -from .loudness_compliance_test_impl import Loudness - - -@common_utils.skipIfNoCuda -class TestLoudnessCUDA(Loudness, common_utils.PytorchTestCase): - device = torch.device("cuda") diff --git a/test/integration_tests/loudness_compliance_test.py b/test/integration_tests/loudness_compliance_test.py new file mode 100644 index 0000000000..7ac8dba9e3 --- /dev/null +++ b/test/integration_tests/loudness_compliance_test.py @@ -0,0 +1,45 @@ +"""Test suite for compliance with the ITU-R BS.1770-4 recommendation""" +import zipfile + +import pytest + +import torch +import torchaudio +import torchaudio.functional as F + + +# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf +@pytest.mark.parametrize( + "filename,url,expected", + [ + ( + "1770-2_Comp_RelGateTest", + "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip", + -10.0, + ), + ( + "1770-2_Comp_AbsGateTest", + "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip", + -69.5, + ), + ( + "1770-2_Comp_24LKFS_500Hz_2ch", + "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip", + -24.0, + ), + ( + "1770-2 Conf Mono Voice+Music-24LKFS", + "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip", + -24.0, + ), + ], +) +def test_loudness(tmp_path, filename, url, expected): + zippath = tmp_path / filename + torch.hub.download_url_to_file(url, zippath, progress=False) + with zipfile.ZipFile(zippath) as file: + file.extractall(zippath.parent) + + waveform, sample_rate = torchaudio.load(zippath.with_suffix(".wav")) + loudness = F.loudness(waveform, sample_rate) + assert pytest.approx(loudness.item(), rel=0.01, abs=0.1) == expected diff --git a/test/integration_tests/loudness_compliance_test_impl.py b/test/integration_tests/loudness_compliance_test_impl.py deleted file mode 100644 index 67b5b2f602..0000000000 --- a/test/integration_tests/loudness_compliance_test_impl.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Test suite for compliance with the ITU-R BS.1770-4 recommendation""" -import os.path -import zipfile - -import torch -import torchaudio.functional as F -from parameterized import parameterized -from torchaudio_unittest.common_utils import load_wav, TempDirMixin, TestBaseMixin - -# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf -_COMPLIANCE_FILE_URLS = { - "1770-2_Comp_RelGateTest": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip", - "1770-2_Comp_AbsGateTest": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip", - "1770-2_Comp_24LKFS_500Hz_2ch": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip", - "1770-2 Conf Mono Voice+Music-24LKFS": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip", -} - - -class Loudness(TempDirMixin, TestBaseMixin): - def download_and_extract_file(self, filename): - zippath = self.get_temp_path(filename + ".zip") - torch.hub.download_url_to_file(_COMPLIANCE_FILE_URLS[filename], zippath, progress=False) - with zipfile.ZipFile(zippath) as file: - file.extractall(os.path.dirname(zippath)) - return self.get_temp_path(filename + ".wav") - - @parameterized.expand( - [ - ("1770-2_Comp_RelGateTest", -10.0), - ("1770-2_Comp_AbsGateTest", -69.5), - ("1770-2_Comp_24LKFS_500Hz_2ch", -24.0), - ("1770-2 Conf Mono Voice+Music-24LKFS", -24.0), - ] - ) - def test_loudness(self, filename, expected): - filepath = self.download_and_extract_file(filename) - waveform, sample_rate = load_wav(filepath) - waveform = waveform.to(self.device) - - loudness = F.loudness(waveform, sample_rate) - expected = torch.tensor(expected, dtype=loudness.dtype, device=self.device) - self.assertEqual(loudness, expected, rtol=0.01, atol=0.1) From 2dbeab0b0eb850d9d4503d2b5df43829d0b82e30 Mon Sep 17 00:00:00 2001 From: Caroline Chen Date: Wed, 3 Aug 2022 11:44:27 -0400 Subject: [PATCH 14/14] nits --- docs/source/transforms.rst | 2 +- test/integration_tests/loudness_compliance_test.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index e2d046fc39..9077616805 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -91,7 +91,7 @@ Transforms are common audio transforms. They can be chained together using :clas .. automethod:: forward :hidden:`Loudness` -------------- +----------------- .. autoclass:: Loudness diff --git a/test/integration_tests/loudness_compliance_test.py b/test/integration_tests/loudness_compliance_test.py index 7ac8dba9e3..d9473cfa50 100644 --- a/test/integration_tests/loudness_compliance_test.py +++ b/test/integration_tests/loudness_compliance_test.py @@ -42,4 +42,5 @@ def test_loudness(tmp_path, filename, url, expected): waveform, sample_rate = torchaudio.load(zippath.with_suffix(".wav")) loudness = F.loudness(waveform, sample_rate) - assert pytest.approx(loudness.item(), rel=0.01, abs=0.1) == expected + expected = torch.tensor(expected, dtype=loudness.dtype, device=loudness.device) + assert torch.allclose(loudness, expected, rtol=0.01, atol=0.1)