Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
36def9c
Make `amplitude_to_DB` clamp items separately
jcaw Dec 20, 2020
10b3301
Don't pack batches in MFCC transform
jcaw Dec 20, 2020
d3ad3cf
Remove range->tuple conversion for torchscript
jcaw Dec 21, 2020
6e43af7
Split basic amplitude_to_DB test
jcaw Jan 6, 2021
f668caf
Test `amplitude_to_DB` without channel dim
jcaw Jan 6, 2021
eba14b0
Always clamp amplitude as a batch
jcaw Jan 6, 2021
18d57bd
Remove `_make_spectrogram` helper method
jcaw Jan 6, 2021
381ee07
Redefine test constants inside each test
jcaw Jan 6, 2021
156f183
Use leftmost dim for test name
jcaw Jan 6, 2021
5275442
Split `test_top_db` into separate tests
jcaw Jan 6, 2021
670fe97
Capitalise test constants
jcaw Jan 6, 2021
fd33b46
Inline channels variable
jcaw Jan 6, 2021
44c2259
Also specify minimum value for tests
jcaw Jan 6, 2021
c9c3a93
Reword comment
jcaw Jan 6, 2021
3bff1fe
Replace `assert_allclose` with `self.assertEqual`
jcaw Jan 6, 2021
d5070f9
Remove unused power constants
jcaw Jan 7, 2021
f68eb8c
Fix indentation
jcaw Jan 7, 2021
50ab54b
Update torchaudio/functional/functional.py
vincentqb Jan 7, 2021
2f07717
Update torchaudio/functional/functional.py
vincentqb Jan 7, 2021
0c0976c
Change upcase local variables to lowercase
jcaw Jan 7, 2021
2d7746d
Set seed when `rand` is called in tests
jcaw Jan 7, 2021
ea02627
Make decibel limit test fail more informatively
jcaw Jan 7, 2021
258b8c7
More descriptive test names, plus overt docstrings
jcaw Jan 7, 2021
a363968
Reference original MFCC clamping issue in test doc
jcaw Jan 7, 2021
da3b808
Move docstring to correct function
jcaw Jan 7, 2021
f8e1943
Pass correct number of item dimensions for a batch
jcaw Jan 20, 2021
ba7e457
Move dim tests to `test_batch_consistency.py`
jcaw Jan 20, 2021
6b8d6b7
Add generic batch test for `amplitude_to_DB`
jcaw Jan 20, 2021
dda8744
Expand test docstring
jcaw Jan 20, 2021
8ba7992
Parameterize `amplitude_to_DB` reversibility tests
jcaw Feb 1, 2021
62efc26
Parameterize `amplitude_to_DB` `top_db` tests
jcaw Feb 1, 2021
f4c7552
Check `top_db` doesn't over-clamp
jcaw Feb 1, 2021
5992235
Clearer test docstring
jcaw Feb 1, 2021
6a14029
Clearer description of scaling operation
jcaw Feb 1, 2021
27d23a9
Correct description of `top_db` behaviour
jcaw Feb 1, 2021
1c75e86
Merge branch 'master' of https://github.com/pytorch/audio into amplit…
jcaw Feb 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion test/torchaudio_unittest/batch_consistency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import itertools
from parameterized import parameterized

import math

import torch
import torchaudio
import torchaudio.functional as F
Expand Down Expand Up @@ -59,6 +61,78 @@ def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
n_channels=n_channels, duration=5)
self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)

def test_amplitude_to_DB(self):
    """Run the generic batch-consistency check on `F.amplitude_to_DB`.

    The check is performed twice on the same random input: once with no
    `top_db` clamp and once with a 40 dB clamp.
    """
    torch.manual_seed(0)
    spec = 200 * torch.rand(2, 100, 100)

    # Conversion settings: amplitude scale (multiplier of 20) relative to a
    # reference value of 1.0.
    reference = 1.0
    floor = 1e-10
    multiplier = 20.
    db_multiplier = math.log10(max(floor, reference))

    # Test with & without a `top_db` clamp
    for top_db in (None, 40.):
        self.assert_batch_consistencies(
            F.amplitude_to_DB, spec, multiplier, floor, db_multiplier,
            top_db=top_db)

