Skip to content

Commit babc24a

Browse files
mthrokjaeyeun97
andauthored
Add test for InverseMelScale (#448)
* Inverse Mel Scale Implementation * Inverse Mel Scale Docs * Better working version. * GPU fix * These shouldn't go on git.. * Even better one, but does not support JITability. * Remove JITability test * Flake8 * n_stft is a must * minor clean up of initialization * Add librosa consistency test This PR follows up #366 and adds test for `InverseMelScale` (and `MelScale`) for librosa compatibility. For `MelScale` compatibility test; 1. Generate spectrogram 2. Feed the spectrogram to `torchaudio.transforms.MelScale` instance 3. Feed the spectrogram to `librosa.feature.melspectrogram` function. 4. Compare the result from 2 and 3 elementwise. Element-wise numerical comparison is possible because under the hood their implementations use the same algorith. For `InverseMelScale` compatibility test, it is more elaborated than that. 1. Generate the original spectrogram 2. Convert the original spectrogram to Mel scale using `torchaudio.transforms.MelScale` instance 3. Reconstruct spectrogram using torchaudio implementation 3.1. Feed the Mel spectrogram to `torchaudio.transforms.InverseMelScale` instance and get reconstructed spectrogram. 3.2. Compute the sum of element-wise P1 distance of the original spectrogram and that from 3.1. 4. Reconstruct spectrogram using librosa 4.1. Feed the Mel spectrogram to `librosa.feature.inverse.mel_to_stft` function and get reconstructed spectrogram. 4.2. Compute the sum of element-wise P1 distance of the original spectrogram and that from 4.1. (this is the reference.) 5. Check that resulting P1 distance are in a roughly same value range. Element-wise numerical comparison is not possible due to the difference algorithms used to compute the inverse. The reconstructed spectrograms can have some values vary in magnitude. Therefore the strategy here is to check that P1 distance (reconstruction loss) is not that different from the value obtained using `librosa`. For this purpose, threshold was empirically chosen ``` print('p1 dist (orig <-> ta):', torch.dist(spec_orig, spec_ta, p=1)) print('p1 dist (orig <-> lr):', torch.dist(spec_orig, spec_lr, p=1)) >>> p1 dist (orig <-> ta): tensor(1482.1917) >>> p1 dist (orig <-> lr): tensor(1420.7103) ``` This value can vary based on the length and the kind of the signal being processed, so it was handpicked. * Address review feedbacks * Support arbitrary batch dimensions. * Add batch test * Use view for batch * fix sgd * Use negative indices and update docstring * Update threshold Co-authored-by: Charles J.Y. Yoon <[email protected]>
1 parent 2cf59c4 commit babc24a

File tree

3 files changed

+204
-0
lines changed

3 files changed

+204
-0
lines changed

docs/source/transforms.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ Transforms are common audio transforms. They can be chained together using :clas
3737

3838
.. automethod:: forward
3939

40+
:hidden:`InverseMelScale`
41+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
42+
43+
.. autoclass:: InverseMelScale
44+
45+
.. automethod:: forward
46+
47+
4048
:hidden:`MelSpectrogram`
4149
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4250

test/test_transforms.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,25 @@ def test_batch_MelScale(self):
410410
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
411411
self.assertTrue(torch.allclose(computed, expected))
412412

413+
def test_batch_InverseMelScale(self):
414+
n_fft = 8
415+
n_mels = 32
416+
n_stft = 5
417+
mel_spec = torch.randn(2, n_mels, 32) ** 2
418+
419+
# Single then transform then batch
420+
expected = transforms.InverseMelScale(n_stft, n_mels)(mel_spec).repeat(3, 1, 1, 1)
421+
422+
# Batch then transform
423+
computed = transforms.InverseMelScale(n_stft, n_mels)(mel_spec.repeat(3, 1, 1, 1))
424+
425+
# shape = (3, 2, n_mels, 32)
426+
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
427+
428+
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
429+
# exactly same result. For this reason, tolerance is very relaxed here.
430+
self.assertTrue(torch.allclose(computed, expected, atol=1.0))
431+
413432
def test_batch_compute_deltas(self):
414433
specgram = torch.randn(2, 31, 2786)
415434

@@ -509,5 +528,97 @@ def test_scriptmodule_TimeMasking(self):
509528
_test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)
510529

