diff --git a/test/torchaudio_unittest/batch_consistency_test.py b/test/torchaudio_unittest/batch_consistency_test.py
index 78e3cd62f9..851fc9c86c 100644
--- a/test/torchaudio_unittest/batch_consistency_test.py
+++ b/test/torchaudio_unittest/batch_consistency_test.py
@@ -3,6 +3,8 @@ import itertools
 
 from parameterized import parameterized
+import math
+
 import torch
 import torchaudio
 import torchaudio.functional as F
 
@@ -59,6 +61,78 @@ def test_detect_pitch_frequency(self, frequency, sample_rate, n_channels):
             n_channels=n_channels, duration=5)
         self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
 
+    def test_amplitude_to_DB(self):
+        torch.manual_seed(0)
+        spec = torch.rand(2, 100, 100) * 200
+
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+
+        # Test with & without a `top_db` clamp
+        self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
+                                        amin, db_mult, top_db=None)
+        self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
+                                        amin, db_mult, top_db=40.)
+
+    def test_amplitude_to_DB_itemwise_clamps(self):
+        """Ensure that the clamps are separate for each spectrogram in a batch.
+
+        The clamp was determined per-batch in a prior implementation, which
+        meant it was determined by the loudest item, thus items weren't
+        independent. See:
+
+        https://github.com/pytorch/audio/issues/994
+
+        """
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+        top_db = 20.
+
+        # Make a batch of noise
+        torch.manual_seed(0)
+        spec = torch.rand([2, 2, 100, 100]) * 200
+        # Make one item blow out the other
+        spec[0] += 50
+
+        batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
+                                          db_mult, top_db=top_db)
+        itemwise_dbs = torch.stack([
+            F.amplitude_to_DB(item, amplitude_mult, amin,
+                              db_mult, top_db=top_db)
+            for item in spec
+        ])
+
+        self.assertEqual(batchwise_dbs, itemwise_dbs)
+
+    def test_amplitude_to_DB_not_channelwise_clamps(self):
+        """Check that clamps are applied per-item, not per channel."""
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+        top_db = 40.
+
+        torch.manual_seed(0)
+        spec = torch.rand([1, 2, 100, 100]) * 200
+        # Make one channel blow out the other
+        spec[:, 0] += 50
+
+        specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
+                                         db_mult, top_db=top_db)
+        channelwise_dbs = torch.stack([
+            F.amplitude_to_DB(spec[:, i], amplitude_mult, amin,
+                              db_mult, top_db=top_db)
+            for i in range(spec.size(-3))
+        ])
+
+        # Just check channelwise gives a different answer.
+        difference = (specwise_dbs - channelwise_dbs).abs()
+        assert (difference >= 1e-5).any()
+
     def test_contrast(self):
         waveform = torch.rand(2, 100) - 0.5
         self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)
@@ -103,7 +177,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):
     """Test suite for classes defined in `transforms` module"""
 
     def test_batch_AmplitudeToDB(self):
-        spec = torch.rand((6, 201))
+        spec = torch.rand((2, 6, 201))
 
         # Single then transform then batch
         expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)
diff --git a/test/torchaudio_unittest/functional/functional_cpu_test.py b/test/torchaudio_unittest/functional/functional_cpu_test.py
index c4ba42391d..21e5a3e1b7 100644
--- a/test/torchaudio_unittest/functional/functional_cpu_test.py
+++ b/test/torchaudio_unittest/functional/functional_cpu_test.py
@@ -83,46 +83,78 @@ def test_pitch(self, frequency):
         self.assertFalse(s)
 
 
-class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
-    def test_DB_to_amplitude(self):
-        # Make some noise
-        x = torch.rand(1000)
-        spectrogram = torchaudio.transforms.Spectrogram()
-        spec = spectrogram(x)
-
+class Testamplitude_to_DB(common_utils.TorchaudioTestCase):
+    @parameterized.expand([
+        ([100, 100],),
+        ([2, 100, 100],),
+        ([2, 2, 100, 100],),
+    ])
+    def test_reversible(self, shape):
+        """Round trip between amplitude and dB should return the original for various shapes.
+
+        This implicitly also tests `DB_to_amplitude`.
+
+        """
+        amplitude_mult = 20.
+        power_mult = 10.
         amin = 1e-10
         ref = 1.0
-        db_multiplier = math.log10(max(amin, ref))
-
-        # Waveform amplitude -> DB -> amplitude
-        multiplier = 20.
-        power = 0.5
-
-        db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
+        db_mult = math.log10(max(amin, ref))
 
-        self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
+        torch.manual_seed(0)
+        spec = torch.rand(*shape) * 200
 
         # Spectrogram amplitude -> DB -> amplitude
-        db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
+        db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
+        x2 = F.DB_to_amplitude(db, ref, 0.5)
 
         self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
 
-        # Waveform power -> DB -> power
-        multiplier = 10.
-        power = 1.
-
-        db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
-
-        self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
-
         # Spectrogram power -> DB -> power
-        db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
-
-        self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
+        db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
+        x2 = F.DB_to_amplitude(db, ref, 1.)
+
+        self.assertEqual(x2, spec)
+
+    @parameterized.expand([
+        ([100, 100],),
+        ([2, 100, 100],),
+        ([2, 2, 100, 100],),
+    ])
+    def test_top_db_clamp(self, shape):
+        """Ensure values are properly clamped when `top_db` is supplied."""
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+        top_db = 40.
+
+        torch.manual_seed(0)
+        # A random tensor is used for increased entropy, but the max and min for
+        # each spectrogram still need to be predictable. The max determines the
+        # decibel cutoff, and the distance from the min must be large enough
+        # that it triggers a clamp.
+        spec = torch.rand(*shape)
+        # Ensure each spectrogram has a min of 0 and a max of 1.
+        spec -= spec.amin([-2, -1])[..., None, None]
+        spec /= spec.amax([-2, -1])[..., None, None]
+        # Expand the range to (0, 200) - wide enough to properly test clamping.
+        spec *= 200
+
+        decibels = F.amplitude_to_DB(spec, amplitude_mult, amin,
+                                     db_mult, top_db=top_db)
+        # Ensure the clamp was applied
+        below_limit = decibels < 6.0205
+        assert not below_limit.any(), (
+            "{} decibel values were below the expected cutoff:\n{}".format(
+                below_limit.sum().item(), decibels
+            )
+        )
+        # Ensure it didn't over-clamp
+        close_to_limit = decibels < 6.0207
+        assert close_to_limit.any(), (
+            f"No values were close to the limit. Did it over-clamp?\n{decibels}"
+        )
 
 
 class TestComplexNorm(common_utils.TorchaudioTestCase):
diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py
index 73941946ec..4dde324a31 100644
--- a/torchaudio/functional/functional.py
+++ b/torchaudio/functional/functional.py
@@ -237,14 +237,16 @@ def amplitude_to_DB(
         db_multiplier: float,
         top_db: Optional[float] = None
 ) -> Tensor:
-    r"""Turn a tensor from the power/amplitude scale to the decibel scale.
+    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.
 
-    This output depends on the maximum value in the input tensor, and so
-    may return different values for an audio clip split into snippets vs. a
-    full clip.
+    The output of each tensor in a batch depends on the maximum value of that tensor,
+    and so may return different values for an audio clip split into snippets vs. a full clip.
 
     Args:
-        x (Tensor): Input tensor before being converted to decibel scale
+
+        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
+            the form `(..., freq, time)`. Batched inputs should include a channel dimension and
+            have the form `(batch, channel, freq, time)`.
         multiplier (float): Use 10. for power and 20. for amplitude
         amin (float): Number to clamp ``x``
         db_multiplier (float): Log10(max(reference value and amin))
@@ -258,7 +260,15 @@ def amplitude_to_DB(
     x_db -= multiplier * db_multiplier
 
     if top_db is not None:
-        x_db = x_db.clamp(min=x_db.max().item() - top_db)
+        # Expand batch
+        shape = x_db.size()
+        packed_channels = shape[-3] if x_db.dim() > 2 else 1
+        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])
+
+        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))
+
+        # Repack batch
+        x_db = x_db.reshape(shape)
 
     return x_db
 
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index dc7e2ddfe1..cc27b3a64a 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -530,24 +530,16 @@ def forward(self, waveform: Tensor) -> Tensor:
         Returns:
             Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
         """
-
-        # pack batch
-        shape = waveform.size()
-        waveform = waveform.reshape(-1, shape[-1])
-
         mel_specgram = self.MelSpectrogram(waveform)
         if self.log_mels:
             log_offset = 1e-6
             mel_specgram = torch.log(mel_specgram + log_offset)
         else:
             mel_specgram = self.amplitude_to_DB(mel_specgram)
 
-        # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
-        # -> (channel, time, n_mfcc).tranpose(...)
-        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
-
-        # unpack batch
-        mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
+        # (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
+        # -> (..., channel, time, n_mfcc).transpose(...)
+        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
 
         return mfcc
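
As a quick illustration of the per-item `top_db` clamp introduced above (a minimal sketch, not part of the patch; it assumes a torchaudio build that includes this change, and the shapes and constants are only illustrative): each spectrogram in a batch is now clamped against its own maximum, so converting a whole batch in one call agrees with converting each item separately.

import math

import torch
import torchaudio.functional as F

amin = 1e-10
ref = 1.0
db_multiplier = math.log10(max(amin, ref))

torch.manual_seed(0)
spec = torch.rand(2, 2, 100, 100) * 200  # (batch, channel, freq, time)
spec[0] += 50  # make one item much louder than the other

batched = F.amplitude_to_DB(spec, 20., amin, db_multiplier, top_db=40.)
itemwise = torch.stack([
    F.amplitude_to_DB(item, 20., amin, db_multiplier, top_db=40.)
    for item in spec
])

# With the per-item clamp, both paths agree.
print(torch.allclose(batched, itemwise))  # True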
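
The `MFCC` change follows the same idea: with the explicit pack/unpack removed and the DCT matmul applied over the last two dimensions, inputs with arbitrary leading dimensions pass straight through. A small sketch, again assuming the patched `transforms.py` is installed and using an illustrative 16 kHz input:

import torch
import torchaudio

mfcc_transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=13)

# (batch, channel, time) input; leading dimensions are carried through.
waveform = torch.rand(3, 2, 16000)
mfcc = mfcc_transform(waveform)
print(mfcc.shape)  # (3, 2, 13, n_frames)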