Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ resample

.. autofunction:: resample

loudness
--------

.. autofunction:: loudness


:hidden:`Filtering`
~~~~~~~~~~~~~~~~~~~
Expand Down
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`Loudness`
-----------------

.. autoclass:: Loudness

.. automethod:: forward

:hidden:`Feature Extractions`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
46 changes: 46 additions & 0 deletions test/integration_tests/loudness_compliance_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Test suite for compliance with the ITU-R BS.1770-4 recommendation"""
import zipfile

import pytest

import torch
import torchaudio
import torchaudio.functional as F


# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf
@pytest.mark.parametrize(
"filename,url,expected",
[
(
"1770-2_Comp_RelGateTest",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip",
-10.0,
),
(
"1770-2_Comp_AbsGateTest",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip",
-69.5,
),
(
"1770-2_Comp_24LKFS_500Hz_2ch",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip",
-24.0,
),
(
"1770-2 Conf Mono Voice+Music-24LKFS",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip",
-24.0,
),
],
)
def test_loudness(tmp_path, filename, url, expected):
zippath = tmp_path / filename
torch.hub.download_url_to_file(url, zippath, progress=False)
with zipfile.ZipFile(zippath) as file:
file.extractall(zippath.parent)

waveform, sample_rate = torchaudio.load(zippath.with_suffix(".wav"))
loudness = F.loudness(waveform, sample_rate)
expected = torch.tensor(expected, dtype=loudness.dtype, device=loudness.device)
assert torch.allclose(loudness, expected, rtol=0.01, atol=0.1)
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ def func(tensor):

self._assert_consistency(func, (waveform,))

def test_measure_loudness(self):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")

sample_rate = 44100
waveform = common_utils.get_sinusoid(sample_rate=sample_rate, device=self.device)
self._assert_consistency(F.loudness, (waveform, sample_rate))

def test_melscale_fbanks(self):
if self.device != torch.device("cpu"):
raise unittest.SkipTest("No need to perform test on device other than CPU")
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
griffinlim,
inverse_spectrogram,
linear_fbanks,
loudness,
mask_along_axis,
mask_along_axis_iid,
melscale_fbanks,
Expand Down Expand Up @@ -62,6 +63,7 @@
"melscale_fbanks",
"linear_fbanks",
"DB_to_amplitude",
"loudness",
"detect_pitch_frequency",
"griffinlim",
"mask_along_axis",
Expand Down
64 changes: 64 additions & 0 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils

from .filtering import highpass_biquad, treble_biquad

__all__ = [
"spectrogram",
"inverse_spectrogram",
Expand All @@ -35,6 +37,7 @@
"apply_codec",
"resample",
"edit_distance",
"loudness",
"pitch_shift",
"rnnt_loss",
"psd",
Expand Down Expand Up @@ -1640,6 +1643,67 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
return int(dold[-1])


def loudness(waveform: Tensor, sample_rate: int):
r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.

.. devices:: CPU CUDA

.. properties:: TorchScript

Args:
waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)`
sample_rate (int): sampling rate of the waveform

Returns:
Tensor: loudness estimates (LKFS)

Reference:
- https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
"""

if waveform.size(-2) > 5:
raise ValueError("Only up to 5 channels are supported.")

gate_duration = 0.4
overlap = 0.75
gamma_abs = -70.0
kweight_bias = -0.691
gate_samples = int(round(gate_duration * sample_rate))
step = int(round(gate_samples * (1 - overlap)))

# Apply K-weighting
waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)

# Compute the energy for each block
energy = torch.square(waveform).unfold(-1, gate_samples, step)
energy = torch.mean(energy, dim=-1)

# Compute channel-weighted summation
g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
g = g[: energy.size(-2)]

energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
loudness = -0.691 + 10 * torch.log10(energy_weighted)

# Apply absolute gating of the blocks
gated_blocks = loudness > gamma_abs
gated_blocks = gated_blocks.unsqueeze(-2)

energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10

# Apply relative gating of the blocks
gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1))
gated_blocks = gated_blocks.unsqueeze(-2)

energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
LKFS = kweight_bias + 10 * torch.log10(energy_weighted)
return LKFS


def pitch_shift(
waveform: Tensor,
sample_rate: int,
Expand Down
2 changes: 2 additions & 0 deletions torchaudio/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InverseMelScale,
InverseSpectrogram,
LFCC,
Loudness,
MelScale,
MelSpectrogram,
MFCC,
Expand Down Expand Up @@ -35,6 +36,7 @@
"InverseMelScale",
"InverseSpectrogram",
"LFCC",
"Loudness",
"MFCC",
"MVDR",
"MelScale",
Expand Down
30 changes: 30 additions & 0 deletions torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,36 @@ def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p)


class Loudness(torch.nn.Module):
r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.

.. devices:: CPU CUDA

.. properties:: TorchScript

Args:
sample_rate (int): Sample rate of audio signal.

Reference:
- https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
"""
__constants__ = ["sample_rate"]

def __init__(self, sample_rate: int):
super(Loudness, self).__init__()
self.sample_rate = sample_rate

def forward(self, wavefrom: Tensor):
r"""
Args:
waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)`

Returns:
Tensor: loudness estimates (LKFS)
"""
return F.loudness(wavefrom, self.sample_rate)


class Vol(torch.nn.Module):
r"""Add a volume to an waveform.

Expand Down