diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index a1309fb21a..2114c99aaf 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -736,7 +736,7 @@ def mask_along_axis_iid( Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) """ - if axis != 2 and axis != 3: + if axis not in [2, 3]: raise ValueError('Only Frequency and Time masking are supported') device = specgrams.device @@ -778,7 +778,7 @@ def mask_along_axis( Returns: Tensor: Masked spectrogram of dimensions (channel, freq, time) """ - if axis != 1 and axis != 2: + if axis not in [1, 2]: raise ValueError('Only Frequency and Time masking are supported') # pack batch