Skip to content

Commit 66f0023

Browse files
committed
attempt at batch for mask.
1 parent a45e619 commit 66f0023

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

test/test_functional.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,6 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs):
511511
atol = kwargs['atol']
512512
del kwargs['atol']
513513
kwargs_compare['atol'] = atol
514-
print(kwargs)
515514

516515
if 'rtol' in kwargs:
517516
rtol = kwargs['rtol']
@@ -520,13 +519,16 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs):
520519

521520
# Single then transform then batch
522521

523-
expected = functional(tensor, *args, **kwargs)
522+
torch.random.manual_seed(42)
523+
expected = functional(tensor.clone(), *args, **kwargs)
524524
expected = expected.unsqueeze(0).unsqueeze(0)
525525

526526
# 1-Batch then transform
527527

528528
tensors = tensor.unsqueeze(0).unsqueeze(0)
529-
computed = functional(tensors, *args, **kwargs)
529+
530+
torch.random.manual_seed(42)
531+
computed = functional(tensors.clone(), *args, **kwargs)
530532

531533
self._compare_estimate(computed, expected, **kwargs_compare)
532534

@@ -555,19 +557,30 @@ def _test_batch(self, functional, tensor, *args, **kwargs):
555557
ind = [3] + [1] * (int(expected.dim()) - 1)
556558
expected = expected.repeat(*ind)
557559

558-
computed = functional(tensors, *args, **kwargs)
560+
torch.random.manual_seed(42)
561+
computed = functional(tensors.clone(), *args, **kwargs)
559562

560563
self._compare_estimate(computed, expected, **kwargs_compare)
561564

562565
def test_batch_mask_along_axis_iid(self):
563566

567+
mask_param = 2
568+
mask_value = 30.
569+
axis = 2
570+
571+
tensor = torch.rand(2, 5, 5)
572+
573+
self._test_batch(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis, atol=1e-1, rtol=1e-1)
574+
575+
def test_batch_mask_along_axis(self):
576+
564577
tensor = torch.rand(2, 5, 5)
565578

566579
mask_param = 2
567580
mask_value = 30.
568581
axis = 2
569582

570-
self._test_batch_shape(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis)
583+
self._test_batch(F.mask_along_axis, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis)
571584

572585
def test_torchscript_create_fb_matrix(self):
573586

torchaudio/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
875875
All examples will have the same mask interval.
876876
877877
Args:
878-
specgram (Tensor): Real spectrogram (..., freq, time)
878+
specgram (Tensor): Real spectrogram (..., channel, freq, time)
879879
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
880880
mask_value (float): Value to assign to the masked columns
881881
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

0 commit comments

Comments
 (0)