Skip to content

Commit c90c18d

Browse files
authored
Move batch from vocoder transform to functional (#350)
* fixing errors in docstring.
* move batch to functional.
1 parent c74e580 commit c90c18d

File tree

2 files changed: +32 -26 lines changed (32 additions, 26 deletions)

torchaudio/functional.py

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
469469
factor of ``rate``.
470470
471471
Args:
472-
complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
472+
complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
473473
rate (float): Speed-up factor
474474
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
475475
of (freq, 1)
476476
477477
Returns:
478-
complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
478+
complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
479479
freq, ceil(time/rate), complex=2)`
480480
481481
Example
@@ -490,6 +490,10 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
490490
torch.Size([2, 1025, 231, 2])
491491
"""
492492

493+
# pack batch
494+
shape = complex_specgrams.size()
495+
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
496+
493497
time_steps = torch.arange(0,
494498
complex_specgrams.size(-2),
495499
rate,
@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
527531

528532
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
529533

534+
# unpack batch
535+
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
536+
530537
return complex_specgrams_stretch
531538

532539

@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
775782
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
776783
"""
777784

785+
# pack batch
786+
shape = specgram.size()
787+
specgram = specgram.reshape([-1] + list(shape[-2:]))
788+
778789
value = torch.rand(1) * mask_param
779790
min_value = torch.rand(1) * (specgram.size(axis) - value)
780791

@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
789800
else:
790801
raise ValueError('Only Frequency and Time masking are supported')
791802

792-
return specgram
803+
# unpack batch
804+
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
805+
806+
return specgram.reshape(shape[:-2] + specgram.shape[-2:])
793807

794808

795809
def compute_deltas(specgram, win_length=5, mode="replicate"):

torchaudio/transforms.py

Lines changed: 15 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -380,9 +380,9 @@ def __init__(self, power=1.0):
380380
def forward(self, complex_tensor):
381381
r"""
382382
Args:
383-
complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
383+
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
384384
Returns:
385-
Tensor: norm of the input tensor, shape of `(*, )`
385+
Tensor: norm of the input tensor, shape of `(..., )`
386386
"""
387387
return F.complex_norm(complex_tensor, self.power)
388388

@@ -438,14 +438,14 @@ def forward(self, complex_specgrams, overriding_rate=None):
438438
# type: (Tensor, Optional[float]) -> Tensor
439439
r"""
440440
Args:
441-
complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
441+
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
442442
overriding_rate (float or None): speed up to apply to this batch.
443443
If no rate is passed, use ``self.fixed_rate``
444444
445445
Returns:
446-
(Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
446+
(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
447447
"""
448-
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
448+
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
449449

450450
if overriding_rate is None:
451451
rate = self.fixed_rate
@@ -458,16 +458,12 @@ def forward(self, complex_specgrams, overriding_rate=None):
458458
if rate == 1.0:
459459
return complex_specgrams
460460

461-
shape = complex_specgrams.size()
462-
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
463-
complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
464-
465-
return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
461+
return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
466462

467463

468464
class _AxisMasking(torch.nn.Module):
469-
r"""
470-
Apply masking to a spectrogram.
465+
r"""Apply masking to a spectrogram.
466+
471467
Args:
472468
mask_param (int): Maximum possible length of the mask
473469
axis: What dimension the mask is applied on
@@ -486,26 +482,22 @@ def forward(self, specgram, mask_value=0.):
486482
# type: (Tensor, float) -> Tensor
487483
r"""
488484
Args:
489-
specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
485+
specgram (torch.Tensor): Tensor of dimension (..., freq, time)
490486
491487
Returns:
492-
torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
488+
torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
493489
"""
494490

495491
# if iid_masks flag marked and specgram has a batch dimension
496492
if self.iid_masks and specgram.dim() == 4:
497493
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
498494
else:
499-
shape = specgram.size()
500-
specgram = specgram.reshape([-1] + list(shape[-2:]))
501-
specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
502-
503-
return specgram.reshape(shape[:-2] + specgram.shape[-2:])
495+
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
504496

505497

506498
class FrequencyMasking(_AxisMasking):
507-
r"""
508-
Apply masking to a spectrogram in the frequency domain.
499+
r"""Apply masking to a spectrogram in the frequency domain.
500+
509501
Args:
510502
freq_mask_param (int): maximum possible length of the mask.
511503
Indices uniformly sampled from [0, freq_mask_param).
@@ -518,8 +510,8 @@ def __init__(self, freq_mask_param, iid_masks=False):
518510

519511

520512
class TimeMasking(_AxisMasking):
521-
r"""
522-
Apply masking to a spectrogram in the time domain.
513+
r"""Apply masking to a spectrogram in the time domain.
514+
523515
Args:
524516
time_mask_param (int): maximum possible length of the mask.
525517
Indices uniformly sampled from [0, time_mask_param).

0 commit comments

Comments
 (0)