def test_amplitude_to_DB_itemwise_clamps(self):
    """Verify each spectrogram in a batch receives its own `top_db` clamp.

    An earlier implementation derived a single clamp for the whole batch,
    so the loudest item dictated the floor for every other item and the
    items were not independent. See:

    https://github.com/pytorch/audio/issues/994

    Converting the batch in one call must therefore give the same result
    as converting each item on its own and stacking the outputs.
    """
    ref = 1.0
    amin = 1e-10
    db_mult = math.log10(max(amin, ref))

    def to_db(x):
        # Amplitude scale (multiplier 20) with a 20 dB clamp.
        return F.amplitude_to_DB(x, 20., amin, db_mult, top_db=20.)

    # Random noise batch of shape (batch, channel, freq, time)
    torch.manual_seed(0)
    spec = torch.rand([2, 2, 100, 100]) * 200
    # Raise the first item so it is louder than the second
    spec[0] += 50

    whole_batch = to_db(spec)
    one_by_one = torch.stack([to_db(item) for item in spec])

    self.assertEqual(whole_batch, one_by_one)

def test_amplitude_to_DB_not_channelwise_clamps(self):
    """Check that clamps are applied per-item, not per channel.

    A single two-channel spectrogram is converted twice: once as a whole
    item, and once channel by channel. One channel is made louder than the
    other, so a clamp derived from the whole item must differ somewhere
    from clamps derived independently for each channel.
    """
    amplitude_mult = 20.
    amin = 1e-10
    ref = 1.0
    db_mult = math.log10(max(amin, ref))
    top_db = 40.

    torch.manual_seed(0)
    spec = torch.rand([1, 2, 100, 100]) * 200
    # Make one channel blow out the other
    spec[:, 0] += 50

    specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
                                     db_mult, top_db=top_db)
    # Stack along the channel dim so the result has the same
    # (batch, channel, freq, time) layout as `specwise_dbs`. The default
    # `dim=0` would produce (channel, batch, freq, time), which silently
    # broadcasts in the subtraction below and compares unrelated channels
    # with each other, making the final assertion pass vacuously.
    channelwise_dbs = torch.stack([
        F.amplitude_to_DB(spec[:, i], amplitude_mult, amin,
                          db_mult, top_db=top_db)
        for i in range(spec.size(-3))
    ], dim=1)
    assert specwise_dbs.shape == channelwise_dbs.shape, (
        "Shape mismatch: {} vs {}".format(
            tuple(specwise_dbs.shape), tuple(channelwise_dbs.shape)
        )
    )

    # Just check channelwise gives a different answer.
    difference = (specwise_dbs - channelwise_dbs).abs()
    assert (difference >= 1e-5).any(), (
        "Per-item and per-channel clamping gave the same result."
    )

def test_contrast(self):
    """Run the generic batch-consistency check on `F.contrast`."""
    # Random values centred on zero, in the range [-0.5, 0.5)
    centred_noise = torch.rand(2, 100) - 0.5
    self.assert_batch_consistencies(
        F.contrast, centred_noise, enhancement_amount=80.)
Expand Down Expand Up @@ -103,7 +177,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):

"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
spec = torch.rand((2, 6, 201))

# Single then transform then batch
expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)
Expand Down
94 changes: 63 additions & 31 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,46 +83,78 @@ def test_pitch(self, frequency):
self.assertFalse(s)


class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
spectrogram = torchaudio.transforms.Spectrogram()
spec = spectrogram(x)

class Testamplitude_to_DB(common_utils.TorchaudioTestCase):
@parameterized.expand([
([100, 100],),
([2, 100, 100],),
([2, 2, 100, 100],),
])
def test_reversible(self, shape):
"""Round trip between amplitude and db should return the original for various shape

This implicitly also tests `DB_to_amplitude`.

"""
amplitude_mult = 20.
power_mult = 10.
amin = 1e-10
ref = 1.0
db_multiplier = math.log10(max(amin, ref))

# Waveform amplitude -> DB -> amplitude
multiplier = 20.
power = 0.5

db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
db_mult = math.log10(max(amin, ref))

