From cfc5b1f2d61ae464851c33070118b7341ba48f43 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 14:42:00 -0400 Subject: [PATCH 01/22] compute deltas. --- test/test_functional.py | 20 ++++++++++++++++++++ torchaudio/functional.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/test/test_functional.py b/test/test_functional.py index 8f4f84942d..9405d28fce 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -192,6 +192,26 @@ def test_istft_of_sine(self): self._test_istft_of_sine(amplitude=99, L=10, n=7) +class TestDeltas(unittest.TestCase): + waveform = torch.tensor([1.,2.,3.,4.]).unsqueeze(0) + + def _test(self, waveform, expected, n_diff=1, atol=1e-6, rtol=1e-8): + computed = F.compute_deltas(waveform, n_diff=1) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected, atol=atol, rtol=rtol)) + + def test_onechannel(self): + waveform = self.waveform + expected = torch.tensor([[ 1.0, 1.0, 1.0, -1.5]]) + self._test(waveform, expected) + + def test_twochannel(self): + waveform = torch.cat([self.waveform, self.waveform], dim=0) + expected = torch.tensor([[ 1.0, 1.0, 1.0, -1.5], + [ 1., 1.0, 1.0, -1.5]]) + self._test(waveform, expected) + + def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 842d8495bf..da4da21fed 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -652,3 +652,36 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): a1 = -2 * math.cos(w0) a2 = 1 - alpha return biquad(waveform, b0, b1, b2, a0, a1, a2) + + +def compute_deltas(waveform, n_diff=2): + r"""Compute delta coefficients. 
+ + Args: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + n_diff (int): Number of differences to consider + + Returns: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + + Example + >>> waveform = torch.randn(2, 100) + >>> deltas = compute_deltas(waveform) + """ + + assert waveform.dim() == 2 + + kernel = ( + torch + .tensor(range(-n_diff, n_diff+1, 1), device=waveform.device, dtype=waveform.dtype) + .repeat(waveform.shape[0], 1) + .unsqueeze(1) + ) + waveform = waveform.unsqueeze(0) + deltas = torch.nn.functional.conv1d(waveform, kernel, padding=n_diff, groups=waveform.shape[1]) + + # twice sum of integer squared + denom = n_diff * (n_diff+1) * (2*n_diff+1) / 3 + + deltas /= denom + return deltas.squeeze(0) From 642cd39fb15708206c882771fdbb786c030279c4 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 15:00:19 -0400 Subject: [PATCH 02/22] flake8. --- test/test_functional.py | 8 ++++---- torchaudio/functional.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 9405d28fce..3d2887eb52 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -193,7 +193,7 @@ def test_istft_of_sine(self): class TestDeltas(unittest.TestCase): - waveform = torch.tensor([1.,2.,3.,4.]).unsqueeze(0) + waveform = torch.tensor([1., 2., 3., 4.]).unsqueeze(0) def _test(self, waveform, expected, n_diff=1, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(waveform, n_diff=1) @@ -202,13 +202,13 @@ def _test(self, waveform, expected, n_diff=1, atol=1e-6, rtol=1e-8): def test_onechannel(self): waveform = self.waveform - expected = torch.tensor([[ 1.0, 1.0, 1.0, -1.5]]) + expected = torch.tensor([[1.0, 1.0, 1.0, -1.5]]) self._test(waveform, expected) def test_twochannel(self): waveform = torch.cat([self.waveform, self.waveform], dim=0) - expected = torch.tensor([[ 1.0, 1.0, 1.0, -1.5], - [ 1., 1.0, 1.0, -1.5]]) + expected = 
torch.tensor([[1.0, 1.0, 1.0, -1.5], + [1.0, 1.0, 1.0, -1.5]]) self._test(waveform, expected) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index da4da21fed..424aa7a1ff 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -673,7 +673,7 @@ def compute_deltas(waveform, n_diff=2): kernel = ( torch - .tensor(range(-n_diff, n_diff+1, 1), device=waveform.device, dtype=waveform.dtype) + .tensor(range(-n_diff, n_diff + 1, 1), device=waveform.device, dtype=waveform.dtype) .repeat(waveform.shape[0], 1) .unsqueeze(1) ) @@ -681,7 +681,7 @@ def compute_deltas(waveform, n_diff=2): deltas = torch.nn.functional.conv1d(waveform, kernel, padding=n_diff, groups=waveform.shape[1]) # twice sum of integer squared - denom = n_diff * (n_diff+1) * (2*n_diff+1) / 3 + denom = n_diff * (n_diff + 1) * (2 * n_diff + 1) / 3 deltas /= denom return deltas.squeeze(0) From 5601fc7119ac45404508123dc4b8d903a5bc7114 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 15:25:39 -0400 Subject: [PATCH 03/22] specgram format. 
--- test/test_functional.py | 20 ++++++++++---------- torchaudio/functional.py | 33 +++++++++++++++------------------ 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 3d2887eb52..4385181598 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -193,23 +193,23 @@ def test_istft_of_sine(self): class TestDeltas(unittest.TestCase): - waveform = torch.tensor([1., 2., 3., 4.]).unsqueeze(0) + specgram = torch.tensor([1., 2., 3., 4.]) - def _test(self, waveform, expected, n_diff=1, atol=1e-6, rtol=1e-8): - computed = F.compute_deltas(waveform, n_diff=1) + def _test(self, specgram, expected, n_diff=1, atol=1e-6, rtol=1e-8): + computed = F.compute_deltas(specgram, n_diff=1) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected, atol=atol, rtol=rtol)) def test_onechannel(self): - waveform = self.waveform - expected = torch.tensor([[1.0, 1.0, 1.0, -1.5]]) - self._test(waveform, expected) + specgram = self.specgram.unsqueeze(0).unsqueeze(0) + expected = torch.tensor([[[1.0, 1.0, 1.0, -1.5]]]) + self._test(specgram, expected) def test_twochannel(self): - waveform = torch.cat([self.waveform, self.waveform], dim=0) - expected = torch.tensor([[1.0, 1.0, 1.0, -1.5], - [1.0, 1.0, 1.0, -1.5]]) - self._test(waveform, expected) + specgram = self.specgram.repeat(1, 2, 1) + expected = torch.tensor([[[1.0, 1.0, 1.0, -1.5], + [1.0, 1.0, 1.0, -1.5]]]) + self._test(specgram, expected) def _num_stft_bins(signal_len, fft_len, hop_length, pad): diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 424aa7a1ff..5cea46c185 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -654,34 +654,31 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -def compute_deltas(waveform, n_diff=2): - r"""Compute delta coefficients. 
+def compute_deltas(specgram, n_diff=2): + r"""Compute delta coefficients of a spectogram. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + specgram (torch.Tensor): Tensor of audio of dimension (channel, time) n_diff (int): Number of differences to consider Returns: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + deltas (torch.Tensor): Tensor of audio of dimension (channel, time) Example - >>> waveform = torch.randn(2, 100) - >>> deltas = compute_deltas(waveform) + >>> specgram = torch.randn(2, 100) + >>> deltas = compute_deltas(specgram) """ - assert waveform.dim() == 2 - - kernel = ( - torch - .tensor(range(-n_diff, n_diff + 1, 1), device=waveform.device, dtype=waveform.dtype) - .repeat(waveform.shape[0], 1) - .unsqueeze(1) - ) - waveform = waveform.unsqueeze(0) - deltas = torch.nn.functional.conv1d(waveform, kernel, padding=n_diff, groups=waveform.shape[1]) + assert specgram.dim() == 3 # twice sum of integer squared denom = n_diff * (n_diff + 1) * (2 * n_diff + 1) / 3 - deltas /= denom - return deltas.squeeze(0) + kernel = ( + torch + .tensor(range(-n_diff, n_diff + 1, 1), device=specgram.device, dtype=specgram.dtype) + .repeat(specgram.shape[1], specgram.shape[0], 1) + ) + return torch.nn.functional.conv1d( + specgram, kernel, padding=n_diff, groups=specgram.shape[1] + ) / denom From 792f5362d006ecc6537d1e2b4a95baa78f5e8e56 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 15:26:46 -0400 Subject: [PATCH 04/22] update doc. 
--- torchaudio/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 5cea46c185..47aab8cb03 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -665,7 +665,7 @@ def compute_deltas(specgram, n_diff=2): deltas (torch.Tensor): Tensor of audio of dimension (channel, time) Example - >>> specgram = torch.randn(2, 100) + >>> specgram = torch.randn(1, 40, 1000) >>> deltas = compute_deltas(specgram) """ From c9895676138ef2bfb53dc509016bf846c17878de Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 15:46:22 -0400 Subject: [PATCH 05/22] multichannel, and random test. --- test/test_functional.py | 8 ++++++++ torchaudio/functional.py | 9 +++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 4385181598..2e80746b1f 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -211,6 +211,14 @@ def test_twochannel(self): [1.0, 1.0, 1.0, -1.5]]]) self._test(specgram, expected) + def test_randn(self): + channel = 13 + n_mfcc = channel * 3 + time = 1021 + specgram = torch.randn(channel, n_mfcc, time) + computed = F.compute_deltas(specgram, n_diff=7) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 47aab8cb03..b09dce8bf8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -658,11 +658,11 @@ def compute_deltas(specgram, n_diff=2): r"""Compute delta coefficients of a spectogram. 
Args: - specgram (torch.Tensor): Tensor of audio of dimension (channel, time) - n_diff (int): Number of differences to consider + specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) + n_diff (int): Number of differences to use in computing delta Returns: - deltas (torch.Tensor): Tensor of audio of dimension (channel, time) + deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) Example >>> specgram = torch.randn(1, 40, 1000) @@ -670,6 +670,7 @@ def compute_deltas(specgram, n_diff=2): """ assert specgram.dim() == 3 + assert not specgram.shape[1] % specgram.shape[0] # twice sum of integer squared denom = n_diff * (n_diff + 1) * (2 * n_diff + 1) / 3 @@ -680,5 +681,5 @@ def compute_deltas(specgram, n_diff=2): .repeat(specgram.shape[1], specgram.shape[0], 1) ) return torch.nn.functional.conv1d( - specgram, kernel, padding=n_diff, groups=specgram.shape[1] + specgram, kernel, padding=n_diff, groups=specgram.shape[1]//specgram.shape[0] ) / denom From 4d3175973f51225f611e332800a548099416d23e Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 16:18:14 -0400 Subject: [PATCH 06/22] documentation. --- torchaudio/functional.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index b09dce8bf8..420ec7bfe7 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -655,7 +655,14 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def compute_deltas(specgram, n_diff=2): - r"""Compute delta coefficients of a spectogram. + r"""Compute delta coefficients of a spectogram: + + .. 
math:: + d_t = \frac{\sum_{n=1}^{N} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^N n^2} + + where :math:`d_t` is the deltas at time :math:`t`, + :math:`N` is n_diff, + :math:`c_t` are the spectogram coeffcients at time :math:`t`, Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) @@ -671,6 +678,7 @@ def compute_deltas(specgram, n_diff=2): assert specgram.dim() == 3 assert not specgram.shape[1] % specgram.shape[0] + assert n_diff > 0 # twice sum of integer squared denom = n_diff * (n_diff + 1) * (2 * n_diff + 1) / 3 From c36710df999284ac557f7e9beb830b1f72ad9f14 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 16:19:25 -0400 Subject: [PATCH 07/22] flake8. --- torchaudio/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 420ec7bfe7..01c1bd4f3a 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -689,5 +689,5 @@ def compute_deltas(specgram, n_diff=2): .repeat(specgram.shape[1], specgram.shape[0], 1) ) return torch.nn.functional.conv1d( - specgram, kernel, padding=n_diff, groups=specgram.shape[1]//specgram.shape[0] + specgram, kernel, padding=n_diff, groups=specgram.shape[1] // specgram.shape[0] ) / denom From ecc358cb009e84bff5696fe3b77d7a45113d5fd8 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 5 Sep 2019 16:31:44 -0400 Subject: [PATCH 08/22] phrasing in doc. 
--- torchaudio/functional.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 01c1bd4f3a..780d1028ea 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -661,12 +661,12 @@ def compute_deltas(specgram, n_diff=2): d_t = \frac{\sum_{n=1}^{N} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^N n^2} where :math:`d_t` is the deltas at time :math:`t`, - :math:`N` is n_diff, - :math:`c_t` are the spectogram coeffcients at time :math:`t`, + :math:`c_t` is the spectogram coeffcients at time :math:`t`, + :math:`N` is n_diff. Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - n_diff (int): Number of differences to use in computing delta + n_diff (int): A nonzero number of differences to use in computing delta Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) @@ -676,9 +676,9 @@ def compute_deltas(specgram, n_diff=2): >>> deltas = compute_deltas(specgram) """ + assert n_diff > 0 assert specgram.dim() == 3 assert not specgram.shape[1] % specgram.shape[0] - assert n_diff > 0 # twice sum of integer squared denom = n_diff * (n_diff + 1) * (2 * n_diff + 1) / 3 From e2814cfecc486431ccd5dc5fec75ac1ec91df224 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 6 Sep 2019 11:46:12 -0400 Subject: [PATCH 09/22] follow kaldi's interface. 
--- test/test_functional.py | 12 ++++++------ torchaudio/functional.py | 24 +++++++++++++++--------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 2e80746b1f..caed801399 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -195,20 +195,20 @@ def test_istft_of_sine(self): class TestDeltas(unittest.TestCase): specgram = torch.tensor([1., 2., 3., 4.]) - def _test(self, specgram, expected, n_diff=1, atol=1e-6, rtol=1e-8): - computed = F.compute_deltas(specgram, n_diff=1) + def _test(self, specgram, expected, window=1, atol=1e-6, rtol=1e-8): + computed = F.compute_deltas(specgram, window=1) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected, atol=atol, rtol=rtol)) def test_onechannel(self): specgram = self.specgram.unsqueeze(0).unsqueeze(0) - expected = torch.tensor([[[1.0, 1.0, 1.0, -1.5]]]) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) self._test(specgram, expected) def test_twochannel(self): specgram = self.specgram.repeat(1, 2, 1) - expected = torch.tensor([[[1.0, 1.0, 1.0, -1.5], - [1.0, 1.0, 1.0, -1.5]]]) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], + [0.5, 1.0, 1.0, 0.5]]]) self._test(specgram, expected) def test_randn(self): @@ -216,7 +216,7 @@ def test_randn(self): n_mfcc = channel * 3 time = 1021 specgram = torch.randn(channel, n_mfcc, time) - computed = F.compute_deltas(specgram, n_diff=7) + computed = F.compute_deltas(specgram, window=7) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 780d1028ea..86a82ae174 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -654,40 +654,46 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -def compute_deltas(specgram, n_diff=2): +def 
compute_deltas(specgram, window=2): r"""Compute delta coefficients of a spectogram: .. math:: - d_t = \frac{\sum_{n=1}^{N} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^N n^2} + d_t = \frac{\sum_{n=1}^{\text{window}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{window} n^2} where :math:`d_t` is the deltas at time :math:`t`, :math:`c_t` is the spectogram coeffcients at time :math:`t`, - :math:`N` is n_diff. + `window` is the parameter given to the function (the actual window size is 2*window+1). + + The behavior at the edges is to replicate the boundaries. Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - n_diff (int): A nonzero number of differences to use in computing delta + window (int): A nonzero number of differences to use in computing delta Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) Example >>> specgram = torch.randn(1, 40, 1000) - >>> deltas = compute_deltas(specgram) + >>> delta = compute_deltas(specgram) + >>> delta2 = compute_deltas(delta) """ - assert n_diff > 0 + assert window > 0 assert specgram.dim() == 3 assert not specgram.shape[1] % specgram.shape[0] # twice sum of integer squared - denom = n_diff * (n_diff + 1) * (2 * n_diff + 1) / 3 + denom = window * (window + 1) * (2 * window + 1) / 3 + + specgram = torch.nn.functional.pad(specgram, (window, window), mode='replicate') kernel = ( torch - .tensor(range(-n_diff, n_diff + 1, 1), device=specgram.device, dtype=specgram.dtype) + .tensor(range(-window, window + 1, 1), device=specgram.device, dtype=specgram.dtype) .repeat(specgram.shape[1], specgram.shape[0], 1) ) + return torch.nn.functional.conv1d( - specgram, kernel, padding=n_diff, groups=specgram.shape[1] // specgram.shape[0] + specgram, kernel, groups=specgram.shape[1] // specgram.shape[0] ) / denom From e0433db1c229245829738758d99de7d8df4ca398 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 6 Sep 2019 15:18:58 -0400 Subject: [PATCH 10/22] adding compliance and 
transform. --- test/test_compliance_kaldi.py | 10 ++++++++++ test/test_functional.py | 3 ++- test/test_transforms.py | 10 ++++++++++ torchaudio/compliance/kaldi.py | 17 +++++++++++++++++ torchaudio/functional.py | 3 ++- torchaudio/transforms.py | 33 +++++++++++++++++++++++++++++++++ 6 files changed, 74 insertions(+), 2 deletions(-) diff --git a/test/test_compliance_kaldi.py b/test/test_compliance_kaldi.py index d640e1671f..789b273021 100644 --- a/test/test_compliance_kaldi.py +++ b/test/test_compliance_kaldi.py @@ -319,5 +319,15 @@ def test_resample_waveform_multi_channel(self): single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2) self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4)) + def test_compute_deltas(self): + channel = 13 + n_mfcc = channel * 3 + time = 1021 + window = 7 + order = 1 + specgram = torch.randn(channel, n_mfcc, time) + computed = kaldi.add_deltas(specgram, window=window, order=order) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + if __name__ == '__main__': unittest.main() diff --git a/test/test_functional.py b/test/test_functional.py index caed801399..fa1ce958f6 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -215,8 +215,9 @@ def test_randn(self): channel = 13 n_mfcc = channel * 3 time = 1021 + window = 7 specgram = torch.randn(channel, n_mfcc, time) - computed = F.compute_deltas(specgram, window=7) + computed = F.compute_deltas(specgram, window=window) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 91074c4177..2d5b52045a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -281,5 +281,15 @@ def test_resample_size(self): # we expect the downsampled signal to have half as many samples self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2) + def test_compute_deltas(self): 
+ channel = 13 + n_mfcc = channel * 3 + time = 1021 + window = 7 + specgram = torch.randn(channel, n_mfcc, time) + transform = transforms.ComputeDeltas(window=window) + computed = transform(specgram) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index a379c09ce5..d26174b432 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -866,3 +866,20 @@ def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6): output += dilated_conv_wave return output + + +def add_deltas(specgram, order=1, window=2): + r"""Compute delta coefficients of given order of a spectogram. + + Args: + specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) + order (int): A nonzero order of difference + window (int): A nonzero number of differences to use in computing delta + + Returns: + deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) + """ + deltas = specgram + for _ in range(order): + deltas = torchaudio.functional.compute_deltas(deltas, window=window) + return deltas diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 86a82ae174..baa78b4816 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -655,6 +655,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def compute_deltas(specgram, window=2): + # type: (Tensor, int) -> Tensor r"""Compute delta coefficients of a spectogram: .. 
math:: @@ -690,7 +691,7 @@ def compute_deltas(specgram, window=2): kernel = ( torch - .tensor(range(-window, window + 1, 1), device=specgram.device, dtype=specgram.dtype) + .arange(-window, window + 1, 1, device=specgram.device, dtype=specgram.dtype) .repeat(specgram.shape[1], specgram.shape[0], 1) ) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 38e703b2b6..72352a9be8 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -365,3 +365,36 @@ def forward(self, waveform): return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) + + +class ComputeDeltas(torch.jit.ScriptModule): + r"""Compute delta coefficients of a spectogram: + + .. math:: + d_t = \frac{\sum_{n=1}^{\text{window}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{window} n^2} + + where :math:`d_t` is the deltas at time :math:`t`, + :math:`c_t` is the spectogram coeffcients at time :math:`t`, + `window` is the parameter given to the function (the actual window size is 2*window+1). + + The behavior at the edges is to replicate the boundaries. 
+ + Args: + window (int): A nonzero number of differences to use in computing delta + """ + __constants__ = ['window'] + + def __init__(self, window=2): + super(ComputeDeltas, self).__init__() + self.window = window + + @torch.jit.script_method + def forward(self, specgram): + r""" + Args: + specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) + + Returns: + deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) + """ + return F.compute_deltas(specgram, window=self.window) From d08af7de6f752eb1068153830301ab5742275dd5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 6 Sep 2019 15:23:08 -0400 Subject: [PATCH 11/22] flake8 --- test/test_transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 2d5b52045a..10c325b10f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -291,5 +291,6 @@ def test_compute_deltas(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + if __name__ == '__main__': unittest.main() From c4c3756e80fe396c29703c1f17f5154f2f56ef97 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 9 Sep 2019 18:24:59 -0400 Subject: [PATCH 12/22] feedback. changing name of window to win_length. 
--- test/test_functional.py | 55 +++++++++++++++++++--------------------- test/test_transforms.py | 12 +++++++-- torchaudio/functional.py | 18 +++++++------ torchaudio/transforms.py | 21 +++++---------- 4 files changed, 53 insertions(+), 53 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index fa1ce958f6..dc5923fd69 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -18,6 +18,32 @@ class TestFunctional(unittest.TestCase): data_sizes = [(2, 20), (3, 15), (4, 10)] number_of_trials = 100 + specgram = torch.tensor([1., 2., 3., 4.]) + + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): + computed = F.compute_deltas(specgram, win_length=3) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + torch.testing.assert_allclose(computed, expected, atol=atol, rtol=rtol) + + def test_compute_deltas_onechannel(self): + specgram = self.specgram.unsqueeze(0).unsqueeze(0) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) + self._test_compute_deltas(specgram, expected) + + def test_compute_deltas_twochannel(self): + specgram = self.specgram.repeat(1, 2, 1) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], + [0.5, 1.0, 1.0, 0.5]]]) + self._test_compute_deltas(specgram, expected) + + def test_compute_deltas_randn(self): + channel = 13 + n_mfcc = channel * 3 + time = 1021 + win_length = 2*7+1 + specgram = torch.randn(channel, n_mfcc, time) + computed = F.compute_deltas(specgram, win_length=win_length) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): # trim sound for case when constructed signal is shorter than original @@ -192,35 +218,6 @@ def test_istft_of_sine(self): self._test_istft_of_sine(amplitude=99, L=10, n=7) -class TestDeltas(unittest.TestCase): - specgram = torch.tensor([1., 2., 3., 4.]) - - def _test(self, specgram, expected, window=1, atol=1e-6, 
rtol=1e-8): - computed = F.compute_deltas(specgram, window=1) - self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected, atol=atol, rtol=rtol)) - - def test_onechannel(self): - specgram = self.specgram.unsqueeze(0).unsqueeze(0) - expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) - self._test(specgram, expected) - - def test_twochannel(self): - specgram = self.specgram.repeat(1, 2, 1) - expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], - [0.5, 1.0, 1.0, 0.5]]]) - self._test(specgram, expected) - - def test_randn(self): - channel = 13 - n_mfcc = channel * 3 - time = 1021 - window = 7 - specgram = torch.randn(channel, n_mfcc, time) - computed = F.compute_deltas(specgram, window=window) - self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) - - def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length diff --git a/test/test_transforms.py b/test/test_transforms.py index 10c325b10f..3e88bba1a4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -285,9 +285,17 @@ def test_compute_deltas(self): channel = 13 n_mfcc = channel * 3 time = 1021 - window = 7 + win_length = 2 * 7 + 1 specgram = torch.randn(channel, n_mfcc, time) - transform = transforms.ComputeDeltas(window=window) + transform = transforms.ComputeDeltas(win_length=win_length) + computed = transform(specgram) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + + def test_compute_deltas_twochannel(self): + specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1) + expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], + [0.5, 1.0, 1.0, 0.5]]]) + transform = transforms.ComputeDeltas() computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index baa78b4816..df6a06d609 100644 
--- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -654,22 +654,22 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -def compute_deltas(specgram, window=2): +def compute_deltas(specgram, win_length=5): # type: (Tensor, int) -> Tensor r"""Compute delta coefficients of a spectogram: .. math:: - d_t = \frac{\sum_{n=1}^{\text{window}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{window} n^2} + d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N} n^2} where :math:`d_t` is the deltas at time :math:`t`, :math:`c_t` is the spectogram coeffcients at time :math:`t`, - `window` is the parameter given to the function (the actual window size is 2*window+1). + :math:`N` is (`win_length`-1)//2. The behavior at the edges is to replicate the boundaries. Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - window (int): A nonzero number of differences to use in computing delta + win_length (int): A nonzero number of differences to use in computing delta Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) @@ -680,18 +680,20 @@ def compute_deltas(specgram, window=2): >>> delta2 = compute_deltas(delta) """ - assert window > 0 + assert win_length >= 3 assert specgram.dim() == 3 assert not specgram.shape[1] % specgram.shape[0] + n = (win_length - 1) // 2 + # twice sum of integer squared - denom = window * (window + 1) * (2 * window + 1) / 3 + denom = n * (n + 1) * (2 * n + 1) / 3 - specgram = torch.nn.functional.pad(specgram, (window, window), mode='replicate') + specgram = torch.nn.functional.pad(specgram, (n, n), mode='replicate') kernel = ( torch - .arange(-window, window + 1, 1, device=specgram.device, dtype=specgram.dtype) + .arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype) .repeat(specgram.shape[1], specgram.shape[0], 1) ) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 
72352a9be8..04fd958ce4 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -368,25 +368,18 @@ def forward(self, waveform): class ComputeDeltas(torch.jit.ScriptModule): - r"""Compute delta coefficients of a spectogram: + r"""Compute delta coefficients of a spectogram. - .. math:: - d_t = \frac{\sum_{n=1}^{\text{window}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{window} n^2} - - where :math:`d_t` is the deltas at time :math:`t`, - :math:`c_t` is the spectogram coeffcients at time :math:`t`, - `window` is the parameter given to the function (the actual window size is 2*window+1). - - The behavior at the edges is to replicate the boundaries. + See `torchaudio.functional.compute_deltas` for more details. Args: - window (int): A nonzero number of differences to use in computing delta + win_length (int): The window length used for computing delta """ - __constants__ = ['window'] + __constants__ = ['win_length'] - def __init__(self, window=2): + def __init__(self, win_length=5): super(ComputeDeltas, self).__init__() - self.window = window + self.win_length = win_length @torch.jit.script_method def forward(self, specgram): @@ -397,4 +390,4 @@ def forward(self, specgram): Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) """ - return F.compute_deltas(specgram, window=self.window) + return F.compute_deltas(specgram, win_length=self.win_length) From 4730dfaef211bd7f0d611e76ee8bbc40f1de3b2e Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 9 Sep 2019 18:52:37 -0400 Subject: [PATCH 13/22] compliance update. 
--- test/test_compliance_kaldi.py | 4 ++-- torchaudio/compliance/kaldi.py | 6 +++--- torchaudio/functional.py | 2 +- torchaudio/transforms.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_compliance_kaldi.py b/test/test_compliance_kaldi.py index 789b273021..ff54d58506 100644 --- a/test/test_compliance_kaldi.py +++ b/test/test_compliance_kaldi.py @@ -323,10 +323,10 @@ def test_compute_deltas(self): channel = 13 n_mfcc = channel * 3 time = 1021 - window = 7 + win_length = 7 order = 1 specgram = torch.randn(channel, n_mfcc, time) - computed = kaldi.add_deltas(specgram, window=window, order=order) + computed = kaldi.add_deltas(specgram, win_length=win_length, order=order) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) if __name__ == '__main__': diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index d26174b432..a83d7a1d26 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -868,18 +868,18 @@ def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6): return output -def add_deltas(specgram, order=1, window=2): +def add_deltas(specgram, order=1, win_length=2): r"""Compute delta coefficients of given order of a spectogram. Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) order (int): A nonzero order of difference - window (int): A nonzero number of differences to use in computing delta + win_length (int): The window length used for computing delta. 
Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) """ deltas = specgram for _ in range(order): - deltas = torchaudio.functional.compute_deltas(deltas, window=window) + deltas = torchaudio.functional.compute_deltas(deltas, win_length=win_length) return deltas diff --git a/torchaudio/functional.py b/torchaudio/functional.py index df6a06d609..b25876e4de 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -669,7 +669,7 @@ def compute_deltas(specgram, win_length=5): Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - win_length (int): A nonzero number of differences to use in computing delta + win_length (int): The window length used for computing delta. Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 04fd958ce4..8a7a3ef7c9 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -373,7 +373,7 @@ class ComputeDeltas(torch.jit.ScriptModule): See `torchaudio.functional.compute_deltas` for more details. Args: - win_length (int): The window length used for computing delta + win_length (int): The window length used for computing delta. """ __constants__ = ['win_length'] From b6dbbbc8a3ec59b96bbcdf7a2efa0b5731748bf7 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 9 Sep 2019 19:11:02 -0400 Subject: [PATCH 14/22] flake8. 
--- test/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_functional.py b/test/test_functional.py index dc5923fd69..fd732274d0 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -40,7 +40,7 @@ def test_compute_deltas_randn(self): channel = 13 n_mfcc = channel * 3 time = 1021 - win_length = 2*7+1 + win_length = 2 * 7 + 1 specgram = torch.randn(channel, n_mfcc, time) computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) From 88b4d5cda061d0aa2a31fa788c172073bc5dc9ef Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 11 Sep 2019 16:25:48 -0400 Subject: [PATCH 15/22] assert same values. --- test/test_transforms.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 3e88bba1a4..d78c5b6000 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -4,8 +4,9 @@ import torch import torchaudio -from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY import torchaudio.transforms as transforms +import torchaudio.functional as F +from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY import unittest import common_utils @@ -291,6 +292,19 @@ def test_compute_deltas(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8): + channel = 13 + n_mfcc = channel * 3 + time = 1021 + win_length = 2 * 7 + 1 + specgram = torch.randn(channel, n_mfcc, time) + + transform = transforms.ComputeDeltas(win_length=win_length) + computed_transform = transform(specgram) + + computed_functional = F.compute_deltas(specgram, win_length=win_length) + torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol) + def 
test_compute_deltas_twochannel(self): specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], From 0501b9adf2e3c7ccf10f4e9fdb3b0736848aa297 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 18 Sep 2019 14:12:17 -0400 Subject: [PATCH 16/22] remove kaldi compliance for add_deltas. --- torchaudio/compliance/kaldi.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index a83d7a1d26..a379c09ce5 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -866,20 +866,3 @@ def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6): output += dilated_conv_wave return output - - -def add_deltas(specgram, order=1, win_length=2): - r"""Compute delta coefficients of given order of a spectogram. - - Args: - specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - order (int): A nonzero order of difference - win_length (int): The window length used for computing delta. - - Returns: - deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - """ - deltas = specgram - for _ in range(order): - deltas = torchaudio.functional.compute_deltas(deltas, win_length=win_length) - return deltas From 6c8d9d5da6e9675388182c1fc0d7984964ff5078 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 18 Sep 2019 14:13:26 -0400 Subject: [PATCH 17/22] remove kaldi compliance for add_deltas. 
--- test/test_compliance_kaldi.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/test_compliance_kaldi.py b/test/test_compliance_kaldi.py index ff54d58506..b01f5e3a7e 100644 --- a/test/test_compliance_kaldi.py +++ b/test/test_compliance_kaldi.py @@ -319,15 +319,6 @@ def test_resample_waveform_multi_channel(self): single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2) self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4)) - def test_compute_deltas(self): - channel = 13 - n_mfcc = channel * 3 - time = 1021 - win_length = 7 - order = 1 - specgram = torch.randn(channel, n_mfcc, time) - computed = kaldi.add_deltas(specgram, win_length=win_length, order=order) - self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) if __name__ == '__main__': unittest.main() From c8bc3c001f99e0fb7f88bc97d1e0c9c5293a6b47 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 18 Sep 2019 17:01:14 -0400 Subject: [PATCH 18/22] typo. --- torchaudio/functional.py | 5 +++-- torchaudio/transforms.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index b25876e4de..f97852b57c 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -20,6 +20,7 @@ "biquad", ] + # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved @torch.jit.ignore def _stft( @@ -656,13 +657,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def compute_deltas(specgram, win_length=5): # type: (Tensor, int) -> Tensor - r"""Compute delta coefficients of a spectogram: + r"""Compute delta coefficients of a spectrogram: .. 
math:: d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N} n^2} where :math:`d_t` is the deltas at time :math:`t`, - :math:`c_t` is the spectogram coeffcients at time :math:`t`, + :math:`c_t` is the spectrogram coefficients at time :math:`t`, :math:`N` is (`win_length`-1)//2. The behavior at the edges is to replicate the boundaries. diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 8a7a3ef7c9..d503959002 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -368,7 +368,7 @@ def forward(self, waveform): class ComputeDeltas(torch.jit.ScriptModule): - r"""Compute delta coefficients of a spectogram. + r"""Compute delta coefficients of a spectrogram. See `torchaudio.functional.compute_deltas` for more details. From 22e9fe05a2eeff4344aa86d68cb1a7760d9ec85b Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 18 Sep 2019 17:03:56 -0400 Subject: [PATCH 19/22] passing padding mode. --- torchaudio/functional.py | 11 +++++------ torchaudio/transforms.py | 7 ++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index f97852b57c..a05513eba3 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -655,8 +655,8 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): return biquad(waveform, b0, b1, b2, a0, a1, a2) -def compute_deltas(specgram, win_length=5): - # type: (Tensor, int) -> Tensor +def compute_deltas(specgram, win_length=5, mode="replicate"): + # type: (Tensor, int, string) -> Tensor r"""Compute delta coefficients of a spectrogram: .. math:: @@ -666,11 +666,10 @@ def compute_deltas(specgram, win_length=5): :math:`c_t` is the spectrogram coefficients at time :math:`t`, :math:`N` is (`win_length`-1)//2. - The behavior at the edges is to replicate the boundaries. 
- Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) - win_length (int): The window length used for computing delta. + win_length (int): The window length used for computing delta + mode (string): Mode parameter passed to padding Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) @@ -690,7 +689,7 @@ def compute_deltas(specgram, win_length=5): # twice sum of integer squared denom = n * (n + 1) * (2 * n + 1) / 3 - specgram = torch.nn.functional.pad(specgram, (n, n), mode='replicate') + specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode) kernel = ( torch diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d503959002..2af3fbaee6 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -375,11 +375,12 @@ class ComputeDeltas(torch.jit.ScriptModule): Args: win_length (int): The window length used for computing delta. """ - __constants__ = ['win_length'] + __constants__ = ['win_length', 'mode'] - def __init__(self, win_length=5): + def __init__(self, win_length=5, mode="replicate"): super(ComputeDeltas, self).__init__() self.win_length = win_length + self.mode = mode @torch.jit.script_method def forward(self, specgram): @@ -390,4 +391,4 @@ def forward(self, specgram): Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) """ - return F.compute_deltas(specgram, win_length=self.win_length) + return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) From d3f69583455b19b039d4a7a73ba92801d3ddccce Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 18 Sep 2019 17:11:01 -0400 Subject: [PATCH 20/22] typo. 
--- test/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_functional.py b/test/test_functional.py index fd732274d0..5a5a169f86 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -21,7 +21,7 @@ class TestFunctional(unittest.TestCase): specgram = torch.tensor([1., 2., 3., 4.]) def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): - computed = F.compute_deltas(specgram, win_length=3) + computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) torch.testing.assert_allclose(computed, expected, atol=atol, rtol=rtol) From 5aee851108f257ee0f8e6495f77c60327b16043f Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 18 Sep 2019 17:38:57 -0400 Subject: [PATCH 21/22] string typo. --- torchaudio/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index a05513eba3..f35786a504 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -656,7 +656,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def compute_deltas(specgram, win_length=5, mode="replicate"): - # type: (Tensor, int, string) -> Tensor + # type: (Tensor, int, str) -> Tensor r"""Compute delta coefficients of a spectrogram: .. 
math:: @@ -669,7 +669,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): Args: specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) win_length (int): The window length used for computing delta - mode (string): Mode parameter passed to padding + mode (str): Mode parameter passed to padding Returns: deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time) From a403ae929561aca2fa9ac42d3eee7bf9aa218ed2 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 19 Sep 2019 10:51:06 -0400 Subject: [PATCH 22/22] fix compilation. specify the method is more general than spectrogram. --- torchaudio/functional.py | 2 +- torchaudio/transforms.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index f35786a504..547738cf9b 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -657,7 +657,7 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def compute_deltas(specgram, win_length=5, mode="replicate"): # type: (Tensor, int, str) -> Tensor - r"""Compute delta coefficients of a spectrogram: + r"""Compute delta coefficients of a tensor, usually a spectrogram: .. math:: d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N} n^2} diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 2af3fbaee6..7362e7268d 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -368,19 +368,19 @@ def forward(self, waveform): class ComputeDeltas(torch.jit.ScriptModule): - r"""Compute delta coefficients of a spectrogram. + r"""Compute delta coefficients of a tensor, usually a spectrogram. See `torchaudio.functional.compute_deltas` for more details. Args: win_length (int): The window length used for computing delta. 
""" - __constants__ = ['win_length', 'mode'] + __constants__ = ['win_length'] def __init__(self, win_length=5, mode="replicate"): super(ComputeDeltas, self).__init__() self.win_length = win_length - self.mode = mode + self.mode = torch.jit.Attribute(mode, str) @torch.jit.script_method def forward(self, specgram):