22
33import math
44from typing import Optional , Tuple
5+ import warnings
56
67import torch
78from torch import Tensor
@@ -49,7 +50,7 @@ def istft(
4950 win_length : Optional [int ] = None ,
5051 window : Optional [Tensor ] = None ,
5152 center : bool = True ,
52- pad_mode : str = "reflect" ,
53+ pad_mode : Optional [ str ] = None ,
5354 normalized : bool = False ,
5455 onesided : bool = True ,
5556 length : Optional [int ] = None ,
@@ -94,8 +95,7 @@ def istft(
9495 center (bool, optional): Whether ``input`` was padded on both sides so
9596 that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
9697 (Default: ``True``)
97- pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
98- ``"reflect"``)
98+ pad_mode (str or None, optional): Deprecated. This argument is ignored and will be removed. (Default: ``None``)
9999 normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
100100 onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
101101 length (int or None, optional): The amount to trim the signal by (i.e. the
@@ -104,105 +104,16 @@ def istft(
104104 Returns:
105105 Tensor: Least squares estimation of the original signal of size (..., signal_length)
106106 """
107- stft_matrix_dim = stft_matrix .dim ()
108- assert 3 <= stft_matrix_dim , "Incorrect stft dimension: %d" % (stft_matrix_dim )
109- assert stft_matrix .numel () > 0
110-
111- if stft_matrix_dim == 3 :
112- # add a channel dimension
113- stft_matrix = stft_matrix .unsqueeze (0 )
114-
115- # pack batch
116- shape = stft_matrix .size ()
117- stft_matrix = stft_matrix .reshape (- 1 , shape [- 3 ], shape [- 2 ], shape [- 1 ])
118-
119- dtype = stft_matrix .dtype
120- device = stft_matrix .device
121- fft_size = stft_matrix .size (1 )
122- assert (onesided and n_fft // 2 + 1 == fft_size ) or (
123- not onesided and n_fft == fft_size
124- ), (
125- "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
126- + "Given values were onesided: %s, n_fft: %d, fft_size: %d"
127- % ("True" if onesided else False , n_fft , fft_size )
128- )
129-
130- # use stft defaults for Optionals
131- if win_length is None :
132- win_length = n_fft
133-
134- if hop_length is None :
135- hop_length = int (win_length // 4 )
136-
137- # There must be overlap
138- assert 0 < hop_length <= win_length
139- assert 0 < win_length <= n_fft
140-
141- if window is None :
142- window = torch .ones (win_length , device = device , dtype = dtype )
143-
144- assert window .dim () == 1 and window .size (0 ) == win_length
145-
146- if win_length != n_fft :
147- # center window with pad left and right zeros
148- left = (n_fft - win_length ) // 2
149- window = torch .nn .functional .pad (window , (left , n_fft - win_length - left ))
150- assert window .size (0 ) == n_fft
151- # win_length and n_fft are synonymous from here on
152-
153- stft_matrix = stft_matrix .transpose (1 , 2 ) # size (channel, n_frame, fft_size, 2)
154- stft_matrix = torch .irfft (
155- stft_matrix , 1 , normalized , onesided , signal_sizes = (n_fft ,)
156- ) # size (channel, n_frame, n_fft)
157-
158- assert stft_matrix .size (2 ) == n_fft
159- n_frame = stft_matrix .size (1 )
160-
161- ytmp = stft_matrix * window .view (1 , 1 , n_fft ) # size (channel, n_frame, n_fft)
162- # each column of a channel is a frame which needs to be overlap added at the right place
163- ytmp = ytmp .transpose (1 , 2 ) # size (channel, n_fft, n_frame)
164-
165- # this does overlap add where the frames of ytmp are added such that the i'th frame of
166- # ytmp is added starting at i*hop_length in the output
167- y = torch .nn .functional .fold (
168- ytmp , (1 , (n_frame - 1 ) * hop_length + n_fft ), (1 , n_fft ), stride = (1 , hop_length )
169- ).squeeze (2 )
170-
171- # do the same for the window function
172- window_sq = (
173- window .pow (2 ).view (n_fft , 1 ).repeat ((1 , n_frame )).unsqueeze (0 )
174- ) # size (1, n_fft, n_frame)
175- window_envelop = torch .nn .functional .fold (
176- window_sq , (1 , (n_frame - 1 ) * hop_length + n_fft ), (1 , n_fft ), stride = (1 , hop_length )
177- ).squeeze (2 ) # size (1, 1, expected_signal_len)
178-
179- expected_signal_len = n_fft + hop_length * (n_frame - 1 )
180- assert y .size (2 ) == expected_signal_len
181- assert window_envelop .size (2 ) == expected_signal_len
182-
183- half_n_fft = n_fft // 2
184- # we need to trim the front padding away if center
185- start = half_n_fft if center else 0
186- end = - half_n_fft if length is None else start + length
187-
188- y = y [:, :, start :end ]
189- window_envelop = window_envelop [:, :, start :end ]
190-
191- # check NOLA non-zero overlap condition
192- window_envelop_lowest = window_envelop .abs ().min ()
193- assert window_envelop_lowest > 1e-11 , "window overlap add min: %f" % (
194- window_envelop_lowest
195- )
196-
197- y = (y / window_envelop ).squeeze (1 ) # size (channel, expected_signal_len)
198-
199- # unpack batch
200- y = y .reshape (shape [:- 3 ] + y .shape [- 1 :])
201-
202- if stft_matrix_dim == 3 : # remove the channel dimension
203- y = y .squeeze (0 )
204-
205- return y
107+ warnings .warn (
108+ 'istft has been moved to PyTorch and will be removed from torchaudio, '
109+ 'please use torch.istft instead.' )
110+ if pad_mode is not None :
111+ warnings .warn (
112+ 'The parameter `pad_mode` was ignored in istft, and is thus being deprecated. '
113+ 'Please set `pad_mode` to None to suppress this warning.' )
114+ return torch .istft (
115+ input = stft_matrix , n_fft = n_fft , hop_length = hop_length , win_length = win_length , window = window ,
116+ center = center , normalized = normalized , onesided = onesided , length = length )
206117
207118
208119def spectrogram (
0 commit comments