511530

531+
class TestLibrosaConsistency(unittest.TestCase):
532+
test_dirpath = None
533+
test_dir = None
534+
535+
@classmethod
536+
def setUpClass(cls):
537+
cls.test_dirpath, cls.test_dir = common_utils.create_temp_assets_dir()
538+
539+
def _to_librosa(self, sound):
540+
return sound.cpu().numpy().squeeze()
541+
542+
def _get_sample_data(self, *asset_paths, **kwargs):
543+
file_path = os.path.join(self.test_dirpath, 'assets', *asset_paths)
544+
545+
sound, sample_rate = torchaudio.load(file_path, **kwargs)
546+
return sound.mean(dim=0, keepdim=True), sample_rate
547+
548+
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
549+
def test_MelScale(self):
550+
"""MelScale transform is comparable to that of librosa"""
551+
n_fft = 2048
552+
n_mels = 256
553+
hop_length = n_fft // 4
554+
555+
# Prepare spectrogram input. We use torchaudio to compute one.
556+
sound, sample_rate = self._get_sample_data('whitenoise_1min.mp3')
557+
spec_ta = F.spectrogram(
558+
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
559+
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
560+
spec_lr = spec_ta.cpu().numpy().squeeze()
561+
# Perform MelScale with torchaudio and librosa
562+
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_ta)
563+
melspec_lr = librosa.feature.melspectrogram(
564+
S=spec_lr, sr=sample_rate, n_fft=n_fft, hop_length=hop_length,
565+
win_length=n_fft, center=True, window='hann', n_mels=n_mels, htk=True, norm=None)
566+
# Note: Using relaxed rtol instead of atol
567+
assert torch.allclose(melspec_ta, torch.from_numpy(melspec_lr[None, ...]), rtol=1e-3)
568+
569+
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
570+
def test_InverseMelScale(self):
571+
"""InverseMelScale transform is comparable to that of librosa"""
572+
n_fft = 2048
573+
n_mels = 256
574+
n_stft = n_fft // 2 + 1
575+
hop_length = n_fft // 4
576+
577+
# Prepare mel spectrogram input. We use torchaudio to compute one.
578+
sound, sample_rate = self._get_sample_data(
579+
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
580+
spec_orig = F.spectrogram(
581+
sound, pad=0, window=torch.hann_window(n_fft), n_fft=n_fft,
582+
hop_length=hop_length, win_length=n_fft, power=2, normalized=False)
583+
melspec_ta = transforms.MelScale(n_mels=n_mels, sample_rate=sample_rate)(spec_orig)
584+
melspec_lr = melspec_ta.cpu().numpy().squeeze()
585+
# Perform InverseMelScale with torch audio and librosa
586+
spec_ta = transforms.InverseMelScale(
587+
n_stft, n_mels=n_mels, sample_rate=sample_rate)(melspec_ta)
588+
spec_lr = librosa.feature.inverse.mel_to_stft(
589+
melspec_lr, sr=sample_rate, n_fft=n_fft, power=2.0, htk=True, norm=None)
590+
spec_lr = torch.from_numpy(spec_lr[None, ...])
591+
592+
# Align dimensions
593+
# librosa does not return power spectrogram while torchaudio returns power spectrogram
594+
spec_orig = spec_orig.sqrt()
595+
spec_ta = spec_ta.sqrt()
596+
597+
threshold = 2.0
598+
# This threshold was choosen empirically, based on the following observation
599+
#
600+
# torch.dist(spec_lr, spec_ta, p=float('inf'))
601+
# >>> tensor(1.9666)
602+
#
603+
# The spectrograms reconstructed by librosa and torchaudio are not very comparable elementwise.
604+
# This is because they use different approximation algorithms and resulting values can live
605+
# in different magnitude. (although most of them are very close)
606+
# See https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
607+
# See https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
608+
# distance over frequencies.
609+
assert torch.allclose(spec_ta, spec_lr, atol=threshold)
610+
611+
threshold = 1700.0
612+
# This threshold was choosen empirically, based on the following observations
613+
#
614+
# torch.dist(spec_orig, spec_ta, p=1)
615+
# >>> tensor(1644.3516)
616+
# torch.dist(spec_orig, spec_lr, p=1)
617+
# >>> tensor(1420.7103)
618+
# torch.dist(spec_lr, spec_ta, p=1)
619+
# >>> tensor(943.2759)
620+
assert torch.dist(spec_orig, spec_ta, p=1) < threshold
621+
622+
512623
if __name__ == '__main__':
513624
unittest.main()

torchaudio/transforms.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
'GriffinLim',
1515
'AmplitudeToDB',
1616
'MelScale',
17+
'InverseMelScale',
1718
'MelSpectrogram',
1819
'MFCC',
1920
'MuLawEncoding',
@@ -233,6 +234,90 @@ def forward(self, specgram):
233234
return mel_specgram
234235

235236

237+
class InverseMelScale(torch.nn.Module):
238+
r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
239+
matrix. This uses triangular filter banks.
240+
241+
It minimizes the euclidian norm between the input mel-spectrogram and the product between
242+
the estimated spectrogram and the filter banks using SGD.
243+
244+
Args:
245+
n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
246+
n_mels (int): Number of mel filterbanks. (Default: ``128``)
247+
sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
248+
f_min (float): Minimum frequency. (Default: ``0.``)
249+
f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
250+
max_iter (int): Maximum number of optimization iterations.
251+
tolerance_loss (float): Value of loss to stop optimization at.
252+
tolerance_change (float): Difference in losses to stop optimization at.
253+
sgdargs (dict): Arguments for the SGD optimizer.
254+
"""
255+
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
256+
'tolerance_change', 'sgdargs']
257+
258+
def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000,
259+
tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None):
260+
super(InverseMelScale, self).__init__()
261+
self.n_mels = n_mels
262+
self.sample_rate = sample_rate
263+
self.f_max = f_max or float(sample_rate // 2)
264+
self.f_min = f_min
265+
self.max_iter = max_iter
266+
self.tolerance_loss = tolerance_loss
267+
self.tolerance_change = tolerance_change
268+
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}
269+
270+
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
271+
272+
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
273+
self.register_buffer('fb', fb)
274+
275+
def forward(self, melspec):
276+
r"""
277+
Args:
278+
melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)
279+
280+
Returns:
281+
torch.Tensor: Linear scale spectrogram of size (..., freq, time)
282+
"""
283+
# pack batch
284+
shape = melspec.size()
285+
melspec = melspec.view(-1, shape[-2], shape[-1])
286+
287+
n_mels, time = shape[-2], shape[-1]
288+
freq, _ = self.fb.size() # (freq, n_mels)
289+
melspec = melspec.transpose(-1, -2)
290+
assert self.n_mels == n_mels
291+
292+
specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
293+
dtype=melspec.dtype, device=melspec.device)
294+
295+
optim = torch.optim.SGD([specgram], **self.sgdargs)
296+
297+
loss = float('inf')
298+
for _ in range(self.max_iter):
299+
optim.zero_grad()
300+
diff = melspec - specgram.matmul(self.fb)
301+
new_loss = diff.pow(2).sum(axis=-1).mean()
302+
# take sum over mel-frequency then average over other dimensions
303+
# so that loss threshold is applied par unit timeframe
304+
new_loss.backward()
305+
optim.step()
306+
specgram.data = specgram.data.clamp(min=0)
307+
308+
new_loss = new_loss.item()
309+
if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
310+
break
311+
loss = new_loss
312+
313+
specgram.requires_grad_(False)
314+
specgram = specgram.clamp(min=0).transpose(-1, -2)
315+
316+
# unpack batch
317+
specgram = specgram.view(shape[:-2] + (freq, time))
318+
return specgram
319+
320+
236321
class MelSpectrogram(torch.nn.Module):
237322
r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
238323
and MelScale.

0 commit comments

Comments
 (0)