Skip to content

Commit d131232

Browse files
committed
initial commit and tests
1 parent 4d2fa19 commit d131232

File tree

7 files changed

+154
-0
lines changed

7 files changed

+154
-0
lines changed

test/torchaudio_unittest/functional/functional_impl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from scipy import signal
1111
from torchaudio_unittest.common_utils import (
1212
beamform_utils,
13+
get_asset_path,
14+
load_wav,
1315
get_sinusoid,
1416
get_whitenoise,
1517
nested_params,
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
2+
from torchaudio_unittest import common_utils
3+
4+
from .loudness_compliance_test_impl import Loudness
5+
6+
7+
class TestLoudnessCPU(Loudness, common_utils.PytorchTestCase):
8+
device = torch.device("cpu")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch
2+
from torchaudio_unittest import common_utils
3+
4+
from .loudness_compliance_test_impl import Loudness
5+
6+
7+
@common_utils.skipIfNoCuda
8+
class TestLoudnessCUDA(Loudness, common_utils.PytorchTestCase):
9+
device = torch.device("cuda")
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Test suite for compliance with the ITU-R BS.1770-4 recommendation"""
2+
import torch
3+
import os.path
4+
import zipfile
5+
import torchaudio.functional as F
6+
from torchaudio_unittest.common_utils import (
7+
load_wav,
8+
TempDirMixin,
9+
TestBaseMixin,
10+
)
11+
12+
# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf
13+
_COMPLIANCE_FILE_URLS = {
14+
"1770-2_Comp_RelGateTest": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip",
15+
"1770-2_Comp_AbsGateTest": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip",
16+
"1770-2_Comp_24LKFS_500Hz_2ch": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip",
17+
"1770-2 Conf Mono Voice+Music-24LKFS": "http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip",
18+
}
19+
20+
21+
class Loudness(TempDirMixin, TestBaseMixin):
22+
def download_and_extract_file(self, filename):
23+
zippath = self.get_temp_path(filename + ".zip")
24+
torch.hub.download_url_to_file(_COMPLIANCE_FILE_URLS[filename], zippath, progress=False)
25+
with zipfile.ZipFile(zippath) as file:
26+
file.extractall(os.path.dirname(zippath))
27+
return self.get_temp_path(filename + ".wav")
28+
29+
def test_measure_loudness_relative_gate(self):
30+
filepath = self.download_and_extract_file("1770-2_Comp_RelGateTest")
31+
waveform, sample_rate = load_wav(filepath)
32+
waveform = waveform.to(self.device)
33+
34+
loudness = F.measure_loudness(waveform, sample_rate)
35+
expected = torch.tensor(-10.0, dtype=loudness.dtype, device=self.device)
36+
self.assertEqual(loudness, expected, rtol=0.01, atol=0.1)
37+
38+
def test_measure_loudness_absolute_gate(self):
39+
filepath = self.download_and_extract_file("1770-2_Comp_AbsGateTest")
40+
waveform, sample_rate = load_wav(filepath)
41+
waveform = waveform.to(self.device)
42+
43+
loudness = F.measure_loudness(waveform, sample_rate)
44+
expected = torch.tensor(-69.5, dtype=loudness.dtype, device=self.device)
45+
self.assertEqual(loudness, expected, rtol=0.01, atol=0.1)
46+
47+
def test_measure_loudness_two_channels(self):
48+
filepath = filepath = self.download_and_extract_file("1770-2_Comp_24LKFS_500Hz_2ch")
49+
waveform, sample_rate = load_wav(filepath)
50+
waveform = waveform.to(self.device)
51+
52+
loudness = F.measure_loudness(waveform, sample_rate)
53+
expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device)
54+
self.assertEqual(loudness, expected, rtol=0.01, atol=0.1)
55+
56+
def test_measure_loudness_mono_voice_music(self):
57+
filepath = self.download_and_extract_file("1770-2 Conf Mono Voice+Music-24LKFS")
58+
waveform, sample_rate = load_wav(filepath)
59+
waveform = waveform.to(self.device)
60+
61+
loudness = F.measure_loudness(waveform, sample_rate)
62+
expected = torch.tensor(-24.0, dtype=loudness.dtype, device=self.device)
63+
self.assertEqual(loudness, expected, rtol=0.01, atol=0.1)

test/torchaudio_unittest/functional/torchscript_consistency_impl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def func(tensor):
111111

112112
self._assert_consistency(func, (waveform,))
113113

114+
def test_measure_loudness(self):
115+
if self.dtype == torch.float64:
116+
raise unittest.SkipTest("This test is known to fail for float64")
117+
118+
sample_rate = 44100
119+
waveform = common_utils.get_sinusoid(sample_rate=sample_rate, device=self.device)
120+
self._assert_consistency(F.measure_loudness, (waveform, sample_rate))
121+
114122
def test_melscale_fbanks(self):
115123
if self.device != torch.device("cpu"):
116124
raise unittest.SkipTest("No need to perform test on device other than CPU")

torchaudio/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
compute_kaldi_pitch,
3131
create_dct,
3232
DB_to_amplitude,
33+
measure_loudness,
3334
detect_pitch_frequency,
3435
edit_distance,
3536
griffinlim,
@@ -62,6 +63,7 @@
6263
"melscale_fbanks",
6364
"linear_fbanks",
6465
"DB_to_amplitude",
66+
"measure_loudness",
6567
"detect_pitch_frequency",
6668
"griffinlim",
6769
"mask_along_axis",

torchaudio/functional/functional.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torchaudio
1111
from torch import Tensor
1212
from torchaudio._internal import module_utils as _mod_utils
13+
from .filtering import highpass_biquad, treble_biquad
1314

1415
__all__ = [
1516
"spectrogram",
@@ -35,6 +36,7 @@
3536
"apply_codec",
3637
"resample",
3738
"edit_distance",
39+
"measure_loudness",
3840
"pitch_shift",
3941
"rnnt_loss",
4042
"psd",
@@ -1602,6 +1604,66 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
16021604
return int(dold[-1])
16031605

16041606

1607+
def measure_loudness(waveform: Tensor, sample_rate: int):
1608+
r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.
1609+
1610+
.. devices:: CPU CUDA
1611+
1612+
.. properties:: TorchScript
1613+
1614+
Args:
1615+
waveform(torch.Tensor): audio waveform of dimension of `(..., channels, time)`
1616+
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
1617+
1618+
Returns:
1619+
Tensor: loudness estimates (LKFS)
1620+
1621+
Reference:
1622+
- https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
1623+
"""
1624+
1625+
if waveform.size(-2) > 5:
1626+
raise ValueError("Only up to 5 channels are supported.")
1627+
1628+
gate_duration: float = 0.4
1629+
overlap: float = 0.75
1630+
gamma_abs: float = -70.0
1631+
gate_samples = int(round(gate_duration * sample_rate))
1632+
step = int(round(gate_samples * (1 - overlap)))
1633+
1634+
# Apply K-weighting
1635+
waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
1636+
waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)
1637+
1638+
# Compute the energy for each block
1639+
energy = torch.square(waveform).unfold(-1, gate_samples, step)
1640+
energy = torch.mean(energy, dim=-1)
1641+
1642+
# Compute channel-weighted summation
1643+
g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
1644+
g = g[: energy.size(-2)]
1645+
1646+
energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
1647+
loudness = -0.691 + 10 * torch.log10(energy_weighted)
1648+
1649+
# Apply absolute gating of the blocks
1650+
gated_blocks = loudness > gamma_abs
1651+
gated_blocks = gated_blocks.unsqueeze(-2)
1652+
1653+
energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
1654+
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
1655+
gamma_rel = -0.691 + 10 * torch.log10(energy_weighted) - 10
1656+
1657+
# Apply relative gating of the blocks
1658+
gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1))
1659+
gated_blocks = gated_blocks.unsqueeze(-2)
1660+
1661+
energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
1662+
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
1663+
LKFS = -0.691 + 10 * torch.log10(energy_weighted)
1664+
return LKFS
1665+
1666+
16051667
def pitch_shift(
16061668
waveform: Tensor,
16071669
sample_rate: int,

0 commit comments

Comments
 (0)