diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 553d2fa403..d807b02365 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -137,7 +137,7 @@ def forward(self, S): self.normalized, self.n_iter, self.momentum, self.length, self.rand_init) -class AmplitudeToDB(torch.jit.ScriptModule): +class AmplitudeToDB(torch.nn.Module): r"""Turn a tensor from the power/amplitude scale to the decibel scale. This output depends on the maximum value in the input tensor, and so @@ -157,7 +157,7 @@ def __init__(self, stype='power', top_db=None): self.stype = stype if top_db is not None and top_db < 0: raise ValueError('top_db must be positive value') - self.top_db = torch.jit.Attribute(top_db, Optional[float]) + self.top_db = top_db self.multiplier = 10.0 if stype == 'power' else 20.0 self.amin = 1e-10 self.ref_value = 1.0 @@ -592,7 +592,7 @@ def forward(self, specgram): return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) -class TimeStretch(torch.jit.ScriptModule): +class TimeStretch(torch.nn.Module): r"""Stretch stft in time without modifying pitch for a given rate. Args: @@ -610,8 +610,7 @@ def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): n_fft = (n_freq - 1) * 2 hop_length = hop_length if hop_length is not None else n_fft // 2 - phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] - self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor) + self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] def forward(self, complex_specgrams, overriding_rate=None): # type: (Tensor, Optional[float]) -> Tensor