diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d807b02365..9767b629a7 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -610,7 +610,7 @@ def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): n_fft = (n_freq - 1) * 2 hop_length = hop_length if hop_length is not None else n_fft // 2 - self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None] + self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None]) def forward(self, complex_specgrams, overriding_rate=None): # type: (Tensor, Optional[float]) -> Tensor