Skip to content

Commit 60fd113

Browse files
authored
replace reshape by view. (#409)
1 parent b32606d commit 60fd113

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

torchaudio/functional.py

Lines changed: 21 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def istft(
129129

130130
# pack batch
131131
shape = stft_matrix.size()
132-
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])
132+
stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1])
133133

134134
dtype = stft_matrix.dtype
135135
device = stft_matrix.device
@@ -214,7 +214,7 @@ def istft(
214214
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
215215

216216
# unpack batch
217-
y = y.reshape(shape[:-3] + y.shape[-1:])
217+
y = y.view(shape[:-3] + y.shape[-1:])
218218

219219
if stft_matrix_dim == 3: # remove the channel dimension
220220
y = y.squeeze(0)
@@ -253,15 +253,15 @@ def spectrogram(
253253

254254
# pack batch
255255
shape = waveform.size()
256-
waveform = waveform.reshape(-1, shape[-1])
256+
waveform = waveform.view(-1, shape[-1])
257257

258258
# default values are consistent with librosa.core.spectrum._spectrogram
259259
spec_f = _stft(
260260
waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
261261
)
262262

263263
# unpack batch
264-
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
264+
spec_f = spec_f.view(shape[:-1] + spec_f.shape[-3:])
265265

266266
if normalized:
267267
spec_f /= window.pow(2.).sum().sqrt()
@@ -317,7 +317,7 @@ def griffinlim(
317317

318318
# pack batch
319319
shape = specgram.size()
320-
specgram = specgram.reshape([-1] + list(shape[-2:]))
320+
specgram = specgram.view([-1] + list(shape[-2:]))
321321

322322
specgram = specgram.pow(1 / power)
323323

@@ -363,7 +363,7 @@ def griffinlim(
363363
length=length)
364364

365365
# unpack batch
366-
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
366+
waveform = waveform.view(shape[:-2] + waveform.shape[-1:])
367367

368368
return waveform
369369

@@ -587,7 +587,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
587587

588588
# pack batch
589589
shape = complex_specgrams.size()
590-
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
590+
complex_specgrams = complex_specgrams.view([-1] + list(shape[-3:]))
591591

592592
time_steps = torch.arange(0,
593593
complex_specgrams.size(-2),
@@ -627,7 +627,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
627627
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
628628

629629
# unpack batch
630-
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
630+
complex_specgrams_stretch = complex_specgrams_stretch.view(shape[:-3] + complex_specgrams_stretch.shape[1:])
631631

632632
return complex_specgrams_stretch
633633

@@ -654,7 +654,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
654654

655655
# pack batch
656656
shape = waveform.size()
657-
waveform = waveform.reshape(-1, shape[-1])
657+
waveform = waveform.view(-1, shape[-1])
658658

659659
assert(a_coeffs.size(0) == b_coeffs.size(0))
660660
assert(len(waveform.size()) == 2)
@@ -697,7 +697,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
697697
output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)
698698

699699
# unpack batch
700-
output = output.reshape(shape[:-1] + output.shape[-1:])
700+
output = output.view(shape[:-1] + output.shape[-1:])
701701

702702
return output
703703

@@ -876,7 +876,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
876876

877877
# pack batch
878878
shape = specgram.size()
879-
specgram = specgram.reshape([-1] + list(shape[-2:]))
879+
specgram = specgram.view([-1] + list(shape[-2:]))
880880

881881
value = torch.rand(1) * mask_param
882882
min_value = torch.rand(1) * (specgram.size(axis) - value)
@@ -893,7 +893,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
893893
raise ValueError('Only Frequency and Time masking are supported')
894894

895895
# unpack batch
896-
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
896+
specgram = specgram.view(shape[:-2] + specgram.shape[-2:])
897897

898898
return specgram
899899

@@ -925,7 +925,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
925925

926926
# pack batch
927927
shape = specgram.size()
928-
specgram = specgram.reshape(1, -1, shape[-1])
928+
specgram = specgram.view(1, -1, shape[-1])
929929

930930
assert win_length >= 3
931931

@@ -945,7 +945,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
945945
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
946946

947947
# unpack batch
948-
output = output.reshape(shape)
948+
output = output.view(shape)
949949

950950
return output
951951

@@ -974,11 +974,10 @@ def _add_noise_shaping(dithered_waveform, waveform):
974974
error[n] = dithered[n] - original[n]
975975
noise_shaped_waveform[n] = dithered[n] + error[n-1]
976976
"""
977-
wf_shape = waveform.size()
978-
waveform = waveform.reshape(-1, wf_shape[-1])
977+
waveform = waveform.view(-1, waveform.size()[-1])
979978

980979
dithered_shape = dithered_waveform.size()
981-
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
980+
dithered_waveform = dithered_waveform.view(-1, dithered_shape[-1])
982981

983982
error = dithered_waveform - waveform
984983

@@ -989,7 +988,7 @@ def _add_noise_shaping(dithered_waveform, waveform):
989988
error[index] = error_offset[:waveform.size()[1]]
990989

991990
noise_shaped = dithered_waveform + error
992-
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
991+
return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:])
993992

994993

995994
def _apply_probability_distribution(waveform, density_function="TPDF"):
@@ -1020,7 +1019,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
10201019

10211020
# pack batch
10221021
shape = waveform.size()
1023-
waveform = waveform.reshape(-1, shape[-1])
1022+
waveform = waveform.view(-1, shape[-1])
10241023

10251024
channel_size = waveform.size()[0] - 1
10261025
time_size = waveform.size()[-1] - 1
@@ -1060,7 +1059,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
10601059
quantised_signal = quantised_signal_scaled / down_scaling
10611060

10621061
# unpack batch
1063-
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
1062+
return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:])
10641063

10651064

10661065
def dither(waveform, density_function="TPDF", noise_shaping=False):
@@ -1231,7 +1230,7 @@ def detect_pitch_frequency(
12311230

12321231
# pack batch
12331232
shape = list(waveform.size())
1234-
waveform = waveform.reshape([-1] + shape[-1:])
1233+
waveform = waveform.view([-1] + shape[-1:])
12351234

12361235
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
12371236
indices = _find_max_per_frame(nccf, sample_rate, freq_high)
@@ -1242,6 +1241,6 @@ def detect_pitch_frequency(
12421241
freq = sample_rate / (EPSILON + indices.to(torch.float))
12431242

12441243
# unpack batch
1245-
freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
1244+
freq = freq.view(shape[:-1] + list(freq.shape[-1:]))
12461245

12471246
return freq

torchaudio/transforms.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -215,7 +215,7 @@ def forward(self, specgram):
215215

216216
# pack batch
217217
shape = specgram.size()
218-
specgram = specgram.reshape(-1, shape[-2], shape[-1])
218+
specgram = specgram.view(-1, shape[-2], shape[-1])
219219

220220
if self.fb.numel() == 0:
221221
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
@@ -228,7 +228,7 @@ def forward(self, specgram):
228228
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
229229

230230
# unpack batch
231-
mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
231+
mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])
232232

233233
return mel_specgram
234234

@@ -349,7 +349,7 @@ def forward(self, waveform):
349349

350350
# pack batch
351351
shape = waveform.size()
352-
waveform = waveform.reshape(-1, shape[-1])
352+
waveform = waveform.view(-1, shape[-1])
353353

354354
mel_specgram = self.MelSpectrogram(waveform)
355355
if self.log_mels:
@@ -362,7 +362,7 @@ def forward(self, waveform):
362362
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
363363

364364
# unpack batch
365-
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
365+
mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])
366366

367367
return mfcc
368368

0 commit comments

Comments
 (0)