|  | 
| 4 | 4 | import torch | 
| 5 | 5 | from torch import Tensor | 
| 6 | 6 | 
 | 
|  | 7 | +from .functional import complex_norm | 
|  | 8 | + | 
| 7 | 9 | 
 | 
| 8 | 10 | def _dB2Linear(x: float) -> float: | 
| 9 | 11 |     return math.exp(x * math.log(10) / 20.0) | 
| @@ -1116,3 +1118,299 @@ def treble_biquad( | 
| 1116 | 1118 |     a2 = (A + 1) - temp2 - temp1 | 
| 1117 | 1119 | 
 | 
| 1118 | 1120 |     return biquad(waveform, b0, b1, b2, a0, a1, a2) | 
|  | 1121 | + | 
|  | 1122 | + | 
def _measure(
    measure_len_ws: int,
    samples: Tensor,
    spectrum: Tensor,
    noise_spectrum: Tensor,
    spectrum_window: Tensor,
    spectrum_start: int,
    spectrum_end: int,
    cepstrum_window: Tensor,
    cepstrum_start: int,
    cepstrum_end: int,
    noise_reduction_amount: float,
    measure_smooth_time_mult: float,
    noise_up_time_mult: float,
    noise_down_time_mult: float,
    index_ns: int,
    boot_count: int
) -> float:
    """Compute one cepstral-power measurement for a single channel.

    Port of the ``measure`` routine of the SoX ``vad`` effect. The running
    estimates ``spectrum`` and ``noise_spectrum`` are updated **in place**
    over the band ``[spectrum_start, spectrum_end)``.

    Args:
        measure_len_ws (int): Number of samples taken from ``samples``.
        samples (Tensor): 1D ring buffer of the most recent input samples.
        spectrum (Tensor): 1D smoothed magnitude spectrum (updated in place).
            Its length is used as the DFT length.
        noise_spectrum (Tensor): 1D noise power estimate (updated in place).
            Must have the same length as ``spectrum``.
        spectrum_window (Tensor): Window applied to the samples before the DFT.
        spectrum_start (int): First DFT bin (inclusive) of the analysed band.
        spectrum_end (int): Last DFT bin (exclusive) of the analysed band.
        cepstrum_window (Tensor): Window of length
            ``spectrum_end - spectrum_start`` applied before the second DFT.
        cepstrum_start (int): First cepstral bin (inclusive) that is summed.
        cepstrum_end (int): Last cepstral bin (exclusive) that is summed.
        noise_reduction_amount (float): Factor applied to the noise estimate
            before it is subtracted from the signal power.
        measure_smooth_time_mult (float): Smoothing coefficient for
            ``spectrum`` once booting has finished (``boot_count < 0``).
        noise_up_time_mult (float): Smoothing coefficient for
            ``noise_spectrum`` where the signal power exceeds the estimate.
        noise_down_time_mult (float): Smoothing coefficient for
            ``noise_spectrum`` elsewhere.
        index_ns (int): Position in ``samples`` of the first sample to read.
        boot_count (int): Number of measurements taken so far while booting,
            or a negative value once booting has finished.

    Returns:
        float: ``max(0, 21 + log(mean cepstral power))``; ``0.0`` when the
        cepstral power is not positive.
    """
    assert spectrum.size()[-1] == noise_spectrum.size()[-1]

    samplesLen_ns = samples.size()[-1]
    dft_len_ws = spectrum.size()[-1]

    # Gather measure_len_ws consecutive samples from the ring buffer, wrapping
    # around its end; computed with tensor arithmetic instead of a Python loop.
    _index_ns = (torch.arange(measure_len_ws) + index_ns) % samplesLen_ns

    # The buffer starts zeroed, so the tail past measure_len_ws is already the
    # zero padding the SoX code creates with memset.
    dftBuf = torch.zeros(dft_len_ws)
    dftBuf[:measure_len_ws] = \
        samples[_index_ns] * spectrum_window[:measure_len_ws]

    # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
    # Only bins in [spectrum_start, spectrum_end) are read below, so the bins
    # below spectrum_start do not need to be cleared.
    _dftBuf = torch.rfft(dftBuf, 1)

    # While booting, average all measurements seen so far; afterwards use the
    # fixed exponential smoothing coefficient.
    mult: float = boot_count / (1. + boot_count) \
        if boot_count >= 0 \
        else measure_smooth_time_mult

    band = slice(spectrum_start, spectrum_end)

    _d = complex_norm(_dftBuf[band])
    spectrum[band].mul_(mult).add_(_d * (1 - mult))
    _d = spectrum[band] ** 2

    _zeros = torch.zeros(spectrum_end - spectrum_start)
    if boot_count >= 0:
        # While booting the noise estimate simply follows the signal power.
        _mult = _zeros
    else:
        _mult = torch.where(
            _d > noise_spectrum[band],
            torch.tensor(noise_up_time_mult),   # if
            torch.tensor(noise_down_time_mult)  # else
        )

    noise_spectrum[band].mul_(_mult).add_(_d * (1 - _mult))
    # Subtract the scaled noise estimate, clamping negative powers to zero.
    _d = torch.sqrt(
        torch.max(
            _zeros,
            _d - noise_reduction_amount * noise_spectrum[band]))

    # Freshly zeroed, so everything outside the band is already zero.
    _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
    _cepstrum_Buf[band] = _d * cepstrum_window

    # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
    _cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1)

    result: float = float(torch.sum(
        complex_norm(
            _cepstrum_Buf[cepstrum_start:cepstrum_end],
            power=2.0)))
    result = \
        math.log(result / (cepstrum_end - cepstrum_start)) \
        if result > 0 \
        else -math.inf
    return max(0.0, 21 + result)
