@@ -380,9 +380,9 @@ def __init__(self, power=1.0):
380380 def forward (self , complex_tensor ):
381381 r"""
382382 Args:
383- complex_tensor (Tensor): Tensor shape of `(* , complex=2)`
383+ complex_tensor (Tensor): Tensor shape of `(... , complex=2)`
384384 Returns:
385- Tensor: norm of the input tensor, shape of `(* , )`
385+ Tensor: norm of the input tensor, shape of `(... , )`
386386 """
387387 return F .complex_norm (complex_tensor , self .power )
388388
@@ -438,14 +438,14 @@ def forward(self, complex_specgrams, overriding_rate=None):
438438 # type: (Tensor, Optional[float]) -> Tensor
439439 r"""
440440 Args:
441- complex_specgrams (Tensor): complex spectrogram (*, channel , freq, time, complex=2)
441+ complex_specgrams (Tensor): complex spectrogram (... , freq, time, complex=2)
442442 overriding_rate (float or None): speed up to apply to this batch.
443443 If no rate is passed, use ``self.fixed_rate``
444444
445445 Returns:
446- (Tensor): Stretched complex spectrogram of dimension (*, channel , freq, ceil(time/rate), complex=2)
446+ (Tensor): Stretched complex spectrogram of dimension (... , freq, ceil(time/rate), complex=2)
447447 """
448- assert complex_specgrams .size (- 1 ) == 2 , "complex_specgrams should be a complex tensor, shape (* , complex=2)"
448+ assert complex_specgrams .size (- 1 ) == 2 , "complex_specgrams should be a complex tensor, shape (... , complex=2)"
449449
450450 if overriding_rate is None :
451451 rate = self .fixed_rate
@@ -458,16 +458,12 @@ def forward(self, complex_specgrams, overriding_rate=None):
458458 if rate == 1.0 :
459459 return complex_specgrams
460460
461- shape = complex_specgrams .size ()
462- complex_specgrams = complex_specgrams .reshape ([- 1 ] + list (shape [- 3 :]))
463- complex_specgrams = F .phase_vocoder (complex_specgrams , rate , self .phase_advance )
464-
465- return complex_specgrams .reshape (shape [:- 3 ] + complex_specgrams .shape [- 3 :])
461+ return F .phase_vocoder (complex_specgrams , rate , self .phase_advance )
466462
467463
468464class _AxisMasking (torch .nn .Module ):
469- r"""
470- Apply masking to a spectrogram.
465+ r"""Apply masking to a spectrogram.
466+
471467 Args:
472468 mask_param (int): Maximum possible length of the mask
473469 axis: What dimension the mask is applied on
@@ -486,26 +482,22 @@ def forward(self, specgram, mask_value=0.):
486482 # type: (Tensor, float) -> Tensor
487483 r"""
488484 Args:
489- specgram (torch.Tensor): Tensor of dimension (*, channel , freq, time)
485+ specgram (torch.Tensor): Tensor of dimension (... , freq, time)
490486
491487 Returns:
492- torch.Tensor: Masked spectrogram of dimensions (*, channel , freq, time)
488+ torch.Tensor: Masked spectrogram of dimensions (... , freq, time)
493489 """
494490
495491 # if iid_masks flag marked and specgram has a batch dimension
496492 if self .iid_masks and specgram .dim () == 4 :
497493 return F .mask_along_axis_iid (specgram , self .mask_param , mask_value , self .axis + 1 )
498494 else :
499- shape = specgram .size ()
500- specgram = specgram .reshape ([- 1 ] + list (shape [- 2 :]))
501- specgram = F .mask_along_axis (specgram , self .mask_param , mask_value , self .axis )
502-
503- return specgram .reshape (shape [:- 2 ] + specgram .shape [- 2 :])
495+ return F .mask_along_axis (specgram , self .mask_param , mask_value , self .axis )
504496
505497
506498class FrequencyMasking (_AxisMasking ):
507- r"""
508- Apply masking to a spectrogram in the frequency domain.
499+ r"""Apply masking to a spectrogram in the frequency domain.
500+
509501 Args:
510502 freq_mask_param (int): maximum possible length of the mask.
511503 Indices uniformly sampled from [0, freq_mask_param).
@@ -518,8 +510,8 @@ def __init__(self, freq_mask_param, iid_masks=False):
518510
519511
520512class TimeMasking (_AxisMasking ):
521- r"""
522- Apply masking to a spectrogram in the time domain.
513+ r"""Apply masking to a spectrogram in the time domain.
514+
523515 Args:
524516 time_mask_param (int): maximum possible length of the mask.
525517 Indices uniformly sampled from [0, time_mask_param).
0 commit comments