From 2e6bc4ae1558504b2ea131518a870dd1ef24c7e9 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Thu, 18 Jul 2019 06:57:19 -0700 Subject: [PATCH 01/28] more --- test/test_transforms.py | 16 +- torchaudio/functional.py | 161 +++++++++---------- torchaudio/transforms.py | 326 ++++++++++++++------------------------- 3 files changed, 195 insertions(+), 308 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 4c3a55565c..88d91e6a8e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -47,25 +47,21 @@ def test_scale(self): def test_pad_trim(self): - audio_orig = self.sig.clone() - length_orig = audio_orig.size(0) + audio_orig = self.sig.clone().t() + length_orig = audio_orig.size(1) length_new = int(length_orig * 1.2) - result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig) - self.assertEqual(result.size(0), length_new) - - result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1)) + result = transforms.PadTrim(max_len=length_new)(audio_orig) self.assertEqual(result.size(1), length_new) audio_orig = self.sig.clone() - length_orig = audio_orig.size(0) length_new = int(length_orig * 0.8) - result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig) + result = transforms.PadTrim(max_len=length_new)(audio_orig) - self.assertEqual(result.size(0), length_new) + self.assertEqual(result.size(1), length_new) - repr_test = transforms.PadTrim(max_len=length_new, channels_first=False) + repr_test = transforms.PadTrim(max_len=length_new) self.assertTrue(repr_test.__repr__()) def test_downmix_mono(self): diff --git a/torchaudio/functional.py b/torchaudio/functional.py index c770090ccb..6fd040b7f8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -9,10 +9,9 @@ 'LC2CL', 'istft', 'spectrogram', - 'create_fb_matrix', 'spectrogram_to_DB', + 'create_fb_matrix', 'create_dct', - 'BLC2CBL', 'mu_law_encoding', 'mu_law_expanding' ] @@ -20,15 +19,15 @@ @torch.jit.script def scale(tensor, factor): - # type: (Tensor, int) -> Tensor + # type: (Tensor, float) -> Tensor r"""Scale audio tensor from a 16-bit integer (represented as a :class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0. Note the 16-bit number is called the "bit depth" or "precision", not to be confused with "bit rate". Args: - tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n) - factor (int): Maximum value of input tensor + tensor (torch.Tensor): Tensor of audio of size (c, n) + factor (float): Maximum value of input tensor Returns: torch.Tensor: Scaled by the scale factor @@ -40,43 +39,34 @@ def scale(tensor, factor): @torch.jit.script -def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value): - # type: (Tensor, int, int, int, float) -> Tensor +def pad_trim(tensor, max_len, fill_value): + # type: (Tensor, int, float) -> Tensor r"""Pad/trim a 2D tensor (signal or labels). Args: - tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n) - ch_dim (int): Dimension of channel (not size) + tensor (torch.Tensor): Tensor of audio of size (c, n) max_len (int): Length to which the tensor will be padded - len_dim (int): Dimension of length (not size) fill_value (float): Value to fill in Returns: torch.Tensor: Padded/trimmed tensor """ - if max_len > tensor.size(len_dim): - # array of [padding_left, padding_right, padding_top, padding_bottom] - # so pad similar to append (aka only right/bottom) and do not pad - # the length dimension. assumes equal sizes of padding. 
- padding = [max_len - tensor.size(len_dim) - if (i % 2 == 1) and (i // 2 != len_dim) - else 0 - for i in [0, 1, 2, 3]] + n = tensor.size(1) + if max_len > n: # TODO add "with torch.no_grad():" back when JIT supports it - tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value) - elif max_len < tensor.size(len_dim): - tensor = tensor.narrow(len_dim, 0, max_len) + tensor = torch.nn.functional.pad(tensor, (0, max_len - n), 'constant', fill_value) + else: + tensor = tensor[:, :max_len] return tensor @torch.jit.script -def downmix_mono(tensor, ch_dim): - # type: (Tensor, int) -> Tensor +def downmix_mono(tensor): + # type: (Tensor) -> Tensor r"""Downmix any stereo signals to mono. Args: - tensor (torch.Tensor): Tensor of audio of size (c, n) or (n, c) - ch_dim (int): Dimension of channel (not size) + tensor (torch.Tensor): Tensor of audio of size (c, n) Returns: torch.Tensor: Mono signal @@ -84,7 +74,7 @@ def downmix_mono(tensor, ch_dim): if not tensor.is_floating_point(): tensor = tensor.to(torch.float32) - tensor = torch.mean(tensor, ch_dim, True) + tensor = torch.mean(tensor, 0, True) return tensor @@ -269,10 +259,9 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): normalize (bool) : Whether to normalize by magnitude after stft Returns: - torch.Tensor: Channels x hops x n_fft (c, l, f), where channels - is unchanged, hops is the number of hops, and n_fft is the - number of fourier bins, which should be the window size divided - by 2 plus 1. + torch.Tensor: Channels x frequency x time (c, f, t), where channels + is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of + fourier bins, time is the number of window hops """ assert sig.dim() == 2 @@ -282,48 +271,14 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): # default values are consistent with librosa.core.spectrum._spectrogram spec_f = _stft(sig, n_fft, hop, ws, window, - True, 'reflect', False, True).transpose(1, 2) + True, 'reflect', False, True) if normalize: spec_f /= window.pow(2).sum().sqrt() - spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor (c, l, n_fft) + spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor return spec_f -@torch.jit.script -def create_fb_matrix(n_stft, f_min, f_max, n_mels): - # type: (int, float, float, int) -> Tensor - r""" Create a frequency bin conversion matrix. - - Args: - n_stft (int): Number of filter banks from spectrogram - f_min (float): Minimum frequency - f_max (float): Maximum frequency - n_mels (int): Number of mel bins - - Returns: - torch.Tensor: Triangular filter banks (fb matrix) - """ - # get stft freq bins - stft_freqs = torch.linspace(f_min, f_max, n_stft) - # calculate mel freq bins - # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) - m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.)) - m_max = 2595. * math.log10(1. + (f_max / 700.)) - m_pts = torch.linspace(m_min, m_max, n_mels + 2) - # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) - f_pts = 700. * (10**(m_pts / 2595.) - 1.) - # calculate the difference between each mel point and each stft freq point in hertz - f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) - slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2) - # create overlapping triangles - z = torch.zeros(1) - down_slopes = (-1. 
* slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels) - up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels) - fb = torch.max(z, torch.min(down_slopes, up_slopes)) - return fb - - @torch.jit.script def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): # type: (Tensor, float, float, float, Optional[float]) -> Tensor @@ -334,7 +289,7 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): a full clip. Args: - spec (torch.Tensor): Normal STFT + spec (torch.Tensor): Normal STFT of size (c, f, t) multiplier (float): Use 10. for power and 20. for amplitude amin (float): Number to clamp spec db_multiplier (float): Log10(max(reference value and amin)) @@ -342,7 +297,7 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): is 80. Returns: - torch.Tensor: Spectrogram in DB + torch.Tensor: Spectrogram in DB of size (c, f, t) """ spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin)) spec_db -= multiplier * db_multiplier @@ -354,50 +309,72 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): return spec_db +@torch.jit.script +def create_fb_matrix(n_freqs, f_min, f_max, n_mels): + # type: (int, float, float, int) -> Tensor + r""" Create a frequency bin conversion matrix. + + Args: + n_freqs (int): Number of frequencies to highlight/apply + f_min (float): Minimum frequency + f_max (float): Maximum frequency + n_mels (int): Number of mel filterbanks + + Returns: + torch.Tensor: Triangular filter banks (fb matrix) of size (`n_freqs`, `n_mels`) + meaning number of frequencies to highlight/apply to x the number of filterbanks. + Each column is a filterbank so that assuming there is a matrix A of + size (..., `n_freqs`), the applied result would be + `A * create_fb_matrix(A.size(-1), ...)`. + """ + # get stft freq bins + stft_freqs = torch.linspace(f_min, f_max, n_freqs) + # calculate mel freq bins + # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) + m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.)) + m_max = 2595. * math.log10(1. + (f_max / 700.)) + m_pts = torch.linspace(m_min, m_max, n_mels + 2) + # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) + f_pts = 700. * (10**(m_pts / 2595.) - 1.) + # calculate the difference between each mel point and each stft freq point in hertz + f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) + slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_freqs, n_mels + 2) + # create overlapping triangles + z = torch.zeros(1) + down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) + up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) + fb = torch.max(z, torch.min(down_slopes, up_slopes)) + return fb + + @torch.jit.script def create_dct(n_mfcc, n_mels, norm): # type: (int, int, Optional[str]) -> Tensor - r"""Creates a DCT transformation matrix with shape (num_mels, num_mfcc), + r"""Creates a DCT transformation matrix with shape (`n_mels`, `n_mfcc`), normalized depending on norm. Args: n_mfcc (int) : Number of mfc coefficients to retain - n_mels (int): Number of MEL bins + n_mels (int): Number of mel filterbanks norm (Optional[str]) : Norm to use (either 'ortho' or None) Returns: - torch.Tensor: The transformation matrix, to be right-multiplied to row-wise data. + torch.Tensor: The transformation matrix, to be right-multiplied to + row-wise data of size (`n_mels`, `n_mfcc`). 
""" - outdim = n_mfcc - dim = n_mels # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II - n = torch.arange(dim) - k = torch.arange(outdim)[:, None] - dct = torch.cos(math.pi / float(dim) * (n + 0.5) * k) + n = torch.arange(n_mels, dtype=torch.get_default_dtype()) + k = torch.arange(n_mfcc, dtype=torch.get_default_dtype()).unsqueeze(1) + dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels) if norm is None: dct *= 2.0 else: assert norm == 'ortho' dct[0] *= 1.0 / math.sqrt(2.0) - dct *= math.sqrt(2.0 / float(dim)) + dct *= math.sqrt(2.0 / float(n_mels)) return dct.t() -@torch.jit.script -def BLC2CBL(tensor): - # type: (Tensor) -> Tensor - r"""Permute a 3D tensor from Bands x Sample length x Channels to Channels x - Bands x Samples length. - - Args: - tensor (torch.Tensor): Tensor of spectrogram with shape (b, l, c) - - Returns: - torch.Tensor: Tensor of spectrogram with shape (c, b, l) - """ - return tensor.permute(2, 0, 1).contiguous() - - @torch.jit.script def mu_law_encoding(x, qc): # type: (Tensor, int) -> Tensor diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 9176540910..e4ccd6b9f1 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -7,97 +7,56 @@ from .compliance import kaldi -# TODO remove this class -class Compose(object): - """Composes several transforms together. - - Args: - transforms (list of ``Transform`` objects): list of transforms to compose. - - Example: - >>> transforms.Compose([ - >>> transforms.Scale(), - >>> transforms.PadTrim(max_len=16000), - >>> ]) - """ - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, audio): - for t in self.transforms: - audio = t(audio) - return audio - - def __repr__(self): - format_string = self.__class__.__name__ + '(' - for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' - return format_string - - class Scale(torch.jit.ScriptModule): """Scale audio tensor from a 16-bit integer (represented as a FloatTensor) to a floating point number between -1.0 and 1.0. Note the 16-bit number is called the "bit depth" or "precision", not to be confused with "bit rate". Args: - factor (int): maximum value of input tensor. default: 16-bit depth - + factor (float): maximum value of input tensor. default: 16-bit depth """ __constants__ = ['factor'] - def __init__(self, factor=2**31): + def __init__(self, factor=float(2**31)): super(Scale, self).__init__() self.factor = factor @torch.jit.script_method def forward(self, tensor): """ - Args: - tensor (Tensor): Tensor of audio of size (Samples x Channels) + tensor (Tensor): Tensor of audio of size (c, n) Returns: Tensor: Scaled by the scale factor. (default between -1.0 and 1.0) - """ return F.scale(tensor, self.factor) - def __repr__(self): - return self.__class__.__name__ + '()' - class PadTrim(torch.jit.ScriptModule): - """Pad/Trim a 2d-Tensor (Signal or Labels) + """Pad/Trim a 2D-Tensor (Signal or Labels) Args: - tensor (Tensor): Tensor of audio of size (n x c) or (c x n) + tensor (Tensor): Tensor of audio of size (c, n) max_len (int): Length to which the tensor will be padded - channels_first (bool): Pad for channels first tensors. 
Default: `True` - """ - __constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim'] + __constants__ = ['max_len', 'fill_value'] - def __init__(self, max_len, fill_value=0., channels_first=True): + def __init__(self, max_len, fill_value=0.): super(PadTrim, self).__init__() self.max_len = max_len self.fill_value = fill_value - self.len_dim, self.ch_dim = int(channels_first), int(not channels_first) @torch.jit.script_method def forward(self, tensor): """ + Args: + tensor (Tensor): Tensor of audio of size (c, n) Returns: - Tensor: (c x n) or (n x c) - + Tensor: (c, `max_len`) """ - return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value) - - def __repr__(self): - return self.__class__.__name__ + '(max_len={0})'.format(self.max_len) + return F.pad_trim(tensor, self.max_len, self.fill_value) class DownmixMono(torch.jit.ScriptModule): @@ -105,54 +64,36 @@ class DownmixMono(torch.jit.ScriptModule): the `channels` effect instead of this transformation. Inputs: - tensor (Tensor): Tensor of audio of size (c x n) or (n x c) - channels_first (bool): Downmix across channels dimension. Default: `True` + tensor (Tensor): Tensor of audio of size (c, n) Returns: - tensor (Tensor) (Samples x 1): - + tensor (Tensor) (1, n): """ - __constants__ = ['ch_dim'] - def __init__(self, channels_first=None): + def __init__(self): super(DownmixMono, self).__init__() - self.ch_dim = int(not channels_first) @torch.jit.script_method def forward(self, tensor): return F.downmix_mono(tensor, self.ch_dim) - def __repr__(self): - return self.__class__.__name__ + '()' - class LC2CL(torch.jit.ScriptModule): - """Permute a 2d tensor from samples (n x c) to (c x n) + """Converts a 2D tensor from (n, c) to (c, n) """ - def __init__(self): super(LC2CL, self).__init__() @torch.jit.script_method def forward(self, tensor): """ - Args: - tensor (Tensor): Tensor of audio signal with shape (LxC) - + tensor (Tensor): Tensor of audio signal with shape (n, c) Returns: - tensor (Tensor): Tensor of audio signal with shape (CxL) + tensor (Tensor): Tensor of audio signal with shape (c, n) """ return F.LC2CL(tensor) - def __repr__(self): - return self.__class__.__name__ + '()' - - -def SPECTROGRAM(*args, **kwargs): - warn("SPECTROGRAM has been renamed to Spectrogram") - return Spectrogram(*args, **kwargs) - class Spectrogram(torch.jit.ScriptModule): """Create a spectrogram from a raw audio signal @@ -192,58 +133,15 @@ def forward(self, sig): sig (Tensor): Tensor of audio of size (c, n) Returns: - spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels - is unchanged, hops is the number of hops, and n_fft is the - number of fourier bins, which should be the window size divided - by 2 plus 1. + spec_f (Tensor): Channels x frequency x time (c, f, t), where channels + is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of + fourier bins, time is the number of window hops """ return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop, self.ws, self.power, self.normalize) -def F2M(*args, **kwargs): - warn("F2M has been renamed to MelScale") - return MelScale(*args, **kwargs) - - -class MelScale(torch.jit.ScriptModule): - """This turns a normal STFT into a mel frequency STFT, using a conversion - matrix. This uses triangular filter banks. - - User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). - - Args: - n_mels (int): number of mel bins - sr (int): sample rate of audio signal - f_max (float, optional): maximum frequency. 
default: `sr` // 2 - f_min (float): minimum frequency. default: 0 - n_stft (int, optional): number of filter banks from stft. Calculated from first input - if `None` is given. See `n_fft` in `Spectrogram`. - """ - __constants__ = ['n_mels', 'sr', 'f_min', 'f_max'] - - def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None): - super(MelScale, self).__init__() - self.n_mels = n_mels - self.sr = sr - self.f_max = f_max if f_max is not None else float(sr // 2) - self.f_min = f_min - fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( - n_stft, self.f_min, self.f_max, self.n_mels) - self.fb = torch.jit.Attribute(fb, torch.Tensor) - - @torch.jit.script_method - def forward(self, spec_f): - if self.fb.numel() == 0: - tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels) - # Attributes cannot be reassigned outside __init__ so workaround - self.fb.resize_(tmp_fb.size()) - self.fb.copy_(tmp_fb) - spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels) - return spec_m - - class SpectrogramToDB(torch.jit.ScriptModule): """Turns a spectrogram from the power/amplitude scale to the decibel scale. @@ -272,80 +170,67 @@ def __init__(self, stype="power", top_db=None): @torch.jit.script_method def forward(self, spec): - # numerically stable implementation from librosa - # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html - return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db) + r"""Numerically stable implementation from Librosa + https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html + Args: + spec (torch.Tensor): STFT of size (c, f, t) -class MFCC(torch.jit.ScriptModule): - """Create the Mel-frequency cepstrum coefficients from an audio signal + Returns: + torch.Tensor: STFT after changing scale of size (c, f, t) + """ + return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db) - By default, this calculates the MFCC on the DB-scaled Mel spectrogram. - This is not the textbook implementation, but is implemented here to - give consistency with librosa. - This output depends on the maximum value in the input spectrogram, and so - may return different values for an audio clip split into snippets vs. a - a full clip. +class MelScale(torch.jit.ScriptModule): + """This turns a normal STFT into a mel frequency STFT, using a conversion + matrix. This uses triangular filter banks. - Args: - sr (int) : sample rate of audio signal - n_mfcc (int) : number of mfc coefficients to retain - dct_type (int) : type of DCT (discrete cosine transform) to use - norm (string, optional) : norm to use - log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled - melkwargs (dict, optional): arguments for MelSpectrogram + User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). + + Args: + n_mels (int): Number of mel filterbanks + sr (int): sample rate of audio signal + f_max (float, optional): maximum frequency. default: `sr` // 2 + f_min (float): minimum frequency. default: 0 + n_stft (int, optional): number of filter banks from stft. Calculated from first input + if `None` is given. See `n_fft` in `Spectrogram`. 
""" - __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] + __constants__ = ['n_mels', 'sr', 'f_min', 'f_max'] - def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, - melkwargs=None): - super(MFCC, self).__init__() - supported_dct_types = [2] - if dct_type not in supported_dct_types: - raise ValueError('DCT type not supported'.format(dct_type)) + def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None): + super(MelScale, self).__init__() + self.n_mels = n_mels self.sr = sr - self.n_mfcc = n_mfcc - self.dct_type = dct_type - self.norm = torch.jit.Attribute(norm, Optional[str]) - self.top_db = 80. - self.s2db = SpectrogramToDB("power", self.top_db) - - if melkwargs is not None: - self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs) - else: - self.MelSpectrogram = MelSpectrogram(sr=self.sr) - - if self.n_mfcc > self.MelSpectrogram.n_mels: - raise ValueError('Cannot select more MFCC coefficients than # mel bins') - dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) - self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor) - self.log_mels = log_mels + self.f_max = f_max if f_max is not None else float(sr // 2) + self.f_min = f_min + fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( + n_stft, self.f_min, self.f_max, self.n_mels) + self.fb = torch.jit.Attribute(fb, torch.Tensor) @torch.jit.script_method - def forward(self, sig): - """ + def forward(self, spec_f): + r""" Args: - sig (Tensor): Tensor of audio of size (channels [c], samples [n]) + spec_f (torch.Tensor): a spectrogram STFT of size (c, f, t) Returns: - spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels - is unchanged, hops is the number of hops, and n_mels is the - number of mel bins. + torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) """ - mel_spect = self.MelSpectrogram(sig) - if self.log_mels: - log_offset = 1e-6 - mel_spect = torch.log(mel_spect + log_offset) - else: - mel_spect = self.s2db(mel_spect) - mfcc = torch.matmul(mel_spect, self.dct_mat) - return mfcc + if self.fb.numel() == 0: + tmp_fb = F.create_fb_matrix(spec_f.size(1), self.f_min, self.f_max, self.n_mels) + # Attributes cannot be reassigned outside __init__ so workaround + self.fb.resize_(tmp_fb.size()) + self.fb.copy_(tmp_fb) + + # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...) + spec_m = torch.matmul(spec_f.transpose(1, 2), self.fb).transpose(1, 2) + return spec_m class MelSpectrogram(torch.jit.ScriptModule): - """Create MEL Spectrograms from a raw audio signal using the stft - function in PyTorch. + """Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram + and MelScale. Sources: * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe @@ -360,13 +245,13 @@ class MelSpectrogram(torch.jit.ScriptModule): f_max (float, optional): maximum frequency. default: `sr` // 2 f_min (float): minimum frequency. 
default: 0 pad (int): two sided padding of signal - n_mels (int): number of MEL bins + n_mels (int): Number of mel filterbanks window (torch windowing function): default: `torch.hann_window` wkwargs (dict, optional): arguments for window function Example: >>> sig, sr = torchaudio.load("test.wav", normalization=True) - >>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, l, m) + >>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, n_mels, t) """ __constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min'] @@ -390,46 +275,79 @@ def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None, def forward(self, sig): """ Args: - sig (Tensor): Tensor of audio of size (channels [c], samples [n]) + sig (torch.Tensor): Tensor of audio of size (c, n) Returns: - spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels - is unchanged, hops is the number of hops, and n_mels is the - number of mel bins. - + torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) """ spec = self.spec(sig) spec_mel = self.fm(spec) return spec_mel -def MEL(*args, **kwargs): - raise DeprecationWarning("MEL has been removed from the library please use MelSpectrogram or librosa") +class MFCC(torch.jit.ScriptModule): + """Create the Mel-frequency cepstrum coefficients from an audio signal + + By default, this calculates the MFCC on the DB-scaled Mel spectrogram. + This is not the textbook implementation, but is implemented here to + give consistency with librosa. + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. -class BLC2CBL(torch.jit.ScriptModule): - """Permute a 3d tensor from Bands x Sample length x Channels to Channels x - Bands x Samples length + Args: + sr (int) : sample rate of audio signal + n_mfcc (int) : number of mfc coefficients to retain + dct_type (int) : type of DCT (discrete cosine transform) to use + norm (string, optional) : norm to use + log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled + melkwargs (dict, optional): arguments for MelSpectrogram """ + __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] - def __init__(self): - super(BLC2CBL, self).__init__() + def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, + melkwargs=None): + super(MFCC, self).__init__() + supported_dct_types = [2] + if dct_type not in supported_dct_types: + raise ValueError('DCT type not supported'.format(dct_type)) + self.sr = sr + self.n_mfcc = n_mfcc + self.dct_type = dct_type + self.norm = torch.jit.Attribute(norm, Optional[str]) + self.top_db = 80. 
+ self.s2db = SpectrogramToDB("power", self.top_db) + + if melkwargs is not None: + self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs) + else: + self.MelSpectrogram = MelSpectrogram(sr=self.sr) + + if self.n_mfcc > self.MelSpectrogram.n_mels: + raise ValueError('Cannot select more MFCC coefficients than # mel bins') + dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm) + self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor) + self.log_mels = log_mels @torch.jit.script_method - def forward(self, tensor): + def forward(self, sig): """ - Args: - tensor (Tensor): Tensor of spectrogram with shape (BxLxC) + sig (torch.Tensor): Tensor of audio of size (c, n) Returns: - tensor (Tensor): Tensor of spectrogram with shape (CxBxL) - + torch.Tensor: spec_mel_db of size (c, `n_mfcc`, t) """ - return F.BLC2CBL(tensor) - - def __repr__(self): - return self.__class__.__name__ + '()' + mel_spect = self.MelSpectrogram(sig) + if self.log_mels: + log_offset = 1e-6 + mel_spect = torch.log(mel_spect + log_offset) + else: + mel_spect = self.s2db(mel_spect) + # (c, `n_mels`, t).tranpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).tranpose(...) + mfcc = torch.matmul(mel_spect.transpose(1, 2), self.dct_mat).transpose(1, 2) + return mfcc class MuLawEncoding(torch.jit.ScriptModule): @@ -452,13 +370,11 @@ def __init__(self, quantization_channels=256): @torch.jit.script_method def forward(self, x): """ - Args: x (FloatTensor/LongTensor) Returns: x_mu (LongTensor) - """ return F.mu_law_encoding(x, self.qc) @@ -486,13 +402,11 @@ def __init__(self, quantization_channels=256): @torch.jit.script_method def forward(self, x_mu): """ - Args: x_mu (Tensor) Returns: x (Tensor) - """ return F.mu_law_expanding(x_mu, self.qc) From 8040752920a4967da43ef6c1f68a3eee68a7934e Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Thu, 18 Jul 2019 07:53:15 -0700 Subject: [PATCH 02/28] more --- torchaudio/functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 6fd040b7f8..b2ff9ef921 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -92,6 +92,7 @@ def LC2CL(tensor): return tensor.transpose(0, 1).contiguous() +@torch.jit.ignore def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided) From 5c0b69346ef62b3c618531ca72327ff91d749f65 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Thu, 18 Jul 2019 08:27:44 -0700 Subject: [PATCH 03/28] more --- test/test_transforms.py | 322 ++++++++++++++++++--------------------- torchaudio/functional.py | 4 +- torchaudio/transforms.py | 3 +- 3 files changed, 149 insertions(+), 180 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 88d91e6a8e..b14ee1d09c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -23,7 +23,7 @@ class Tester(unittest.TestCase): freq = 440 volume = .3 sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr)) - sig.unsqueeze_(1) # (64000, 1) + sig.unsqueeze_(0) # (1, 64000) sig = (sig * volume * 2**31).long() # file for stereo stft test test_dirpath, test_dir = test.common_utils.create_temp_assets_dir() @@ -37,24 +37,22 @@ def test_scale(self): self.assertTrue(result.min() >= -1. and result.max() <= 1.) 
maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item() - result = transforms.Scale(factor=maxminmax)(audio_orig) + result = transforms.Scale(factor=float(maxminmax))(audio_orig) self.assertTrue((result.min() == -1. or result.max() == 1.) and result.min() >= -1. and result.max() <= 1.) repr_test = transforms.Scale() - self.assertTrue(repr_test.__repr__()) def test_pad_trim(self): - audio_orig = self.sig.clone().t() + audio_orig = self.sig.clone() length_orig = audio_orig.size(1) length_new = int(length_orig * 1.2) result = transforms.PadTrim(max_len=length_new)(audio_orig) self.assertEqual(result.size(1), length_new) - audio_orig = self.sig.clone() length_new = int(length_orig * 0.8) result = transforms.PadTrim(max_len=length_new)(audio_orig) @@ -62,7 +60,6 @@ def test_pad_trim(self): self.assertEqual(result.size(1), length_new) repr_test = transforms.PadTrim(max_len=length_new) - self.assertTrue(repr_test.__repr__()) def test_downmix_mono(self): @@ -71,43 +68,21 @@ def test_downmix_mono(self): R_idx = int(audio_R.size(0) * 0.1) audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx])) - audio_Stereo = torch.cat((audio_L, audio_R), dim=1) + audio_Stereo = torch.cat((audio_L, audio_R), dim=0) - self.assertTrue(audio_Stereo.size(1) == 2) + self.assertTrue(audio_Stereo.size(0) == 2) - result = transforms.DownmixMono(channels_first=False)(audio_Stereo) + result = transforms.DownmixMono()(audio_Stereo) - self.assertTrue(result.size(1) == 1) - - repr_test = transforms.DownmixMono(channels_first=False) - self.assertTrue(repr_test.__repr__()) + self.assertTrue(result.size(0) == 1) def test_lc2cl(self): - audio = self.sig.clone() + audio = self.sig.clone().t() result = transforms.LC2CL()(audio) self.assertTrue(result.size()[::-1] == audio.size()) repr_test = transforms.LC2CL() - self.assertTrue(repr_test.__repr__()) - - def test_compose(self): - - audio_orig = self.sig.clone() - length_orig = audio_orig.size(0) - length_new = int(length_orig * 1.2) - maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item() - - tset = (transforms.Scale(factor=maxminmax), - transforms.PadTrim(max_len=length_new, channels_first=False)) - result = transforms.Compose(tset)(audio_orig) - - self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.) - - self.assertTrue(result.size(0) == length_new) - - repr_test = transforms.Compose(tset) - self.assertTrue(repr_test.__repr__()) def test_mu_law_companding(self): @@ -123,21 +98,16 @@ def test_mu_law_companding(self): sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu) self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.) - repr_test = transforms.MuLawEncoding(quantization_channels) - self.assertTrue(repr_test.__repr__()) - repr_test = transforms.MuLawExpanding(quantization_channels) - self.assertTrue(repr_test.__repr__()) - def test_mel2(self): top_db = 80. 
s2db = transforms.SpectrogramToDB("power", top_db) - audio_orig = self.sig.clone() # (16000, 1) - audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) - audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) + audio_orig = self.sig.clone() # (1, 16000) + audio_scaled = transforms.Scale()(audio_orig) # (1, 16000) mel_transform = transforms.MelSpectrogram() # check defaults spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 319, 40) + print(spectrogram_torch.shape) self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels) @@ -166,141 +136,141 @@ def test_mel2(self): self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) - def test_mfcc(self): - audio_orig = self.sig.clone() - audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) - audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) - - sample_rate = 16000 - n_mfcc = 40 - n_mels = 128 - mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, - n_mfcc=n_mfcc, - norm='ortho') - # check defaults - torch_mfcc = mfcc_transform(audio_scaled) - self.assertTrue(torch_mfcc.dim() == 3) - self.assertTrue(torch_mfcc.shape[2] == n_mfcc) - self.assertTrue(torch_mfcc.shape[1] == 321) - # check melkwargs are passed through - melkwargs = {'ws': 200} - mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate, - n_mfcc=n_mfcc, - norm='ortho', - melkwargs=melkwargs) - torch_mfcc2 = mfcc_transform2(audio_scaled) - self.assertTrue(torch_mfcc2.shape[1] == 641) - - # check norms work correctly - mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate, - n_mfcc=n_mfcc, - norm=None) - torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) - - norm_check = torch_mfcc.clone() - norm_check[:, :, 0] *= math.sqrt(n_mels) * 2 - norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2 - - self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) - - @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available') - def test_librosa_consistency(self): - def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): - input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') - sound, sample_rate = torchaudio.load(input_path) - sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first - - # test core spectrogram - spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) - out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, - n_fft=n_fft, - hop_length=hop_length, - power=2) - - out_torch = spect_transform(sound).squeeze().cpu().t() - self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) - - # test mel spectrogram - melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, - hop=hop_length, n_mels=n_mels, n_fft=n_fft) - librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, - n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, - htk=True, norm=None) - librosa_mel_tensor = torch.from_numpy(librosa_mel) - torch_mel = melspect_transform(sound).squeeze().cpu().t() - - self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) - - # test s2db - - db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) 
- db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t() - db_librosa = librosa.core.spectrum.power_to_db(out_librosa) - self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) - - db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t() - db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) - db_librosa_tensor = torch.from_numpy(db_librosa) - - self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) - - # test MFCC - melkwargs = {'hop': hop_length, 'n_fft': n_fft} - mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, - n_mfcc=n_mfcc, - norm='ortho', - melkwargs=melkwargs) - - # librosa.feature.mfcc doesn't pass kwargs properly since some of the - # kwargs for melspectrogram and mfcc are the same. We just follow the - # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram - # to mirror this function call with correct args: - - # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa, - # sr=sample_rate, - # n_mfcc = n_mfcc, - # hop_length=hop_length, - # n_fft=n_fft, - # htk=True, - # norm=None, - # n_mels=n_mels) - - librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] - librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) - torch_mfcc = mfcc_transform(sound).squeeze().cpu().t() - - self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)) - - kwargs1 = { - 'n_fft': 400, - 'hop_length': 200, - 'power': 2.0, - 'n_mels': 128, - 'n_mfcc': 40, - 'sample_rate': 16000 - } - - kwargs2 = { - 'n_fft': 600, - 'hop_length': 100, - 'power': 2.0, - 'n_mels': 128, - 'n_mfcc': 20, - 'sample_rate': 16000 - } - - kwargs3 = { - 'n_fft': 200, - 'hop_length': 50, - 'power': 2.0, - 'n_mels': 128, - 'n_mfcc': 50, - 'sample_rate': 24000 - } - - _test_librosa_consistency_helper(**kwargs1) - _test_librosa_consistency_helper(**kwargs2) - _test_librosa_consistency_helper(**kwargs3) + # def test_mfcc(self): + # audio_orig = self.sig.clone() + # audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) + # audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) + # + # sample_rate = 16000 + # n_mfcc = 40 + # n_mels = 128 + # mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, + # n_mfcc=n_mfcc, + # norm='ortho') + # # check defaults + # torch_mfcc = mfcc_transform(audio_scaled) + # self.assertTrue(torch_mfcc.dim() == 3) + # self.assertTrue(torch_mfcc.shape[2] == n_mfcc) + # self.assertTrue(torch_mfcc.shape[1] == 321) + # # check melkwargs are passed through + # melkwargs = {'ws': 200} + # mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate, + # n_mfcc=n_mfcc, + # norm='ortho', + # melkwargs=melkwargs) + # torch_mfcc2 = mfcc_transform2(audio_scaled) + # self.assertTrue(torch_mfcc2.shape[1] == 641) + # + # # check norms work correctly + # mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate, + # n_mfcc=n_mfcc, + # norm=None) + # torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) + # + # norm_check = torch_mfcc.clone() + # norm_check[:, :, 0] *= math.sqrt(n_mels) * 2 + # norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2 + # + # self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) + # + # @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available') + # def test_librosa_consistency(self): + # def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): + # 
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') + # sound, sample_rate = torchaudio.load(input_path) + # sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first + # + # # test core spectrogram + # spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) + # out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, + # n_fft=n_fft, + # hop_length=hop_length, + # power=2) + # + # out_torch = spect_transform(sound).squeeze().cpu().t() + # self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) + # + # # test mel spectrogram + # melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, + # hop=hop_length, n_mels=n_mels, n_fft=n_fft) + # librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, + # n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, + # htk=True, norm=None) + # librosa_mel_tensor = torch.from_numpy(librosa_mel) + # torch_mel = melspect_transform(sound).squeeze().cpu().t() + # + # self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) + # + # # test s2db + # + # db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) + # db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t() + # db_librosa = librosa.core.spectrum.power_to_db(out_librosa) + # self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) + # + # db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t() + # db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) + # db_librosa_tensor = torch.from_numpy(db_librosa) + # + # self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) + # + # # test MFCC + # melkwargs = {'hop': hop_length, 'n_fft': n_fft} + # mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, + # n_mfcc=n_mfcc, + # norm='ortho', + # melkwargs=melkwargs) + # + # # librosa.feature.mfcc doesn't pass kwargs properly since some of the + # # kwargs for melspectrogram and mfcc are the same. 
We just follow the + # # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram + # # to mirror this function call with correct args: + # + # # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa, + # # sr=sample_rate, + # # n_mfcc = n_mfcc, + # # hop_length=hop_length, + # # n_fft=n_fft, + # # htk=True, + # # norm=None, + # # n_mels=n_mels) + # + # librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] + # librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) + # torch_mfcc = mfcc_transform(sound).squeeze().cpu().t() + # + # self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)) + # + # kwargs1 = { + # 'n_fft': 400, + # 'hop_length': 200, + # 'power': 2.0, + # 'n_mels': 128, + # 'n_mfcc': 40, + # 'sample_rate': 16000 + # } + # + # kwargs2 = { + # 'n_fft': 600, + # 'hop_length': 100, + # 'power': 2.0, + # 'n_mels': 128, + # 'n_mfcc': 20, + # 'sample_rate': 16000 + # } + # + # kwargs3 = { + # 'n_fft': 200, + # 'hop_length': 50, + # 'power': 2.0, + # 'n_mels': 128, + # 'n_mfcc': 50, + # 'sample_rate': 24000 + # } + # + # _test_librosa_consistency_helper(**kwargs1) + # _test_librosa_consistency_helper(**kwargs2) + # _test_librosa_consistency_helper(**kwargs3) def test_resample_size(self): input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') diff --git a/torchaudio/functional.py b/torchaudio/functional.py index b2ff9ef921..0fa76e1431 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -364,8 +364,8 @@ def create_dct(n_mfcc, n_mels, norm): row-wise data of size (`n_mels`, `n_mfcc`). """ # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II - n = torch.arange(n_mels, dtype=torch.get_default_dtype()) - k = torch.arange(n_mfcc, dtype=torch.get_default_dtype()).unsqueeze(1) + n = torch.arange(float(n_mels)) + k = torch.arange(float(n_mfcc)).unsqueeze(1) dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels) if norm is None: dct *= 2.0 diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index e4ccd6b9f1..4638376406 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -69,13 +69,12 @@ class DownmixMono(torch.jit.ScriptModule): Returns: tensor (Tensor) (1, n): """ - def __init__(self): super(DownmixMono, self).__init__() @torch.jit.script_method def forward(self, tensor): - return F.downmix_mono(tensor, self.ch_dim) + return F.downmix_mono(tensor) class LC2CL(torch.jit.ScriptModule): From 23d29355e9792fe88200ccd982f7c24537893bfd Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Thu, 18 Jul 2019 08:50:56 -0700 Subject: [PATCH 04/28] more --- test/test_transforms.py | 283 ++++++++++++++++++++-------------------- 1 file changed, 140 insertions(+), 143 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index b14ee1d09c..7d7e222302 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -106,171 +106,168 @@ def test_mel2(self): audio_scaled = transforms.Scale()(audio_orig) # (1, 16000) mel_transform = transforms.MelSpectrogram() # check defaults - spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 319, 40) - print(spectrogram_torch.shape) + spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 128, 321) self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) - self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels) + 
self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels) # check correctness of filterbank conversion matrix self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all()) # check options kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50} mel_transform2 = transforms.MelSpectrogram(**kwargs) - spectrogram2_torch = s2db(mel_transform2(audio_scaled)) # (1, 506, 50) + spectrogram2_torch = s2db(mel_transform2(audio_scaled)) # (1, 50, 513) self.assertTrue(spectrogram2_torch.dim() == 3) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) - self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels) + self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels) self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all()) self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all()) # check on multi-channel audio - x_stereo, sr_stereo = torchaudio.load(self.test_filepath) - spectrogram_stereo = s2db(mel_transform(x_stereo)) + x_stereo, sr_stereo = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394) self.assertTrue(spectrogram_stereo.dim() == 3) self.assertTrue(spectrogram_stereo.size(0) == 2) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) - self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels) + self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels) # check filterbank matrix creation fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400) self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) - # def test_mfcc(self): - # audio_orig = self.sig.clone() - # audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) - # audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) - # - # sample_rate = 16000 - # n_mfcc = 40 - # n_mels = 128 - # mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, - # n_mfcc=n_mfcc, - # norm='ortho') - # # check defaults - # torch_mfcc = mfcc_transform(audio_scaled) - # self.assertTrue(torch_mfcc.dim() == 3) - # self.assertTrue(torch_mfcc.shape[2] == n_mfcc) - # self.assertTrue(torch_mfcc.shape[1] == 321) - # # check melkwargs are passed through - # melkwargs = {'ws': 200} - # mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate, - # n_mfcc=n_mfcc, - # norm='ortho', - # melkwargs=melkwargs) - # torch_mfcc2 = mfcc_transform2(audio_scaled) - # self.assertTrue(torch_mfcc2.shape[1] == 641) - # - # # check norms work correctly - # mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate, - # n_mfcc=n_mfcc, - # norm=None) - # torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) - # - # norm_check = torch_mfcc.clone() - # norm_check[:, :, 0] *= math.sqrt(n_mels) * 2 - # norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2 - # - # self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) - # - # @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available') - # def test_librosa_consistency(self): - # def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): - # input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') - # sound, sample_rate = torchaudio.load(input_path) - # sound_librosa = sound.cpu().numpy().squeeze().T # 
squeeze batch and channel first - # - # # test core spectrogram - # spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) - # out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, - # n_fft=n_fft, - # hop_length=hop_length, - # power=2) - # - # out_torch = spect_transform(sound).squeeze().cpu().t() - # self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) - # - # # test mel spectrogram - # melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, - # hop=hop_length, n_mels=n_mels, n_fft=n_fft) - # librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, - # n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, - # htk=True, norm=None) - # librosa_mel_tensor = torch.from_numpy(librosa_mel) - # torch_mel = melspect_transform(sound).squeeze().cpu().t() - # - # self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) - # - # # test s2db - # - # db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) - # db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t() - # db_librosa = librosa.core.spectrum.power_to_db(out_librosa) - # self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) - # - # db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t() - # db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) - # db_librosa_tensor = torch.from_numpy(db_librosa) - # - # self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) - # - # # test MFCC - # melkwargs = {'hop': hop_length, 'n_fft': n_fft} - # mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, - # n_mfcc=n_mfcc, - # norm='ortho', - # melkwargs=melkwargs) - # - # # librosa.feature.mfcc doesn't pass kwargs properly since some of the - # # kwargs for melspectrogram and mfcc are the same. 
We just follow the - # # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram - # # to mirror this function call with correct args: - # - # # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa, - # # sr=sample_rate, - # # n_mfcc = n_mfcc, - # # hop_length=hop_length, - # # n_fft=n_fft, - # # htk=True, - # # norm=None, - # # n_mels=n_mels) - # - # librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] - # librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) - # torch_mfcc = mfcc_transform(sound).squeeze().cpu().t() - # - # self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)) - # - # kwargs1 = { - # 'n_fft': 400, - # 'hop_length': 200, - # 'power': 2.0, - # 'n_mels': 128, - # 'n_mfcc': 40, - # 'sample_rate': 16000 - # } - # - # kwargs2 = { - # 'n_fft': 600, - # 'hop_length': 100, - # 'power': 2.0, - # 'n_mels': 128, - # 'n_mfcc': 20, - # 'sample_rate': 16000 - # } - # - # kwargs3 = { - # 'n_fft': 200, - # 'hop_length': 50, - # 'power': 2.0, - # 'n_mels': 128, - # 'n_mfcc': 50, - # 'sample_rate': 24000 - # } - # - # _test_librosa_consistency_helper(**kwargs1) - # _test_librosa_consistency_helper(**kwargs2) - # _test_librosa_consistency_helper(**kwargs3) + def test_mfcc(self): + audio_orig = self.sig.clone() + audio_scaled = transforms.Scale()(audio_orig) # (1, 16000) + + sample_rate = 16000 + n_mfcc = 40 + n_mels = 128 + mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, + n_mfcc=n_mfcc, + norm='ortho') + # check defaults + torch_mfcc = mfcc_transform(audio_scaled) # (1, 40, 321) + self.assertTrue(torch_mfcc.dim() == 3) + self.assertTrue(torch_mfcc.shape[1] == n_mfcc) + self.assertTrue(torch_mfcc.shape[2] == 321) + # check melkwargs are passed through + melkwargs = {'ws': 200} + mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate, + n_mfcc=n_mfcc, + norm='ortho', + melkwargs=melkwargs) + torch_mfcc2 = mfcc_transform2(audio_scaled) # (1, 40, 641) + self.assertTrue(torch_mfcc2.shape[2] == 641) + + # check norms work correctly + mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate, + n_mfcc=n_mfcc, + norm=None) + torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) # (1, 40, 321) + + norm_check = torch_mfcc.clone() + norm_check[:, 0, :] *= math.sqrt(n_mels) * 2 + norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2 + + self.assertTrue(torch_mfcc_norm_none.allclose(norm_check)) + + @unittest.skipIf(not IMPORT_LIBROSA or not IMPORT_SCIPY, 'Librosa and scipy are not available') + def test_librosa_consistency(self): + def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate): + input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') + sound, sample_rate = torchaudio.load(input_path) + sound_librosa = sound.cpu().numpy().squeeze() # (64000) + + # test core spectrogram + spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) + out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, + n_fft=n_fft, + hop_length=hop_length, + power=2) + + out_torch = spect_transform(sound).squeeze().cpu() + self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) + + # test mel spectrogram + melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, + hop=hop_length, n_mels=n_mels, n_fft=n_fft) + librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, + 
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, + htk=True, norm=None) + librosa_mel_tensor = torch.from_numpy(librosa_mel) + torch_mel = melspect_transform(sound).squeeze().cpu() + + self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) + + # test s2db + db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) + db_torch = db_transform(spect_transform(sound)).squeeze().cpu() + db_librosa = librosa.core.spectrum.power_to_db(out_librosa) + self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) + + db_torch = db_transform(melspect_transform(sound)).squeeze().cpu() + db_librosa = librosa.core.spectrum.power_to_db(librosa_mel) + db_librosa_tensor = torch.from_numpy(db_librosa) + + self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) + + # test MFCC + melkwargs = {'hop': hop_length, 'n_fft': n_fft} + mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, + n_mfcc=n_mfcc, + norm='ortho', + melkwargs=melkwargs) + + # librosa.feature.mfcc doesn't pass kwargs properly since some of the + # kwargs for melspectrogram and mfcc are the same. We just follow the + # function body in https://librosa.github.io/librosa/_modules/librosa/feature/spectral.html#melspectrogram + # to mirror this function call with correct args: + + # librosa_mfcc = librosa.feature.mfcc(y=sound_librosa, + # sr=sample_rate, + # n_mfcc = n_mfcc, + # hop_length=hop_length, + # n_fft=n_fft, + # htk=True, + # norm=None, + # n_mels=n_mels) + + librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc] + librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc) + torch_mfcc = mfcc_transform(sound).squeeze().cpu() + + self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3)) + + kwargs1 = { + 'n_fft': 400, + 'hop_length': 200, + 'power': 2.0, + 'n_mels': 128, + 'n_mfcc': 40, + 'sample_rate': 16000 + } + + kwargs2 = { + 'n_fft': 600, + 'hop_length': 100, + 'power': 2.0, + 'n_mels': 128, + 'n_mfcc': 20, + 'sample_rate': 16000 + } + + kwargs3 = { + 'n_fft': 200, + 'hop_length': 50, + 'power': 2.0, + 'n_mels': 128, + 'n_mfcc': 50, + 'sample_rate': 24000 + } + + _test_librosa_consistency_helper(**kwargs1) + _test_librosa_consistency_helper(**kwargs2) + _test_librosa_consistency_helper(**kwargs3) def test_resample_size(self): input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') From fce66374ae6df61337583b7eddbfd91a3c32ca44 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Thu, 18 Jul 2019 10:02:54 -0700 Subject: [PATCH 05/28] more --- test/test_jit.py | 64 +++++++++++++--------------------------- torchaudio/functional.py | 1 + 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index d2652a9dc4..3cccca9d05 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -33,11 +33,11 @@ def _test_script_module(self, tensor, f, *args): def test_torchscript_scale(self): @torch.jit.script def jit_method(tensor, factor): - # type: (Tensor, int) -> Tensor + # type: (Tensor, float) -> Tensor return F.scale(tensor, factor) - tensor = torch.rand((10, 1)) - factor = 2 + tensor = torch.rand((1, 10)) + factor = 2.0 jit_out = jit_method(tensor, factor) py_out = F.scale(tensor, factor) @@ -46,24 +46,22 @@ def jit_method(tensor, factor): @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_scale(self): - tensor = torch.rand((10, 1), device="cuda") + tensor = torch.rand((1, 10), 
device="cuda") self._test_script_module(tensor, transforms.Scale) def test_torchscript_pad_trim(self): @torch.jit.script - def jit_method(tensor, ch_dim, max_len, len_dim, fill_value): - # type: (Tensor, int, int, int, float) -> Tensor - return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value) + def jit_method(tensor, max_len, fill_value): + # type: (Tensor, int, float) -> Tensor + return F.pad_trim(tensor, max_len, fill_value) - tensor = torch.rand((10, 1)) - ch_dim = 1 + tensor = torch.rand((1, 10)) max_len = 5 - len_dim = 0 fill_value = 3. - jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value) - py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value) + jit_out = jit_method(tensor, max_len, fill_value) + py_out = F.pad_trim(tensor, max_len, fill_value) self.assertTrue(torch.allclose(jit_out, py_out)) @@ -76,21 +74,20 @@ def test_scriptmodule_pad_trim(self): def test_torchscript_downmix_mono(self): @torch.jit.script - def jit_method(tensor, ch_dim): - # type: (Tensor, int) -> Tensor - return F.downmix_mono(tensor, ch_dim) + def jit_method(tensor): + # type: (Tensor) -> Tensor + return F.downmix_mono(tensor) - tensor = torch.rand((10, 1)) - ch_dim = 1 + tensor = torch.rand((2, 10)) - jit_out = jit_method(tensor, ch_dim) - py_out = F.downmix_mono(tensor, ch_dim) + jit_out = jit_method(tensor) + py_out = F.downmix_mono(tensor) self.assertTrue(torch.allclose(jit_out, py_out)) @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_downmix_mono(self): - tensor = torch.rand((1, 10), device="cuda") + tensor = torch.rand((2, 10), device="cuda") self._test_script_module(tensor, transforms.DownmixMono) @@ -211,32 +208,13 @@ def test_scriptmodule_MelSpectrogram(self): self._test_script_module(tensor, transforms.MelSpectrogram) - def test_torchscript_BLC2CBL(self): - @torch.jit.script - def jit_method(tensor): - # type: (Tensor) -> Tensor - return F.BLC2CBL(tensor) - - tensor = torch.rand((10, 1000, 1)) - - jit_out = jit_method(tensor) - py_out = F.BLC2CBL(tensor) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_BLC2CBL(self): - tensor = torch.rand((10, 1000, 1), device="cuda") - - self._test_script_module(tensor, transforms.BLC2CBL) - def test_torchscript_mu_law_encoding(self): @torch.jit.script def jit_method(tensor, qc): # type: (Tensor, int) -> Tensor return F.mu_law_encoding(tensor, qc) - tensor = torch.rand((10, 1)) + tensor = torch.rand((1, 10)) qc = 256 jit_out = jit_method(tensor, qc) @@ -246,7 +224,7 @@ def jit_method(tensor, qc): @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MuLawEncoding(self): - tensor = torch.rand((10, 1), device="cuda") + tensor = torch.rand((1, 10), device="cuda") self._test_script_module(tensor, transforms.MuLawEncoding) @@ -256,7 +234,7 @@ def jit_method(tensor, qc): # type: (Tensor, int) -> Tensor return F.mu_law_expanding(tensor, qc) - tensor = torch.rand((10, 1)) + tensor = torch.rand((1, 10)) qc = 256 jit_out = jit_method(tensor, qc) @@ -266,7 +244,7 @@ def jit_method(tensor, qc): @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_MuLawExpanding(self): - tensor = torch.rand((10, 1), device="cuda") + tensor = torch.rand((1, 10), device="cuda") self._test_script_module(tensor, transforms.MuLawExpanding) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 0fa76e1431..68de1ba788 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -92,6 +92,7 @@ def LC2CL(tensor): return 
tensor.transpose(0, 1).contiguous() +# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved @torch.jit.ignore def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor From 99f449bcbd391382deb25bb72fa2c800c0861adc Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Thu, 18 Jul 2019 10:04:08 -0700 Subject: [PATCH 06/28] more --- test/test_jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 3cccca9d05..9e3dcd107e 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -164,7 +164,7 @@ def jit_method(spec, multiplier, amin, db_multiplier, top_db): # type: (Tensor, float, float, float, Optional[float]) -> Tensor return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db) - spec = torch.rand((10, 1)) + spec = torch.rand((6, 201)) multiplier = 10. amin = 1e-10 db_multiplier = 0. @@ -177,7 +177,7 @@ def jit_method(spec, multiplier, amin, db_multiplier, top_db): @unittest.skipIf(not RUN_CUDA, "no CUDA") def test_scriptmodule_SpectrogramToDB(self): - spec = torch.rand((10, 1), device="cuda") + spec = torch.rand((6, 201), device="cuda") self._test_script_module(spec, transforms.SpectrogramToDB) From f00c46c6b462958dacc5ba47506ec43c61a63657 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Mon, 22 Jul 2019 11:43:42 -0700 Subject: [PATCH 07/28] small push to save progress --- torchaudio/functional.py | 67 +++++++------ torchaudio/transforms.py | 207 ++++++++++++++++++++------------------- 2 files changed, 139 insertions(+), 135 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 68de1ba788..7b03288df8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -20,14 +20,12 @@ @torch.jit.script def scale(tensor, factor): # type: (Tensor, float) -> Tensor - r"""Scale audio tensor from a 16-bit integer (represented as a - :class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0. - Note the 16-bit number is called the "bit depth" or "precision", not to be - confused with "bit rate". + r"""Scales tensor by a factor. By default, assuming the input is int32, it + will scale the tensor to have values between -1.0 and 1.0. Args: - tensor (torch.Tensor): Tensor of audio of size (c, n) - factor (float): Maximum value of input tensor + tensor (torch.Tensor): Tensor input to scale + factor (float): Factor to scale by Returns: torch.Tensor: Scaled by the scale factor @@ -39,43 +37,44 @@ def scale(tensor, factor): @torch.jit.script -def pad_trim(tensor, max_len, fill_value): +def pad_trim(waveform, max_len, fill_value): # type: (Tensor, int, float) -> Tensor - r"""Pad/trim a 2D tensor (signal or labels). 
+ r"""Pad/trim a 2D tensor Args: - tensor (torch.Tensor): Tensor of audio of size (c, n) - max_len (int): Length to which the tensor will be padded + waveform (torch.Tensor): Tensor of audio of size (c, n) + max_len (int): Length to which the waveform will be padded fill_value (float): Value to fill in Returns: torch.Tensor: Padded/trimmed tensor """ - n = tensor.size(1) + n = waveform.size(1) if max_len > n: # TODO add "with torch.no_grad():" back when JIT supports it - tensor = torch.nn.functional.pad(tensor, (0, max_len - n), 'constant', fill_value) + waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value) else: - tensor = tensor[:, :max_len] - return tensor + waveform = waveform[:, :max_len] + return waveform @torch.jit.script -def downmix_mono(tensor): +def downmix_mono(waveform): # type: (Tensor) -> Tensor - r"""Downmix any stereo signals to mono. + r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with + the `channels` effect instead of this transformation. Args: - tensor (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of size (c, n) Returns: - torch.Tensor: Mono signal + torch.Tensor: Tensor that has been downmixed of size (1, n) """ - if not tensor.is_floating_point(): - tensor = tensor.to(torch.float32) + if not waveform.is_floating_point(): + waveform = waveform.to(torch.float32) - tensor = torch.mean(tensor, 0, True) - return tensor + waveform = torch.mean(waveform, 0, True) + return waveform @torch.jit.script @@ -245,25 +244,25 @@ def istft(stft_matrix, # type: Tensor @torch.jit.script -def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): +def spectrogram(sig, pad, window, n_fft, hop_length, win_length, power, normalized): # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor r"""Create a spectrogram from a raw audio signal. Args: sig (torch.Tensor): Tensor of audio of size (c, n) pad (int): Two sided padding of signal - window (torch.Tensor): Window_tensor + window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of fft - hop (int): Length of hop between STFT windows - ws (int): Window size + hop_length (int): Length of hop between STFT windows + win_length (int): Window size power (int) : Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. - normalize (bool) : Whether to normalize by magnitude after stft + normalized (bool) : Whether to normalize by magnitude after stft Returns: torch.Tensor: Channels x frequency x time (c, f, t), where channels is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of - fourier bins, time is the number of window hops + fourier bins, and time is the number of window hops (n_frames). 
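
As a quick sanity check of the (c, f, t) convention documented above, here is a standalone sketch, not part of the patch. It assumes the torch.stft of this era, which returns a real tensor with a trailing real/imaginary dimension; the frame count follows from center=True framing:

    import torch

    waveform = torch.randn(1, 16000)               # (c, n)
    n_fft, hop_length = 400, 200
    window = torch.hann_window(n_fft)
    # square and sum the trailing real/imag dimension, mirroring power=2
    spec = torch.stft(waveform, n_fft, hop_length, n_fft, window,
                      center=True, pad_mode='reflect', normalized=False,
                      onesided=True).pow(2).sum(-1)
    assert spec.shape == (1, n_fft // 2 + 1, 16000 // hop_length + 1)  # (c, f, t)
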
""" assert sig.dim() == 2 @@ -272,17 +271,17 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): sig = torch.nn.functional.pad(sig, (pad, pad), "constant") # default values are consistent with librosa.core.spectrum._spectrogram - spec_f = _stft(sig, n_fft, hop, ws, window, + spec_f = _stft(sig, n_fft, hop_length, win_length, window, True, 'reflect', False, True) - if normalize: + if normalized: spec_f /= window.pow(2).sum().sqrt() spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor return spec_f @torch.jit.script -def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): +def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None): # type: (Tensor, float, float, float, Optional[float]) -> Tensor r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. @@ -291,9 +290,9 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): a full clip. Args: - spec (torch.Tensor): Normal STFT of size (c, f, t) + specgram (torch.Tensor): Normal STFT of size (c, f, t) multiplier (float): Use 10. for power and 20. for amplitude - amin (float): Number to clamp spec + amin (float): Number to clamp specgram db_multiplier (float): Log10(max(reference value and amin)) top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number is 80. @@ -301,7 +300,7 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None): Returns: torch.Tensor: Spectrogram in DB of size (c, f, t) """ - spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin)) + spec_db = multiplier * torch.log10(torch.clamp(specgram, min=amin)) spec_db -= multiplier * db_multiplier if top_db is not None: diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4638376406..3a09247fff 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -8,12 +8,11 @@ class Scale(torch.jit.ScriptModule): - """Scale audio tensor from a 16-bit integer (represented as a FloatTensor) - to a floating point number between -1.0 and 1.0. Note the 16-bit number is - called the "bit depth" or "precision", not to be confused with "bit rate". + r"""Scales tensor by a factor. By default, assuming the input is int32, it + will scale the tensor to have values between -1.0 and 1.0. Args: - factor (float): maximum value of input tensor. default: 16-bit depth + factor (float): Factor to scale by. (Default: `float(2**31)`) """ __constants__ = ['factor'] @@ -23,22 +22,22 @@ def __init__(self, factor=float(2**31)): @torch.jit.script_method def forward(self, tensor): - """ + r""" Args: - tensor (Tensor): Tensor of audio of size (c, n) + tensor (torch.Tensor): Tensor input to scale Returns: - Tensor: Scaled by the scale factor. 
(default between -1.0 and 1.0) + torch.Tensor: Scaled by the scale factor """ return F.scale(tensor, self.factor) class PadTrim(torch.jit.ScriptModule): - """Pad/Trim a 2D-Tensor (Signal or Labels) + r"""Pad/Trim a 2D tensor Args: - tensor (Tensor): Tensor of audio of size (c, n) - max_len (int): Length to which the tensor will be padded + max_len (int): Length to which the waveform will be padded + fill_value (float): Value to fill in """ __constants__ = ['max_len', 'fill_value'] @@ -48,109 +47,112 @@ def __init__(self, max_len, fill_value=0.): self.fill_value = fill_value @torch.jit.script_method - def forward(self, tensor): - """ + def forward(self, waveform): + r""" Args: - tensor (Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of size (c, n) Returns: - Tensor: (c, `max_len`) + Tensor: Tensor of size (c, `max_len`) """ - return F.pad_trim(tensor, self.max_len, self.fill_value) + return F.pad_trim(waveform, self.max_len, self.fill_value) class DownmixMono(torch.jit.ScriptModule): - """Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with + r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with the `channels` effect instead of this transformation. - - Inputs: - tensor (Tensor): Tensor of audio of size (c, n) - - Returns: - tensor (Tensor) (1, n): """ def __init__(self): super(DownmixMono, self).__init__() @torch.jit.script_method - def forward(self, tensor): - return F.downmix_mono(tensor) + def forward(self, waveform): + r""" + Args: + waveform (torch.Tensor): Tensor of audio of size (c, n) + + Returns: + torch.Tensor: Tensor that has been downmixed of size (1, n) + """ + return F.downmix_mono(waveform) class LC2CL(torch.jit.ScriptModule): - """Converts a 2D tensor from (n, c) to (c, n) + r"""Converts a 2D tensor from (n, c) to (c, n) """ def __init__(self): super(LC2CL, self).__init__() @torch.jit.script_method def forward(self, tensor): - """ + r""" Args: - tensor (Tensor): Tensor of audio signal with shape (n, c) + tensor (torch.Tensor): Tensor of audio signal with shape (n, c) Returns: - tensor (Tensor): Tensor of audio signal with shape (c, n) + torch.Tensor: Tensor of audio signal with shape (c, n) """ return F.LC2CL(tensor) class Spectrogram(torch.jit.ScriptModule): - """Create a spectrogram from a raw audio signal + r"""Create a spectrogram from a audio signal Args: - n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins - ws (int): window size. default: n_fft - hop (int, optional): length of hop between STFT windows. default: ws // 2 - pad (int): two sided padding of signal - window (torch windowing function): default: torch.hann_window - power (int > 0 ) : Exponent for the magnitude spectrogram, - e.g., 1 for energy, 2 for power, etc. - normalize (bool) : whether to normalize by magnitude after stft - wkwargs (dict, optional): arguments for window function + n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins + win_length (int): Window size. (Default: `n_fft`) + hop_length (int, optional): Length of hop between STFT windows. ( + Default: `win_length // 2`) + pad (int): Two sided padding of signal. (Default: 0) + window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) + power (int) : Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. + normalized (bool) : Whether to normalized by magnitude after stft. 
(Default: `False`)
+        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
     """
-    __constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize']
+    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
 
-    def __init__(self, n_fft=400, ws=None, hop=None,
-                 pad=0, window=torch.hann_window,
-                 power=2, normalize=False, wkwargs=None):
+    def __init__(self, n_fft=400, win_length=None, hop_length=None,
+                 pad=0, window_fn=torch.hann_window,
+                 power=2, normalized=False, wkwargs=None):
         super(Spectrogram, self).__init__()
         self.n_fft = n_fft
         # number of fft bins. the returned STFT result will have n_fft // 2 + 1
         # number of frequencies due to onesided=True in torch.stft
-        self.ws = ws if ws is not None else n_fft
-        self.hop = hop if hop is not None else self.ws // 2
-        window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
+        self.win_length = win_length if win_length is not None else n_fft
+        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
         self.window = torch.jit.Attribute(window, torch.Tensor)
         self.pad = pad
         self.power = power
-        self.normalize = normalize
+        self.normalized = normalized
 
     @torch.jit.script_method
-    def forward(self, sig):
-        """
+    def forward(self, waveform):
+        r"""
         Args:
-            sig (Tensor): Tensor of audio of size (c, n)
+            waveform (torch.Tensor): Tensor of audio of size (c, n)
 
         Returns:
-            spec_f (Tensor): Channels x frequency x time (c, f, t), where channels
+            torch.Tensor: Channels x frequency x time (c, f, t), where channels
             is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
-            fourier bins, time is the number of window hops
+            fourier bins, and time is the number of window hops (n_frames).
 
         """
-        return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop,
-                             self.ws, self.power, self.normalize)
+        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
+                             self.win_length, self.power, self.normalized)
 
 
 class SpectrogramToDB(torch.jit.ScriptModule):
-    """Turns a spectrogram from the power/amplitude scale to the decibel scale.
+    r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
 
     This output depends on the maximum value in the input spectrogram, and so
     may return different values for an audio clip split into snippets vs. a
     a full clip.
 
     Args:
-        stype (str): scale of input spectrogram ("power" or "magnitude"). The
-        power being the elementwise square of the magnitude. default: "power"
+        stype (str): scale of input spectrogram ("power" or "magnitude"). The
+        power being the elementwise square of the magnitude. (Default: 'power')
         top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
         is 80.
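
For reference, the conversion this class delegates to can be written out in a few lines. The helper name below is made up for illustration; the constants mirror what __init__ stores (multiplier 10. for power input, amin=1e-10, and db_multiplier = log10(max(amin, ref_value)) with ref_value = 1.0):

    import math
    import torch

    def power_to_db_sketch(specgram, top_db=80.0, amin=1e-10, ref_value=1.0):
        multiplier = 10.0                                 # would be 20.0 for 'magnitude'
        db_multiplier = math.log10(max(amin, ref_value))  # 0.0 when ref_value is 1.0
        db = multiplier * torch.log10(specgram.clamp(min=amin))
        db -= multiplier * db_multiplier
        # floor everything more than top_db below the clip-wide maximum
        return torch.max(db, db.max() - top_db)
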
""" @@ -168,67 +170,68 @@ def __init__(self, stype="power", top_db=None): self.db_multiplier = math.log10(max(self.amin, self.ref_value)) @torch.jit.script_method - def forward(self, spec): + def forward(self, specgram): r"""Numerically stable implementation from Librosa https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html Args: - spec (torch.Tensor): STFT of size (c, f, t) + specgram (torch.Tensor): STFT of size (c, f, t) Returns: torch.Tensor: STFT after changing scale of size (c, f, t) """ - return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db) + return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db) class MelScale(torch.jit.ScriptModule): - """This turns a normal STFT into a mel frequency STFT, using a conversion + r"""This turns a normal STFT into a mel frequency STFT, using a conversion matrix. This uses triangular filter banks. User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). Args: - n_mels (int): Number of mel filterbanks - sr (int): sample rate of audio signal - f_max (float, optional): maximum frequency. default: `sr` // 2 - f_min (float): minimum frequency. default: 0 - n_stft (int, optional): number of filter banks from stft. Calculated from first input + n_mels (int): Number of mel filterbanks. (Default: 128) + sample_rate (int): Sample rate of audio signal. (Default: 16000). + f_min (float): Minimum frequency. (Default: 0.) + f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`) + n_stft (int, optional): Number of bins in STFT. Calculated from first input if `None` is given. See `n_fft` in `Spectrogram`. """ - __constants__ = ['n_mels', 'sr', 'f_min', 'f_max'] + __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] - def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None): + def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None): super(MelScale, self).__init__() self.n_mels = n_mels - self.sr = sr - self.f_max = f_max if f_max is not None else float(sr // 2) + self.sample_rate = sample_rate + self.f_max = f_max if f_max is not None else float(sample_rate // 2) + assert f_min <= f_max, 'Require f_min: %f < f_max: %f' % (f_min, f_max) self.f_min = f_min fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels) self.fb = torch.jit.Attribute(fb, torch.Tensor) @torch.jit.script_method - def forward(self, spec_f): + def forward(self, specgram): r""" Args: - spec_f (torch.Tensor): a spectrogram STFT of size (c, f, t) + specgram (torch.Tensor): a spectrogram STFT of size (c, f, t) Returns: torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) """ if self.fb.numel() == 0: - tmp_fb = F.create_fb_matrix(spec_f.size(1), self.f_min, self.f_max, self.n_mels) + tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels) # Attributes cannot be reassigned outside __init__ so workaround self.fb.resize_(tmp_fb.size()) self.fb.copy_(tmp_fb) # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...) - spec_m = torch.matmul(spec_f.transpose(1, 2), self.fb).transpose(1, 2) + spec_m = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) return spec_m class MelSpectrogram(torch.jit.ScriptModule): - """Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram + r"""Create MelSpectrogram for a raw audio signal. 
This is a composition of Spectrogram and MelScale. Sources: @@ -237,51 +240,53 @@ class MelSpectrogram(torch.jit.ScriptModule): * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html Args: - sr (int): sample rate of audio signal - ws (int): window size - hop (int, optional): length of hop between STFT windows. default: `ws` // 2 - n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1 - f_max (float, optional): maximum frequency. default: `sr` // 2 - f_min (float): minimum frequency. default: 0 - pad (int): two sided padding of signal - n_mels (int): Number of mel filterbanks - window (torch windowing function): default: `torch.hann_window` - wkwargs (dict, optional): arguments for window function + sample_rate (int): Sample rate of audio signal. (Default: 16000). + win_length (int): Window size. (Default: `n_fft`) + hop_length (int, optional): Length of hop between STFT windows. ( + Default: `win_length // 2`) + n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins + f_min (float): Minimum frequency. (Default: 0.) + f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`) + pad (int): Two sided padding of signal. (Default: 0) + n_mels (int): Number of mel filterbanks. (Default: 128) + window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) + wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`) Example: - >>> sig, sr = torchaudio.load("test.wav", normalization=True) - >>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, n_mels, t) + >>> waveform, sample_rate = torchaudio.load("test.wav", normalization=True) + >>> specgram_mel = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) """ - __constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min'] + __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] - def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None, - pad=0, n_mels=128, window=torch.hann_window, wkwargs=None): + def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None, + pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None): super(MelSpectrogram, self).__init__() - self.sr = sr + self.sample_rate = sample_rate self.n_fft = n_fft - self.ws = ws if ws is not None else n_fft - self.hop = hop if hop is not None else self.ws // 2 + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 self.pad = pad self.n_mels = n_mels # number of mel frequency bins self.f_max = torch.jit.Attribute(f_max, Optional[float]) self.f_min = f_min - self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop, - pad=self.pad, window=window, power=2, - normalize=False, wkwargs=wkwargs) - self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min) + self.spec = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length, + pad=self.pad, window_fn=window_fn, power=2, + normalized=False, wkwargs=wkwargs) + self.fm = MelScale(self.n_mels, self.sample_rate, self.f_max, self.f_min) @torch.jit.script_method - def forward(self, sig): + def forward(self, waveform): """ Args: - sig (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of size (c, n) Returns: torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) 
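
A shape-level illustration of the Spectrogram plus MelScale composition described above; a sketch assuming the keyword names introduced in this patch:

    import torch
    import torchaudio.transforms as T

    waveform = torch.randn(1, 16000)                    # (c, n)
    transform = T.MelSpectrogram(sample_rate=16000, n_fft=400,
                                 hop_length=200, n_mels=128)
    mel_specgram = transform(waveform)
    assert mel_specgram.shape == (1, 128, 16000 // 200 + 1)  # (c, n_mels, t)
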
""" - spec = self.spec(sig) - spec_mel = self.fm(spec) - return spec_mel + specgram = self.spec(waveform) + specgram_mel = self.fm(specgram) + return specgram_mel class MFCC(torch.jit.ScriptModule): From e3085d389b87418473e3363f7bba6866358f2fc5 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Mon, 22 Jul 2019 12:26:14 -0700 Subject: [PATCH 08/28] small push to save progress --- torchaudio/functional.py | 37 +++++++------- torchaudio/transforms.py | 107 +++++++++++++++++++-------------------- 2 files changed, 71 insertions(+), 73 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 7b03288df8..ddbd58362d 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -61,7 +61,7 @@ def pad_trim(waveform, max_len, fill_value): @torch.jit.script def downmix_mono(waveform): # type: (Tensor) -> Tensor - r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with + r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with the `channels` effect instead of this transformation. Args: @@ -244,12 +244,12 @@ def istft(stft_matrix, # type: Tensor @torch.jit.script -def spectrogram(sig, pad, window, n_fft, hop_length, win_length, power, normalized): +def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized): # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor r"""Create a spectrogram from a raw audio signal. Args: - sig (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of size (c, n) pad (int): Two sided padding of signal window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of fft @@ -264,14 +264,14 @@ def spectrogram(sig, pad, window, n_fft, hop_length, win_length, power, normaliz is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of fourier bins, and time is the number of window hops (n_frames). """ - assert sig.dim() == 2 + assert waveform.dim() == 2 if pad > 0: # TODO add "with torch.no_grad():" back when JIT supports it - sig = torch.nn.functional.pad(sig, (pad, pad), "constant") + waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant") # default values are consistent with librosa.core.spectrum._spectrogram - spec_f = _stft(sig, n_fft, hop_length, win_length, window, + spec_f = _stft(waveform, n_fft, hop_length, win_length, window, True, 'reflect', False, True) if normalized: @@ -294,20 +294,21 @@ def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None): multiplier (float): Use 10. for power and 20. for amplitude amin (float): Number to clamp specgram db_multiplier (float): Log10(max(reference value and amin)) - top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number + top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number is 80. 
Returns: torch.Tensor: Spectrogram in DB of size (c, f, t) """ - spec_db = multiplier * torch.log10(torch.clamp(specgram, min=amin)) - spec_db -= multiplier * db_multiplier + specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin)) + specgram_db -= multiplier * db_multiplier if top_db is not None: - new_spec_db_max = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device) - spec_db = torch.max(spec_db, new_spec_db_max) + new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db, + dtype=specgram_db.dtype, device=specgram_db.device) + specgram_db = torch.max(specgram_db, new_spec_db_max) - return spec_db + return specgram_db @torch.jit.script @@ -328,8 +329,8 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): size (..., `n_freqs`), the applied result would be `A * create_fb_matrix(A.size(-1), ...)`. """ - # get stft freq bins - stft_freqs = torch.linspace(f_min, f_max, n_freqs) + # freq bins + freqs = torch.linspace(f_min, f_max, n_freqs) # calculate mel freq bins # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.)) @@ -339,12 +340,12 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): f_pts = 700. * (10**(m_pts / 2595.) - 1.) # calculate the difference between each mel point and each stft freq point in hertz f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) - slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_freqs, n_mels + 2) + slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1) # (n_freqs, n_mels + 2) # create overlapping triangles - z = torch.zeros(1) + zero = torch.zeros(1) down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) - fb = torch.max(z, torch.min(down_slopes, up_slopes)) + fb = torch.max(zero, torch.min(down_slopes, up_slopes)) return fb @@ -392,7 +393,6 @@ def mu_law_encoding(x, qc): Returns: torch.Tensor: Input after mu-law companding """ - assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor' mu = qc - 1. if not x.is_floating_point(): x = x.to(torch.float) @@ -419,7 +419,6 @@ def mu_law_expanding(x_mu, qc): Returns: torch.Tensor: Input after decoding """ - assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor' mu = qc - 1. if not x_mu.is_floating_point(): x_mu = x_mu.to(torch.float) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 3a09247fff..658f6d8c6f 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -191,7 +191,7 @@ class MelScale(torch.jit.ScriptModule): Args: n_mels (int): Number of mel filterbanks. (Default: 128) - sample_rate (int): Sample rate of audio signal. (Default: 16000). + sample_rate (int): Sample rate of audio signal. (Default: 16000) f_min (float): Minimum frequency. (Default: 0.) f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`) n_stft (int, optional): Number of bins in STFT. Calculated from first input @@ -226,8 +226,8 @@ def forward(self, specgram): self.fb.copy_(tmp_fb) # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...) 
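
The transpose, matmul, transpose pattern in the lines below has an equivalent einsum spelling; a sketch with placeholder shapes, shown only to make the contraction explicit:

    import torch

    specgram = torch.randn(2, 201, 81)   # (c, f, t)
    fb = torch.randn(201, 128)           # (f, n_mels)
    via_matmul = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)
    via_einsum = torch.einsum('cft,fm->cmt', specgram, fb)  # contract over f
    assert torch.allclose(via_matmul, via_einsum, atol=1e-5)
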
- spec_m = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) - return spec_m + mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) + return mel_specgram class MelSpectrogram(torch.jit.ScriptModule): @@ -240,7 +240,7 @@ class MelSpectrogram(torch.jit.ScriptModule): * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html Args: - sample_rate (int): Sample rate of audio signal. (Default: 16000). + sample_rate (int): Sample rate of audio signal. (Default: 16000) win_length (int): Window size. (Default: `n_fft`) hop_length (int, optional): Length of hop between STFT windows. ( Default: `win_length // 2`) @@ -255,7 +255,7 @@ class MelSpectrogram(torch.jit.ScriptModule): Example: >>> waveform, sample_rate = torchaudio.load("test.wav", normalization=True) - >>> specgram_mel = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) + >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) """ __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] @@ -270,63 +270,64 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non self.n_mels = n_mels # number of mel frequency bins self.f_max = torch.jit.Attribute(f_max, Optional[float]) self.f_min = f_min - self.spec = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, hop_length=self.hop_length, - pad=self.pad, window_fn=window_fn, power=2, - normalized=False, wkwargs=wkwargs) - self.fm = MelScale(self.n_mels, self.sample_rate, self.f_max, self.f_min) + self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, + hop_length=self.hop_length, + pad=self.pad, window_fn=window_fn, power=2, + normalized=False, wkwargs=wkwargs) + self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_max, self.f_min) @torch.jit.script_method def forward(self, waveform): - """ + r""" Args: waveform (torch.Tensor): Tensor of audio of size (c, n) Returns: torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) """ - specgram = self.spec(waveform) - specgram_mel = self.fm(specgram) - return specgram_mel + specgram = self.spectrogram(waveform) + mel_specgram = self.mel_scale(specgram) + return mel_specgram class MFCC(torch.jit.ScriptModule): - """Create the Mel-frequency cepstrum coefficients from an audio signal + r"""Create the Mel-frequency cepstrum coefficients from an audio signal - By default, this calculates the MFCC on the DB-scaled Mel spectrogram. - This is not the textbook implementation, but is implemented here to - give consistency with librosa. + By default, this calculates the MFCC on the DB-scaled Mel spectrogram. + This is not the textbook implementation, but is implemented here to + give consistency with librosa. - This output depends on the maximum value in the input spectrogram, and so - may return different values for an audio clip split into snippets vs. a - a full clip. + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. - Args: - sr (int) : sample rate of audio signal - n_mfcc (int) : number of mfc coefficients to retain + Args: + sample_rate (int) : Sample rate of audio signal. 
(Default: 16000) + n_mfcc (int) : Number of mfc coefficients to retain dct_type (int) : type of DCT (discrete cosine transform) to use norm (string, optional) : norm to use log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled melkwargs (dict, optional): arguments for MelSpectrogram """ - __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] + __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] - def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, + def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, melkwargs=None): super(MFCC, self).__init__() supported_dct_types = [2] if dct_type not in supported_dct_types: raise ValueError('DCT type not supported'.format(dct_type)) - self.sr = sr + self.sample_rate = sample_rate self.n_mfcc = n_mfcc self.dct_type = dct_type self.norm = torch.jit.Attribute(norm, Optional[str]) - self.top_db = 80. - self.s2db = SpectrogramToDB("power", self.top_db) + self.top_db = 80.0 + self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db) if melkwargs is not None: - self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs) + self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs) else: - self.MelSpectrogram = MelSpectrogram(sr=self.sr) + self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate) if self.n_mfcc > self.MelSpectrogram.n_mels: raise ValueError('Cannot select more MFCC coefficients than # mel bins') @@ -335,27 +336,27 @@ def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False self.log_mels = log_mels @torch.jit.script_method - def forward(self, sig): - """ + def forward(self, waveform): + r""" Args: - sig (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of size (c, n) Returns: - torch.Tensor: spec_mel_db of size (c, `n_mfcc`, t) + torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t) """ - mel_spect = self.MelSpectrogram(sig) + mel_specgram = self.MelSpectrogram(waveform) if self.log_mels: log_offset = 1e-6 - mel_spect = torch.log(mel_spect + log_offset) + mel_specgram = torch.log(mel_specgram + log_offset) else: - mel_spect = self.s2db(mel_spect) + mel_specgram = self.spectrogram_to_DB(mel_specgram) # (c, `n_mels`, t).tranpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).tranpose(...) - mfcc = torch.matmul(mel_spect.transpose(1, 2), self.dct_mat).transpose(1, 2) + mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) return mfcc class MuLawEncoding(torch.jit.ScriptModule): - """Encode signal based on mu-law companding. For more info see the + r"""Encode signal based on mu-law companding. For more info see the `Wikipedia Entry `_ This algorithm assumes the signal has been scaled to between -1 and 1 and @@ -363,7 +364,6 @@ class MuLawEncoding(torch.jit.ScriptModule): Args: quantization_channels (int): Number of channels. default: 256 - """ __constants__ = ['qc'] @@ -373,12 +373,12 @@ def __init__(self, quantization_channels=256): @torch.jit.script_method def forward(self, x): - """ + r""" Args: - x (FloatTensor/LongTensor) + x (torch.Tensor): A signal to be encoded Returns: - x_mu (LongTensor) + x_mu (torch.Tensor): An encoded signal """ return F.mu_law_encoding(x, self.qc) @@ -387,7 +387,7 @@ def __repr__(self): class MuLawExpanding(torch.jit.ScriptModule): - """Decode mu-law encoded signal. For more info see the + r"""Decode mu-law encoded signal. 
For more info see the `Wikipedia Entry `_ This expects an input with values between 0 and quantization_channels - 1 @@ -395,7 +395,6 @@ class MuLawExpanding(torch.jit.ScriptModule): Args: quantization_channels (int): Number of channels. default: 256 - """ __constants__ = ['qc'] @@ -405,12 +404,12 @@ def __init__(self, quantization_channels=256): @torch.jit.script_method def forward(self, x_mu): - """ + r""" Args: - x_mu (Tensor) + x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded Returns: - x (Tensor) + torch.Tensor: The signal decoded """ return F.mu_law_expanding(x_mu, self.qc) @@ -419,7 +418,7 @@ def __repr__(self): class Resample(torch.nn.Module): - """Resamples a signal from one frequency to another. A resampling method can + r"""Resamples a signal from one frequency to another. A resampling method can be given. Args: @@ -434,15 +433,15 @@ def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'): self.new_freq = new_freq self.resampling_method = resampling_method - def forward(self, sig): - """ + def forward(self, waveform): + r""" Args: - sig (Tensor): the input signal of size (c, n) + waveform (torch.Tensor): The input signal of size (c, n) Returns: - Tensor: output signal of size (c, m) + torch.Tensor: Output signal of size (c, m) """ if self.resampling_method == 'sinc_interpolation': - return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq) + return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) raise ValueError('Invalid resampling method: %s' % (self.resampling_method)) From e9c805f6f0daca06ac9012d33bfbf9d2d2d0f411 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Mon, 22 Jul 2019 14:29:21 -0700 Subject: [PATCH 09/28] fix test --- test/test_transforms.py | 125 +++++++++++++++++++-------------------- torchaudio/transforms.py | 10 ++-- 2 files changed, 65 insertions(+), 70 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 7d7e222302..419e9b6e5a 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -19,56 +19,50 @@ class Tester(unittest.TestCase): # create a sinewave signal for testing - sr = 16000 + sample_rate = 16000 freq = 440 volume = .3 - sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr)) - sig.unsqueeze_(0) # (1, 64000) - sig = (sig * volume * 2**31).long() + waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate)) + waveform.unsqueeze_(0) # (1, 64000) + waveform = (waveform * volume * 2**31).long() # file for stereo stft test test_dirpath, test_dir = test.common_utils.create_temp_assets_dir() - test_filepath = os.path.join(test_dirpath, "assets", - "steam-train-whistle-daniel_simon.mp3") + test_filepath = os.path.join(test_dirpath, 'assets', + 'steam-train-whistle-daniel_simon.mp3') def test_scale(self): - audio_orig = self.sig.clone() - result = transforms.Scale()(audio_orig) + waveform = self.waveform.clone() + result = transforms.Scale()(waveform) self.assertTrue(result.min() >= -1. and result.max() <= 1.) - maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item() - result = transforms.Scale(factor=float(maxminmax))(audio_orig) - + maxminmax = max(abs(waveform.min()), abs(waveform.max())).item() + result = transforms.Scale(factor=float(maxminmax))(waveform) self.assertTrue((result.min() == -1. or result.max() == 1.) and result.min() >= -1. and result.max() <= 1.) 
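
The rewritten assertions above pin down the scaling convention being tested: integer samples at int32 range divided by 2**31 land in [-1.0, 1.0]. A minimal sketch of that arithmetic:

    import torch

    pcm = torch.tensor([-2 ** 31, 0, 2 ** 31 - 1])  # samples at int32 range
    scaled = pcm.to(torch.float32) / 2 ** 31
    assert scaled.min().item() >= -1.0 and scaled.max().item() <= 1.0
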
- repr_test = transforms.Scale() - def test_pad_trim(self): - audio_orig = self.sig.clone() - length_orig = audio_orig.size(1) + waveform = self.waveform.clone() + length_orig = waveform.size(1) length_new = int(length_orig * 1.2) - result = transforms.PadTrim(max_len=length_new)(audio_orig) + result = transforms.PadTrim(max_len=length_new)(waveform) self.assertEqual(result.size(1), length_new) length_new = int(length_orig * 0.8) - result = transforms.PadTrim(max_len=length_new)(audio_orig) - + result = transforms.PadTrim(max_len=length_new)(waveform) self.assertEqual(result.size(1), length_new) - repr_test = transforms.PadTrim(max_len=length_new) - def test_downmix_mono(self): - audio_L = self.sig.clone() - audio_R = self.sig.clone() - R_idx = int(audio_R.size(0) * 0.1) - audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx])) + waveform_L = self.waveform.clone() + waveform_R = self.waveform.clone() + R_idx = int(waveform_R.size(0) * 0.1) + waveform_R = torch.cat((waveform_R[R_idx:], waveform_R[:R_idx])) - audio_Stereo = torch.cat((audio_L, audio_R), dim=0) + audio_Stereo = torch.cat((waveform_L, waveform_R), dim=0) self.assertTrue(audio_Stereo.size(0) == 2) @@ -78,50 +72,49 @@ def test_downmix_mono(self): def test_lc2cl(self): - audio = self.sig.clone().t() - result = transforms.LC2CL()(audio) - self.assertTrue(result.size()[::-1] == audio.size()) - - repr_test = transforms.LC2CL() + waveform = self.waveform.clone().t() + result = transforms.LC2CL()(waveform) + self.assertTrue(result.size()[::-1] == waveform.size()) def test_mu_law_companding(self): quantization_channels = 256 - sig = self.sig.clone() - sig = sig / torch.abs(sig).max() - self.assertTrue(sig.min() >= -1. and sig.max() <= 1.) + waveform = self.waveform.clone() + waveform /= torch.abs(waveform).max() + self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.) - sig_mu = transforms.MuLawEncoding(quantization_channels)(sig) - self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels) + waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform) + self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels) - sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu) - self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.) + waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu) + self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.) def test_mel2(self): top_db = 80. 
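
Before the spectrogram checks below, it may help to see what top_db = 80. buys (an illustrative sketch, not part of the test): any bin more than 80 dB under the per-clip maximum is floored:

    import torch

    db = torch.tensor([[0.0, -30.0, -100.0]])
    clipped = torch.max(db, db.max() - 80.0)
    # -> tensor([[  0., -30., -80.]]); the -100 dB bin is floored at -80 dB
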
- s2db = transforms.SpectrogramToDB("power", top_db) + s2db = transforms.SpectrogramToDB('power', top_db) - audio_orig = self.sig.clone() # (1, 16000) - audio_scaled = transforms.Scale()(audio_orig) # (1, 16000) + waveform = self.waveform.clone() # (1, 16000) + waveform_scaled = transforms.Scale()(waveform) # (1, 16000) mel_transform = transforms.MelSpectrogram() # check defaults - spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 128, 321) + spectrogram_torch = s2db(mel_transform(waveform_scaled)) # (1, 128, 321) self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels) # check correctness of filterbank conversion matrix - self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all()) - self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all()) + self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all()) + self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all()) # check options - kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50} + kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500, + 'hop_length': 125, 'n_fft': 800, 'n_mels': 50} mel_transform2 = transforms.MelSpectrogram(**kwargs) - spectrogram2_torch = s2db(mel_transform2(audio_scaled)) # (1, 50, 513) + spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513) self.assertTrue(spectrogram2_torch.dim() == 3) self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels) - self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all()) - self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all()) + self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all()) + self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all()) # check on multi-channel audio x_stereo, sr_stereo = torchaudio.load(self.test_filepath) # (2, 278756), 44100 spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394) @@ -130,19 +123,20 @@ def test_mel2(self): self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all()) self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels) # check filterbank matrix creation - fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400) + fb_matrix_transform = transforms.MelScale( + n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400) self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all()) self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all()) self.assertEqual(fb_matrix_transform.fb.size(), (400, 100)) def test_mfcc(self): - audio_orig = self.sig.clone() + audio_orig = self.waveform.clone() audio_scaled = transforms.Scale()(audio_orig) # (1, 16000) sample_rate = 16000 n_mfcc = 40 n_mels = 128 - mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, + mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho') # check defaults @@ -151,8 +145,8 @@ def test_mfcc(self): self.assertTrue(torch_mfcc.shape[1] == n_mfcc) self.assertTrue(torch_mfcc.shape[2] == 321) # check melkwargs are passed through - melkwargs = {'ws': 200} - mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate, + melkwargs = {'win_length': 200} + mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs) @@ -160,7 +154,7 @@ def test_mfcc(self): 
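
The 641 in the comment above follows from defaults documented earlier in this series: melkwargs={'win_length': 200} leaves hop_length at its default of win_length // 2 = 100, and the 64000-sample test signal then yields 64000 // 100 + 1 = 641 frames (the default case gives 64000 // 200 + 1 = 321). The arithmetic as a sketch:

    win_length = 200
    hop_length = win_length // 2   # MelSpectrogram's documented default
    n_samples = 4 * 16000          # the 4 second, 16 kHz test signal
    assert n_samples // hop_length + 1 == 641
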
self.assertTrue(torch_mfcc2.shape[2] == 641) # check norms work correctly - mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate, + mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm=None) torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) # (1, 40, 321) @@ -179,7 +173,7 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s sound_librosa = sound.cpu().numpy().squeeze() # (64000) # test core spectrogram - spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2) + spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2) out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa, n_fft=n_fft, hop_length=hop_length, @@ -189,8 +183,9 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5)) # test mel spectrogram - melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window, - hop=hop_length, n_mels=n_mels, n_fft=n_fft) + melspect_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, window_fn=torch.hann_window, + hop_length=hop_length, n_mels=n_mels, n_fft=n_fft) librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, htk=True, norm=None) @@ -200,7 +195,7 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) # test s2db - db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.) + db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.) 
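
The cross-check below works because both libraries compute the same quantity. As a hedged sketch (the array here is a random placeholder, not the test fixture), the librosa side reduces to a single call:

    import numpy as np
    import librosa

    S = np.abs(np.random.randn(201, 321)) ** 2   # placeholder power spectrogram
    db = librosa.core.spectrum.power_to_db(S, ref=1.0, amin=1e-10, top_db=80.0)
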
db_torch = db_transform(spect_transform(sound)).squeeze().cpu() db_librosa = librosa.core.spectrum.power_to_db(out_librosa) self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) @@ -212,8 +207,8 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3)) # test MFCC - melkwargs = {'hop': hop_length, 'n_fft': n_fft} - mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate, + melkwargs = {'hop_length': hop_length, 'n_fft': n_fft} + mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs) @@ -271,27 +266,27 @@ def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, s def test_resample_size(self): input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') - sound, sample_rate = torchaudio.load(input_path) + waveform, sample_rate = torchaudio.load(input_path) upsample_rate = sample_rate * 2 downsample_rate = sample_rate // 2 invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo') - self.assertRaises(ValueError, invalid_resample, sound) + self.assertRaises(ValueError, invalid_resample, waveform) upsample_resample = torchaudio.transforms.Resample( sample_rate, upsample_rate, resampling_method='sinc_interpolation') - up_sampled = upsample_resample(sound) + up_sampled = upsample_resample(waveform) # we expect the upsampled signal to have twice as many samples - self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2) + self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2) downsample_resample = torchaudio.transforms.Resample( sample_rate, downsample_rate, resampling_method='sinc_interpolation') - down_sampled = downsample_resample(sound) + down_sampled = downsample_resample(waveform) # we expect the downsampled signal to have half as many samples - self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2) + self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2) if __name__ == '__main__': unittest.main() diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 658f6d8c6f..0741ca9ab6 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -164,9 +164,9 @@ def __init__(self, stype="power", top_db=None): if top_db is not None and top_db < 0: raise ValueError('top_db must be positive value') self.top_db = torch.jit.Attribute(top_db, Optional[float]) - self.multiplier = 10. if stype == "power" else 20. + self.multiplier = 10.0 if stype == "power" else 20.0 self.amin = 1e-10 - self.ref_value = 1. + self.ref_value = 1.0 self.db_multiplier = math.log10(max(self.amin, self.ref_value)) @torch.jit.script_method @@ -204,7 +204,7 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N self.n_mels = n_mels self.sample_rate = sample_rate self.f_max = f_max if f_max is not None else float(sample_rate // 2) - assert f_min <= f_max, 'Require f_min: %f < f_max: %f' % (f_min, f_max) + assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max) self.f_min = f_min fb = torch.empty(0) if n_stft is None else F.create_fb_matrix( n_stft, self.f_min, self.f_max, self.n_mels) @@ -246,7 +246,7 @@ class MelSpectrogram(torch.jit.ScriptModule): Default: `win_length // 2`) n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins f_min (float): Minimum frequency. (Default: 0.) - f_max (float, optional): Maximum frequency. 
(Default: `sample_rate // 2`) + f_max (float, optional): Maximum frequency. (Default: `None`) pad (int): Two sided padding of signal. (Default: 0) n_mels (int): Number of mel filterbanks. (Default: 128) window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor @@ -274,7 +274,7 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non hop_length=self.hop_length, pad=self.pad, window_fn=window_fn, power=2, normalized=False, wkwargs=wkwargs) - self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_max, self.f_min) + self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max) @torch.jit.script_method def forward(self, waveform): From d090ff60da2c00ff4787d2976cf5760f2ca628ab Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Tue, 23 Jul 2019 06:27:48 -0700 Subject: [PATCH 10/28] more --- torchaudio/transforms.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 0741ca9ab6..4ccf09f942 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -60,7 +60,7 @@ def forward(self, waveform): class DownmixMono(torch.jit.ScriptModule): r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with - the `channels` effect instead of this transformation. + the `channels` effect instead of this transformation. """ def __init__(self): super(DownmixMono, self).__init__() @@ -107,7 +107,7 @@ class Spectrogram(torch.jit.ScriptModule): that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) power (int) : Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. - normalized (bool) : Whether to normalized by magnitude after stft. (Default: `False`) + normalized (bool) : Whether to normalize by magnitude after stft. (Default: `False`) wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`) """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] @@ -137,7 +137,6 @@ def forward(self, waveform): torch.Tensor: Channels x frequency x time (c, f, t), where channels is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of fourier bins, and time is the number of window hops (n_frames). - """ return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, self.win_length, self.power, self.normalized) @@ -151,20 +150,20 @@ class SpectrogramToDB(torch.jit.ScriptModule): a full clip. Args: - stype (str): scale of input spectrogram ("power" or "magnitude"). The + stype (str): scale of input spectrogram ('power' or 'magnitude'). The power being the elementwise square of the magnitude. (Default: 'power') top_db (float, optional): minimum negative cut-off in decibels. A reasonable number is 80. 
""" __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] - def __init__(self, stype="power", top_db=None): + def __init__(self, stype='power', top_db=None): super(SpectrogramToDB, self).__init__() self.stype = torch.jit.Attribute(stype, str) if top_db is not None and top_db < 0: raise ValueError('top_db must be positive value') self.top_db = torch.jit.Attribute(top_db, Optional[float]) - self.multiplier = 10.0 if stype == "power" else 20.0 + self.multiplier = 10.0 if stype == 'power' else 20.0 self.amin = 1e-10 self.ref_value = 1.0 self.db_multiplier = math.log10(max(self.amin, self.ref_value)) @@ -254,7 +253,7 @@ class MelSpectrogram(torch.jit.ScriptModule): wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`) Example: - >>> waveform, sample_rate = torchaudio.load("test.wav", normalization=True) + >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) """ __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] From fca025a45a429b6066da77657bae7066bee5193a Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Tue, 23 Jul 2019 11:54:03 -0700 Subject: [PATCH 11/28] remove trailing zero --- torchaudio/functional.py | 8 ++++---- torchaudio/transforms.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index ddbd58362d..1606979092 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -255,9 +255,9 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor n_fft (int): Size of fft hop_length (int): Length of hop between STFT windows win_length (int): Window size - power (int) : Exponent for the magnitude spectrogram, + power (int): Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. - normalized (bool) : Whether to normalize by magnitude after stft + normalized (bool): Whether to normalize by magnitude after stft Returns: torch.Tensor: Channels x frequency x time (c, f, t), where channels @@ -356,9 +356,9 @@ def create_dct(n_mfcc, n_mels, norm): normalized depending on norm. Args: - n_mfcc (int) : Number of mfc coefficients to retain + n_mfcc (int): Number of mfc coefficients to retain n_mels (int): Number of mel filterbanks - norm (Optional[str]) : Norm to use (either 'ortho' or None) + norm (Optional[str]): Norm to use (either 'ortho' or None) Returns: torch.Tensor: The transformation matrix, to be right-multiplied to diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4ccf09f942..a056915250 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -105,9 +105,9 @@ class Spectrogram(torch.jit.ScriptModule): pad (int): Two sided padding of signal. (Default: 0) window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) - power (int) : Exponent for the magnitude spectrogram, + power (int): Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. - normalized (bool) : Whether to normalize by magnitude after stft. (Default: `False`) + normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`) wkwargs (Dict[..., ...]): Arguments for window function. 
(Default: `None`) """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] @@ -301,11 +301,11 @@ class MFCC(torch.jit.ScriptModule): a full clip. Args: - sample_rate (int) : Sample rate of audio signal. (Default: 16000) - n_mfcc (int) : Number of mfc coefficients to retain - dct_type (int) : type of DCT (discrete cosine transform) to use - norm (string, optional) : norm to use - log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled + sample_rate (int): Sample rate of audio signal. (Default: 16000) + n_mfcc (int): Number of mfc coefficients to retain + dct_type (int): type of DCT (discrete cosine transform) to use + norm (string, optional): norm to use + log_mels (bool): whether to use log-mel spectrograms instead of db-scaled melkwargs (dict, optional): arguments for MelSpectrogram """ __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] From b435f1bc166f9b4912c9aec87a8789fa007204e4 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 06:39:06 -0700 Subject: [PATCH 12/28] apply feedback: remove scale and lc2cl --- test/test_jit.py | 39 ------------------------------------- test/test_transforms.py | 25 +++++++----------------- torchaudio/functional.py | 35 --------------------------------- torchaudio/transforms.py | 42 ---------------------------------------- 4 files changed, 7 insertions(+), 134 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 9e3dcd107e..1aec8427c8 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -30,26 +30,6 @@ def _test_script_module(self, tensor, f, *args): self.assertTrue(torch.allclose(jit_out, py_out)) - def test_torchscript_scale(self): - @torch.jit.script - def jit_method(tensor, factor): - # type: (Tensor, float) -> Tensor - return F.scale(tensor, factor) - - tensor = torch.rand((1, 10)) - factor = 2.0 - - jit_out = jit_method(tensor, factor) - py_out = F.scale(tensor, factor) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_scale(self): - tensor = torch.rand((1, 10), device="cuda") - - self._test_script_module(tensor, transforms.Scale) - def test_torchscript_pad_trim(self): @torch.jit.script def jit_method(tensor, max_len, fill_value): @@ -91,25 +71,6 @@ def test_scriptmodule_downmix_mono(self): self._test_script_module(tensor, transforms.DownmixMono) - def test_torchscript_LC2CL(self): - @torch.jit.script - def jit_method(tensor): - # type: (Tensor) -> Tensor - return F.LC2CL(tensor) - - tensor = torch.rand((10, 1)) - - jit_out = jit_method(tensor) - py_out = F.LC2CL(tensor) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_LC2CL(self): - tensor = torch.rand((10, 1), device="cuda") - - self._test_script_module(tensor, transforms.LC2CL) - def test_torchscript_spectrogram(self): @torch.jit.script def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): diff --git a/test/test_transforms.py b/test/test_transforms.py index 419e9b6e5a..cfa8ce3477 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -30,16 +30,11 @@ class Tester(unittest.TestCase): test_filepath = os.path.join(test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3') - def test_scale(self): - - waveform = self.waveform.clone() - result = transforms.Scale()(waveform) - self.assertTrue(result.min() >= -1. and result.max() <= 1.) 
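With `transforms.Scale` gone, callers divide by the full-scale value inline; a minimal sketch mirroring the new `scale` test helper above (the literal samples are illustrative):

    >>> waveform = torch.tensor([[0, 2 ** 30, -2 ** 31]], dtype=torch.int32)  # hypothetical int32 PCM
    >>> waveform = waveform.to(torch.get_default_dtype()) / float(2 ** 31)    # now within [-1.0, 1.0]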
- - maxminmax = max(abs(waveform.min()), abs(waveform.max())).item() - result = transforms.Scale(factor=float(maxminmax))(waveform) - self.assertTrue((result.min() == -1. or result.max() == 1.) and - result.min() >= -1. and result.max() <= 1.) + def scale(self, waveform, factor=float(2**31)): + # scales a waveform by a factor + if not waveform.is_floating_point(): + waveform = waveform.to(torch.get_default_dtype()) + return waveform / factor def test_pad_trim(self): @@ -70,12 +65,6 @@ def test_downmix_mono(self): self.assertTrue(result.size(0) == 1) - def test_lc2cl(self): - - waveform = self.waveform.clone().t() - result = transforms.LC2CL()(waveform) - self.assertTrue(result.size()[::-1] == waveform.size()) - def test_mu_law_companding(self): quantization_channels = 256 @@ -95,7 +84,7 @@ def test_mel2(self): s2db = transforms.SpectrogramToDB('power', top_db) waveform = self.waveform.clone() # (1, 16000) - waveform_scaled = transforms.Scale()(waveform) # (1, 16000) + waveform_scaled = self.scale(waveform) # (1, 16000) mel_transform = transforms.MelSpectrogram() # check defaults spectrogram_torch = s2db(mel_transform(waveform_scaled)) # (1, 128, 321) @@ -131,7 +120,7 @@ def test_mel2(self): def test_mfcc(self): audio_orig = self.waveform.clone() - audio_scaled = transforms.Scale()(audio_orig) # (1, 16000) + audio_scaled = self.scale(audio_orig) # (1, 16000) sample_rate = 16000 n_mfcc = 40 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 1606979092..93b662338e 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -3,10 +3,8 @@ __all__ = [ - 'scale', 'pad_trim', 'downmix_mono', - 'LC2CL', 'istft', 'spectrogram', 'spectrogram_to_DB', @@ -17,25 +15,6 @@ ] -@torch.jit.script -def scale(tensor, factor): - # type: (Tensor, float) -> Tensor - r"""Scales tensor by a factor. By default, assuming the input is int32, it - will scale the tensor to have values between -1.0 and 1.0. - - Args: - tensor (torch.Tensor): Tensor input to scale - factor (float): Factor to scale by - - Returns: - torch.Tensor: Scaled by the scale factor - """ - if not tensor.is_floating_point(): - tensor = tensor.to(torch.float32) - - return tensor / factor - - @torch.jit.script def pad_trim(waveform, max_len, fill_value): # type: (Tensor, int, float) -> Tensor @@ -77,20 +56,6 @@ def downmix_mono(waveform): return waveform -@torch.jit.script -def LC2CL(tensor): - # type: (Tensor) -> Tensor - r"""Permute a 2D tensor from samples (n, c) to (c, n). - - Args: - tensor (torch.Tensor): Tensor of audio signal with shape (n, c) - - Returns: - torch.Tensor: Tensor of audio signal with shape (c, n) - """ - return tensor.transpose(0, 1).contiguous() - - # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved @torch.jit.ignore def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index a056915250..d3d3ef662b 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -7,31 +7,6 @@ from .compliance import kaldi -class Scale(torch.jit.ScriptModule): - r"""Scales tensor by a factor. By default, assuming the input is int32, it - will scale the tensor to have values between -1.0 and 1.0. - - Args: - factor (float): Factor to scale by. 
(Default: `float(2**31)`) - """ - __constants__ = ['factor'] - - def __init__(self, factor=float(2**31)): - super(Scale, self).__init__() - self.factor = factor - - @torch.jit.script_method - def forward(self, tensor): - r""" - Args: - tensor (torch.Tensor): Tensor input to scale - - Returns: - torch.Tensor: Scaled by the scale factor - """ - return F.scale(tensor, self.factor) - - class PadTrim(torch.jit.ScriptModule): r"""Pad/Trim a 2D tensor @@ -77,23 +52,6 @@ def forward(self, waveform): return F.downmix_mono(waveform) -class LC2CL(torch.jit.ScriptModule): - r"""Converts a 2D tensor from (n, c) to (c, n) - """ - def __init__(self): - super(LC2CL, self).__init__() - - @torch.jit.script_method - def forward(self, tensor): - r""" - Args: - tensor (torch.Tensor): Tensor of audio signal with shape (n, c) - Returns: - torch.Tensor: Tensor of audio signal with shape (c, n) - """ - return F.LC2CL(tensor) - - class Spectrogram(torch.jit.ScriptModule): r"""Create a spectrogram from a audio signal From 710a236aaa87813ab5050eb2523945621864b9d0 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 07:40:18 -0700 Subject: [PATCH 13/28] apply feedback: remove downmix --- test/test_jit.py | 19 ------------------- test/test_transforms.py | 14 -------------- torchaudio/transforms.py | 19 ------------------- 3 files changed, 52 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 1aec8427c8..22113a295e 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -52,25 +52,6 @@ def test_scriptmodule_pad_trim(self): self._test_script_module(tensor, transforms.PadTrim, max_len) - def test_torchscript_downmix_mono(self): - @torch.jit.script - def jit_method(tensor): - # type: (Tensor) -> Tensor - return F.downmix_mono(tensor) - - tensor = torch.rand((2, 10)) - - jit_out = jit_method(tensor) - py_out = F.downmix_mono(tensor) - - self.assertTrue(torch.allclose(jit_out, py_out)) - - @unittest.skipIf(not RUN_CUDA, "no CUDA") - def test_scriptmodule_downmix_mono(self): - tensor = torch.rand((2, 10), device="cuda") - - self._test_script_module(tensor, transforms.DownmixMono) - def test_torchscript_spectrogram(self): @torch.jit.script def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize): diff --git a/test/test_transforms.py b/test/test_transforms.py index cfa8ce3477..954b3db24f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -50,20 +50,6 @@ def test_pad_trim(self): result = transforms.PadTrim(max_len=length_new)(waveform) self.assertEqual(result.size(1), length_new) - def test_downmix_mono(self): - - waveform_L = self.waveform.clone() - waveform_R = self.waveform.clone() - R_idx = int(waveform_R.size(0) * 0.1) - waveform_R = torch.cat((waveform_R[R_idx:], waveform_R[:R_idx])) - - audio_Stereo = torch.cat((waveform_L, waveform_R), dim=0) - - self.assertTrue(audio_Stereo.size(0) == 2) - - result = transforms.DownmixMono()(audio_Stereo) - - self.assertTrue(result.size(0) == 1) def test_mu_law_companding(self): diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d3d3ef662b..3937deacd9 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -33,25 +33,6 @@ def forward(self, waveform): return F.pad_trim(waveform, self.max_len, self.fill_value) -class DownmixMono(torch.jit.ScriptModule): - r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with - the `channels` effect instead of this transformation. 
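The reduction `DownmixMono` performed remains a one-liner on the tensor itself; a sketch assuming a float waveform of size (c, n):

    >>> stereo = torch.randn(2, 16000)       # (c, n)
    >>> mono = torch.mean(stereo, 0, True)   # (1, n), the same mean DownmixMono computed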
- """ - def __init__(self): - super(DownmixMono, self).__init__() - - @torch.jit.script_method - def forward(self, waveform): - r""" - Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) - - Returns: - torch.Tensor: Tensor that has been downmixed of size (1, n) - """ - return F.downmix_mono(waveform) - - class Spectrogram(torch.jit.ScriptModule): r"""Create a spectrogram from a audio signal From 015dd0e7cd4873ebf2f67a5607fec41793691245 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 07:47:46 -0700 Subject: [PATCH 14/28] apply feedback: remove downmix --- torchaudio/functional.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 93b662338e..7a3e34ba6d 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -4,7 +4,6 @@ __all__ = [ 'pad_trim', - 'downmix_mono', 'istft', 'spectrogram', 'spectrogram_to_DB', @@ -37,25 +36,6 @@ def pad_trim(waveform, max_len, fill_value): return waveform -@torch.jit.script -def downmix_mono(waveform): - # type: (Tensor) -> Tensor - r"""Downmix stereo waveform to mono. Consider using a `SoxEffectsChain` with - the `channels` effect instead of this transformation. - - Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) - - Returns: - torch.Tensor: Tensor that has been downmixed of size (1, n) - """ - if not waveform.is_floating_point(): - waveform = waveform.to(torch.float32) - - waveform = torch.mean(waveform, 0, True) - return waveform - - # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved @torch.jit.ignore def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): From 840707d1a09f2564e2e701d615e24a292fab732b Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 08:47:46 -0700 Subject: [PATCH 15/28] apply feedback: rearrange functions --- torchaudio/functional.py | 64 ++++++++++++++++----------------- torchaudio/transforms.py | 77 +++++++++++++++++++++------------------- 2 files changed, 72 insertions(+), 69 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 7a3e34ba6d..af6dc9e003 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -6,8 +6,8 @@ 'pad_trim', 'istft', 'spectrogram', - 'spectrogram_to_DB', 'create_fb_matrix', + 'spectrogram_to_DB', 'create_dct', 'mu_law_encoding', 'mu_law_expanding' @@ -225,37 +225,6 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor return spec_f -@torch.jit.script -def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None): - # type: (Tensor, float, float, float, Optional[float]) -> Tensor - r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. - - This output depends on the maximum value in the input spectrogram, and so - may return different values for an audio clip split into snippets vs. a - a full clip. - - Args: - specgram (torch.Tensor): Normal STFT of size (c, f, t) - multiplier (float): Use 10. for power and 20. for amplitude - amin (float): Number to clamp specgram - db_multiplier (float): Log10(max(reference value and amin)) - top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number - is 80. 
- - Returns: - torch.Tensor: Spectrogram in DB of size (c, f, t) - """ - specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin)) - specgram_db -= multiplier * db_multiplier - - if top_db is not None: - new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db, - dtype=specgram_db.dtype, device=specgram_db.device) - specgram_db = torch.max(specgram_db, new_spec_db_max) - - return specgram_db - - @torch.jit.script def create_fb_matrix(n_freqs, f_min, f_max, n_mels): # type: (int, float, float, int) -> Tensor @@ -294,6 +263,37 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): return fb +@torch.jit.script +def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None): + # type: (Tensor, float, float, float, Optional[float]) -> Tensor + r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. + + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. + + Args: + specgram (torch.Tensor): Normal STFT of size (c, f, t) + multiplier (float): Use 10. for power and 20. for amplitude + amin (float): Number to clamp specgram + db_multiplier (float): Log10(max(reference value and amin)) + top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number + is 80. + + Returns: + torch.Tensor: Spectrogram in DB of size (c, f, t) + """ + specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin)) + specgram_db -= multiplier * db_multiplier + + if top_db is not None: + new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db, + dtype=specgram_db.dtype, device=specgram_db.device) + specgram_db = torch.max(specgram_db, new_spec_db_max) + + return specgram_db + + @torch.jit.script def create_dct(n_mfcc, n_mels, norm): # type: (int, int, Optional[str]) -> Tensor diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 3937deacd9..8cd4fe6f33 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -81,44 +81,7 @@ def forward(self, waveform): self.win_length, self.power, self.normalized) -class SpectrogramToDB(torch.jit.ScriptModule): - r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. - This output depends on the maximum value in the input spectrogram, and so - may return different values for an audio clip split into snippets vs. a - a full clip. - - Args: - stype (str): scale of input spectrogram ('power' or 'magnitude'). The - power being the elementwise square of the magnitude. (Default: 'power') - top_db (float, optional): minimum negative cut-off in decibels. A reasonable number - is 80. 
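With the defaults visible here (multiplier 10.0 for power, amin 1e-10, ref_value 1.0, hence db_multiplier 0.0), the conversion reduces to a clamp and a log; a hand-rolled sketch:

    >>> import math
    >>> specgram = torch.rand(1, 201, 100)  # power spectrogram of size (c, f, t)
    >>> specgram_db = 10.0 * torch.log10(torch.clamp(specgram, min=1e-10))
    >>> specgram_db -= 10.0 * math.log10(max(1e-10, 1.0))  # reference term; zero for ref_value 1.0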
- """ - __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] - - def __init__(self, stype='power', top_db=None): - super(SpectrogramToDB, self).__init__() - self.stype = torch.jit.Attribute(stype, str) - if top_db is not None and top_db < 0: - raise ValueError('top_db must be positive value') - self.top_db = torch.jit.Attribute(top_db, Optional[float]) - self.multiplier = 10.0 if stype == 'power' else 20.0 - self.amin = 1e-10 - self.ref_value = 1.0 - self.db_multiplier = math.log10(max(self.amin, self.ref_value)) - - @torch.jit.script_method - def forward(self, specgram): - r"""Numerically stable implementation from Librosa - https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html - - Args: - specgram (torch.Tensor): STFT of size (c, f, t) - - Returns: - torch.Tensor: STFT after changing scale of size (c, f, t) - """ - return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db) class MelScale(torch.jit.ScriptModule): @@ -168,6 +131,46 @@ def forward(self, specgram): return mel_specgram +class SpectrogramToDB(torch.jit.ScriptModule): + r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. + + This output depends on the maximum value in the input spectrogram, and so + may return different values for an audio clip split into snippets vs. a + a full clip. + + Args: + stype (str): scale of input spectrogram ('power' or 'magnitude'). The + power being the elementwise square of the magnitude. (Default: 'power') + top_db (float, optional): minimum negative cut-off in decibels. A reasonable number + is 80. + """ + __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] + + def __init__(self, stype='power', top_db=None): + super(SpectrogramToDB, self).__init__() + self.stype = torch.jit.Attribute(stype, str) + if top_db is not None and top_db < 0: + raise ValueError('top_db must be positive value') + self.top_db = torch.jit.Attribute(top_db, Optional[float]) + self.multiplier = 10.0 if stype == 'power' else 20.0 + self.amin = 1e-10 + self.ref_value = 1.0 + self.db_multiplier = math.log10(max(self.amin, self.ref_value)) + + @torch.jit.script_method + def forward(self, specgram): + r"""Numerically stable implementation from Librosa + https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html + + Args: + specgram (torch.Tensor): STFT of size (c, f, t) + + Returns: + torch.Tensor: STFT after changing scale of size (c, f, t) + """ + return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db) + + class MelSpectrogram(torch.jit.ScriptModule): r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram and MelScale. From 9da5089b05d48bcae870b20f35d4ad6218e30177 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 08:49:49 -0700 Subject: [PATCH 16/28] apply feedback: rearrange functions --- torchaudio/transforms.py | 120 +++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 8cd4fe6f33..9be2995c2d 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -171,66 +171,6 @@ def forward(self, specgram): return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db) -class MelSpectrogram(torch.jit.ScriptModule): - r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram - and MelScale. 
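Because `MelSpectrogram` is exactly this composition, the two stages can also be run separately; a sketch assuming both keep their default arguments:

    >>> waveform = torch.randn(1, 16000)                # (c, n)
    >>> specgram = transforms.Spectrogram()(waveform)   # (c, n_fft // 2 + 1, t)
    >>> mel_specgram = transforms.MelScale()(specgram)  # (c, n_mels, t)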
- - Sources: - * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe - * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html - * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html - - Args: - sample_rate (int): Sample rate of audio signal. (Default: 16000) - win_length (int): Window size. (Default: `n_fft`) - hop_length (int, optional): Length of hop between STFT windows. ( - Default: `win_length // 2`) - n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins - f_min (float): Minimum frequency. (Default: 0.) - f_max (float, optional): Maximum frequency. (Default: `None`) - pad (int): Two sided padding of signal. (Default: 0) - n_mels (int): Number of mel filterbanks. (Default: 128) - window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor - that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) - wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`) - - Example: - >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) - >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) - """ - __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] - - def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None, - pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None): - super(MelSpectrogram, self).__init__() - self.sample_rate = sample_rate - self.n_fft = n_fft - self.win_length = win_length if win_length is not None else n_fft - self.hop_length = hop_length if hop_length is not None else self.win_length // 2 - self.pad = pad - self.n_mels = n_mels # number of mel frequency bins - self.f_max = torch.jit.Attribute(f_max, Optional[float]) - self.f_min = f_min - self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, - hop_length=self.hop_length, - pad=self.pad, window_fn=window_fn, power=2, - normalized=False, wkwargs=wkwargs) - self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max) - - @torch.jit.script_method - def forward(self, waveform): - r""" - Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) - - Returns: - torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) - """ - specgram = self.spectrogram(waveform) - mel_specgram = self.mel_scale(specgram) - return mel_specgram - - class MFCC(torch.jit.ScriptModule): r"""Create the Mel-frequency cepstrum coefficients from an audio signal @@ -296,6 +236,66 @@ def forward(self, waveform): return mfcc +class MelSpectrogram(torch.jit.ScriptModule): + r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram + and MelScale. + + Sources: + * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe + * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html + * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html + + Args: + sample_rate (int): Sample rate of audio signal. (Default: 16000) + win_length (int): Window size. (Default: `n_fft`) + hop_length (int, optional): Length of hop between STFT windows. ( + Default: `win_length // 2`) + n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins + f_min (float): Minimum frequency. (Default: 0.) + f_max (float, optional): Maximum frequency. (Default: `None`) + pad (int): Two sided padding of signal. (Default: 0) + n_mels (int): Number of mel filterbanks. 
(Default: 128) + window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) + wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`) + + Example: + >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) + >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) + """ + __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] + + def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None, + pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None): + super(MelSpectrogram, self).__init__() + self.sample_rate = sample_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + self.pad = pad + self.n_mels = n_mels # number of mel frequency bins + self.f_max = torch.jit.Attribute(f_max, Optional[float]) + self.f_min = f_min + self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, + hop_length=self.hop_length, + pad=self.pad, window_fn=window_fn, power=2, + normalized=False, wkwargs=wkwargs) + self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max) + + @torch.jit.script_method + def forward(self, waveform): + r""" + Args: + waveform (torch.Tensor): Tensor of audio of size (c, n) + + Returns: + torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) + """ + specgram = self.spectrogram(waveform) + mel_specgram = self.mel_scale(specgram) + return mel_specgram + + class MuLawEncoding(torch.jit.ScriptModule): r"""Encode signal based on mu-law companding. For more info see the `Wikipedia Entry `_ From afe528ab9038d09b71f83ad2208cdf852137eaec Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:05:58 -0700 Subject: [PATCH 17/28] merge: delete stft --- test/test_functional.py | 56 +++--------------------------------- test/test_transforms.py | 1 - torchaudio/functional.py | 62 ++++------------------------------------ torchaudio/transforms.py | 3 -- 4 files changed, 9 insertions(+), 113 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 04ef533f05..774ce728b2 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -2,6 +2,7 @@ import torch import torchaudio +import pytest import unittest import test.common_utils @@ -11,8 +12,6 @@ import numpy as np import librosa -import pytest -import torchaudio.functional as F xfail = pytest.mark.xfail @@ -197,54 +196,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length -@pytest.mark.parametrize('fft_length', [512]) -@pytest.mark.parametrize('hop_length', [256]) -@pytest.mark.parametrize('waveform', [ - (torch.randn(1, 100000)), - (torch.randn(1, 2, 100000)), - pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)), -]) -@pytest.mark.parametrize('pad_mode', [ - # 'constant', - 'reflect', -]) -@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available') -def test_stft(waveform, fft_length, hop_length, pad_mode): - """ - Test STFT for multi-channel signals. - - Padding: Value in having padding outside of torch.stft? 
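With the wrapper deleted, `torch.stft` is called directly on a 2D waveform; a sketch (sizes illustrative):

    >>> waveform = torch.randn(1, 16000)                          # (c, n)
    >>> window = torch.hann_window(400)
    >>> specgram = torch.stft(waveform, 400, 200, window=window)  # (c, 201, frames, 2)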
- """ - pad = fft_length // 2 - window = torch.hann_window(fft_length) - complex_spec = F.stft(waveform, - fft_length=fft_length, - hop_length=hop_length, - window=window, - pad_mode=pad_mode) - mag_spec, phase_spec = F.magphase(complex_spec) - - # == Test shape - expected_size = list(waveform.size()[:-1]) - expected_size += [fft_length // 2 + 1, _num_stft_bins( - waveform.size(-1), fft_length, hop_length, pad), 2] - assert complex_spec.dim() == waveform.dim() + 2 - assert complex_spec.size() == torch.Size(expected_size) - - # == Test values - fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode) - # note that librosa *automatically* pad with fft_length // 2. - expected_complex_spec = np.apply_along_axis(librosa.stft, -1, - waveform.numpy(), **fft_config) - expected_mag_spec, _ = librosa.magphase(expected_complex_spec) - # Convert torch to np.complex - complex_spec = complex_spec.numpy() - complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1] - - assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5) - assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5) - - @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3]) @pytest.mark.parametrize('complex_specgrams', [ torch.randn(1, 2, 1025, 400, 2), @@ -261,7 +212,8 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): complex_specgrams = complex_specgrams.type(torch.float64) phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None] - complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance) + complex_specgrams_stretch = torchaudio.functional.phase_vocoder( + complex_specgrams, rate=rate, phase_advance=phase_advance) # == Test shape expected_size = list(complex_specgrams.size()) @@ -292,7 +244,7 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): @pytest.mark.parametrize('power', [1, 2, 0.7]) def test_complex_norm(complex_tensor, power): expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) - norm_tensor = F.complex_norm(complex_tensor, power) + norm_tensor = torchaudio.functional.complex_norm(complex_tensor, power) assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5) diff --git a/test/test_transforms.py b/test/test_transforms.py index 954b3db24f..1d3a41a564 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -50,7 +50,6 @@ def test_pad_trim(self): result = transforms.PadTrim(max_len=length_new)(waveform) self.assertEqual(result.size(1), length_new) - def test_mu_law_companding(self): quantization_channels = 256 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 318b4b39cc..d82c586e61 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -10,7 +10,11 @@ 'spectrogram_to_DB', 'create_dct', 'mu_law_encoding', - 'mu_law_expanding' + 'mu_law_expanding', + 'complex_norm', + 'angle', + 'magphase', + 'phase_vocoder', ] @@ -373,62 +377,6 @@ def mu_law_expanding(x_mu, qc): return x -def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None, - center=True, pad_mode='reflect', normalized=False, onesided=True): - """Compute a short time Fourier transform of the input waveform(s). - It wraps `torch.stft` after reshaping the input audio to allow for `waveforms` that `.dim()` >= 3. - It follows most of the `torch.stft` default values, but for `window`, which defaults to hann window. 
- - Args: - waveforms (torch.Tensor): Audio signal of size `(*, channel, time)` - fft_length (int): FFT size [sample]. - hop_length (int): Hop size [sample] between STFT frames. - (Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`). - win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`). - window (torch.Tensor): window function. (Defaults to Hann Window of size `win_length` *unlike* `torch.stft`). - center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered - at time `t * hop_length`. (Defaults to `True` by `torch.stft`) - pad_mode (str): padding method (see `torch.nn.functional.pad`). (Defaults to `'reflect'` by `torch.stft`). - normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`). - onesided (bool): Whether the half + 1 frequency bins are returned to removethe symmetric part of STFT - of real-valued signal. (Defaults to `True` by `torch.stft`). - - Returns: - torch.Tensor: `(*, channel, num_freqs, time, complex=2)` - - Example: - >>> waveforms = torch.randn(16, 2, 10000) # (batch, channel, time) - >>> x = stft(waveforms, 2048, 512) - >>> x.shape - torch.Size([16, 2, 1025, 20]) - """ - leading_dims = waveforms.shape[:-1] - - waveforms = waveforms.reshape(-1, waveforms.size(-1)) - - if window is None: - if win_length is None: - window = torch.hann_window(fft_length) - else: - window = torch.hann_window(win_length) - - complex_specgrams = torch.stft(waveforms, - n_fft=fft_length, - hop_length=hop_length, - win_length=win_length, - window=window, - center=center, - pad_mode=pad_mode, - normalized=normalized, - onesided=onesided) - - complex_specgrams = complex_specgrams.reshape( - leading_dims + - complex_specgrams.shape[1:]) - - return complex_specgrams - - def complex_norm(complex_tensor, power=1.0): """Compute the norm of complex tensor input diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 9be2995c2d..99dcca771b 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -81,9 +81,6 @@ def forward(self, waveform): self.win_length, self.power, self.normalized) - - - class MelScale(torch.jit.ScriptModule): r"""This turns a normal STFT into a mel frequency STFT, using a conversion matrix. This uses triangular filter banks. 
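An aside on the conversion matrix mentioned above: `create_fb_matrix` returns the (f, n_mels) bank of triangular filters, so the mel spectrogram is one matmul away; a sketch with an illustrative frequency range:

    >>> specgram = torch.rand(1, 201, 100)                         # power spectrogram, (c, f, t)
    >>> fb = F.create_fb_matrix(specgram.size(1), 0., 8000., 128)  # (f, n_mels)
    >>> mel_specgram = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)  # (c, n_mels, t)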
From be082f8581e5b61a7f47d02142b6fd879a4ca670 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:09:45 -0700 Subject: [PATCH 18/28] merge --- test/test_functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 774ce728b2..0ad10ae916 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -2,6 +2,7 @@ import torch import torchaudio +import torchaudio.functional as F import pytest import unittest import test.common_utils @@ -212,8 +213,7 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): complex_specgrams = complex_specgrams.type(torch.float64) phase_advance = torch.linspace(0, np.pi * hop_length, complex_specgrams.shape[-3], dtype=torch.float64)[..., None] - complex_specgrams_stretch = torchaudio.functional.phase_vocoder( - complex_specgrams, rate=rate, phase_advance=phase_advance) + complex_specgrams_stretch = F.phase_vocoder(complex_specgrams, rate=rate, phase_advance=phase_advance) # == Test shape expected_size = list(complex_specgrams.size()) @@ -244,7 +244,7 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): @pytest.mark.parametrize('power', [1, 2, 0.7]) def test_complex_norm(complex_tensor, power): expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) - norm_tensor = torchaudio.functional.complex_norm(complex_tensor, power) + norm_tensor = F.complex_norm(complex_tensor, power) assert torch.allclose(expected_norm_tensor, norm_tensor, atol=1e-5) From a7aa4409c93fda8b3d6f8d8b28a59a7815255ab1 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:13:24 -0700 Subject: [PATCH 19/28] remove batch support for istft --- torchaudio/functional.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index d82c586e61..28b33a2619 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -89,8 +89,7 @@ def istft(stft_matrix, # type: Tensor Args: stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each - column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or ( - fft_size, n_frames, 2) + column is a window. it has a shape of (fft_size, n_frames, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``win_length // 4``) @@ -107,10 +106,13 @@ def istft(stft_matrix, # type: Tensor Returns: torch.Tensor: Least squares estimation of the original signal of size - (batch, signal_length) or (signal_length) + (signal_length) """ stft_matrix_dim = stft_matrix.dim() - assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) + # Technically this function can accept either (batch, fft_size, n_frames, 2) or + # (fft_size, n_frames, 2). But going to temporarily remove batch support ( + # through adding an assert) to make torchaudio functions consistent. 
+ assert stft_matrix_dim == 3, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) if stft_matrix_dim == 3: # add a batch dimension From af1c8c80e013bc6667a40ec7b94362d00492d321 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:42:01 -0700 Subject: [PATCH 20/28] docstring --- torchaudio/functional.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 28b33a2619..79aa21d791 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -121,8 +121,8 @@ def istft(stft_matrix, # type: Tensor device = stft_matrix.device fft_size = stft_matrix.size(1) assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), ( - 'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' - + 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size)) + 'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' + + 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size)) # use stft defaults for Optionals if win_length is None: @@ -383,11 +383,11 @@ def complex_norm(complex_tensor, power=1.0): """Compute the norm of complex tensor input Args: - complex_tensor (Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` power (float): Power of the norm. Defaults to `1.0`. Returns: - Tensor: power of the normed input tensor, shape of `(*, )` + torch.Tensor: power of the normed input tensor, shape of `(*, )` """ if power == 1.0: return torch.norm(complex_tensor, 2, -1) @@ -417,15 +417,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): without modifying pitch by a factor of `rate`. Args: - complex_specgrams (Tensor): - (*, channel, num_freqs, time, complex=2) + complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2) rate (float): Speed-up factor. - phase_advance (Tensor): Expected phase advance in - each bin. (num_freqs, 1). + phase_advance (torch.Tensor): Expected phase advance in + each bin. Size of (f, 1). Returns: - complex_specgrams_stretch (Tensor): - (*, channel, num_freqs, ceil(time/rate), complex=2). + complex_specgrams_stretch (torch.Tensor): + (*, c, f, ceil(t/rate), complex=2). Example: >>> num_freqs, hop_length = 1025, 512 From 44e1f4deb0688cb1d0c1347122736d11261dee3a Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:47:00 -0700 Subject: [PATCH 21/28] docstring --- torchaudio/functional.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 79aa21d791..020f509f65 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -384,7 +384,7 @@ def complex_norm(complex_tensor, power=1.0): Args: complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` - power (float): Power of the norm. Defaults to `1.0`. + power (float): Power of the norm. (Default: `1.0`). Returns: torch.Tensor: power of the normed input tensor, shape of `(*, )` @@ -396,7 +396,11 @@ def complex_norm(complex_tensor, power=1.0): def angle(complex_tensor): """ - Return angle of a complex tensor with shape (*, 2). 
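The helpers being documented here compose; a doctest-style sketch (shapes illustrative):

    >>> complex_tensor = torch.randn(1, 201, 100, 2)       # (*, complex=2)
    >>> mag, phase = F.magphase(complex_tensor, power=1.)  # each of size (1, 201, 100)
    >>> torch.allclose(mag, F.complex_norm(complex_tensor, 1.))
    True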
+ Args: + complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + + Return: + torch.Tensor: Angle of a complex tensor """ return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) @@ -405,6 +409,13 @@ def magphase(complex_tensor, power=1.): """ Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase. + + Args: + complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + power (float): Power of the norm. (Default: `1.0`). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor """ mag = complex_norm(complex_tensor, power) phase = angle(complex_tensor) From b383f0030e863bb97e69c3a970d605e7cbc17809 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:52:04 -0700 Subject: [PATCH 22/28] docstring --- torchaudio/functional.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 020f509f65..992e8e89ee 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -380,14 +380,14 @@ def mu_law_expanding(x_mu, qc): def complex_norm(complex_tensor, power=1.0): - """Compute the norm of complex tensor input + r"""Compute the norm of complex tensor input. Args: complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` power (float): Power of the norm. (Default: `1.0`). Returns: - torch.Tensor: power of the normed input tensor, shape of `(*, )` + torch.Tensor: Power of the normed input tensor. Shape of `(*, )` """ if power == 1.0: return torch.norm(complex_tensor, 2, -1) @@ -395,24 +395,22 @@ def complex_norm(complex_tensor, power=1.0): def angle(complex_tensor): - """ + r""" Args: complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` Return: - torch.Tensor: Angle of a complex tensor + torch.Tensor: Angle of a complex tensor. Shape of `(*, )` """ return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) def magphase(complex_tensor, power=1.): - """ - Separate a complex-valued spectrogram with shape (*,2) - into its magnitude and phase. + r"""Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase. Args: complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` - power (float): Power of the norm. (Default: `1.0`). + power (float): Power of the norm. (Default: `1.0`) Returns: Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor @@ -423,19 +421,16 @@ def magphase(complex_tensor, power=1.): def phase_vocoder(complex_specgrams, rate, phase_advance): - """ - Phase vocoder. Given a STFT tensor, speed up in time - without modifying pitch by a factor of `rate`. + r"""Given a STFT tensor, speed up in time without modifying pitch by a + factor of `rate`. Args: complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2) - rate (float): Speed-up factor. - phase_advance (torch.Tensor): Expected phase advance in - each bin. Size of (f, 1). + rate (float): Speed-up factor + phase_advance (torch.Tensor): Expected phase advance in each bin. Size of (f, 1) Returns: - complex_specgrams_stretch (torch.Tensor): - (*, c, f, ceil(t/rate), complex=2). 
+ complex_specgrams_stretch (torch.Tensor): Size of (*, c, f, ceil(t/rate), complex=2) Example: >>> num_freqs, hop_length = 1025, 512 From fea5c0625ea693581281bcc5e6a2c4614918269c Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:53:05 -0700 Subject: [PATCH 23/28] docstring --- torchaudio/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 992e8e89ee..d47710c449 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -395,7 +395,8 @@ def complex_norm(complex_tensor, power=1.0): def angle(complex_tensor): - r""" + r"""Compute the angle of complex tensor input. + Args: complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` From a4f7d0f5fbd950a50da3f6d3421866910b85a18b Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:55:25 -0700 Subject: [PATCH 24/28] more --- torchaudio/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index d47710c449..af91df7f54 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -42,9 +42,9 @@ def pad_trim(waveform, max_len, fill_value): # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved @torch.jit.ignore -def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): +def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor - return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided) + return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided) def istft(stft_matrix, # type: Tensor From 3997d12a261ba59eaedfb9758489196da7301780 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 09:57:16 -0700 Subject: [PATCH 25/28] more --- torchaudio/functional.py | 12 ++++++------ torchaudio/transforms.py | 18 ++++++------------ 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index af91df7f54..10e7ae2b67 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -329,7 +329,7 @@ def create_dct(n_mfcc, n_mels, norm): @torch.jit.script -def mu_law_encoding(x, qc): +def mu_law_encoding(x, quantization_channels): # type: (Tensor, int) -> Tensor r"""Encode signal based on mu-law companding. For more info see the `Wikipedia Entry `_ @@ -339,12 +339,12 @@ def mu_law_encoding(x, qc): Args: x (torch.Tensor): Input tensor - qc (int): Number of channels (i.e. quantization channels) + quantization_channels (int): Number of channels Returns: torch.Tensor: Input after mu-law companding """ - mu = qc - 1. + mu = quantization_channels - 1. if not x.is_floating_point(): x = x.to(torch.float) mu = torch.tensor(mu, dtype=x.dtype) @@ -355,7 +355,7 @@ def mu_law_encoding(x, qc): @torch.jit.script -def mu_law_expanding(x_mu, qc): +def mu_law_expanding(x_mu, quantization_channels): # type: (Tensor, int) -> Tensor r"""Decode mu-law encoded signal. For more info see the `Wikipedia Entry `_ @@ -365,12 +365,12 @@ def mu_law_expanding(x_mu, qc): Args: x_mu (torch.Tensor): Input tensor - qc (int): Number of channels (i.e. quantization channels) + quantization_channels (int): Number of channels Returns: torch.Tensor: Input after decoding """ - mu = qc - 1. 
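After the rename, the pair reads as an encode/decode round trip; a sketch:

    >>> x = torch.linspace(-1., 1., 5)
    >>> x_mu = F.mu_law_encoding(x, 256)       # int64 codes in [0, 255]
    >>> x_hat = F.mu_law_expanding(x_mu, 256)  # approximately recovers x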
+ mu = quantization_channels - 1. if not x_mu.is_floating_point(): x_mu = x_mu.to(torch.float) mu = torch.tensor(mu, dtype=x_mu.dtype) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 99dcca771b..e1c821ea2f 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -303,11 +303,11 @@ class MuLawEncoding(torch.jit.ScriptModule): Args: quantization_channels (int): Number of channels. default: 256 """ - __constants__ = ['qc'] + __constants__ = ['quantization_channels'] def __init__(self, quantization_channels=256): super(MuLawEncoding, self).__init__() - self.qc = quantization_channels + self.quantization_channels = quantization_channels @torch.jit.script_method def forward(self, x): @@ -318,10 +318,7 @@ def forward(self, x): Returns: x_mu (torch.Tensor): An encoded signal """ - return F.mu_law_encoding(x, self.qc) - - def __repr__(self): - return self.__class__.__name__ + '()' + return F.mu_law_encoding(x, self.quantization_channels) class MuLawExpanding(torch.jit.ScriptModule): @@ -334,11 +331,11 @@ class MuLawExpanding(torch.jit.ScriptModule): Args: quantization_channels (int): Number of channels. default: 256 """ - __constants__ = ['qc'] + __constants__ = ['quantization_channels'] def __init__(self, quantization_channels=256): super(MuLawExpanding, self).__init__() - self.qc = quantization_channels + self.quantization_channels = quantization_channels @torch.jit.script_method def forward(self, x_mu): @@ -349,10 +346,7 @@ def forward(self, x_mu): Returns: torch.Tensor: The signal decoded """ - return F.mu_law_expanding(x_mu, self.qc) - - def __repr__(self): - return self.__class__.__name__ + '()' + return F.mu_law_expanding(x_mu, self.quantization_channels) class Resample(torch.nn.Module): From dc226c9c3993aff086241131523162db991f0790 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 10:39:32 -0700 Subject: [PATCH 26/28] remove unused xfail --- test/test_functional.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 0ad10ae916..02b01620ed 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -13,8 +13,6 @@ import numpy as np import librosa -xfail = pytest.mark.xfail - class TestFunctional(unittest.TestCase): data_sizes = [(2, 20), (3, 15), (4, 10)] From ab4ecb6b142b66d7efb885b7bf8c8452a6e614b3 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 10:43:03 -0700 Subject: [PATCH 27/28] Revert "remove batch support for istft" This reverts commit a7aa4409c93fda8b3d6f8d8b28a59a7815255ab1. --- torchaudio/functional.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 10e7ae2b67..5ad0982ee1 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -89,7 +89,8 @@ def istft(stft_matrix, # type: Tensor Args: stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each - column is a window. it has a shape of (fft_size, n_frames, 2) + column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or ( + fft_size, n_frames, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. 
(Default: ``win_length // 4``) @@ -106,13 +107,10 @@ def istft(stft_matrix, # type: Tensor Returns: torch.Tensor: Least squares estimation of the original signal of size - (signal_length) + (batch, signal_length) or (signal_length) """ stft_matrix_dim = stft_matrix.dim() - # Technically this function can accept either (batch, fft_size, n_frames, 2) or - # (fft_size, n_frames, 2). But going to temporarily remove batch support ( - # through adding an assert) to make torchaudio functions consistent. - assert stft_matrix_dim == 3, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) + assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) if stft_matrix_dim == 3: # add a batch dimension From 99675a4fe0bad35ef4724de3950059d7e9c3a649 Mon Sep 17 00:00:00 2001 From: Jason Lian Date: Wed, 24 Jul 2019 10:45:50 -0700 Subject: [PATCH 28/28] rename batch to channel --- torchaudio/functional.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 5ad0982ee1..065de43a5e 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -88,8 +88,8 @@ def istft(stft_matrix, # type: Tensor IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. Args: - stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each - column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or ( + stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each + column is a window. it has a shape of either (channel, fft_size, n_frames, 2) or ( fft_size, n_frames, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. 
@@ -107,13 +107,13 @@ def istft(stft_matrix, # type: Tensor Returns: torch.Tensor: Least squares estimation of the original signal of size - (batch, signal_length) or (signal_length) + (channel, signal_length) or (signal_length) """ stft_matrix_dim = stft_matrix.dim() assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim)) if stft_matrix_dim == 3: - # add a batch dimension + # add a channel dimension stft_matrix = stft_matrix.unsqueeze(0) device = stft_matrix.device @@ -145,16 +145,16 @@ def istft(stft_matrix, # type: Tensor assert window.size(0) == n_fft # win_length and n_fft are synonymous from here on - stft_matrix = stft_matrix.transpose(1, 2) # size (batch, n_frames, fft_size, 2) + stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2) stft_matrix = torch.irfft(stft_matrix, 1, normalized, - onesided, signal_sizes=(n_fft,)) # size (batch, n_frames, n_fft) + onesided, signal_sizes=(n_fft,)) # size (channel, n_frames, n_fft) assert stft_matrix.size(2) == n_fft n_frames = stft_matrix.size(1) - ytmp = stft_matrix * window.view(1, 1, n_fft) # size (batch, n_frames, n_fft) - # each column of a batch is a frame which needs to be overlap added at the right place - ytmp = ytmp.transpose(1, 2) # size (batch, n_fft, n_frames) + ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frames, n_fft) + # each column of a channel is a frame which needs to be overlap added at the right place + ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames) eye = torch.eye(n_fft, requires_grad=False, device=device).unsqueeze(1) # size (n_fft, 1, n_fft) @@ -162,7 +162,7 @@ def istft(stft_matrix, # type: Tensor # this does overlap add where the frames of ytmp are added such that the i'th frame of # ytmp is added starting at i*hop_length in the output y = torch.nn.functional.conv_transpose1d( - ytmp, eye, stride=hop_length, padding=0) # size (batch, 1, expected_signal_len) + ytmp, eye, stride=hop_length, padding=0) # size (channel, 1, expected_signal_len) # do the same for the window function window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) # size (1, n_fft, n_frames) @@ -185,9 +185,9 @@ def istft(stft_matrix, # type: Tensor window_envelop_lowest = window_envelop.abs().min() assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest)) - y = (y / window_envelop).squeeze(1) # size (batch, expected_signal_len) + y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) - if stft_matrix_dim == 3: # remove the batch dimension + if stft_matrix_dim == 3: # remove the channel dimension y = y.squeeze(0) return y