|  | 1204 | + | 
|  | 1205 | + | 
def _scaled_hann_window(length: int) -> Tensor:
    """Return a float32 Hann window of ``length`` points scaled by ``2 / sqrt(length)``."""
    window = torch.hann_window(length, dtype=torch.float)
    if length > 0:
        # Guarded so that an empty window does not divide by zero.
        window *= 2. / math.sqrt(float(length))
    return window


def vad(
    waveform: Tensor,
    sample_rate: int,
    trigger_level: float = 7.0,
    trigger_time: float = 0.25,
    search_time: float = 1.0,
    allowed_gap: float = 0.25,
    pre_trigger_time: float = 0.0,
    # Fine-tuning parameters
    boot_time: float = .35,
    noise_up_time: float = .1,
    noise_down_time: float = .01,
    noise_reduction_amount: float = 1.35,
    measure_freq: float = 20.0,
    measure_duration: Optional[float] = None,
    measure_smooth_time: float = .4,
    hp_filter_freq: float = 50.,
    lp_filter_freq: float = 6000.,
    hp_lifter_freq: float = 150.,
    lp_lifter_freq: float = 2000.,
) -> Tensor:
    r"""Voice Activity Detector. Similar to SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm's
            processing/measurements. (Default: 20.0)
        measure_duration (float or None, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    Returns:
        Tensor: Tensor of audio of dimension (..., time). The time dimension is
        empty when no voice activity is detected.

    References:
        http://sox.sourceforge.net/sox.html
    """

    if measure_duration is None:
        measure_duration = 2.0 / measure_freq

    measure_len_ws = int(sample_rate * measure_duration + .5)
    measure_len_ns = measure_len_ws
    # for (dft_len_ws = 16; dft_len_ws < measure_len_ws; dft_len_ws <<= 1);
    dft_len_ws = 16
    while dft_len_ws < measure_len_ws:
        dft_len_ws *= 2

    measure_period_ns = int(sample_rate / measure_freq + .5)
    measures_len = math.ceil(search_time * measure_freq)
    search_pre_trigger_len_ns = measures_len * measure_period_ns
    gap_len = int(allowed_gap * measure_freq + .5)

    fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + .5)
    # Length of the ring buffer holding the most recent input samples.
    samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns

    # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
    spectrum_window = _scaled_hann_window(measure_len_ws)

    spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + .5)
    spectrum_start = max(spectrum_start, 1)
    spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + .5)
    spectrum_end = min(spectrum_end, dft_len_ws // 2)

    # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
    cepstrum_window = _scaled_hann_window(spectrum_end - spectrum_start)

    cepstrum_start = math.ceil(sample_rate * .5 / lp_lifter_freq)
    cepstrum_end = math.floor(sample_rate * .5 / hp_lifter_freq)
    cepstrum_end = min(cepstrum_end, dft_len_ws // 4)

    assert cepstrum_end > cepstrum_start

    noise_up_time_mult = math.exp(-1. / (noise_up_time * measure_freq))
    noise_down_time_mult = math.exp(-1. / (noise_down_time * measure_freq))
    measure_smooth_time_mult = math.exp(-1. / (measure_smooth_time * measure_freq))
    trigger_meas_time_mult = math.exp(-1. / (trigger_time * measure_freq))

    boot_count_max = int(boot_time * measure_freq - .5)
    measure_timer_ns = measure_len_ns
    boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0

    # pack batch; reshape (rather than view) also accepts non-contiguous input
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    n_channels, ilen = waveform.size()

    mean_meas = torch.zeros(n_channels)
    samples = torch.zeros(n_channels, samplesLen_ns)
    spectrum = torch.zeros(n_channels, dft_len_ws)
    noise_spectrum = torch.zeros(n_channels, dft_len_ws)
    measures = torch.zeros(n_channels, measures_len)

    has_triggered: bool = False
    num_measures_to_flush: int = 0
    pos: int = 0

    while pos < ilen and not has_triggered:
        measure_timer_ns -= 1
        for i in range(n_channels):
            samples[i, samplesIndex_ns] = waveform[i, pos]
            # if (!p->measure_timer_ns) {
            if measure_timer_ns == 0:
                # Start of the most recent measure_len_ns samples in the ring buffer.
                index_ns: int = \
                    (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
                meas: float = _measure(
                    measure_len_ws=measure_len_ws,
                    samples=samples[i],
                    spectrum=spectrum[i],
                    noise_spectrum=noise_spectrum[i],
                    spectrum_window=spectrum_window,
                    spectrum_start=spectrum_start,
                    spectrum_end=spectrum_end,
                    cepstrum_window=cepstrum_window,
                    cepstrum_start=cepstrum_start,
                    cepstrum_end=cepstrum_end,
                    noise_reduction_amount=noise_reduction_amount,
                    measure_smooth_time_mult=measure_smooth_time_mult,
                    noise_up_time_mult=noise_up_time_mult,
                    noise_down_time_mult=noise_down_time_mult,
                    index_ns=index_ns,
                    boot_count=boot_count)
                measures[i, measures_index] = meas
                mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1. - trigger_meas_time_mult)

                # Keep the flag a plain bool instead of a 0-dim tensor.
                has_triggered = has_triggered or bool(mean_meas[i] >= trigger_level)
                if has_triggered:
                    # Walk backwards through the stored measurements to find how
                    # far before the trigger point quieter/shorter bursts extend.
                    n: int = measures_len
                    k: int = measures_index
                    jTrigger: int = n
                    jZero: int = n
                    j: int = 0

                    for j in range(n):
                        if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
                            jZero = jTrigger = j
                        elif (measures[i, k] == 0) and (jTrigger >= jZero):
                            jZero = j
                        k = (k + n - 1) % n
                    j = min(j, jZero)
                    # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
                    num_measures_to_flush = min(max(num_measures_to_flush, j), n)
                # end if has_triggered
            # end if (measure_timer_ns == 0):
        # end for
        samplesIndex_ns += 1
        pos += 1

        if samplesIndex_ns == samplesLen_ns:
            samplesIndex_ns = 0
        if measure_timer_ns == 0:
            measure_timer_ns = measure_period_ns
            measures_index += 1
            measures_index = measures_index % measures_len
            if boot_count >= 0:
                boot_count = -1 if boot_count == boot_count_max else boot_count + 1

        if has_triggered:
            flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
            samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns
    # end while

    if not has_triggered:
        # No voice activity found: everything is trimmed.
        return waveform[..., :0].view(shape[:-1] + torch.Size([0]))

    # A negative offset would make the slice wrap around and return the end of
    # the audio; audio from before the start of the input cannot be preserved.
    start = max(0, pos - samplesLen_ns + flushedLen_ns)
    res = waveform[:, start:]
    # unpack batch
    return res.view(shape[:-1] + res.shape[-1:])
0 commit comments