@@ -835,21 +835,14 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
835835 axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
836836
837837 Returns:
838- torch.Tensor: Masked spectrograms of dimensions (... , channel, freq, time)
838+ torch.Tensor: Masked spectrograms of dimensions (batch , channel, freq, time)
839839 """
840840
841- # pack batch
842- shape = specgrams .size ()
843- specgrams = specgrams .reshape ([- 1 ] + list (shape [- 3 :]))
844-
845841 if axis != 2 and axis != 3 :
846842 raise ValueError ('Only Frequency and Time masking are supported' )
847843
848- # Shift so as to start from the end
849- axis -= 4
850-
851- value = torch .rand (specgrams .shape [:- 2 ]) * mask_param
852- min_value = torch .rand (specgrams .shape [:- 2 ]) * (specgrams .size (axis ) - value )
844+ value = torch .rand (specgrams .shape [:2 ]) * mask_param
845+ min_value = torch .rand (specgrams .shape [:2 ]) * (specgrams .size (axis ) - value )
853846
854847 # Create broadcastable mask
855848 mask_start = (min_value .long ())[..., None , None ].float ()
@@ -861,9 +854,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
861854 specgrams .masked_fill_ ((mask >= mask_start ) & (mask < mask_end ), mask_value )
862855 specgrams = specgrams .transpose (axis , - 1 )
863856
864- # unpack batch
865- specgrams = specgrams .reshape (shape [:- 3 ] + specgrams .shape [- 3 :])
866-
867857 return specgrams
868858
869859
@@ -875,19 +865,15 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
875865 All examples will have the same mask interval.
876866
877867 Args:
878- specgram (Tensor): Real spectrogram (..., channel, freq, time)
868+ specgram (Tensor): Real spectrogram (channel, freq, time)
879869 mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
880870 mask_value (float): Value to assign to the masked columns
881871 axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
882872
883873 Returns:
884- torch.Tensor: Masked spectrogram of dimensions (..., channel, freq, time)
874+ torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
885875 """
886876
887- # pack batch
888- shape = specgram .size ()
889- specgram = specgram .reshape ([- 1 ] + list (shape [- 3 :]))
890-
891877 value = torch .rand (1 ) * mask_param
892878 min_value = torch .rand (1 ) * (specgram .size (axis ) - value )
893879
@@ -902,9 +888,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
902888 else :
903889 raise ValueError ('Only Frequency and Time masking are supported' )
904890
905- # unpack batch
906- specgram = specgram .reshape (shape [:- 3 ] + specgram .shape [- 3 :])
907-
908891 return specgram
909892
910893
0 commit comments