Skip to content

Commit 08a7127

Browse files
authored
Switch string formatting to str.format to be TorchScript friendly. (#850)
1 parent 3bab2b2 commit 08a7127

File tree

5 files changed

+18
-14
lines changed

5 files changed

+18
-14
lines changed

test/compliance/generate_fbank_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir):
9292
'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
9393
'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
9494
fn_split = fn.split('-')
95-
assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr)))
95+
assert len(fn_split) == len(arr), ('Len mismatch: {} and {}'.format(len(fn_split), len(arr)))
9696
inputs = {arr[i]: utils.parse(fn_split[i]) for i in range(len(arr))}
9797

9898
# print flags for C++

test/test_compliance_kaldi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ def _compliance_test_helper(self, sound_filepath, filepath_key, expected_num_fil
148148
sound, sr = torchaudio.load_wav(sound_filepath)
149149
files = self.test_filepaths[filepath_key]
150150

151-
assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))
151+
assert len(files) == expected_num_files, \
152+
('number of kaldi {} file changed to {}'.format(
153+
filepath_key, len(files)))
152154

153155
for f in files:
154156
print(f)

torchaudio/compliance/kaldi.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,15 @@ def _get_waveform_and_window_properties(waveform: Tensor,
135135
r"""Gets the waveform and window properties
136136
"""
137137
channel = max(channel, 0)
138-
assert channel < waveform.size(0), ('Invalid channel %d for size %d' % (channel, waveform.size(0)))
138+
assert channel < waveform.size(0), ('Invalid channel {} for size {}'.format(channel, waveform.size(0)))
139139
waveform = waveform[channel, :] # size (n)
140140
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
141141
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
142142
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
143143

144-
assert 2 <= window_size <= len(waveform), ('choose a window size %d that is [2, %d]' % (window_size, len(waveform)))
144+
assert 2 <= window_size <= len(
145+
waveform), ('choose a window size {} that is [2, {}]'
146+
.format(window_size, len(waveform)))
145147
assert 0 < window_shift, '`window_shift` must be greater than 0'
146148
assert padded_window_size % 2 == 0, 'the padded `window_size` must be divisible by two.' \
147149
' use `round_to_power_of_two` or change `frame_length`'
@@ -430,7 +432,7 @@ def get_mel_banks(num_bins: int,
430432
high_freq += nyquist
431433

432434
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), \
433-
('Bad values in options: low-freq %f and high-freq %f vs. nyquist %f' % (low_freq, high_freq, nyquist))
435+
('Bad values in options: low-freq {} and high-freq {} vs. nyquist {}'.format(low_freq, high_freq, nyquist))
434436

435437
# fft-bin width [think of it as Nyquist-freq / half-window-length]
436438
fft_bin_width = sample_freq / window_length_padded
@@ -446,8 +448,8 @@ def get_mel_banks(num_bins: int,
446448

447449
assert vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
448450
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)), \
449-
('Bad values in options: vtln-low %f and vtln-high %f, versus low-freq %f and high-freq %f' %
450-
(vtln_low, vtln_high, low_freq, high_freq))
451+
('Bad values in options: vtln-low {} and vtln-high {}, versus '
452+
'low-freq {} and high-freq {}'.format(vtln_low, vtln_high, low_freq, high_freq))
451453

452454
bin = torch.arange(num_bins).unsqueeze(1)
453455
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)

torchaudio/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def griffinlim(
149149
Returns:
150150
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
151151
"""
152-
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
153-
assert momentum >= 0, 'momentum=%s < 0' % momentum
152+
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
153+
assert momentum >= 0, 'momentum={} < 0'.format(momentum)
154154

155155
# pack batch
156156
shape = specgram.size()

torchaudio/transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def __init__(self,
141141
rand_init: bool = True) -> None:
142142
super(GriffinLim, self).__init__()
143143

144-
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
145-
assert momentum > 0, 'momentum=%s < 0' % momentum
144+
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
145+
assert momentum > 0, 'momentum={} < 0'.format(momentum)
146146

147147
self.n_fft = n_fft
148148
self.n_iter = n_iter
@@ -237,7 +237,7 @@ def __init__(self,
237237
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
238238
self.f_min = f_min
239239

240-
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
240+
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
241241

242242
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
243243
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
@@ -313,7 +313,7 @@ def __init__(self,
313313
self.tolerance_change = tolerance_change
314314
self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}
315315

316-
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
316+
assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
317317

318318
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
319319
self.register_buffer('fb', fb)
@@ -607,7 +607,7 @@ def forward(self, waveform: Tensor) -> Tensor:
607607

608608
return waveform
609609

610-
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
610+
raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))
611611

612612

613613
class ComplexNorm(torch.nn.Module):

0 commit comments

Comments
 (0)