@@ -137,7 +137,7 @@ def forward(self, S):
137137 self .normalized , self .n_iter , self .momentum , self .length , self .rand_init )
138138
139139
140- class AmplitudeToDB (torch .jit . ScriptModule ):
140+ class AmplitudeToDB (torch .nn . Module ):
141141 r"""Turn a tensor from the power/amplitude scale to the decibel scale.
142142
143143 This output depends on the maximum value in the input tensor, and so
@@ -157,7 +157,7 @@ def __init__(self, stype='power', top_db=None):
157157 self .stype = stype
158158 if top_db is not None and top_db < 0 :
159159 raise ValueError ('top_db must be positive value' )
160- self .top_db = torch . jit . Attribute ( top_db , Optional [ float ])
160+ self .top_db = top_db
161161 self .multiplier = 10.0 if stype == 'power' else 20.0
162162 self .amin = 1e-10
163163 self .ref_value = 1.0
@@ -592,7 +592,7 @@ def forward(self, specgram):
592592 return F .compute_deltas (specgram , win_length = self .win_length , mode = self .mode )
593593
594594
595- class TimeStretch (torch .jit . ScriptModule ):
595+ class TimeStretch (torch .nn . Module ):
596596 r"""Stretch stft in time without modifying pitch for a given rate.
597597
598598 Args:
@@ -610,8 +610,7 @@ def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
610610
611611 n_fft = (n_freq - 1 ) * 2
612612 hop_length = hop_length if hop_length is not None else n_fft // 2
613- phase_advance = torch .linspace (0 , math .pi * hop_length , n_freq )[..., None ]
614- self .phase_advance = torch .jit .Attribute (phase_advance , torch .Tensor )
613+ self .phase_advance = torch .linspace (0 , math .pi * hop_length , n_freq )[..., None ]
615614
616615 def forward (self , complex_specgrams , overriding_rate = None ):
617616 # type: (Tensor, Optional[float]) -> Tensor
0 commit comments