Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def test_mask_along_axis(self):
specgram = torch.randn(2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2
axis = -1

_test_torchscript_functional(F.mask_along_axis, specgram, mask_param, mask_value, axis)

Expand All @@ -569,7 +569,7 @@ def test_mask_along_axis_iid(self):
specgrams = torch.randn(4, 2, 1025, 400)
mask_param = 100
mask_value = 30.
axis = 2
axis = -2

_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)

Expand Down Expand Up @@ -651,7 +651,7 @@ def test_complex_norm(complex_tensor, power):
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [1, 2])
@pytest.mark.parametrize('axis', [-2, -1])
def test_mask_along_axis(specgram, mask_param, mask_value, axis):

mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
Expand All @@ -671,7 +671,7 @@ def test_mask_along_axis(specgram, mask_param, mask_value, axis):
])
@pytest.mark.parametrize('mask_param', [100])
@pytest.mark.parametrize('mask_value', [0., 30.])
@pytest.mark.parametrize('axis', [2, 3])
@pytest.mark.parametrize('axis', [-2, -1])
def test_mask_along_axis_iid(specgrams, mask_param, mask_value, axis):

mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
Expand Down
10 changes: 5 additions & 5 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,13 +823,13 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time)

Returns:
torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
"""

if axis != 2 and axis != 3:
if axis != -2 and axis != -1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
if axis != -2 and axis != -1:
if axis not in [-1, -2]:

raise ValueError('Only Frequency and Time masking are supported')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it a little bit too strict to disallow positive indices, but as long as it is explained in doctoring, I think it's fair.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh wait, isn't this BC breaking?


value = torch.rand(specgrams.shape[:2]) * mask_param
Expand Down Expand Up @@ -859,7 +859,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
specgram (Tensor): Real spectrogram (channel, freq, time)
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
axis (int): Axis to apply masking on (-2 -> frequency, -1 -> time)

Returns:
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
Expand All @@ -876,9 +876,9 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
mask_end = (min_value.long() + value.long()).squeeze()

assert mask_end - mask_start < mask_param
if axis == 1:
if axis == -2:
specgram[:, mask_start:mask_end] = mask_value
elif axis == 2:
elif axis == -1:
specgram[:, :, mask_start:mask_end] = mask_value
else:
raise ValueError('Only Frequency and Time masking are supported')
Expand Down
6 changes: 3 additions & 3 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def forward(self, specgram, mask_value=0.):

# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis)
else:
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)

Expand All @@ -572,7 +572,7 @@ class FrequencyMasking(_AxisMasking):
"""

def __init__(self, freq_mask_param, iid_masks=False):
super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
super(FrequencyMasking, self).__init__(freq_mask_param, -2, iid_masks)


class TimeMasking(_AxisMasking):
Expand All @@ -586,4 +586,4 @@ class TimeMasking(_AxisMasking):
"""

def __init__(self, time_mask_param, iid_masks=False):
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
super(TimeMasking, self).__init__(time_mask_param, -1, iid_masks)