Skip to content

Commit 9e3778d

Browse files
authored
Add SpecAugment figure/citation (#1887)
1 parent e885204 commit 9e3778d

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

docs/source/refs.bib

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
@article{specaugment,
2+
title={SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition},
3+
url={http://dx.doi.org/10.21437/Interspeech.2019-2680},
4+
DOI={10.21437/interspeech.2019-2680},
5+
journal={Interspeech 2019},
6+
publisher={ISCA},
7+
author={Park, Daniel S. and Chan, William and Zhang, Yu and Chiu, Chung-Cheng and Zoph, Barret and Cubuk, Ekin D. and Le, Quoc V.},
8+
year={2019},
9+
month={Sep}
10+
}
111
@misc{ljspeech17,
212
author = {Keith Ito and Linda Johnson},
313
title = {The LJ Speech Dataset},

torchaudio/transforms.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,11 +947,34 @@ def forward(self, specgram: Tensor) -> Tensor:
947947
class TimeStretch(torch.nn.Module):
948948
r"""Stretch stft in time without modifying pitch for a given rate.
949949
950+
Proposed in *SpecAugment* [:footcite:`specaugment`].
951+
950952
Args:
951953
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
952954
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
953955
fixed_rate (float or None, optional): rate to speed up or slow down by.
954956
If None is provided, rate must be passed to the forward method. (Default: ``None``)
957+
958+
Example
959+
>>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
960+
>>> stretch = torchaudio.transforms.TimeStretch()
961+
>>>
962+
>>> original = spectrogram(waveform)
963+
>>> stretched_1_2 = stretch(original, 1.2)
964+
>>> stretched_0_9 = stretch(original, 0.9)
965+
966+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
967+
:width: 600
968+
:alt: Spectrogram stretched by 1.2
969+
970+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png
971+
:width: 600
972+
:alt: The original spectrogram
973+
974+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
975+
:width: 600
976+
:alt: Spectrogram stretched by 0.9
977+
955978
"""
956979
__constants__ = ['fixed_rate']
957980

@@ -1111,12 +1134,27 @@ def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
11111134
class FrequencyMasking(_AxisMasking):
11121135
r"""Apply masking to a spectrogram in the frequency domain.
11131136
1137+
Proposed in *SpecAugment* [:footcite:`specaugment`].
1138+
11141139
Args:
11151140
freq_mask_param (int): maximum possible length of the mask.
11161141
Indices uniformly sampled from [0, freq_mask_param).
11171142
iid_masks (bool, optional): whether to apply different masks to each
11181143
example/channel in the batch. (Default: ``False``)
11191144
This option is applicable only when the input tensor is 4D.
1145+
1146+
Example
1147+
>>> spectrogram = torchaudio.transforms.Spectrogram()
1148+
>>> masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)
1149+
>>>
1150+
>>> original = spectrogram(waveform)
1151+
>>> masked = masking(original)
1152+
1153+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking1.png
1154+
:alt: The original spectrogram
1155+
1156+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking2.png
1157+
:alt: The spectrogram masked along frequency axis
11201158
"""
11211159

11221160
def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
@@ -1126,12 +1164,27 @@ def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
11261164
class TimeMasking(_AxisMasking):
11271165
r"""Apply masking to a spectrogram in the time domain.
11281166
1167+
Proposed in *SpecAugment* [:footcite:`specaugment`].
1168+
11291169
Args:
11301170
time_mask_param (int): maximum possible length of the mask.
11311171
Indices uniformly sampled from [0, time_mask_param).
11321172
iid_masks (bool, optional): whether to apply different masks to each
11331173
example/channel in the batch. (Default: ``False``)
11341174
This option is applicable only when the input tensor is 4D.
1175+
1176+
Example
1177+
>>> spectrogram = torchaudio.transforms.Spectrogram()
1178+
>>> masking = torchaudio.transforms.TimeMasking(time_mask_param=80)
1179+
>>>
1180+
>>> original = spectrogram(waveform)
1181+
>>> masked = masking(original)
1182+
1183+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking1.png
1184+
:alt: The original spectrogram
1185+
1186+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking2.png
1187+
:alt: The spectrogram masked along time axis
11351188
"""
11361189

11371190
def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:

0 commit comments

Comments
 (0)