From 15c730308b1246e49a3430e013a7007df381e2bc Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 10:42:36 -0500 Subject: [PATCH 1/3] standardizing freq/time axis. --- torchaudio/functional.py | 14 +++++++------- torchaudio/transforms.py | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9538d70ed5..2690d04907 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -820,16 +820,16 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgrams (Tensor): Real spectrograms (batch, channel, freq, time) + specgrams (Tensor): Real spectrograms (..., freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns - axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time) Returns: - torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) + torch.Tensor: Masked spectrograms of dimensions (..., freq, time) """ - if axis != 2 and axis != 3: + if axis != -2 and axis != -1: raise ValueError('Only Frequency and Time masking are supported') value = torch.rand(specgrams.shape[:2]) * mask_param @@ -859,7 +859,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): specgram (Tensor): Real spectrogram (channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns - axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time) Returns: torch.Tensor: Masked spectrogram of dimensions (channel, freq, time) @@ -876,9 +876,9 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): mask_end = (min_value.long() + value.long()).squeeze() assert mask_end - mask_start < mask_param - if axis == 1: + if axis == -2: specgram[:, mask_start:mask_end] = mask_value - elif axis == 2: + elif axis == -1: specgram[:, :, mask_start:mask_end] = mask_value else: raise ValueError('Only Frequency and Time masking are supported') diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 92b336cf92..57453dcce6 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -556,7 +556,7 @@ def forward(self, specgram, mask_value=0.): # if iid_masks flag marked and specgram has a batch dimension if self.iid_masks and specgram.dim() == 4: - return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis) else: return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) @@ -572,7 +572,7 @@ class FrequencyMasking(_AxisMasking): """ def __init__(self, freq_mask_param, iid_masks=False): - super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) + super(FrequencyMasking, self).__init__(freq_mask_param, -2, iid_masks) class TimeMasking(_AxisMasking): @@ -586,4 +586,4 @@ class TimeMasking(_AxisMasking): """ def __init__(self, time_mask_param, iid_masks=False): - super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) + super(TimeMasking, self).__init__(time_mask_param, -1, iid_masks) From 6f166576272e5c19401cf60414fa375504fa531a Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 11:00:55 -0500 Subject: [PATCH 2/3] update test. --- test/test_functional.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 43aa885506..1125427c0c 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -560,7 +560,7 @@ def test_mask_along_axis(self): specgram = torch.randn(2, 1025, 400) mask_param = 100 mask_value = 30. - axis = 2 + axis = -1 _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis) @@ -569,7 +569,7 @@ def test_mask_along_axis_iid(self): specgrams = torch.randn(4, 2, 1025, 400) mask_param = 100 mask_value = 30. - axis = 2 + axis = -2 _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) @@ -651,7 +651,7 @@ def test_complex_norm(complex_tensor, power): ]) @pytest.mark.parametrize('mask_param', [100]) @pytest.mark.parametrize('mask_value', [0., 30.]) -@pytest.mark.parametrize('axis', [1, 2]) +@pytest.mark.parametrize('axis', [-2, -1]) def test_mask_along_axis(specgram, mask_param, mask_value, axis): mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) @@ -671,7 +671,7 @@ def test_mask_along_axis(specgram, mask_param, mask_value, axis): ]) @pytest.mark.parametrize('mask_param', [100]) @pytest.mark.parametrize('mask_value', [0., 30.]) -@pytest.mark.parametrize('axis', [2, 3]) +@pytest.mark.parametrize('axis', [-2, -1]) def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis): mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) From a1ea7dfcea8ab8602209a39e099b6fe22537b017 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 17:14:11 -0500 Subject: [PATCH 3/3] keep batch, channel, in docstring. --- torchaudio/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 2690d04907..68ff6662a0 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -820,13 +820,13 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgrams (Tensor): Real spectrograms (..., freq, time) + specgrams (Tensor): Real spectrograms (batch, channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time) Returns: - torch.Tensor: Masked spectrograms of dimensions (..., freq, time) + torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) """ if axis != -2 and axis != -1: