Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions test/torchaudio_unittest/functional/functional_cpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
import itertools

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoExtension,
)
from torchaudio_unittest.backend.sox_io.common import name_func

from .functional_impl import Lfilter, Spectrogram


Expand Down Expand Up @@ -53,6 +59,7 @@ def test_warning(self):

class TestComputeDeltas(common_utils.TorchaudioTestCase):
"""Test suite for correctness of compute_deltas"""

def test_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
Expand Down Expand Up @@ -211,3 +218,48 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):

assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()


@skipIfNoExtension
class TestApplyCodec(TorchaudioTestCase):
    """Smoke tests for ``F.apply_codec`` across several audio formats."""

    # Read by the base test case; presumably selects the audio backend
    # used while these tests run -- confirm in TorchaudioTestCase.
    backend = "sox_io"

    def _smoke_test(self, format, compression, check_num_frames):
        """
        The purpose of this test suite is to verify that apply_codec functionalities do not exhibit
        abnormal behaviors.

        Args:
            format: File format name handed to ``F.apply_codec``.
            compression: Compression setting handed to ``F.apply_codec``
                (``None`` is used for WAV).
            check_num_frames: When true, additionally assert that the number
                of frames is unchanged by the round trip.
        """
        # Fixed seed so the random input is the same on every run.
        torch.random.manual_seed(42)
        sample_rate = 8000
        num_channels = 2
        num_frames = 3 * sample_rate
        waveform = torch.rand(num_channels, num_frames)

        result = F.apply_codec(
            waveform,
            sample_rate,
            format,
            channels_first=True,
            compression=compression,
        )

        assert result.dtype == waveform.dtype
        assert result.shape[0] == num_channels
        if check_num_frames:
            assert result.shape[1] == num_frames

    def test_wave(self):
        """WAV round trip; the frame count is checked as well."""
        self._smoke_test("wav", compression=None, check_num_frames=True)

    @parameterized.expand(
        [(96,), (128,), (160,), (192,), (224,), (256,), (320,)],
        name_func=name_func,
    )
    def test_mp3(self, compression):
        """MP3 round trip for each compression value; frame count is not checked."""
        self._smoke_test("mp3", compression, check_num_frames=False)

    @parameterized.expand(
        [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,)],
        name_func=name_func,
    )
    def test_flac(self, compression):
        """FLAC round trip for each compression value; frame count is not checked."""
        self._smoke_test("flac", compression, check_num_frames=False)

    @parameterized.expand(
        [(-1,), (0,), (1,), (2,), (3,), (3.6,), (5,), (10,)],
        name_func=name_func,
    )
    def test_vorbis(self, compression):
        """Vorbis round trip for each compression value; frame count is not checked."""
        self._smoke_test("vorbis", compression, check_num_frames=False)
2 changes: 2 additions & 0 deletions torchaudio/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
sliding_window_cmn,
spectrogram,
spectral_centroid,
apply_codec,
)
from .filtering import (
allpass_biquad,
Expand Down Expand Up @@ -84,4 +85,5 @@
'riaa_biquad',
'treble_biquad',
'vad',
'apply_codec'
]
54 changes: 53 additions & 1 deletion torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-

import io
import math
from typing import Optional, Tuple
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchaudio._internal import (
module_utils as _mod_utils,
)
import torchaudio

__all__ = [
"spectrogram",
Expand All @@ -29,6 +34,7 @@
'mask_along_axis_iid',
'sliding_window_cmn',
"spectral_centroid",
"apply_codec",
]


Expand Down Expand Up @@ -994,6 +1000,52 @@ def spectral_centroid(
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)


@_mod_utils.requires_module('torchaudio._torchaudio')
def apply_codec(
waveform: Tensor,
sample_rate: int,
format: str,
channels_first: bool = True,
compression: Optional[float] = None,
encoding: Optional[str] = None,
bits_per_sample: Optional[int] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you reorder the parameters as waveform, sample_rate, channels_first, format, encoding, bits_per_sample?

Sorry I am reverting what I originally asked (moving channels_first after format) but grouping format related parameters together looks nicer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving `channels_first` (a parameter with a default) before `format` (a parameter without one) raises `SyntaxError: non-default argument follows default argument`.
Also, what about the `compression` parameter?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I made an invalid suggestion again.
Let's do `waveform, sample_rate, format, channels_first, compression, encoding, bits_per_sample`.

) -> Tensor:
r"""
Applies codecs as a form of augmentation.

Args:
waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
sample_rate (int): Sample rate of the audio waveform
format (str): file format
channels_first (bool):
When True, both the input and output Tensor have dimension ``[channel, time]``.
Otherwise, they have dimension ``[time, channel]``.
compression (float, optional): Used for formats other than WAV.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`
encoding (str, optional): Changes the encoding for the supported formats.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`
bits_per_sample (int, optional): Changes the bit depth for the supported formats.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`

Returns:
torch.Tensor: Resulting Tensor.
If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``
"""
bytes = io.BytesIO()
torchaudio.backend.sox_io_backend.save(bytes,
waveform,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample
)
bytes.seek(0)
augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
return augmented


def compute_kaldi_pitch(
waveform: torch.Tensor,
sample_rate: float,
Expand Down