From 9aa256c456dc460fb09cb1db2bc6f4259b6ff6e7 Mon Sep 17 00:00:00 2001 From: Artyom Astafurov Date: Tue, 18 May 2021 13:27:23 -0400 Subject: [PATCH 1/3] Update VAD docstring and check for input shape length --- .../transforms/sox_compatibility_test.py | 26 +++++++++++++++++++ torchaudio/functional/filtering.py | 16 +++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/test/torchaudio_unittest/transforms/sox_compatibility_test.py b/test/torchaudio_unittest/transforms/sox_compatibility_test.py index 81582c8393..b9e7348b05 100644 --- a/test/torchaudio_unittest/transforms/sox_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/sox_compatibility_test.py @@ -1,3 +1,6 @@ +import warnings + +import torch import torchaudio.transforms as T from parameterized import parameterized @@ -61,3 +64,26 @@ def test_vad(self, filename): data, sample_rate = load_wav(path) result = T.Vad(sample_rate)(data) self.assert_sox_effect(result, path, ['vad']) + + def test_vad_warning(self): + sample_rate = 41100 + data = torch.rand(5, 5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 1 + + def test_vad_no_warning(self): + sample_rate = 41100 + + data = torch.rand(5, sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 + + data = torch.rand(sample_rate) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + T.Vad(sample_rate)(data) + assert len(w) == 0 diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 93da27493b..78aac2472f 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -1,4 +1,5 @@ import math +import warnings from typing import Optional import torch @@ -1374,7 +1375,10 @@ def vad( so in order to trim from the back, the reverse effect must also be used. Args: - waveform (Tensor): Tensor of audio of dimension `(..., time)` + waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)` + Tensor of shape `(channels, time)` is treated as a multi-channel recording + of the same event and the resulting output will be trimmed to the earliest + voice activity in any channel. sample_rate (int): Sample rate of audio signal. trigger_level (float, optional): The measurement level used to trigger activity detection. This may need to be cahnged depending on the noise level, signal level, @@ -1478,6 +1482,16 @@ def vad( # pack batch shape = waveform.size() + + if len(shape) > 2: + warnings.warn( + "Expected input tensor shape length of 1 for single channel" + f" or 2 for multi-channel. Got {len(shape)} instead.\n" + "Batch semantics is not supported yet.\n" + "Please refer to https://github.com/pytorch/audio/issues/1348" + " and https://github.com/pytorch/audio/issues/1468." 
+ ) + waveform = waveform.view(-1, shape[-1]) n_channels, ilen = waveform.size() From 643a78708abe52409a889b4b862c6d2470252a86 Mon Sep 17 00:00:00 2001 From: Artyom Astafurov Date: Tue, 18 May 2021 14:03:56 -0400 Subject: [PATCH 2/3] Update docstring in forward for transform --- torchaudio/transforms.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 4704c4eb39..32676d40c9 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -1110,7 +1110,10 @@ def __init__(self, def forward(self, waveform: Tensor) -> Tensor: r""" Args: - waveform (Tensor): Tensor of audio of dimension `(..., time)` + waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)` + Tensor of shape `(channels, time)` is treated as a multi-channel recording + of the same event and the resulting output will be trimmed to the earliest + voice activity in any channel. """ return F.vad( waveform=waveform, From 3eb2829e145c8c248bed68cf2727715ce3c4b883 Mon Sep 17 00:00:00 2001 From: Artyom Astafurov Date: Wed, 19 May 2021 16:04:22 -0400 Subject: [PATCH 3/3] Address review feedback: merge tests, update wording --- .../transforms/sox_compatibility_test.py | 5 ++--- torchaudio/functional/filtering.py | 19 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/test/torchaudio_unittest/transforms/sox_compatibility_test.py b/test/torchaudio_unittest/transforms/sox_compatibility_test.py index b9e7348b05..be6c9020ab 100644 --- a/test/torchaudio_unittest/transforms/sox_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/sox_compatibility_test.py @@ -66,16 +66,15 @@ def test_vad(self, filename): self.assert_sox_effect(result, path, ['vad']) def test_vad_warning(self): + """vad should throw a warning if input dimension is greater than 2""" sample_rate = 41100 + data = torch.rand(5, 5, sample_rate) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") T.Vad(sample_rate)(data) assert len(w) == 1 - def test_vad_no_warning(self): - sample_rate = 41100 - data = torch.rand(5, sample_rate) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py index 78aac2472f..85abe81339 100644 --- a/torchaudio/functional/filtering.py +++ b/torchaudio/functional/filtering.py @@ -1424,6 +1424,15 @@ def vad( http://sox.sourceforge.net/sox.html """ + if waveform.ndim > 2: + warnings.warn( + "Expected input tensor dimension of 1 for single channel" + f" or 2 for multi-channel. Got {waveform.ndim} instead. " + "Batch semantics is not supported. " + "Please refer to https://github.com/pytorch/audio/issues/1348" + " and https://github.com/pytorch/audio/issues/1468." + ) + measure_duration: float = ( 2.0 / measure_freq if measure_duration is None else measure_duration ) @@ -1482,16 +1491,6 @@ def vad( # pack batch shape = waveform.size() - - if len(shape) > 2: - warnings.warn( - "Expected input tensor shape length of 1 for single channel" - f" or 2 for multi-channel. Got {len(shape)} instead.\n" - "Batch semantics is not supported yet.\n" - "Please refer to https://github.com/pytorch/audio/issues/1348" - " and https://github.com/pytorch/audio/issues/1468." - ) - waveform = waveform.view(-1, shape[-1]) n_channels, ilen = waveform.size()
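
Note (not part of the patches above): a minimal usage sketch of the behavior the new warning guards against. The sample rate, tensor shapes, and variable names below are illustrative assumptions, not taken from the patch. Because F.vad / T.Vad flattens all leading dimensions into channels, an input with more than 2 dimensions is treated as one multi-channel recording and trimmed to the earliest voice activity found in any channel of any example; applying the transform per example keeps each recording's own trim point.

    import torch
    import torchaudio.transforms as T

    sample_rate = 16000          # illustrative value, not from the patch
    vad = T.Vad(sample_rate)

    # Shape (batch, channels, time): ndim > 2, so the new warning fires and all
    # examples are trimmed to a single position (the earliest activity in any
    # channel of any example), which is usually not what a batched caller wants.
    batch = torch.rand(10, 2, sample_rate)
    trimmed_as_one = vad(batch)

    # Applying VAD per example preserves each recording's own trim point. The
    # outputs can have different lengths, so they are kept in a list rather
    # than stacked back into a single tensor.
    trimmed_per_example = [vad(example) for example in batch]

Proper batch semantics for vad is tracked in the issues referenced by the warning message (pytorch/audio#1348 and pytorch/audio#1468); until then, per-example application as sketched above is the straightforward workaround.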