Skip to content

Commit db1e7da

Browse files
authored
Migrate TimeStretch and AmplitudeToDB to torch.nn.Module (#456)
* AmplitudeToDB to torch.nn.Module * TimeStretch use torch.nn.Module
1 parent babc24a commit db1e7da

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

torchaudio/transforms.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def forward(self, S):
137137
self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
138138

139139

140-
class AmplitudeToDB(torch.jit.ScriptModule):
140+
class AmplitudeToDB(torch.nn.Module):
141141
r"""Turn a tensor from the power/amplitude scale to the decibel scale.
142142
143143
This output depends on the maximum value in the input tensor, and so
@@ -157,7 +157,7 @@ def __init__(self, stype='power', top_db=None):
157157
self.stype = stype
158158
if top_db is not None and top_db < 0:
159159
raise ValueError('top_db must be positive value')
160-
self.top_db = torch.jit.Attribute(top_db, Optional[float])
160+
self.top_db = top_db
161161
self.multiplier = 10.0 if stype == 'power' else 20.0
162162
self.amin = 1e-10
163163
self.ref_value = 1.0
@@ -592,7 +592,7 @@ def forward(self, specgram):
592592
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
593593

594594

595-
class TimeStretch(torch.jit.ScriptModule):
595+
class TimeStretch(torch.nn.Module):
596596
r"""Stretch stft in time without modifying pitch for a given rate.
597597
598598
Args:
@@ -610,8 +610,7 @@ def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
610610

611611
n_fft = (n_freq - 1) * 2
612612
hop_length = hop_length if hop_length is not None else n_fft // 2
613-
phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
614-
self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)
613+
self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
615614

616615
def forward(self, complex_specgrams, overriding_rate=None):
617616
# type: (Tensor, Optional[float]) -> Tensor

0 commit comments

Comments
 (0)