diff --git a/test/torchaudio_unittest/transforms/sox_compatibility_test.py b/test/torchaudio_unittest/transforms/sox_compatibility_test.py index 81582c8393..be6c9020ab 100644 --- a/test/torchaudio_unittest/transforms/sox_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/sox_compatibility_test.py @@ -1,3 +1,6 @@ +import warnings + +import torch import torchaudio.transforms as T from parameterized import parameterized @@ -61,3 +64,25 @@ def test_vad(self, filename): data, sample_rate = load_wav(path) result = T.Vad(sample_rate)(data) self.assert_sox_effect(result, path, ['vad']) + + def test_vad_warning(self): + """vad should throw a warning if input dimension is greater than 2""" + sample_rate = 41100 + + data = torch.rand(5, 5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 1 + + data = torch.rand(5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 + + data = torch.rand(sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 93da27493b..85abe81339 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -1,4 +1,5 @@ import math +import warnings from typing import Optional import torch @@ -1374,7 +1375,10 @@ def vad( so in order to trim from the back, the reverse effect must also be used. Args: - waveform (Tensor): Tensor of audio of dimension `(..., time)` + waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)` + Tensor of shape `(channels, time)` is treated as a multi-channel recording + of the same event and the resulting output will be trimmed to the earliest + voice activity in any channel. sample_rate (int): Sample rate of audio signal. trigger_level (float, optional): The measurement level used to trigger activity detection. This may need to be cahnged depending on the noise level, signal level, @@ -1420,6 +1424,15 @@ def vad( http://sox.sourceforge.net/sox.html """ + if waveform.ndim > 2: + warnings.warn( + "Expected input tensor dimension of 1 for single channel" + f" or 2 for multi-channel. Got {waveform.ndim} instead. " + "Batch semantics is not supported. " + "Please refer to https://github.com/pytorch/audio/issues/1348" + " and https://github.com/pytorch/audio/issues/1468." + ) + measure_duration: float = ( 2.0 / measure_freq if measure_duration is None else measure_duration ) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4704c4eb39..32676d40c9 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -1110,7 +1110,10 @@ def __init__(self, def forward(self, waveform: Tensor) -> Tensor: r""" Args: - waveform (Tensor): Tensor of audio of dimension `(..., time)` + waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)` + Tensor of shape `(channels, time)` is treated as a multi-channel recording + of the same event and the resulting output will be trimmed to the earliest + voice activity in any channel. """ return F.vad( waveform=waveform,