Skip to content

Commit 65aee40

Browse files
committed
move vad
1 parent 2293ea5 commit 65aee40

File tree

3 files changed

+299
-298
lines changed

3 files changed

+299
-298
lines changed

torchaudio/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
phase_vocoder,
1717
sliding_window_cmn,
1818
spectrogram,
19-
vad,
2019
)
2120
from .filtering import (
2221
allpass_biquad,
@@ -39,4 +38,5 @@
3938
phaser,
4039
riaa_biquad,
4140
treble_biquad,
41+
vad,
4242
)

torchaudio/functional/filtering.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch
55
from torch import Tensor
66

7+
from .functional import complex_norm
8+
79

810
def _dB2Linear(x: float) -> float:
911
return math.exp(x * math.log(10) / 20.0)
@@ -1116,3 +1118,299 @@ def treble_biquad(
11161118
a2 = (A + 1) - temp2 - temp1
11171119

11181120
return biquad(waveform, b0, b1, b2, a0, a1, a2)
1121+
1122+
1123+
def _measure(
1124+
measure_len_ws: int,
1125+
samples: Tensor,
1126+
spectrum: Tensor,
1127+
noise_spectrum: Tensor,
1128+
spectrum_window: Tensor,
1129+
spectrum_start: int,
1130+
spectrum_end: int,
1131+
cepstrum_window: Tensor,
1132+
cepstrum_start: int,
1133+
cepstrum_end: int,
1134+
noise_reduction_amount: float,
1135+
measure_smooth_time_mult: float,
1136+
noise_up_time_mult: float,
1137+
noise_down_time_mult: float,
1138+
index_ns: int,
1139+
boot_count: int
1140+
) -> float:
1141+
1142+
assert spectrum.size()[-1] == noise_spectrum.size()[-1]
1143+
1144+
samplesLen_ns = samples.size()[-1]
1145+
dft_len_ws = spectrum.size()[-1]
1146+
1147+
dftBuf = torch.zeros(dft_len_ws)
1148+
1149+
_index_ns = torch.tensor([index_ns] + [
1150+
(index_ns + i) % samplesLen_ns
1151+
for i in range(1, measure_len_ws)
1152+
])
1153+
dftBuf[:measure_len_ws] = \
1154+
samples[_index_ns] * spectrum_window[:measure_len_ws]
1155+
1156+
# memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
1157+
dftBuf[measure_len_ws:dft_len_ws].zero_()
1158+
1159+
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
1160+
_dftBuf = torch.rfft(dftBuf, 1)
1161+
1162+
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
1163+
_dftBuf[:spectrum_start].zero_()
1164+
1165+
mult: float = boot_count / (1. + boot_count) \
1166+
if boot_count >= 0 \
1167+
else measure_smooth_time_mult
1168+
1169+
_d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
1170+
spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
1171+
_d = spectrum[spectrum_start:spectrum_end] ** 2
1172+
1173+
_zeros = torch.zeros(spectrum_end - spectrum_start)
1174+
_mult = _zeros \
1175+
if boot_count >= 0 \
1176+
else torch.where(
1177+
_d > noise_spectrum[spectrum_start:spectrum_end],
1178+
torch.tensor(noise_up_time_mult), # if
1179+
torch.tensor(noise_down_time_mult) # else
1180+
)
1181+
1182+
noise_spectrum[spectrum_start:spectrum_end].mul_(_mult).add_(_d * (1 - _mult))
1183+
_d = torch.sqrt(
1184+
torch.max(
1185+
_zeros,
1186+
_d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end]))
1187+
1188+
_cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
1189+
_cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
1190+
_cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()
1191+
1192+
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
1193+
_cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1)
1194+
1195+
result: float = float(torch.sum(
1196+
complex_norm(
1197+
_cepstrum_Buf[cepstrum_start:cepstrum_end],
1198+
power=2.0)))
1199+
result = \
1200+
math.log(result / (cepstrum_end - cepstrum_start)) \
1201+
if result > 0 \
1202+
else -math.inf
1203+
return max(0, 21 + result)
1204+
1205+
1206+
def vad(
1207+
waveform: Tensor,
1208+
sample_rate: int,
1209+
trigger_level: float = 7.0,
1210+
trigger_time: float = 0.25,
1211+
search_time: float = 1.0,
1212+
allowed_gap: float = 0.25,
1213+
pre_trigger_time: float = 0.0,
1214+
# Fine-tuning parameters
1215+
boot_time: float = .35,
1216+
noise_up_time: float = .1,
1217+
noise_down_time: float = .01,
1218+
noise_reduction_amount: float = 1.35,
1219+
measure_freq: float = 20.0,
1220+
measure_duration: Optional[float] = None,
1221+
measure_smooth_time: float = .4,
1222+
hp_filter_freq: float = 50.,
1223+
lp_filter_freq: float = 6000.,
1224+
hp_lifter_freq: float = 150.,
1225+
lp_lifter_freq: float = 2000.,
1226+
) -> Tensor:
1227+
r"""Voice Activity Detector. Similar to SoX implementation.
1228+
Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
1229+
The algorithm currently uses a simple cepstral power measurement to detect voice,
1230+
so may be fooled by other things, especially music.
1231+
1232+
The effect can trim only from the front of the audio,
1233+
so in order to trim from the back, the reverse effect must also be used.
1234+
1235+
Args:
1236+
waveform (Tensor): Tensor of audio of dimension `(..., time)`
1237+
sample_rate (int): Sample rate of audio signal.
1238+
trigger_level (float, optional): The measurement level used to trigger activity detection.
1239+
This may need to be cahnged depending on the noise level, signal level,
1240+
and other characteristics of the input audio. (Default: 7.0)
1241+
trigger_time (float, optional): The time constant (in seconds)
1242+
used to help ignore short bursts of sound. (Default: 0.25)
1243+
search_time (float, optional): The amount of audio (in seconds)
1244+
to search for quieter/shorter bursts of audio to include prior
1245+
to the detected trigger point. (Default: 1.0)
1246+
allowed_gap (float, optional): The allowed gap (in seconds) between
1247+
quiteter/shorter bursts of audio to include prior
1248+
to the detected trigger point. (Default: 0.25)
1249+
pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
1250+
before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
1251+
boot_time (float, optional) The algorithm (internally) uses adaptive noise
1252+
estimation/reduction in order to detect the start of the wanted audio.
1253+
This option sets the time for the initial noise estimate. (Default: 0.35)
1254+
noise_up_time (float, optional) Time constant used by the adaptive noise estimator
1255+
for when the noise level is increasing. (Default: 0.1)
1256+
noise_down_time (float, optional) Time constant used by the adaptive noise estimator
1257+
for when the noise level is decreasing. (Default: 0.01)
1258+
noise_reduction_amount (float, optional) Amount of noise reduction to use in
1259+
the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
1260+
measure_freq (float, optional) Frequency of the algorithm’s
1261+
processing/measurements. (Default: 20.0)
1262+
measure_duration: (float, optional) Measurement duration.
1263+
(Default: Twice the measurement period; i.e. with overlap.)
1264+
measure_smooth_time (float, optional) Time constant used to smooth
1265+
spectral measurements. (Default: 0.4)
1266+
hp_filter_freq (float, optional) "Brick-wall" frequency of high-pass filter applied
1267+
at the input to the detector algorithm. (Default: 50.0)
1268+
lp_filter_freq (float, optional) "Brick-wall" frequency of low-pass filter applied
1269+
at the input to the detector algorithm. (Default: 6000.0)
1270+
hp_lifter_freq (float, optional) "Brick-wall" frequency of high-pass lifter used
1271+
in the detector algorithm. (Default: 150.0)
1272+
lp_lifter_freq (float, optional) "Brick-wall" frequency of low-pass lifter used
1273+
in the detector algorithm. (Default: 2000.0)
1274+
1275+
Returns:
1276+
Tensor: Tensor of audio of dimension (..., time).
1277+
1278+
References:
1279+
http://sox.sourceforge.net/sox.html
1280+
"""
1281+
1282+
measure_duration: float = 2.0 / measure_freq \
1283+
if measure_duration is None \
1284+
else measure_duration
1285+
1286+
measure_len_ws = int(sample_rate * measure_duration + .5)
1287+
measure_len_ns = measure_len_ws
1288+
# for (dft_len_ws = 16; dft_len_ws < measure_len_ws; dft_len_ws <<= 1);
1289+
dft_len_ws = 16
1290+
while (dft_len_ws < measure_len_ws):
1291+
dft_len_ws *= 2
1292+
1293+
measure_period_ns = int(sample_rate / measure_freq + .5)
1294+
measures_len = math.ceil(search_time * measure_freq)
1295+
search_pre_trigger_len_ns = measures_len * measure_period_ns
1296+
gap_len = int(allowed_gap * measure_freq + .5)
1297+
1298+
fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + .5)
1299+
samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
1300+
1301+
spectrum_window = torch.zeros(measure_len_ws)
1302+
for i in range(measure_len_ws):
1303+
# sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
1304+
spectrum_window[i] = 2. / math.sqrt(float(measure_len_ws))
1305+
# lsx_apply_hann(spectrum_window, (int)measure_len_ws);
1306+
spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float)
1307+
1308+
spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + .5)
1309+
spectrum_start: int = max(spectrum_start, 1)
1310+
spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + .5)
1311+
spectrum_end: int = min(spectrum_end, dft_len_ws // 2)
1312+
1313+
cepstrum_window = torch.zeros(spectrum_end - spectrum_start)
1314+
for i in range(spectrum_end - spectrum_start):
1315+
cepstrum_window[i] = 2. / math.sqrt(float(spectrum_end) - spectrum_start)
1316+
# lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
1317+
cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)
1318+
1319+
cepstrum_start = math.ceil(sample_rate * .5 / lp_lifter_freq)
1320+
cepstrum_end = math.floor(sample_rate * .5 / hp_lifter_freq)
1321+
cepstrum_end = min(cepstrum_end, dft_len_ws // 4)
1322+
1323+
assert cepstrum_end > cepstrum_start
1324+
1325+
noise_up_time_mult = math.exp(-1. / (noise_up_time * measure_freq))
1326+
noise_down_time_mult = math.exp(-1. / (noise_down_time * measure_freq))
1327+
measure_smooth_time_mult = math.exp(-1. / (measure_smooth_time * measure_freq))
1328+
trigger_meas_time_mult = math.exp(-1. / (trigger_time * measure_freq))
1329+
1330+
boot_count_max = int(boot_time * measure_freq - .5)
1331+
measure_timer_ns = measure_len_ns
1332+
boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0
1333+
1334+
# pack batch
1335+
shape = waveform.size()
1336+
waveform = waveform.view(-1, shape[-1])
1337+
1338+
n_channels, ilen = waveform.size()
1339+
1340+
mean_meas = torch.zeros(n_channels)
1341+
samples = torch.zeros(n_channels, samplesLen_ns)
1342+
spectrum = torch.zeros(n_channels, dft_len_ws)
1343+
noise_spectrum = torch.zeros(n_channels, dft_len_ws)
1344+
measures = torch.zeros(n_channels, measures_len)
1345+
1346+
has_triggered: bool = False
1347+
num_measures_to_flush: int = 0
1348+
pos: int = 0
1349+
1350+
while (pos < ilen and not has_triggered):
1351+
measure_timer_ns -= 1
1352+
for i in range(n_channels):
1353+
samples[i, samplesIndex_ns] = waveform[i, pos]
1354+
# if (!p->measure_timer_ns) {
1355+
if (measure_timer_ns == 0):
1356+
index_ns: int = \
1357+
(samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
1358+
meas: float = _measure(
1359+
measure_len_ws=measure_len_ws,
1360+
samples=samples[i],
1361+
spectrum=spectrum[i],
1362+
noise_spectrum=noise_spectrum[i],
1363+
spectrum_window=spectrum_window,
1364+
spectrum_start=spectrum_start,
1365+
spectrum_end=spectrum_end,
1366+
cepstrum_window=cepstrum_window,
1367+
cepstrum_start=cepstrum_start,
1368+
cepstrum_end=cepstrum_end,
1369+
noise_reduction_amount=noise_reduction_amount,
1370+
measure_smooth_time_mult=measure_smooth_time_mult,
1371+
noise_up_time_mult=noise_up_time_mult,
1372+
noise_down_time_mult=noise_down_time_mult,
1373+
index_ns=index_ns,
1374+
boot_count=boot_count)
1375+
measures[i, measures_index] = meas
1376+
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1. - trigger_meas_time_mult)
1377+
1378+
has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
1379+
if has_triggered:
1380+
n: int = measures_len
1381+
k: int = measures_index
1382+
jTrigger: int = n
1383+
jZero: int = n
1384+
j: int = 0
1385+
1386+
for j in range(n):
1387+
if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
1388+
jZero = jTrigger = j
1389+
elif (measures[i, k] == 0) and (jTrigger >= jZero):
1390+
jZero = j
1391+
k = (k + n - 1) % n
1392+
j = min(j, jZero)
1393+
# num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
1394+
num_measures_to_flush = (min(max(num_measures_to_flush, j), n))
1395+
# end if has_triggered
1396+
# end if (measure_timer_ns == 0):
1397+
# end for
1398+
samplesIndex_ns += 1
1399+
pos += 1
1400+
# end while
1401+
if samplesIndex_ns == samplesLen_ns:
1402+
samplesIndex_ns = 0
1403+
if measure_timer_ns == 0:
1404+
measure_timer_ns = measure_period_ns
1405+
measures_index += 1
1406+
measures_index = measures_index % measures_len
1407+
if boot_count >= 0:
1408+
boot_count = -1 if boot_count == boot_count_max else boot_count + 1
1409+
1410+
if has_triggered:
1411+
flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
1412+
samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns
1413+
1414+
res = waveform[:, pos - samplesLen_ns + flushedLen_ns:]
1415+
# unpack batch
1416+
return res.view(shape[:-1] + res.shape[-1:])

0 commit comments

Comments
 (0)