Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ def detect_pitch_frequency(


def sliding_window_cmn(
waveform: Tensor,
specgram: Tensor,
cmn_window: int = 600,
min_cmn_window: int = 100,
center: bool = False,
Expand All @@ -952,7 +952,7 @@ def sliding_window_cmn(
Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

Args:
waveform (Tensor): Tensor of audio of dimension (..., freq, time)
specgram (Tensor): Tensor of audio of dimension (..., freq, time)
cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
Only applicable if center == false, ignored if center==true (int, default = 100)
Expand All @@ -961,19 +961,19 @@ def sliding_window_cmn(
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

Returns:
Tensor: Tensor of freq of dimension (..., frame)
Tensor: Tensor matching input shape (..., freq, time)
"""
input_shape = waveform.shape
input_shape = specgram.shape
num_frames, num_feats = input_shape[-2:]
waveform = waveform.view(-1, num_frames, num_feats)
num_channels = waveform.shape[0]
specgram = specgram.view(-1, num_frames, num_feats)
num_channels = specgram.shape[0]

dtype = waveform.dtype
device = waveform.device
dtype = specgram.dtype
device = specgram.device
last_window_start = last_window_end = -1
cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
cmn_waveform = torch.zeros(
cmn_specgram = torch.zeros(
num_channels, num_frames, num_feats, dtype=dtype, device=device)
for t in range(num_frames):
window_start = 0
Expand All @@ -996,40 +996,40 @@ def sliding_window_cmn(
if window_start < 0:
window_start = 0
if last_window_start == -1:
input_part = waveform[:, window_start: window_end - window_start, :]
input_part = specgram[:, window_start: window_end - window_start, :]
cur_sum += torch.sum(input_part, 1)
if norm_vars:
cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
else:
if window_start > last_window_start:
frame_to_remove = waveform[:, last_window_start, :]
frame_to_remove = specgram[:, last_window_start, :]
cur_sum -= frame_to_remove
if norm_vars:
cur_sumsq -= (frame_to_remove ** 2)
if window_end > last_window_end:
frame_to_add = waveform[:, last_window_end, :]
frame_to_add = specgram[:, last_window_end, :]
cur_sum += frame_to_add
if norm_vars:
cur_sumsq += (frame_to_add ** 2)
window_frames = window_end - window_start
last_window_start = window_start
last_window_end = window_end
cmn_waveform[:, t, :] = waveform[:, t, :] - cur_sum / window_frames
cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
if norm_vars:
if window_frames == 1:
cmn_waveform[:, t, :] = torch.zeros(
cmn_specgram[:, t, :] = torch.zeros(
num_channels, num_feats, dtype=dtype, device=device)
else:
variance = cur_sumsq
variance = variance / window_frames
variance -= ((cur_sum ** 2) / (window_frames ** 2))
variance = torch.pow(variance, -0.5)
cmn_waveform[:, t, :] *= variance
cmn_specgram[:, t, :] *= variance

cmn_waveform = cmn_waveform.view(input_shape[:-2] + (num_frames, num_feats))
cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
if len(input_shape) == 2:
cmn_waveform = cmn_waveform.squeeze(0)
return cmn_waveform
cmn_specgram = cmn_specgram.squeeze(0)
return cmn_specgram


def spectral_centroid(
Expand Down