From af9c5ead3ee987e4eefe7946c08e14dbf165c1b6 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 15 Oct 2021 20:42:10 -0400 Subject: [PATCH 1/2] Add SpecAugment figure/citation --- docs/source/refs.bib | 10 ++++++++ torchaudio/transforms.py | 53 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/docs/source/refs.bib b/docs/source/refs.bib index ab6cfb73b6..d1b0ef17c3 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -1,3 +1,13 @@ +@article{specaugment, + title={SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition}, + url={http://dx.doi.org/10.21437/Interspeech.2019-2680}, + DOI={10.21437/interspeech.2019-2680}, + journal={Interspeech 2019}, + publisher={ISCA}, + author={Park, Daniel S. and Chan, William and Zhang, Yu and Chiu, Chung-Cheng and Zoph, Barret and Cubuk, Ekin D. and Le, Quoc V.}, + year={2019}, + month={Sep} +} @misc{ljspeech17, author = {Keith Ito and Linda Johnson}, title = {The LJ Speech Dataset}, diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 97a1ae84b0..f5a639d370 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -947,11 +947,34 @@ def forward(self, specgram: Tensor) -> Tensor: class TimeStretch(torch.nn.Module): r"""Stretch stft in time without modifying pitch for a given rate. + Proposed in *SpecAugment* [:footcite:`specaugment`]. + Args: hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) n_freq (int, optional): number of filter banks from stft. (Default: ``201``) fixed_rate (float or None, optional): rate to speed up or slow down by. If None is provided, rate must be passed to the forward method. (Default: ``None``) + + Example: + >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> stretch = torchaudio.transforms.TimeStretch() + >>> + >>> original = spectrogram(waveform) + >>> streched_1_2 = stretch(original, 1.2) + >>> streched_0_9 = stretch(original, 0.9) + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png + :width: 600 + :alt: Spectrogram streched by 1.2 + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png + :width: 600 + :alt: The original spectrogram + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png + :width: 600 + :alt: Spectrogram streched by 0.9 + """ __constants__ = ['fixed_rate'] @@ -1111,12 +1134,27 @@ def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor: class FrequencyMasking(_AxisMasking): r"""Apply masking to a spectrogram in the frequency domain. + Proposed in *SpecAugment* [:footcite:`specaugment`]. + Args: freq_mask_param (int): maximum possible length of the mask. Indices uniformly sampled from [0, freq_mask_param). iid_masks (bool, optional): whether to apply different masks to each example/channel in the batch. (Default: ``False``) This option is applicable only when the input tensor is 4D. + + Example: + >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80) + >>> + >>> original = spectrogram(waveform) + >>> masked = masking(original) + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking1.png + :alt: The original spectrogram + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking2.png + :alt: The spectrogram masked along frequency axis """ def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None: @@ -1126,12 +1164,27 @@ def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None: class TimeMasking(_AxisMasking): r"""Apply masking to a spectrogram in the time domain. + Proposed in *SpecAugment* [:footcite:`specaugment`]. + Args: time_mask_param (int): maximum possible length of the mask. Indices uniformly sampled from [0, time_mask_param). iid_masks (bool, optional): whether to apply different masks to each example/channel in the batch. (Default: ``False``) This option is applicable only when the input tensor is 4D. + + Example: + >>> spectrogram = torchaudio.transforms.Spectrogram() + >>> masking = torchaudio.transforms.TimeMasking(time_mask_param=80) + >>> + >>> original = spectrogram(waveform) + >>> masked = masking(original) + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking1.png + :alt: The original spectrogram + + .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking2.png + :alt: The spectrogram masked along time axis """ def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None: From ba43b7c545b34be416ea18c3e986732dc62c6580 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Sat, 16 Oct 2021 11:18:46 -0400 Subject: [PATCH 2/2] apply suggestions --- torchaudio/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index f5a639d370..b8f3044a3e 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -955,7 +955,7 @@ class TimeStretch(torch.nn.Module): fixed_rate (float or None, optional): rate to speed up or slow down by. If None is provided, rate must be passed to the forward method. (Default: ``None``) - Example: + Example >>> spectrogram = torchaudio.transforms.Spectrogram() >>> stretch = torchaudio.transforms.TimeStretch() >>> @@ -1143,7 +1143,7 @@ class FrequencyMasking(_AxisMasking): example/channel in the batch. (Default: ``False``) This option is applicable only when the input tensor is 4D. - Example: + Example >>> spectrogram = torchaudio.transforms.Spectrogram() >>> masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80) >>> @@ -1173,7 +1173,7 @@ class TimeMasking(_AxisMasking): example/channel in the batch. (Default: ``False``) This option is applicable only when the input tensor is 4D. - Example: + Example >>> spectrogram = torchaudio.transforms.Spectrogram() >>> masking = torchaudio.transforms.TimeMasking(time_mask_param=80) >>>