self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
torch.manual_seed(0)
spec = torch.rand(*shape) * 200

# Spectrogram amplitude -> DB -> amplitude
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
x2 = F.DB_to_amplitude(db, ref, 0.5)

self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)

# Waveform power -> DB -> power
multiplier = 10.
power = 1.

db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)

self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)

# Spectrogram power -> DB -> power
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)

self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
x2 = F.DB_to_amplitude(db, ref, 1.)

self.assertEqual(x2, spec)

@parameterized.expand([
    ([100, 100],),
    ([2, 100, 100],),
    ([2, 2, 100, 100],),
])
def test_top_db_clamp(self, shape):
    """Ensure values are properly clamped when `top_db` is supplied.

    Each spectrogram is rescaled to span exactly 0 to 200. With an
    amplitude multiplier of 20 the peak is expected to sit near 46.02 dB,
    so a `top_db` of 40 should place the floor just above 6.0205 dB. The
    test checks that nothing falls below that floor and that at least one
    value sits right at it.
    """
    top_db = 40.
    ref = 1.0
    amin = 1e-10
    amplitude_mult = 20.
    db_mult = math.log10(max(amin, ref))

    torch.manual_seed(0)
    # Random data gives more varied input, but each spectrogram's extremes
    # must still be known in advance: the max sets the decibel cutoff, and
    # the min has to be far enough below it for the clamp to take effect.
    spec = torch.rand(*shape)
    # Shift so the minimum of each spectrogram is 0 ...
    per_spec_min = spec.amin([-2, -1])[..., None, None]
    spec -= per_spec_min
    # ... and scale so the maximum of each spectrogram is 1.
    per_spec_max = spec.amax([-2, -1])[..., None, None]
    spec /= per_spec_max
    # Stretch the range to (0, 200) so the clamp is properly exercised.
    spec *= 200

    decibels = F.amplitude_to_DB(
        spec, amplitude_mult, amin, db_mult, top_db=top_db)

    # The clamp must have removed everything under the cutoff
    below_limit = decibels < 6.0205
    assert not below_limit.any(), (
        "{} decibel values were below the expected cutoff:\n{}".format(
            below_limit.sum().item(), decibels
        )
    )
    # The clamp must not have been set higher than expected
    close_to_limit = decibels < 6.0207
    assert close_to_limit.any(), (
        f"No values were close to the limit. Did it over-clamp?\n{decibels}"
    )


class TestComplexNorm(common_utils.TorchaudioTestCase):
Expand Down
22 changes: 16 additions & 6 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,16 @@ def amplitude_to_DB(
db_multiplier: float,
top_db: Optional[float] = None
) -> Tensor:
r"""Turn a tensor from the power/amplitude scale to the decibel scale.
r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

This output depends on the maximum value in the input tensor, and so
may return different values for an audio clip split into snippets vs. a
full clip.
The output of each tensor in a batch depends on the maximum value of that tensor,
and so may return different values for an audio clip split into snippets vs. a full clip.

Args:
x (Tensor): Input tensor before being converted to decibel scale

x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
the form `(..., freq, time)`. Batched inputs should include a channel dimension and
have the form `(batch, channel, freq, time)`.
multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp ``x``
db_multiplier (float): Log10(max(reference value and amin))
Expand All @@ -258,7 +260,15 @@ def amplitude_to_DB(
x_db -= multiplier * db_multiplier

if top_db is not None:
x_db = x_db.clamp(min=x_db.max().item() - top_db)
# Expand batch
shape = x_db.size()
packed_channels = shape[-3] if x_db.dim() > 2 else 1
x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])

x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))

# Repack batch
x_db = x_db.reshape(shape)

return x_db

Expand Down
14 changes: 3 additions & 11 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,24 +530,16 @@ def forward(self, waveform: Tensor) -> Tensor:
Returns:
Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
"""

# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
log_offset = 1e-6
mel_specgram = torch.log(mel_specgram + log_offset)
else:
mel_specgram = self.amplitude_to_DB(mel_specgram)
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

# unpack batch
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])

# (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
# -> (..., channel, time, n_mfcc).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
return mfcc


Expand Down