diff --git a/test/test_functional.py b/test/test_functional.py index 43aa885506..1125427c0c 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -560,7 +560,7 @@ def test_mask_along_axis(self): specgram = torch.randn(2, 1025, 400) mask_param = 100 mask_value = 30. - axis = 2 + axis = -1 _test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis) @@ -569,7 +569,7 @@ def test_mask_along_axis_iid(self): specgrams = torch.randn(4, 2, 1025, 400) mask_param = 100 mask_value = 30. - axis = 2 + axis = -2 _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) @@ -651,7 +651,7 @@ def test_complex_norm(complex_tensor, power): ]) @pytest.mark.parametrize('mask_param', [100]) @pytest.mark.parametrize('mask_value', [0., 30.]) -@pytest.mark.parametrize('axis', [1, 2]) +@pytest.mark.parametrize('axis', [-2, -1]) def test_mask_along_axis(specgram, mask_param, mask_value, axis): mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) @@ -671,7 +671,7 @@ def test_mask_along_axis(specgram, mask_param, mask_value, axis): ]) @pytest.mark.parametrize('mask_param', [100]) @pytest.mark.parametrize('mask_value', [0., 30.]) -@pytest.mark.parametrize('axis', [2, 3]) +@pytest.mark.parametrize('axis', [-2, -1]) def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis): mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9538d70ed5..68ff6662a0 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -823,13 +823,13 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): specgrams (Tensor): Real spectrograms (batch, channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns - axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) + axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time) Returns: torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) """ - if axis != 2 and axis != 3: + if axis != -2 and axis != -1: raise ValueError('Only Frequency and Time masking are supported') value = torch.rand(specgrams.shape[:2]) * mask_param @@ -859,7 +859,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): specgram (Tensor): Real spectrogram (channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns - axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) + axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time) Returns: torch.Tensor: Masked spectrogram of dimensions (channel, freq, time) @@ -876,9 +876,9 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): mask_end = (min_value.long() + value.long()).squeeze() assert mask_end - mask_start < mask_param - if axis == 1: + if axis == -2: specgram[:, mask_start:mask_end] = mask_value - elif axis == 2: + elif axis == -1: specgram[:, :, mask_start:mask_end] = mask_value else: raise ValueError('Only Frequency and Time masking are supported') diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 92b336cf92..57453dcce6 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -556,7 +556,7 @@ def forward(self, specgram, mask_value=0.): # if iid_masks flag marked and specgram has a batch dimension if self.iid_masks and specgram.dim() == 4: - return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1) + return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis) else: return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis) @@ -572,7 +572,7 @@ class FrequencyMasking(_AxisMasking): """ def __init__(self, freq_mask_param, iid_masks=False): - super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks) + super(FrequencyMasking, self).__init__(freq_mask_param, -2, iid_masks) class TimeMasking(_AxisMasking): @@ -586,4 +586,4 @@ class TimeMasking(_AxisMasking): """ def __init__(self, time_mask_param, iid_masks=False): - super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks) + super(TimeMasking, self).__init__(time_mask_param, -1, iid_masks)