3 changes: 2 additions & 1 deletion README.md
@@ -20,7 +20,7 @@ to use and feel like a natural extension.
- [Kaldi (ark/scp)](http://pytorch.org/audio/kaldi_io.html)
- [Dataloaders for common audio datasets (VCTK, YesNo)](http://pytorch.org/audio/datasets.html)
- Common audio transforms
- [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/transforms.html)
- [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample, RandomCrop](http://pytorch.org/audio/transforms.html)
- Compliance interfaces: Run code using PyTorch that aligns with other libraries
- [Kaldi: spectrogram, fbank, mfcc, resample_waveform](https://pytorch.org/audio/compliance.kaldi.html)

@@ -143,6 +143,7 @@ Transforms expect and return the following dimensions.
* `MuLawEncoding`: (channel, time) -> (channel, time)
* `MuLawDecoding`: (channel, time) -> (channel, time)
* `Resample`: (channel, time) -> (channel, time)
* `RandomCrop`: (channel, time) -> (channel, time)
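
A minimal sketch of the new `RandomCrop` contract (a 2-second crop at 16 kHz; the transform itself is added later in this PR):

import torch
import torchaudio

crop = torchaudio.transforms.RandomCrop(length=2, freq=16000)
waveform = torch.randn(1, 48000)  # (channel, time): 3 s of noise at 16 kHz
cropped = crop(waveform)          # (channel, time): (1, 32000), i.e. length * freq samples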

Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.
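
A quick sketch of this convention (assuming `complex_norm` and `angle` live in `torchaudio.functional`, as the docs suggest):

import torch
import torchaudio

# Real and imaginary parts are stacked in the last dimension
specgram = torch.randn(1, 257, 100, 2)  # (channel, freq, time, complex=2)

magnitude = torchaudio.functional.complex_norm(specgram)  # (1, 257, 100)
phase = torchaudio.functional.angle(specgram)             # (1, 257, 100)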

7 changes: 7 additions & 0 deletions docs/source/transforms.rst
@@ -72,6 +72,13 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`RandomCrop`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: RandomCrop

.. automethod:: forward
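
As the section intro above notes, transforms can be chained with `torch.nn.Sequential`; a minimal sketch using transforms from the list in this file (parameter values are illustrative):

import torch
import torchaudio.transforms as T

# Chain two transforms into a single module
pipeline = torch.nn.Sequential(
    T.MelSpectrogram(sample_rate=16000),
    T.AmplitudeToDB(),
)
mel_db = pipeline(torch.randn(1, 16000))  # (channel, n_mels, time)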

:hidden:`ComplexNorm`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

44 changes: 30 additions & 14 deletions test/test_transforms.py
@@ -22,7 +22,6 @@


def _test_script_module(f, tensor, *args, **kwargs):

py_method = f(*args, **kwargs)
jit_method = torch.jit.script(py_method)

@@ -32,7 +31,6 @@ def _test_script_module(f, tensor, *args, **kwargs):
assert torch.allclose(jit_out, py_out)

if RUN_CUDA:

tensor = tensor.to("cuda")

py_method = py_method.cuda()
@@ -45,20 +43,19 @@


class Tester(unittest.TestCase):

# create a sinewave signal for testing
sample_rate = 16000
freq = 440
volume = .3
waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
waveform.unsqueeze_(0) # (1, 64000)
waveform = (waveform * volume * 2**31).long()
waveform = (waveform * volume * 2 ** 31).long()
# file for stereo stft test
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')

def scale(self, waveform, factor=float(2**31)):
def scale(self, waveform, factor=float(2 ** 31)):
# scales a waveform by a factor
if not waveform.is_floating_point():
waveform = waveform.to(torch.get_default_dtype())
@@ -73,7 +70,6 @@ def test_scriptmodule_GriffinLim(self):
_test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False)

def test_mu_law_companding(self):

quantization_channels = 256

waveform = self.waveform.clone()
@@ -266,14 +262,14 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s
# function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
# to mirror this function call with correct args:

# librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
# sr=sample_rate,
# n_mfcc = n_mfcc,
# hop_length=hop_length,
# n_fft=n_fft,
# htk=True,
# norm=None,
# n_mels=n_mels)
# librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
# sr=sample_rate,
# n_mfcc = n_mfcc,
# hop_length=hop_length,
# n_fft=n_fft,
# htk=True,
# norm=None,
# n_mels=n_mels)

librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
@@ -448,6 +444,26 @@ def test_scriptmodule_TimeMasking(self):
tensor = torch.rand((10, 2, 50, 10, 2))
_test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)

def test_random_crop_size(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path)
length = 2
channels = waveform.shape[0]
random_cropper = torchaudio.transforms.RandomCrop(length=length, freq=sample_rate)
# The cropped waveform should have exactly length * sample_rate samples per channel
self.assertEqual(random_cropper(waveform).shape, torch.Size([channels, length * sample_rate]))

def test_random_crop_content(self):
import numpy as np
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
waveform, sample_rate = torchaudio.load(input_path)
length = 2
channels = waveform.shape[0]
random_cropper = torchaudio.transforms.RandomCrop(length=length, freq=sample_rate)
cropped = random_cropper(waveform)
# Every value in the crop must come from the original waveform; assume_unique
# is left at its default (False) because a periodic signal repeats values
self.assertTrue(
np.sum(np.isin(waveform.numpy(), cropped.numpy())) >= length * sample_rate * channels)


if __name__ == '__main__':
unittest.main()
53 changes: 52 additions & 1 deletion torchaudio/transforms.py
@@ -8,7 +8,6 @@
from . import functional as F
from .compliance import kaldi


__all__ = [
'Spectrogram',
'GriffinLim',
@@ -23,6 +22,7 @@
'TimeStretch',
'FrequencyMasking',
'TimeMasking',
'RandomCrop',
]


@@ -412,6 +412,7 @@ class Resample(torch.nn.Module):
new_freq (float): The desired frequency. (Default: ``16000``)
resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
"""

def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
super(Resample, self).__init__()
self.orig_freq = orig_freq
@@ -587,3 +588,53 @@ class TimeMasking(_AxisMasking):

def __init__(self, time_mask_param, iid_masks=False):
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class RandomCrop(torch.nn.Module):
r"""Crop the given Audio at a random location with fixed length

Args:
length (float, optional): Length of the Audio to crop in seconds
freq (float): The frequency of the signal. (Default: ``16000``)
padding_mode (string): Type of padding. Should be: ``'constant'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
Default: ``'constant'``
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value on the edge of the audio
- repeat: pads with repetition of audio
fill (float/int): fill value for constant fill. Default is 0.
"""
SUPPORTED_PADDING_MODES = ['constant', 'reflect', 'replicate', 'circular']
__constants__ = ['length', 'freq', 'padding_mode', 'fill']

def __init__(self, length=1, freq=16000, padding_mode='constant', fill=0):
super(RandomCrop, self).__init__()
# Validate arguments
if length <= 0:
raise ValueError("Audio length must be greater than 0")
if freq <= 0:
raise ValueError("Audio frequency must be greater than 0")
if padding_mode not in self.SUPPORTED_PADDING_MODES:
raise ValueError("Unsupported padding mode, pick one of: {}".format(self.SUPPORTED_PADDING_MODES))
# Assign values
self.length = length
self.freq = freq
self.padding_mode = padding_mode
self.fill = fill

def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): The input signal of dimension (channel, time)

Returns:
torch.Tensor: Output signal of dimension (channel, length * freq)
"""
n_samples = int(self.length * self.freq)
total_samples = waveform.shape[1]
if total_samples == n_samples:
return waveform
elif total_samples < n_samples:
# Too short: pad on the right up to the requested length
n_pad = n_samples - total_samples
if self.padding_mode == 'constant':
return torch.nn.functional.pad(waveform, [0, n_pad], mode=self.padding_mode, value=self.fill)
# reflect/replicate/circular expect a batch dimension for 1d padding
return torch.nn.functional.pad(waveform.unsqueeze(0), [0, n_pad], mode=self.padding_mode).squeeze(0)
else:
# Long enough: draw a random start index; randint's high is exclusive,
# so +1 keeps the last valid window reachable
s_index = torch.randint(low=0, high=total_samples - n_samples + 1, size=[1]).item()
return waveform[:, s_index:s_index + n_samples]
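
A short usage sketch of the behavior implemented above (8 kHz chosen for illustration; shapes follow the forward docstring):

import torch
from torchaudio.transforms import RandomCrop

crop = RandomCrop(length=1, freq=8000, padding_mode='constant', fill=0)

# Longer than 1 s at 8 kHz: a random 8000-sample window is returned
assert crop(torch.randn(2, 20000)).shape == (2, 8000)

# Shorter than 1 s: right-padded with the fill value up to 8000 samples
assert crop(torch.randn(2, 5000)).shape == (2, 8000)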