Skip to content

Commit bcdefb7

Browse files
committed
remove batch for mask.
1 parent 66f0023 commit bcdefb7

File tree

2 files changed

+5
-44
lines changed

2 files changed

+5
-44
lines changed

test/test_functional.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -560,28 +560,6 @@ def _test_batch(self, functional, tensor, *args, **kwargs):
560560
torch.random.manual_seed(42)
561561
computed = functional(tensors.clone(), *args, **kwargs)
562562

563-
self._compare_estimate(computed, expected, **kwargs_compare)
564-
565-
def test_batch_mask_along_axis_iid(self):
566-
567-
mask_param = 2
568-
mask_value = 30.
569-
axis = 2
570-
571-
tensor = torch.rand(2, 5, 5)
572-
573-
self._test_batch(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis, atol=1e-1, rtol=1e-1)
574-
575-
def test_batch_mask_along_axis(self):
576-
577-
tensor = torch.rand(2, 5, 5)
578-
579-
mask_param = 2
580-
mask_value = 30.
581-
axis = 2
582-
583-
self._test_batch(F.mask_along_axis, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis)
584-
585563
def test_torchscript_create_fb_matrix(self):
586564

587565
n_stft = 100

torchaudio/functional.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -835,21 +835,14 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
835835
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
836836
837837
Returns:
838-
torch.Tensor: Masked spectrograms of dimensions (..., channel, freq, time)
838+
torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
839839
"""
840840

841-
# pack batch
842-
shape = specgrams.size()
843-
specgrams = specgrams.reshape([-1] + list(shape[-3:]))
844-
845841
if axis != 2 and axis != 3:
846842
raise ValueError('Only Frequency and Time masking are supported')
847843

848-
# Shift so as to start from the end
849-
axis -= 4
850-
851-
value = torch.rand(specgrams.shape[:-2]) * mask_param
852-
min_value = torch.rand(specgrams.shape[:-2]) * (specgrams.size(axis) - value)
844+
value = torch.rand(specgrams.shape[:2]) * mask_param
845+
min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value)
853846

854847
# Create broadcastable mask
855848
mask_start = (min_value.long())[..., None, None].float()
@@ -861,9 +854,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
861854
specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
862855
specgrams = specgrams.transpose(axis, -1)
863856

864-
# unpack batch
865-
specgrams = specgrams.reshape(shape[:-3] + specgrams.shape[-3:])
866-
867857
return specgrams
868858

869859

@@ -875,19 +865,15 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
875865
All examples will have the same mask interval.
876866
877867
Args:
878-
specgram (Tensor): Real spectrogram (..., channel, freq, time)
868+
specgram (Tensor): Real spectrogram (channel, freq, time)
879869
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
880870
mask_value (float): Value to assign to the masked columns
881871
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
882872
883873
Returns:
884-
torch.Tensor: Masked spectrogram of dimensions (..., channel, freq, time)
874+
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
885875
"""
886876

887-
# pack batch
888-
shape = specgram.size()
889-
specgram = specgram.reshape([-1] + list(shape[-3:]))
890-
891877
value = torch.rand(1) * mask_param
892878
min_value = torch.rand(1) * (specgram.size(axis) - value)
893879

@@ -902,9 +888,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
902888
else:
903889
raise ValueError('Only Frequency and Time masking are supported')
904890

905-
# unpack batch
906-
specgram = specgram.reshape(shape[:-3] + specgram.shape[-3:])
907-
908891
return specgram
909892

910893

0 commit comments

Comments
 (0)