diff --git a/README.md b/README.md
index e018c87403..40045c35cc 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ to use and feel like a natural extension.
   - [Kaldi (ark/scp)](http://pytorch.org/audio/kaldi_io.html)
 - [Dataloaders for common audio datasets (VCTK, YesNo)](http://pytorch.org/audio/datasets.html)
 - Common audio transforms
-  - [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/transforms.html)
+  - [Spectrogram, AmplitudeToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample, RandomCrop](http://pytorch.org/audio/transforms.html)
 - Compliance interfaces: Run code using PyTorch that align with other libraries
   - [Kaldi: spectrogram, fbank, mfcc, resample_waveform](https://pytorch.org/audio/compliance.kaldi.html)
@@ -143,6 +143,7 @@ Transforms expect and return the following dimensions.
 * `MuLawEncode`: (channel, time) -> (channel, time)
 * `MuLawDecode`: (channel, time) -> (channel, time)
 * `Resample`: (channel, time) -> (channel, time)
+* `RandomCrop`: (channel, time) -> (channel, time)
 
 Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions.
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index b460115a88..d57761205f 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -72,6 +72,13 @@ Transforms are common audio transforms. They can be chained together using :class:`torch.nn.Sequential`
 
    .. automethod:: forward
 
+:hidden:`RandomCrop`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: RandomCrop
+
+   .. automethod:: forward
+
 :hidden:`ComplexNorm`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/test/test_transforms.py b/test/test_transforms.py
index f2b45ec625..79c52771ae 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -22,7 +22,6 @@
 
 
 def _test_script_module(f, tensor, *args, **kwargs):
-
     py_method = f(*args, **kwargs)
 
     jit_method = torch.jit.script(py_method)
@@ -32,7 +31,6 @@ def _test_script_module(f, tensor, *args, **kwargs):
 
     assert torch.allclose(jit_out, py_out)
 
     if RUN_CUDA:
-
         tensor = tensor.to("cuda")
         py_method = py_method.cuda()
@@ -45,20 +43,19 @@ def _test_script_module(f, tensor, *args, **kwargs):
 
 
 class Tester(unittest.TestCase):
-
     # create a sinewave signal for testing
     sample_rate = 16000
     freq = 440
     volume = .3
     waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
     waveform.unsqueeze_(0)  # (1, 64000)
-    waveform = (waveform * volume * 2**31).long()
+    waveform = (waveform * volume * 2 ** 31).long()
 
     # file for stereo stft test
     test_dirpath, test_dir = common_utils.create_temp_assets_dir()
     test_filepath = os.path.join(test_dirpath, 'assets',
                                  'steam-train-whistle-daniel_simon.mp3')
 
-    def scale(self, waveform, factor=float(2**31)):
+    def scale(self, waveform, factor=float(2 ** 31)):
         # scales a waveform by a factor
         if not waveform.is_floating_point():
             waveform = waveform.to(torch.get_default_dtype())
@@ -73,7 +70,6 @@ def test_scriptmodule_GriffinLim(self):
         _test_script_module(transforms.GriffinLim, tensor, length=1000, rand_init=False)
 
     def test_mu_law_companding(self):
-
         quantization_channels = 256
         waveform = self.waveform.clone()
@@ -266,14 +262,14 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
         # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram
         # to mirror this function call with correct args:
-        #   librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
-        #                                       sr=sample_rate,
-        #                                       n_mfcc = n_mfcc,
-        #                                       hop_length=hop_length,
-        #                                       n_fft=n_fft,
-        #                                       htk=True,
-        #                                       norm=None,
-        #                                       n_mels=n_mels)
+        # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa,
+        #                                     sr=sample_rate,
+        #                                     n_mfcc = n_mfcc,
+        #                                     hop_length=hop_length,
+        #                                     n_fft=n_fft,
+        #                                     htk=True,
+        #                                     norm=None,
+        #                                     n_mels=n_mels)
         librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
         librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
@@ -448,6 +444,26 @@ def test_scriptmodule_TimeMasking(self):
         tensor = torch.rand((10, 2, 50, 10, 2))
         _test_script_module(transforms.TimeMasking, tensor, time_mask_param=30, iid_masks=False)
 
+    def test_random_crop_size(self):
+        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
+        waveform, sample_rate = torchaudio.load(input_path)
+        length = 2
+        channels = waveform.shape[0]
+        random_cropper = torchaudio.transforms.RandomCrop(length=length, freq=sample_rate)
+        # The crop must contain exactly length * sample_rate samples per channel
+        self.assertTrue(random_cropper(waveform).shape == torch.Size([channels, length * sample_rate]))
+
+    def test_random_crop_content(self):
+        import numpy as np
+        input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
+        waveform, sample_rate = torchaudio.load(input_path)
+        length = 2
+        channels = waveform.shape[0]
+        random_cropper = torchaudio.transforms.RandomCrop(length=length, freq=sample_rate)
+        # Every value in the crop should come from the original waveform. Sine wave samples
+        # repeat, so assume_unique must not be passed to np.isin here.
+        self.assertTrue(
+            np.sum(np.isin(waveform, random_cropper(waveform))) >=
+            length * sample_rate * channels)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 92b336cf92..6e9302eeb2 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -8,7 +8,6 @@
 from . import functional as F
 from .compliance import kaldi
 
-
 __all__ = [
     'Spectrogram',
     'GriffinLim',
@@ -23,6 +22,7 @@
     'TimeStretch',
     'FrequencyMasking',
     'TimeMasking',
+    'RandomCrop',
 ]
@@ -412,6 +412,7 @@ class Resample(torch.nn.Module):
         new_freq (float): The desired frequency. (Default: ``16000``)
         resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
     """
+
     def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
         super(Resample, self).__init__()
         self.orig_freq = orig_freq
@@ -587,3 +588,53 @@ class TimeMasking(_AxisMasking):
 
     def __init__(self, time_mask_param, iid_masks=False):
         super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
+
+
+class RandomCrop(torch.nn.Module):
+    r"""Crop the given audio at a random location to a fixed length.
+
+    If the input is shorter than the requested length, it is padded on the right instead.
+
+    Args:
+        length (float, optional): Length of the cropped audio in seconds. (Default: ``1``)
+        freq (float, optional): Sample rate of the signal. (Default: ``16000``)
+        padding_mode (str, optional): Type of padding. One of ``'constant'``, ``'reflect'``,
+            ``'replicate'`` or ``'circular'``. (Default: ``'constant'``)
+            - constant: pads with a constant value, specified with ``fill``
+            - reflect: pads with a reflection of the signal, without repeating the edge value
+            - replicate: pads with the last value at the edge of the signal
+            - circular: pads with a circular repetition of the signal
+        fill (float or int, optional): Fill value for ``'constant'`` padding. (Default: ``0``)
+    """
+    _padding_modes = ['constant', 'reflect', 'replicate', 'circular']
+
+    def __init__(self, length=1, freq=16000, padding_mode='constant', fill=0):
+        super(RandomCrop, self).__init__()
+        # Validate arguments
+        if length <= 0:
+            raise ValueError("Audio length must be greater than 0")
+        if freq <= 0:
+            raise ValueError("Audio frequency must be greater than 0")
+        if padding_mode not in self._padding_modes:
+            raise ValueError("Unknown padding mode, pick one of: {}".format(self._padding_modes))
+        # Assign values
+        self.length = length
+        self.freq = freq
+        self.padding_mode = padding_mode
+        self.fill = fill
+
+    def forward(self, waveform):
+        r"""
+        Args:
+            waveform (torch.Tensor): The input signal of dimension (channel, time)
+
+        Returns:
+            torch.Tensor: Output signal of dimension (channel, length * freq)
+        """
+        n_samples = int(self.length * self.freq)
+        total_samples = waveform.shape[1]
+        if total_samples == n_samples:
+            return waveform
+        elif total_samples < n_samples:
+            # Input is too short: right-pad up to the requested number of samples
+            n_pad = n_samples - total_samples
+            if self.padding_mode == 'constant':
+                return torch.nn.functional.pad(waveform, [0, n_pad], mode='constant', value=self.fill)
+            # Non-constant 1D padding expects a (batch, channel, time) input, so add and
+            # remove a batch dimension around the call
+            return torch.nn.functional.pad(waveform.unsqueeze(0), [0, n_pad], mode=self.padding_mode).squeeze(0)
+        else:
+            # torch.randint's upper bound is exclusive, so +1 keeps the last valid start index reachable
+            s_index = torch.randint(low=0, high=total_samples - n_samples + 1, size=[1]).item()
+            return waveform[:, s_index:s_index + n_samples]
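
For reviewers, a minimal usage sketch of the transform this patch introduces (not part of the diff). It assumes this branch of torchaudio is installed and that a local `sinewave.wav` file exists; the 2-second length is only an illustration.

```python
import torch
import torchaudio

# Load any local waveform; "sinewave.wav" is a placeholder path (the test asset is assumed here).
waveform, sample_rate = torchaudio.load("sinewave.wav")  # (channel, time)

# Crop to a random 2-second window, or right-pad with zeros if the file is shorter.
cropper = torchaudio.transforms.RandomCrop(length=2, freq=sample_rate, padding_mode='constant', fill=0)
cropped = cropper(waveform)

assert cropped.shape == torch.Size([waveform.shape[0], 2 * sample_rate])
```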