From c3e080c7873b4be1762d2a59706dfe5381347180 Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 21 Nov 2019 14:54:46 -0500
Subject: [PATCH 1/4] fixing errors in docstring.

---
 torchaudio/transforms.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 3d0269fe1b..77459d5e1e 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -380,9 +380,9 @@ def __init__(self, power=1.0):
     def forward(self, complex_tensor):
         r"""
         Args:
-            complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
+            complex_tensor (Tensor): Tensor shape of `(\*, complex=2)`
 
         Returns:
-            Tensor: norm of the input tensor, shape of `(*, )`
+            Tensor: norm of the input tensor, shape of `(\*, )`
         """
         return F.complex_norm(complex_tensor, self.power)
@@ -435,17 +435,16 @@ def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
         self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
 
     def forward(self, complex_specgrams, overriding_rate=None):
-        # type: (Tensor, Optional[float]) -> Tensor
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
+            complex_specgrams (Tensor): complex spectrogram (\*, channel, freq, time, complex=2)
             overriding_rate (float or None): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``
 
         Returns:
-            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
+            (Tensor): Stretched complex spectrogram of dimension (\*, channel, freq, ceil(time/rate), complex=2)
         """
-        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
+        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (\*, complex=2)"
 
         if overriding_rate is None:
             rate = self.fixed_rate
@@ -466,8 +465,8 @@ def forward(self, complex_specgrams, overriding_rate=None):
 
 
 class _AxisMasking(torch.nn.Module):
-    r"""
-    Apply masking to a spectrogram.
+    r"""Apply masking to a spectrogram.
+
     Args:
         mask_param (int): Maximum possible length of the mask
         axis: What dimension the mask is applied on
@@ -483,13 +482,12 @@ def __init__(self, mask_param, axis, iid_masks):
         self.iid_masks = iid_masks
 
     def forward(self, specgram, mask_value=0.):
-        # type: (Tensor, float) -> Tensor
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
+            specgram (torch.Tensor): Tensor of dimension (\*, channel, freq, time)
 
         Returns:
-            torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
+            torch.Tensor: Masked spectrogram of dimensions (\*, channel, freq, time)
         """
 
         # if iid_masks flag marked and specgram has a batch dimension
@@ -504,8 +502,8 @@ def forward(self, specgram, mask_value=0.):
 
 
 class FrequencyMasking(_AxisMasking):
-    r"""
-    Apply masking to a spectrogram in the frequency domain.
+    r"""Apply masking to a spectrogram in the frequency domain.
+
     Args:
         freq_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, freq_mask_param).
@@ -518,8 +516,8 @@ def __init__(self, freq_mask_param, iid_masks=False):
 
 
 class TimeMasking(_AxisMasking):
-    r"""
-    Apply masking to a spectrogram in the time domain.
+    r"""Apply masking to a spectrogram in the time domain.
+
     Args:
         time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
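
The docstrings touched above describe FrequencyMasking and TimeMasking in terms of (channel, freq, time) spectrograms. A minimal usage sketch, assuming only the parameter names shown in the diff; the input sizes and variable names are illustrative and not part of the patch:

import torch
import torchaudio

# Fake power spectrogram of shape (channel, freq, time); sizes are arbitrary.
specgram = torch.rand(1, 201, 400)

# Mask up to 30 consecutive frequency bins and up to 50 consecutive frames.
freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
time_mask = torchaudio.transforms.TimeMasking(time_mask_param=50)

# Output keeps the (channel, freq, time) shape of the input.
masked = time_mask(freq_mask(specgram))
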
From d112f2be2a12de82640068afc0276c79b379a15f Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 21 Nov 2019 15:13:15 -0500
Subject: [PATCH 2/4] using ellipses for batch.

---
 torchaudio/transforms.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 77459d5e1e..2e6cf2eb1a 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -380,9 +380,9 @@ def __init__(self, power=1.0):
     def forward(self, complex_tensor):
         r"""
         Args:
-            complex_tensor (Tensor): Tensor shape of `(\*, complex=2)`
+            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
 
         Returns:
-            Tensor: norm of the input tensor, shape of `(\*, )`
+            Tensor: norm of the input tensor, shape of `(..., )`
         """
         return F.complex_norm(complex_tensor, self.power)
@@ -435,16 +435,17 @@ def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
         self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
 
     def forward(self, complex_specgrams, overriding_rate=None):
+        # type: (Tensor, Optional[float]) -> Tensor
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (\*, channel, freq, time, complex=2)
+            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
             overriding_rate (float or None): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``
 
         Returns:
-            (Tensor): Stretched complex spectrogram of dimension (\*, channel, freq, ceil(time/rate), complex=2)
+            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
         """
-        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (\*, complex=2)"
+        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
 
         if overriding_rate is None:
             rate = self.fixed_rate
@@ -482,12 +483,13 @@ def __init__(self, mask_param, axis, iid_masks):
         self.iid_masks = iid_masks
 
     def forward(self, specgram, mask_value=0.):
+        # type: (Tensor, float) -> Tensor
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of dimension (\*, channel, freq, time)
+            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
 
         Returns:
-            torch.Tensor: Masked spectrogram of dimensions (\*, channel, freq, time)
+            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
         """
 
         # if iid_masks flag marked and specgram has a batch dimension

From 1f7ff9ee631e007910b71cbc05ecf52dbdd29acf Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 21 Nov 2019 16:03:53 -0500
Subject: [PATCH 3/4] move batch to functional.

---
 torchaudio/functional.py | 20 +++++++++++++++++---
 torchaudio/transforms.py | 12 ++----------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/torchaudio/functional.py b/torchaudio/functional.py
index 78e7a5b759..98aa998fa4 100644
--- a/torchaudio/functional.py
+++ b/torchaudio/functional.py
@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
     factor of ``rate``.
 
     Args:
-        complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
+        complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
         rate (float): Speed-up factor
         phase_advance (torch.Tensor): Expected phase advance in each bin.
            Dimension of (freq, 1)
 
     Returns:
-        complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
+        complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
         freq, ceil(time/rate), complex=2)`
 
     Example
         >>> freq, hop_length = 1025, 512
         >>> #  (channel, freq, time, complex=2)
         >>> complex_specgrams = torch.randn(2, freq, 300, 2)
         >>> rate = 1.3 # Speed up by 30%
         >>> phase_advance = torch.linspace(
         >>>    0, math.pi * hop_length, freq)[..., None]
         >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
         >>> x.shape # with 231 == ceil(300 / 1.3)
         torch.Size([2, 1025, 231, 2])
     """
+    # pack batch
+    shape = complex_specgrams.size()
+    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
+
     time_steps = torch.arange(0,
                               complex_specgrams.size(-2),
                               rate,
@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
 
     complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
 
+    # unpack batch
+    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams.shape[-3:])
+
     return complex_specgrams_stretch
 
 
@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
         torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
     """
 
+    # pack batch
+    shape = specgram.size()
+    specgram = specgram.reshape([-1] + list(shape[-2:]))
+
     value = torch.rand(1) * mask_param
     min_value = torch.rand(1) * (specgram.size(axis) - value)
 
@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
     else:
         raise ValueError('Only Frequency and Time masking are supported')
 
-    return specgram
+    # unpack batch
+    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
+
+    return specgram.reshape(shape[:-2] + specgram.shape[-2:])
 
 
 def compute_deltas(specgram, win_length=5, mode="replicate"):
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 2e6cf2eb1a..ed0fcece25 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -458,11 +458,7 @@ def forward(self, complex_specgrams, overriding_rate=None):
         if rate == 1.0:
             return complex_specgrams
 
-        shape = complex_specgrams.size()
-        complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
-        complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
-
-        return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
+        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
 
 
 class _AxisMasking(torch.nn.Module):
@@ -496,11 +492,7 @@ def forward(self, specgram, mask_value=0.):
         if self.iid_masks and specgram.dim() == 4:
             return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
         else:
-            shape = specgram.size()
-            specgram = specgram.reshape([-1] + list(shape[-2:]))
-            specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
-
-            return specgram.reshape(shape[:-2] + specgram.shape[-2:])
+            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
 
 
 class FrequencyMasking(_AxisMasking):

From 41524a22c20b86e78d0afb6ec7cd58483237397d Mon Sep 17 00:00:00 2001
From: Vincent Quenneville-Belair
Date: Thu, 21 Nov 2019 18:54:43 -0500
Subject: [PATCH 4/4] update reshape.
---
 torchaudio/functional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchaudio/functional.py b/torchaudio/functional.py
index 98aa998fa4..496de76357 100644
--- a/torchaudio/functional.py
+++ b/torchaudio/functional.py
@@ -532,7 +532,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
     complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
 
     # unpack batch
-    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams.shape[-3:])
+    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
 
     return complex_specgrams_stretch
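
Patches 3 and 4 revolve around a pack/unpack idiom: every leading batch dimension is flattened into a single one before the per-spectrogram work, then restored afterwards. Below is a standalone sketch of that idiom, assuming a hypothetical stretch_time stand-in for F.phase_vocoder (names and sizes are illustrative, not torchaudio's API). The unpack takes the trailing (freq, time, complex) sizes from the output itself, the one-line correction PATCH 4 makes, because the time axis no longer matches the input after stretching:

import math
import torch

def stretch_time(specgrams, rate):
    # Hypothetical per-spectrogram op that changes the time axis,
    # standing in for F.phase_vocoder. Input: (batch, freq, time, complex=2).
    new_time = int(math.ceil(specgrams.size(-2) / rate))
    return specgrams[:, :, :new_time, :]

def batched_stretch(complex_specgrams, rate):
    # pack batch: collapse all leading dims into a single batch dim
    shape = complex_specgrams.size()
    packed = complex_specgrams.reshape([-1] + list(shape[-3:]))

    stretched = stretch_time(packed, rate)

    # unpack batch: restore the leading dims, taking (freq, time, complex)
    # from the output because its time length differs from the input's
    return stretched.reshape(shape[:-3] + stretched.shape[1:])

x = torch.randn(3, 2, 201, 300, 2)  # (..., freq, time, complex=2)
y = batched_stretch(x, 1.3)
print(y.shape)  # torch.Size([3, 2, 201, 231, 2])
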