From 0bcd11ab19416f8d5c32c722b26e9bebb4f55945 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 15 Jan 2020 10:51:44 -0500 Subject: [PATCH] replace reshape by view. --- torchaudio/functional.py | 43 ++++++++++++++++++++-------------------- torchaudio/transforms.py | 8 ++++---- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 90fb87758f..c93aaa5741 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -129,7 +129,7 @@ def istft( # pack batch shape = stft_matrix.size() - stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1]) + stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1]) dtype = stft_matrix.dtype device = stft_matrix.device @@ -214,7 +214,7 @@ def istft( y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) # unpack batch - y = y.reshape(shape[:-3] + y.shape[-1:]) + y = y.view(shape[:-3] + y.shape[-1:]) if stft_matrix_dim == 3: # remove the channel dimension y = y.squeeze(0) @@ -253,7 +253,7 @@ def spectrogram( # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) + waveform = waveform.view(-1, shape[-1]) # default values are consistent with librosa.core.spectrum._spectrogram spec_f = _stft( @@ -261,7 +261,7 @@ def spectrogram( ) # unpack batch - spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:]) + spec_f = spec_f.view(shape[:-1] + spec_f.shape[-3:]) if normalized: spec_f /= window.pow(2.).sum().sqrt() @@ -317,7 +317,7 @@ def griffinlim( # pack batch shape = specgram.size() - specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = specgram.view([-1] + list(shape[-2:])) specgram = specgram.pow(1 / power) @@ -363,7 +363,7 @@ def griffinlim( length=length) # unpack batch - waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:]) + waveform = waveform.view(shape[:-2] + waveform.shape[-1:]) return waveform @@ -587,7 +587,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): # pack batch shape = complex_specgrams.size() - complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:])) + complex_specgrams = complex_specgrams.view([-1] + list(shape[-3:])) time_steps = torch.arange(0, complex_specgrams.size(-2), @@ -627,7 +627,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1) # unpack batch - complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:]) + complex_specgrams_stretch = complex_specgrams_stretch.view(shape[:-3] + complex_specgrams_stretch.shape[1:]) return complex_specgrams_stretch @@ -654,7 +654,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) + waveform = waveform.view(-1, shape[-1]) assert(a_coeffs.size(0) == b_coeffs.size(0)) assert(len(waveform.size()) == 2) @@ -697,7 +697,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.) # unpack batch - output = output.reshape(shape[:-1] + output.shape[-1:]) + output = output.view(shape[:-1] + output.shape[-1:]) return output @@ -876,7 +876,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): # pack batch shape = specgram.size() - specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = specgram.view([-1] + list(shape[-2:])) value = torch.rand(1) * mask_param min_value = torch.rand(1) * (specgram.size(axis) - value) @@ -893,7 +893,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): raise ValueError('Only Frequency and Time masking are supported') # unpack batch - specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:]) + specgram = specgram.view(shape[:-2] + specgram.shape[-2:]) return specgram @@ -925,7 +925,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): # pack batch shape = specgram.size() - specgram = specgram.reshape(1, -1, shape[-1]) + specgram = specgram.view(1, -1, shape[-1]) assert win_length >= 3 @@ -945,7 +945,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom # unpack batch - output = output.reshape(shape) + output = output.view(shape) return output @@ -974,11 +974,10 @@ def _add_noise_shaping(dithered_waveform, waveform): error[n] = dithered[n] - original[n] noise_shaped_waveform[n] = dithered[n] + error[n-1] """ - wf_shape = waveform.size() - waveform = waveform.reshape(-1, wf_shape[-1]) + waveform = waveform.view(-1, waveform.size()[-1]) dithered_shape = dithered_waveform.size() - dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1]) + dithered_waveform = dithered_waveform.view(-1, dithered_shape[-1]) error = dithered_waveform - waveform @@ -989,7 +988,7 @@ def _add_noise_shaping(dithered_waveform, waveform): error[index] = error_offset[:waveform.size()[1]] noise_shaped = dithered_waveform + error - return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:]) + return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:]) def _apply_probability_distribution(waveform, density_function="TPDF"): @@ -1020,7 +1019,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) + waveform = waveform.view(-1, shape[-1]) channel_size = waveform.size()[0] - 1 time_size = waveform.size()[-1] - 1 @@ -1060,7 +1059,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): quantised_signal = quantised_signal_scaled / down_scaling # unpack batch - return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:]) + return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:]) def dither(waveform, density_function="TPDF", noise_shaping=False): @@ -1231,7 +1230,7 @@ def detect_pitch_frequency( # pack batch shape = list(waveform.size()) - waveform = waveform.reshape([-1] + shape[-1:]) + waveform = waveform.view([-1] + shape[-1:]) nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) indices = _find_max_per_frame(nccf, sample_rate, freq_high) @@ -1242,6 +1241,6 @@ def detect_pitch_frequency( freq = sample_rate / (EPSILON + indices.to(torch.float)) # unpack batch - freq = freq.reshape(shape[:-1] + list(freq.shape[-1:])) + freq = freq.view(shape[:-1] + list(freq.shape[-1:])) return freq diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index e2e461b82e..10390e6651 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -215,7 +215,7 @@ def forward(self, specgram): # pack batch shape = specgram.size() - specgram = specgram.reshape(-1, shape[-2], shape[-1]) + specgram = specgram.view(-1, shape[-2], shape[-1]) if self.fb.numel() == 0: tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate) @@ -228,7 +228,7 @@ def forward(self, specgram): mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) # unpack batch - mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:]) + mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:]) return mel_specgram @@ -349,7 +349,7 @@ def forward(self, waveform): # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-1]) + waveform = waveform.view(-1, shape[-1]) mel_specgram = self.MelSpectrogram(waveform) if self.log_mels: @@ -362,7 +362,7 @@ def forward(self, waveform): mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) # unpack batch - mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:]) + mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:]) return mfcc