20 changes: 17 additions & 3 deletions torchaudio/functional.py
@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
factor of ``rate``.

Args:
-complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
+complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
of (freq, 1)

Returns:
-complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
+complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
freq, ceil(time/rate), complex=2)`

Example
@@ -490,6 +490,10 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
torch.Size([2, 1025, 231, 2])
"""

+# pack batch
+shape = complex_specgrams.size()
+complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
+
time_steps = torch.arange(0,
complex_specgrams.size(-2),
rate,
@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):

complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)

+# unpack batch
+complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
+
return complex_specgrams_stretch


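For context, a minimal usage sketch of the batched call that this pack/unpack pair enables; the shapes, hop_length, and phase_advance values below are illustrative, mirroring the existing docstring example with an extra leading batch dimension:

    import math
    import torch
    import torchaudio.functional as F

    freq, hop_length = 1025, 512
    # Batched input: (batch, channel, freq, time, complex=2) -- leading dims are arbitrary.
    complex_specgrams = torch.randn(3, 2, freq, 300, 2)
    rate = 1.3  # speed up by 30%
    phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]

    stretched = F.phase_vocoder(complex_specgrams, rate, phase_advance)
    print(stretched.shape)  # torch.Size([3, 2, 1025, 231, 2]), since ceil(300 / 1.3) == 231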
@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
"""

+# pack batch
+shape = specgram.size()
+specgram = specgram.reshape([-1] + list(shape[-2:]))
+
value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)

@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
else:
raise ValueError('Only Frequency and Time masking are supported')

-return specgram
+# unpack batch
+specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
+
+return specgram


def compute_deltas(specgram, win_length=5, mode="replicate"):
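Likewise, a brief sketch of F.mask_along_axis on input with extra leading dimensions; the shapes and mask_param values are illustrative, and .clone() simply guards against in-place modification of the original tensor:

    import torch
    import torchaudio.functional as F

    # (batch, channel, freq, time): the leading dims are flattened before masking
    # and restored afterwards, so the output shape matches the input shape.
    specgram = torch.randn(4, 2, 201, 400)

    # axis=1 masks frequency, axis=2 masks time in the packed (batch', freq, time) view.
    time_masked = F.mask_along_axis(specgram.clone(), mask_param=80, mask_value=0., axis=2)
    freq_masked = F.mask_along_axis(specgram.clone(), mask_param=30, mask_value=0., axis=1)
    print(time_masked.shape, freq_masked.shape)  # torch.Size([4, 2, 201, 400]) for both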
38 changes: 15 additions & 23 deletions torchaudio/transforms.py
@@ -380,9 +380,9 @@ def __init__(self, power=1.0):
def forward(self, complex_tensor):
r"""
Args:
-complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
+complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
Returns:
-Tensor: norm of the input tensor, shape of `(*, )`
+Tensor: norm of the input tensor, shape of `(..., )`
"""
return F.complex_norm(complex_tensor, self.power)

@@ -438,14 +438,14 @@ def forward(self, complex_specgrams, overriding_rate=None):
# type: (Tensor, Optional[float]) -> Tensor
r"""
Args:
-complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
+complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
overriding_rate (float or None): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``

Returns:
-(Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
+(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
"""
-assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
+assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

if overriding_rate is None:
rate = self.fixed_rate
@@ -458,16 +458,12 @@ def forward(self, complex_specgrams, overriding_rate=None):
if rate == 1.0:
return complex_specgrams

-shape = complex_specgrams.size()
-complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
-complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
-
-return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
+return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class _AxisMasking(torch.nn.Module):
r"""
Apply masking to a spectrogram.
r"""Apply masking to a spectrogram.

Args:
mask_param (int): Maximum possible length of the mask
axis: What dimension the mask is applied on
@@ -486,26 +482,22 @@ def forward(self, specgram, mask_value=0.):
# type: (Tensor, float) -> Tensor
r"""
Args:
-specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
+specgram (torch.Tensor): Tensor of dimension (..., freq, time)

Returns:
-torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
+torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
"""

# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
else:
-shape = specgram.size()
-specgram = specgram.reshape([-1] + list(shape[-2:]))
-specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
-
-return specgram.reshape(shape[:-2] + specgram.shape[-2:])
+return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the frequency domain.
r"""Apply masking to a spectrogram in the frequency domain.

Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
@@ -518,8 +510,8 @@ def __init__(self, freq_mask_param, iid_masks=False):


class TimeMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the time domain.
r"""Apply masking to a spectrogram in the time domain.

Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
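Finally, a rough end-to-end sketch at the transforms level now that batch handling is deferred to the functional layer; the constructor arguments assume the existing TimeStretch(hop_length, n_freq, fixed_rate) and TimeMasking(time_mask_param) signatures, and the shapes are illustrative:

    import torch
    import torchaudio

    freq, hop_length = 1025, 512
    stretch = torchaudio.transforms.TimeStretch(hop_length=hop_length, n_freq=freq, fixed_rate=1.3)
    masking = torchaudio.transforms.TimeMasking(time_mask_param=100)

    # Any leading dimensions are accepted; packing happens inside the functional calls.
    complex_specgrams = torch.randn(3, 2, freq, 300, 2)  # (..., freq, time, complex=2)
    specgrams = torch.randn(3, 2, freq, 300)             # (..., freq, time)

    print(stretch(complex_specgrams).shape)  # torch.Size([3, 2, 1025, 231, 2])
    print(masking(specgrams).shape)          # torch.Size([3, 2, 1025, 300])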