diff --git a/README.md b/README.md index cbbfa87853..c60b2662f7 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,9 @@ torchaudio: an audio library for PyTorch - [Kaldi (ark/scp)](http://pytorch.org/audio/kaldi_io.html) - [Dataloaders for common audio datasets (VCTK, YesNo)](http://pytorch.org/audio/datasets.html) - Common audio transforms - - [Scale, PadTrim, DownmixMono, LC2CL, BLC2CBL, MuLawEncoding, MuLawExpanding](http://pytorch.org/audio/transforms.html) + - [Spectrogram, SpectrogramToDB, MelScale, MelSpectrogram, MFCC, MuLawEncoding, MuLawDecoding, Resample](http://pytorch.org/audio/transforms.html) +- Compliance interfaces: Run code using PyTorch that aligns with other libraries + - [Kaldi: fbank, spectrogram, resample_waveform](https://pytorch.org/audio/compliance.kaldi.html) Dependencies ------------ diff --git a/docs/source/compliance.kaldi.rst b/docs/source/compliance.kaldi.rst index 50fdf5838a..1dfee29eb1 100644 --- a/docs/source/compliance.kaldi.rst +++ b/docs/source/compliance.kaldi.rst @@ -24,3 +24,8 @@ Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: spectrogram + +:hidden:`resample_waveform` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: resample_waveform diff --git a/docs/source/functional.rst b/docs/source/functional.rst index d07b7cc52e..a8d57bb36c 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -8,63 +8,57 @@ torchaudio.functional Functions to perform common audio operations. -:hidden:`scale` -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: scale - - -:hidden:`pad_trim` +:hidden:`istft` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: pad_trim +.. autofunction:: istft -:hidden:`downmix_mono` +:hidden:`spectrogram` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: downmix_mono +.. autofunction:: spectrogram -:hidden:`LC2CL` +:hidden:`amplitude_to_DB` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: LC2CL +.. autofunction:: amplitude_to_DB -:hidden:`istft` +:hidden:`create_fb_matrix` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: istft +.. autofunction:: create_fb_matrix -:hidden:`spectrogram` +:hidden:`create_dct` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: spectrogram +.. autofunction:: create_dct -:hidden:`create_fb_matrix` +:hidden:`mu_law_encoding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: create_fb_matrix +.. autofunction:: mu_law_encoding -:hidden:`spectrogram_to_DB` +:hidden:`mu_law_decoding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: spectrogram_to_DB +.. autofunction:: mu_law_decoding -:hidden:`create_dct` +:hidden:`complex_norm` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: create_dct +.. autofunction:: complex_norm -:hidden:`BLC2CBL` +:hidden:`angle` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: BLC2CBL +.. autofunction:: angle -:hidden:`mu_law_encoding` +:hidden:`magphase` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: mu_law_encoding +.. autofunction:: magphase -:hidden:`mu_law_expanding` +:hidden:`phase_vocoder` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: mu_law_expanding +.. autofunction:: phase_vocoder diff --git a/docs/source/kaldi_io.rst b/docs/source/kaldi_io.rst index 74b26645f3..2744bcc897 100644 --- a/docs/source/kaldi_io.rst +++ b/docs/source/kaldi_io.rst @@ -7,7 +7,7 @@ torchaudio.kaldi_io .. currentmodule:: torchaudio.kaldi_io To use this module, the dependency kaldi_io_ needs to be installed. -This is a light wrapper around ``kaldi_io`` that returns :class:`torch.Tensors`. +This is a light wrapper around ``kaldi_io`` that returns :class:`torch.Tensor`. ..
_kaldi_io: https://github.com/vesis84/kaldi-io-for-python diff --git a/docs/source/legacy.rst b/docs/source/legacy.rst index 86e5390be5..ac95a657f7 100644 --- a/docs/source/legacy.rst +++ b/docs/source/legacy.rst @@ -1,7 +1,19 @@ +.. role:: hidden + :class: hidden-section + torchaudio.legacy ====================== +.. currentmodule:: torchaudio.legacy + Legacy loading and save functions. -.. automodule:: torchaudio.legacy - :members: +:hidden:`load` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: load + +:hidden:`save` +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: save diff --git a/docs/source/sox_effects.rst b/docs/source/sox_effects.rst index e02c220eac..56cd985d0a 100644 --- a/docs/source/sox_effects.rst +++ b/docs/source/sox_effects.rst @@ -1,3 +1,6 @@ +.. role:: hidden + :class: hidden-section + torchaudio.sox_effects ====================== @@ -5,8 +8,14 @@ Create SoX effects chain for preprocessing audio. .. currentmodule:: torchaudio.sox_effects +:hidden:`SoxEffect` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + .. autoclass:: SoxEffect :members: +:hidden:`SoxEffectsChain` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + .. autoclass:: SoxEffectsChain :members: append_effect_to_chain, sox_build_flow_effects, clear_chain, set_input_file diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 52f6ca1973..ac2c733ac6 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1,24 +1,66 @@ +.. role:: hidden + :class: hidden-section + torchaudio.transforms ====================== .. currentmodule:: torchaudio.transforms -Transforms are common audio transforms. They can be chained together using :class:`Compose` +Transforms are common audio transforms. They can be chained together using :class:`torch.nn.Sequential` + + +:hidden:`Spectrogram` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: Spectrogram + + .. automethod:: torchaudio._docs.Spectrogram.forward + +:hidden:`AmplitudeToDB` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: AmplitudeToDB + + .. automethod:: torchaudio._docs.AmplitudeToDB.forward + +:hidden:`MelScale` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: Compose +.. autoclass:: MelScale -.. autoclass:: Scale + .. automethod:: torchaudio._docs.MelScale.forward -.. autoclass:: PadTrim +:hidden:`MelSpectrogram` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: DownmixMono +.. autoclass:: MelSpectrogram -.. autoclass:: LC2CL + .. automethod:: torchaudio._docs.MelSpectrogram.forward -.. autoclass:: MEL +:hidden:`MFCC` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: BLC2CBL +.. autoclass:: MFCC + + .. automethod:: torchaudio._docs.MFCC.forward + +:hidden:`MuLawEncoding` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: MuLawEncoding -.. autoclass:: MuLawExpanding + .. automethod:: torchaudio._docs.MuLawEncoding.forward + +:hidden:`MuLawDecoding` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: MuLawDecoding + + .. automethod:: torchaudio._docs.MuLawDecoding.forward + +:hidden:`Resample` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: Resample + + .. 
automethod:: torchaudio._docs.Resample.forward diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index 54fe40f04b..942ba3530c 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -5,7 +5,7 @@ import _torch_sox from .version import __version__, git_version -from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy, compliance +from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy, compliance, _docs def check_input(src): @@ -24,33 +24,35 @@ def load(filepath, signalinfo=None, encodinginfo=None, filetype=None): - """Loads an audio file from disk into a Tensor + r"""Loads an audio file from disk into a tensor Args: - filepath (string or pathlib.Path): path to audio file - out (Tensor, optional): an output Tensor to use instead of creating one + filepath (str or pathlib.Path): Path to audio file + out (torch.Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``) normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31` - (assumes signed 32-bit audio), and normalizes to `[0, 1]`. - If `number`, then output is divided by that number - If `callable`, then the output is passed as a parameter - to the given function, then the output is divided by - the result. - channels_first (bool): Set channels first or length first in result. Default: ``True`` - num_frames (int, optional): number of frames to load. 0 to load everything after the offset. - offset (int, optional): number of frames from the start of the file to begin data loading. - signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the - audio type cannot be automatically determined - encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the - audio type cannot be automatically determined - filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically - - Returns: tuple(Tensor, int) - - Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames and - C is the number of channels - - int: the sample rate of the audio (as listed in the metadata of the file) - - Example:: + (assumes signed 32-bit audio), and normalizes to `[0, 1]`. + If `number`, then output is divided by that number + If `callable`, then the output is passed as a parameter + to the given function, then the output is divided by + the result. (Default: ``True``) + channels_first (bool): Set channels first or length first in result. (Default: ``True``) + num_frames (int, optional): Number of frames to load. 0 to load everything after the offset. + (Default: ``0``) + offset (int, optional): Number of frames from the start of the file to begin data loading. + (Default: ``0``) + signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the + audio type cannot be automatically determined. (Default: ``None``) + encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the + audio type cannot be automatically determined. (Default: ``None``) + filetype (str, optional): A filetype or extension to be set if sox cannot determine it + automatically. (Default: ``None``) + Returns: + Tuple[torch.Tensor, int]: An output tensor of size `[C x L]` or `[L x C]` where L is the number + of audio frames and C is the number of channels. 
An integer which is the sample rate of the + audio (as listed in the metadata of the file) + + Example >>> data, sample_rate = torchaudio.load('foo.mp3') >>> print(data.size()) torch.Size([2, 278756]) @@ -94,16 +96,33 @@ def load(filepath, def load_wav(filepath, **kwargs): - """ Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting + r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting the input right by 16 bits. + + Args: + filepath (str or pathlib.Path): Path to audio file + + Returns: + Tuple[torch.Tensor, int]: An output tensor of size `[C x L]` or `[L x C]` where L is the number + of audio frames and C is the number of channels. An integer which is the sample rate of the + audio (as listed in the metadata of the file) """ kwargs['normalization'] = 1 << 16 return load(filepath, **kwargs) def save(filepath, src, sample_rate, precision=16, channels_first=True): - """Convenience function for `save_encinfo`. + r"""Convenience function for `save_encinfo`. + Args: + filepath (str): Path to audio file + src (torch.Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is + the number of audio frames, C is the number of channels + sample_rate (int): An integer which is the sample rate of the + audio (as listed in the metadata of the file) + precision (int): Bit precision (Default: ``16``) + channels_first (bool): Set channels first or length first in result. ( + Default: ``True``) """ si = sox_signalinfo_t() ch_idx = 0 if channels_first else 1 @@ -120,21 +139,21 @@ def save_encinfo(filepath, signalinfo=None, encodinginfo=None, filetype=None): - """Saves a Tensor of an audio signal to disk as a standard format like mp3, wav, etc. + r"""Saves a tensor of an audio signal to disk as a standard format like mp3, wav, etc. Args: - filepath (string): path to audio file - src (Tensor): an input 2D Tensor of shape `[C x L]` or `[L x C]` where L is - the number of audio frames, C is the number of channels - channels_first (bool): Set channels first or length first in result. Default: ``True`` - signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the - audio type cannot be automatically determined - encodinginfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the - audio type cannot be automatically determined - filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically - - Example:: - + filepath (str): Path to audio file + src (torch.Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is + the number of audio frames, C is the number of channels + channels_first (bool): Set channels first or length first in result. (Default: ``True``) + signalinfo (sox_signalinfo_t): A sox_signalinfo_t type, which could be helpful if the + audio type cannot be automatically determined. (Default: ``None``) + encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the + audio type cannot be automatically determined. (Default: ``None``) + filetype (str, optional): A filetype or extension to be set if sox cannot determine it + automatically. (Default: ``None``) + + Example >>> data, sample_rate = torchaudio.load('foo.mp3') >>> torchaudio.save('foo.wav', data, sample_rate) @@ -184,16 +203,16 @@ def save_encinfo(filepath, def info(filepath): - """Gets metadata from an audio file without loading the signal. 
+ r"""Gets metadata from an audio file without loading the signal. Args: - filepath (string): path to audio file + filepath (str): Path to audio file - Returns: tuple(si, ei) - - si (sox_signalinfo_t): signal info as a python object - - ei (sox_encodinginfo_t): encoding info as a python object + Returns: + Tuple[sox_signalinfo_t, sox_encodinginfo_t]: A si (sox_signalinfo_t) signal + info as a python object. An ei (sox_encodinginfo_t) encoding info - Example:: + Example >>> si, ei = torchaudio.info('foo.wav') >>> rate, channels, encoding = si.rate, si.channels, ei.encoding """ @@ -206,13 +225,13 @@ def sox_signalinfo_t(): primarily for effects Returns: sox_signalinfo_t(object) - - rate (float), sample rate as a float, practically will likely be an integer float - - channel (int), number of audio channels - - precision (int), bit precision - - length (int), length of audio in samples * channels, 0 for unspecified and -1 for unknown - - mult (float, optional), headroom multiplier for effects and None for no multiplier + - rate (float), sample rate as a float, practically will likely be an integer float + - channel (int), number of audio channels + - precision (int), bit precision + - length (int), length of audio in samples * channels, 0 for unspecified and -1 for unknown + - mult (float, optional), headroom multiplier for effects and ``None`` for no multiplier - Example:: + Example >>> si = torchaudio.sox_signalinfo_t() >>> si.channels = 1 >>> si.rate = 16000. @@ -223,7 +242,7 @@ def sox_signalinfo_t(): def sox_encodinginfo_t(): - """Create a sox_encodinginfo_t object. This object can be used to set the encoding + r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding type, bit precision, compression factor, reverse bytes, reverse nibbles, reverse bits and endianness. This can be used in an effects chain to encode the final output or to save a file with a specific encoding. For example, one could @@ -232,15 +251,15 @@ def sox_encodinginfo_t(): the bit precision. Returns: sox_encodinginfo_t(object) - - encoding (sox_encoding_t), output encoding - - bits_per_sample (int), bit precision, same as `precision` in sox_signalinfo_t - - compression (float), compression for lossy formats, 0.0 for default compression - - reverse_bytes (sox_option_t), reverse bytes, use sox_option_default - - reverse_nibbles (sox_option_t), reverse nibbles, use sox_option_default - - reverse_bits (sox_option_t), reverse bytes, use sox_option_default - - opposite_endian (sox_bool), change endianness, use sox_false - - Example:: + - encoding (sox_encoding_t), output encoding + - bits_per_sample (int), bit precision, same as `precision` in sox_signalinfo_t + - compression (float), compression for lossy formats, 0.0 for default compression + - reverse_bytes (sox_option_t), reverse bytes, use sox_option_default + - reverse_nibbles (sox_option_t), reverse nibbles, use sox_option_default + - reverse_bits (sox_option_t), reverse bytes, use sox_option_default + - opposite_endian (sox_bool), change endianness, use sox_false + + Example >>> ei = torchaudio.sox_encodinginfo_t() >>> ei.encoding = torchaudio.get_sox_encoding_t(1) >>> ei.bits_per_sample = 16 @@ -260,13 +279,14 @@ def sox_encodinginfo_t(): def get_sox_encoding_t(i=None): - """Get enum of sox_encoding_t for sox encodings. + r"""Get enum of sox_encoding_t for sox encodings. 
Args: - i (int, optional): choose type or get a dict with all possible options - use `__members__` to see all options when not specified + i (int, optional): Choose type or get a dict with all possible options + use ``__members__`` to see all options when not specified. (Default: ``None``) + Returns: - sox_encoding_t: a sox_encoding_t type for output encoding + sox_encoding_t: A sox_encoding_t type for output encoding """ if i is None: # one can see all possible values using the .__members__ attribute @@ -276,14 +296,14 @@ def get_sox_encoding_t(i=None): def get_sox_option_t(i=2): - """Get enum of sox_option_t for sox encodinginfo options. + r"""Get enum of sox_option_t for sox encodinginfo options. Args: - i (int, optional): choose type or get a dict with all possible options - use `__members__` to see all options when not specified. - Defaults to sox_option_default. + i (int, optional): Choose type or get a dict with all possible options + use ``__members__`` to see all options when not specified. + (Default: ``sox_option_default`` or ``2``) Returns: - sox_option_t: a sox_option_t type + sox_option_t: A sox_option_t type """ if i is None: return _torch_sox.sox_option_t @@ -292,14 +312,15 @@ def get_sox_option_t(i=2): def get_sox_bool(i=0): - """Get enum of sox_bool for sox encodinginfo options. + r"""Get enum of sox_bool for sox encodinginfo options. Args: - i (int, optional): choose type or get a dict with all possible options - use `__members__` to see all options when not specified. - Defaults to sox_false. + i (int, optional): Choose type or get a dict with all possible options + use ``__members__`` to see all options when not specified. (Default: + ``sox_false`` or ``0``) + Returns: - sox_bool: a sox_bool type + sox_bool: A sox_bool type """ if i is None: return _torch_sox.sox_bool diff --git a/torchaudio/_docs.py b/torchaudio/_docs.py new file mode 100644 index 0000000000..2b2c3000f6 --- /dev/null +++ b/torchaudio/_docs.py @@ -0,0 +1,35 @@ +import torchaudio + + +# TODO See https://github.com/pytorch/audio/issues/165 +class Spectrogram: + forward = torchaudio.transforms.Spectrogram().forward + + +class AmplitudeToDB: + forward = torchaudio.transforms.AmplitudeToDB().forward + + +class MelScale: + forward = torchaudio.transforms.MelScale().forward + + +class MelSpectrogram: + forward = torchaudio.transforms.MelSpectrogram().forward + + +class MFCC: + forward = torchaudio.transforms.MFCC().forward + + +class MuLawEncoding: + forward = torchaudio.transforms.MuLawEncoding().forward + + +class MuLawDecoding: + forward = torchaudio.transforms.MuLawDecoding().forward + + +class Resample: + # Resample isn't a script_method + forward = torchaudio.transforms.Resample.forward diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 9291722fa3..d0591f4411 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -37,11 +37,11 @@ def _next_power_of_2(x): def _get_strided(waveform, window_size, window_shift, snip_edges): - r"""Given a waveform (1D tensor of size `num_samples`), it returns a 2D tensor (m, `window_size`) + r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``) representing how the window is shifted along the waveform. Each row is a frame. 
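This framing can be pictured as a strided view over the waveform. A minimal illustrative sketch of the ``snip_edges=True`` case with toy sizes (not the function's actual implementation, which also handles the reflected-padding case):
>>> import torch
>>> wave = torch.arange(10.)
>>> window_size, window_shift = 4, 2
>>> m = 1 + (wave.size(0) - window_size) // window_shift  # number of frames that fit entirely
>>> wave.as_strided((m, window_size), (window_shift * wave.stride(0), wave.stride(0)))
tensor([[0., 1., 2., 3.],
        [2., 3., 4., 5.],
        [4., 5., 6., 7.],
        [6., 7., 8., 9.]])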
Args: - waveform (torch.Tensor): Tensor of size `num_samples` + waveform (torch.Tensor): Tensor of size ``num_samples`` window_size (int): Frame length window_shift (int): Frame shift snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit @@ -49,7 +49,7 @@ def _get_strided(waveform, window_size, window_shift, snip_edges): depends only on the frame_shift, and we reflect the data at the ends. Returns: - torch.Tensor: 2D tensor of size (m, `window_size`) where each row is a frame + torch.Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame """ assert waveform.dim() == 1 num_samples = waveform.size(0) @@ -134,7 +134,7 @@ def _get_window(waveform, padded_window_size, window_size, window_shift, window_ r"""Gets a window and its log energy Returns: - strided_input (torch.Tensor): size (m, `padded_window_size`) + strided_input (torch.Tensor): size (m, ``padded_window_size``) signal_log_energy (torch.Tensor): size (m) """ # size (m, window_size) @@ -191,33 +191,33 @@ def spectrogram( Args: waveform (torch.Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) - blackman_coeff (float): Constant coefficient for generalized Blackman window. (Default: 0.42) - channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: -1) + blackman_coeff (float): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) dither (float): Dithering constant (0.0 means no dither). If you turn this off, you should set - the energy_floor option, e.g. to 1.0 or 0.1 (Default: 1.0) + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``1.0``) energy_floor (float): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: this floor is applied to the zeroth component, representing the total signal energy. The floor on the - individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: 0.0) - frame_length (float): Frame length in milliseconds (Default: 25.0) - frame_shift (float): Frame shift in milliseconds (Default: 10.0) - min_duration (float): Minimum duration of segments to process (in seconds). (Default: 0.0) - preemphasis_coefficient (float): Coefficient for use in signal preemphasis (Default: 0.97) - raw_energy (bool): If True, compute energy before preemphasis and windowing (Default: True) - remove_dc_offset: Subtract mean from waveform on each frame (Default: True) + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``0.0``) + frame_length (float): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float): Frame shift in milliseconds (Default: ``10.0``) + min_duration (float): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + preemphasis_coefficient (float): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset: Subtract mean from waveform on each frame (Default: ``True``) round_to_power_of_two (bool): If True, round window size to power of two by zero-padding input - to FFT. (Default: True) + to FFT. 
(Default: ``True``) sample_frequency (float): Waveform data sample frequency (must match the waveform file, if - specified there) (Default: 16000.0) + specified there) (Default: ``16000.0``) snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames - depends only on the frame_shift, and we reflect the data at the ends. (Default: True) + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) subtract_mean (bool): Subtract mean of each feature file [CMS]; not recommended to do - it this way. (Default: False) - window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: 'povey') + it this way. (Default: ``False``) + window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``'povey'``) Returns: torch.Tensor: A spectrogram identical to what Kaldi would output. The shape is - (m, `padded_window_size` // 2 + 1) where m is calculated in _get_strided + (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided """ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient) @@ -343,7 +343,7 @@ def vtln_warp_mel_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, mel_freq (torch.Tensor): Given frequency in Mel Returns: - torch.Tensor: `mel_freq` after vtln warp + torch.Tensor: ``mel_freq`` after vtln warp """ return mel_scale(vtln_warp_freq(vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq))) @@ -354,9 +354,9 @@ def get_mel_banks(num_bins, window_length_padded, sample_freq, # type: (int, int, float, float, float, float, float) """ Returns: - Tuple[torch.Tensor, torch.Tensor]: The tuple consists of `bins` (which is - Melbank of size (`num_bins`, `num_fft_bins`)) and `center_freqs` (which is - Center frequencies of bins of size (`num_bins`)). + Tuple[torch.Tensor, torch.Tensor]: The tuple consists of ``bins`` (which is + melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is + center frequencies of bins of size (``num_bins``)). """ assert num_bins > 3, 'Must have at least 3 mel bins' assert window_length_padded % 2 == 0 @@ -430,44 +430,44 @@ def fbank( Args: waveform (torch.Tensor): Tensor of audio of size (c, n) where c is in the range [0,2) - blackman_coeff (float): Constant coefficient for generalized Blackman window. (Default: 0.42) - channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: -1) + blackman_coeff (float): Constant coefficient for generalized Blackman window. (Default: ``0.42``) + channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``) dither (float): Dithering constant (0.0 means no dither). If you turn this off, you should set - the energy_floor option, e.g. to 1.0 or 0.1 (Default: 1.0) + the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``1.0``) energy_floor (float): Floor on energy (absolute, not relative) in Spectrogram computation. Caution: this floor is applied to the zeroth component, representing the total signal energy. The floor on the - individual spectrogram elements is fixed at std::numeric_limits::epsilon(). 
(Default: 0.0) - frame_length (float): Frame length in milliseconds (Default: 25.0) - frame_shift (float): Frame shift in milliseconds (Default: 10.0) - high_freq (float): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (Default: 0.0) + individual spectrogram elements is fixed at std::numeric_limits::epsilon(). (Default: ``0.0``) + frame_length (float): Frame length in milliseconds (Default: ``25.0``) + frame_shift (float): Frame shift in milliseconds (Default: ``10.0``) + high_freq (float): High cutoff frequency for mel bins (if <= 0, offset from Nyquist) (Default: ``0.0``) htk_compat (bool): If true, put energy last. Warning: not sufficient to get HTK compatible features (need - to change other parameters). (Default: False) - low_freq (float): Low cutoff frequency for mel bins (Default: 20.0) - min_duration (float): Minimum duration of segments to process (in seconds). (Default: 0.0) - num_mel_bins (int): Number of triangular mel-frequency bins (Default: 23) - preemphasis_coefficient (float): Coefficient for use in signal preemphasis (Default: 0.97) - raw_energy (bool): If True, compute energy before preemphasis and windowing (Default: True) - remove_dc_offset: Subtract mean from waveform on each frame (Default: True) + to change other parameters). (Default: ``False``) + low_freq (float): Low cutoff frequency for mel bins (Default: ``20.0``) + min_duration (float): Minimum duration of segments to process (in seconds). (Default: ``0.0``) + num_mel_bins (int): Number of triangular mel-frequency bins (Default: ``23``) + preemphasis_coefficient (float): Coefficient for use in signal preemphasis (Default: ``0.97``) + raw_energy (bool): If True, compute energy before preemphasis and windowing (Default: ``True``) + remove_dc_offset: Subtract mean from waveform on each frame (Default: ``True``) round_to_power_of_two (bool): If True, round window size to power of two by zero-padding input - to FFT. (Default: True) + to FFT. (Default: ``True``) sample_frequency (float): Waveform data sample frequency (must match the waveform file, if - specified there) (Default: 16000.0) + specified there) (Default: ``16000.0``) snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit in the file, and the number of frames depends on the frame_length. If False, the number of frames - depends only on the frame_shift, and we reflect the data at the ends. (Default: True) + depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``) subtract_mean (bool): Subtract mean of each feature file [CMS]; not recommended to do - it this way. (Default: False) - use_energy (bool): Add an extra dimension with energy to the FBANK output. (Default: False) - use_log_fbank (bool):If true, produce log-filterbank, else produce linear. (Default: True) - use_power (bool): If true, use power, else use magnitude. (Default: True) + it this way. (Default: ``False``) + use_energy (bool): Add an extra dimension with energy to the FBANK output. (Default: ``False``) + use_log_fbank (bool):If true, produce log-filterbank, else produce linear. (Default: ``True``) + use_power (bool): If true, use power, else use magnitude. 
(Default: ``True``) vtln_high (float): High inflection point in piecewise linear VTLN warping function (if - negative, offset from high-mel-freq (Default: -500.0) - vtln_low (float): Low inflection point in piecewise linear VTLN warping function (Default: 100.0) - vtln_warp (float): Vtln warp factor (only applicable if vtln_map not specified) (Default: 1.0) - window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: 'povey') + negative, offset from high-mel-freq (Default: ``-500.0``) + vtln_low (float): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``) + vtln_warp (float): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``) + window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (Default: ``'povey'``) Returns: - torch.Tensor: A fbank identical to what Kaldi would output. The shape is (m, `num_mel_bins` + `use_energy`) + torch.Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``) where m is calculated in _get_strided """ waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties( @@ -523,7 +523,7 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win r"""Based on LinearResample::SetIndexesAndWeights where it retrieves the weights for resampling as well as the indices in which they are valid. LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e the output signal has a - frequency of `new_freq`). It uses sinc/bandlimited interpolation to upsample/downsample + frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample the signal. The reason why the same filter is not used for multiple convolutions is because the @@ -541,7 +541,7 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win assuming the center of the sinc function is at 0, 16, and 32 (the deltas [..., 6, 1, 4, ....] for 16 vs [...., 2, 3, ....] for 32) - Example, one case is when the orig_freq and new_freq are multiples of each other then + Example, one case is when the ``orig_freq`` and ``new_freq`` are multiples of each other then there needs to be one filter. A windowed filter function (i.e. Hanning * sinc) because the ideal case of sinc function @@ -562,9 +562,9 @@ def _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, win efficient. We suggest around 4 to 10 for normal use Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple of `min_input_index` (which is the minimum indices - where the window is valid, size (`output_samples_in_unit`)) and `weights` (which is the weights - which correspond with min_input_index, size (`output_samples_in_unit`, `max_weight_width`)). + Tuple[torch.Tensor, torch.Tensor]: A tuple of ``min_input_index`` (which is the minimum indices + where the window is valid, size (``output_samples_in_unit``)) and ``weights`` (which is the weights + which correspond with min_input_index, size (``output_samples_in_unit``, ``max_weight_width``)). """ assert lowpass_cutoff < min(orig_freq, new_freq) / 2 output_t = torch.arange(0, output_samples_in_unit, dtype=torch.get_default_dtype()) / new_freq @@ -606,7 +606,7 @@ def _lcm(a, b): def _get_num_LR_output_samples(input_num_samp, samp_rate_in, samp_rate_out): r"""Based on LinearResample::GetNumOutputSamples. 
LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e the output signal has a - frequency of `new_freq`). It uses sinc/bandlimited interpolation to upsample/downsample + frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample the signal. Args: @@ -651,7 +651,7 @@ def resample_waveform(waveform, orig_freq, new_freq, lowpass_filter_width=6): r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e - the output signal has a frequency of `new_freq`). It uses sinc/bandlimited interpolation to + the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to upsample/downsample the signal. https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html @@ -662,10 +662,10 @@ orig_freq (float): The original frequency of the signal new_freq (float): The desired frequency lowpass_filter_width (int): Controls the sharpness of the filter, more == sharper - but less efficient. We suggest around 4 to 10 for normal use. (Default: 6) + but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``) Returns: - torch.Tensor: The signal at the new frequency + torch.Tensor: The waveform at the new frequency """ assert waveform.dim() == 2 assert orig_freq > 0.0 and new_freq > 0.0 diff --git a/torchaudio/datasets/vctk.py b/torchaudio/datasets/vctk.py index ed7f1f82f8..d66c988303 100644 --- a/torchaudio/datasets/vctk.py +++ b/torchaudio/datasets/vctk.py @@ -71,21 +71,22 @@ def load_txts(dir): class VCTK(data.Dataset): - """`VCTK `_ Dataset. - `alternate url ` + r"""`VCTK `_ Dataset. + `alternate url `_ Args: - root (string): Root directory of dataset where ``processed/training.pt`` + root (str): Root directory of dataset where ``processed/training.pt`` and ``processed/test.pt`` exist. + downsample (bool, optional): Whether to downsample the signal. (Default: ``True``) + transform (Callable, optional): A function/transform that takes in a raw audio + and returns a transformed version. E.g., ``transforms.Spectrogram``. (Default: ``None``) + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. (Default: ``None``) download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an raw audio - and returns a transformed version. E.g, ``transforms.Scale`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - dev_mode(bool, optional): if true, clean up is not performed on downloaded - files. Useful to keep raw audio and transcriptions. + downloaded again. (Default: ``True``) + dev_mode (bool, optional): If true, clean up is not performed on downloaded + files. Useful to keep raw audio and transcriptions. (Default: ``False``) """ raw_folder = 'vctk/raw' processed_folder = 'vctk/processed' @@ -121,7 +122,8 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class.
+ Tuple[torch.Tensor, int]: The output tuple (image, target) where target + is index of the target class. """ if self.cached_pt != index // self.chunk_size: self.cached_pt = int(index // self.chunk_size) diff --git a/torchaudio/datasets/yesno.py b/torchaudio/datasets/yesno.py index ee086c1d4f..8d80b9e14a 100644 --- a/torchaudio/datasets/yesno.py +++ b/torchaudio/datasets/yesno.py @@ -9,20 +9,21 @@ class YESNO(data.Dataset): - """`YesNo Hebrew `_ Dataset. + r"""`YesNo Hebrew `_ Dataset. Args: - root (string): Root directory of dataset where ``processed/training.pt`` + root (str): Root directory of dataset where ``processed/training.pt`` and ``processed/test.pt`` exist. + transform (Callable, optional): A function/transform that takes in a PIL image + and returns a transformed version. E.g., ``transforms.Spectrogram``. ( + Default: ``None``) + target_transform (Callable, optional): A function/transform that takes in the + target and transforms it. (Default: ``None``) download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - transform (callable, optional): A function/transform that takes in an PIL image - and returns a transformed version. E.g, ``transforms.Scale`` - target_transform (callable, optional): A function/transform that takes in the - target and transforms it. - dev_mode(bool, optional): if true, clean up is not performed on downloaded - files. Useful to keep raw audio and transcriptions. + downloaded again. (Default: ``False``) + dev_mode (bool, optional): If true, clean up is not performed on downloaded + files. Useful to keep raw audio and transcriptions. (Default: ``False``) """ raw_folder = 'yesno/raw' processed_folder = 'yesno/processed' @@ -55,7 +56,8 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + Tuple[torch.Tensor, int]: The output tuple (image, target) where target + is index of the target class. """ audio, target = self.data[index], self.labels[index] diff --git a/torchaudio/functional.py b/torchaudio/functional.py index ca6aa66694..9155c827d7 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -36,7 +36,7 @@ def istft(stft_matrix, # type: Tensor length=None # type: Optional[int] ): # type: (...) -> Tensor - r""" Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft. + r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft. It has the same parameters (+ additional optional parameter of ``length``) and it should return the least squares estimation of the original signal. The algorithm will check using the NOLA condition ( nonzero overlap). @@ -46,7 +46,7 @@ :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \cancel{=} 0`. Since stft discards elements at the end of the signal if they do not fit in a frame, the - istft may return a shorter signal than the original signal (can occur if `center` is False + istft may return a shorter signal than the original signal (can occur if ``center`` is False since the signal isn't padded). If ``center`` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding @@ -66,7 +66,7 @@ Args: stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each - column is a window.
it has a shape of either (channel, fft_size, n_frames, 2) or ( + column is a window. It has a size of either (channel, fft_size, n_frames, 2) or ( fft_size, n_frames, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``win_length // 2``) win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``) window (Optional[torch.Tensor]): The optional window function. (Default: ``torch.ones(win_length)``) center (bool): Whether ``input`` was padded on both sides so - that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}` - pad_mode (str): Controls the padding method used when ``center`` is ``True`` - normalized (bool): Whether the STFT was normalized - onesided (bool): Whether the STFT is onesided + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + (Default: ``True``) + pad_mode (str): Controls the padding method used when ``center`` is True. (Default: + ``'reflect'``) + normalized (bool): Whether the STFT was normalized. (Default: ``False``) + onesided (bool): Whether the STFT is onesided. (Default: ``True``) length (Optional[int]): The amount to trim the signal by (i.e. the original signal length). (Default: whole signal) @@ -175,10 +177,10 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor r"""Create a spectrogram from a raw audio signal. Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) pad (int): Two sided padding of signal window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window - n_fft (int): Size of fft + n_fft (int): Size of FFT hop_length (int): Length of hop between STFT windows win_length (int): Window size power (int): Exponent for the magnitude spectrogram, (must be > 0) e.g., 1 for energy, 2 for power, etc. normalized (bool): Whether to normalize by magnitude after stft Returns: - torch.Tensor: Channels x frequency x time (c, f, t), where channels - is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of - fourier bins, and time is the number of window hops (n_frames). + torch.Tensor: Dimension (channel, freq, time), where channel + is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + Fourier bins, and time is the number of window hops (n_frames). """ assert waveform.dim() == 2 @@ -221,7 +223,7 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): amin (float): Number to clamp ``x`` db_multiplier (float): Log10(max(reference value and amin)) top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number - is 80. + is 80. (Default: ``None``) Returns: torch.Tensor: Output tensor in decibel scale @@ -249,11 +251,11 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): n_mels (int): Number of mel filterbanks Returns: - torch.Tensor: Triangular filter banks (fb matrix) of size (`n_freqs`, `n_mels`) + torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) meaning number of frequencies to highlight/apply to x the number of filterbanks. Each column is a filterbank so that assuming there is a matrix A of - size (..., `n_freqs`), the applied result would be - `A * create_fb_matrix(A.size(-1), ...)`. + size (..., ``n_freqs``), the applied result would be + ``A * create_fb_matrix(A.size(-1), ...)``.
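Concretely, a usage sketch (the sizes are assumptions: ``n_freqs = 201`` corresponds to ``n_fft = 400``, and ``f_max = 8000.`` assumes a 16 kHz signal):
>>> import torch
>>> import torchaudio.functional as F
>>> specgram = torch.randn(1, 100, 201)          # (channel, time, n_freqs)
>>> fb = F.create_fb_matrix(201, 0., 8000., 40)  # (n_freqs, n_mels)
>>> mel_specgram = torch.matmul(specgram, fb)    # (channel, time, n_mels)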
""" # freq bins freqs = torch.linspace(f_min, f_max, n_freqs) @@ -278,7 +280,7 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels): @torch.jit.script def create_dct(n_mfcc, n_mels, norm): # type: (int, int, Optional[str]) -> Tensor - r"""Creates a DCT transformation matrix with shape (`n_mels`, `n_mfcc`), + r"""Creates a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), normalized depending on norm. Args: @@ -288,7 +290,7 @@ def create_dct(n_mfcc, n_mels, norm): Returns: torch.Tensor: The transformation matrix, to be right-multiplied to - row-wise data of size (`n_mels`, `n_mfcc`). + row-wise data of size (``n_mels``, ``n_mfcc``). """ # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II n = torch.arange(float(n_mels)) @@ -317,7 +319,7 @@ def mu_law_encoding(x, quantization_channels): quantization_channels (int): Number of channels Returns: - torch.Tensor: Input after mu-law companding + torch.Tensor: Input after mu-law encoding """ mu = quantization_channels - 1. if not x.is_floating_point(): @@ -343,7 +345,7 @@ def mu_law_decoding(x_mu, quantization_channels): quantization_channels (int): Number of channels Returns: - torch.Tensor: Input after decoding + torch.Tensor: Input after mu-law decoding """ mu = quantization_channels - 1. if not x_mu.is_floating_point(): @@ -382,14 +384,14 @@ def angle(complex_tensor): def magphase(complex_tensor, power=1.): - r"""Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase. + r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase. Args: complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` power (float): Power of the norm. (Default: `1.0`) Returns: - Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor + Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex tensor """ mag = complex_norm(complex_tensor, power) phase = angle(complex_tensor) @@ -398,17 +400,19 @@ def magphase(complex_tensor, power=1.): def phase_vocoder(complex_specgrams, rate, phase_advance): r"""Given a STFT tensor, speed up in time without modifying pitch by a - factor of `rate`. + factor of ``rate``. Args: - complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2) + complex_specgrams (torch.Tensor): Dimension of `(*, channel, freq, time, complex=2)` rate (float): Speed-up factor - phase_advance (torch.Tensor): Expected phase advance in each bin. Size of (f, 1) + phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension + of (freq, 1) Returns: - complex_specgrams_stretch (torch.Tensor): Size of (*, c, f, ceil(t/rate), complex=2) + complex_specgrams_stretch (torch.Tensor): Dimension of `(*, channel, + freq, ceil(time/rate), complex=2)` - Example: + Example >>> num_freqs, hop_length = 1025, 512 >>> # (batch, channel, num_freqs, time, complex=2) >>> complex_specgrams = torch.randn(16, 1, num_freqs, 300, 2) diff --git a/torchaudio/kaldi_io.py b/torchaudio/kaldi_io.py index 6663b87f5c..50b101db7a 100644 --- a/torchaudio/kaldi_io.py +++ b/torchaudio/kaldi_io.py @@ -21,17 +21,18 @@ def _convert_method_output_to_tensor(file_or_fd, fn, convert_contiguous=False): - r""" Takes a method invokes it. The output is converted to a tensor. + r"""Takes a method invokes it. The output is converted to a tensor. - Arguments: - file_or_fd (string/File Descriptor): file name or file descriptor. 
- fn (Function): function that has the signature (file name/descriptor) -> generator(string, ndarray) - and converts it to (file name/descriptor) -> generator(string, Tensor). - convert_contiguous (bool): determines whether the array should be converted into a - contiguous layout. + Args: + file_or_fd (str/FileDescriptor): File name or file descriptor + fn (Callable[[...], Generator[str, numpy.ndarray]]): Function that has the signature ( file name/descriptor) -> Generator(str, numpy.ndarray) and converts it to ( file name/descriptor) -> Generator(str, torch.Tensor). + convert_contiguous (bool): Determines whether the array should be converted into a + contiguous layout. (Default: ``False``) Returns: - generator[key (string), vec/mat (Tensor)] + Generator[str, torch.Tensor]: The string is the key and the tensor is vec/mat """ if not IMPORT_KALDI_IO: raise ImportError('Could not import kaldi_io. Did you install it?') @@ -45,14 +46,13 @@ def _convert_method_output_to_tensor(file_or_fd, fn, convert_contiguous=False): def read_vec_int_ark(file_or_fd): r"""Create generator of (key,vector) tuples, which reads from the ark file/stream. - Arguments: - file_or_fd (string/File Descriptor): ark, gzipped ark, pipe or opened file descriptor. + Args: + file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor Returns: - generator[key (string), vec (Tensor)] - - Example:: + Generator[str, torch.Tensor]: The string is the key and the tensor is the vector read from file + Example >>> # read ark to a 'dictionary' >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_int_ark(file) } """ @@ -63,16 +63,15 @@ def read_vec_int_ark(file_or_fd): def read_vec_flt_scp(file_or_fd): - r"""Create generator of (key,vector) tuples, read according to kaldi scp. + r"""Create generator of (key,vector) tuples, read according to Kaldi scp. - Arguments: - file_or_fd (string/File Descriptor): scp, gzipped scp, pipe or opened file descriptor. + Args: + file_or_fd (str/FileDescriptor): scp, gzipped scp, pipe or opened file descriptor Returns: - generator[key (string), vec (Tensor)] - - Example:: + Generator[str, torch.Tensor]: The string is the key and the tensor is the vector read from file + Example >>> # read scp to a 'dictionary' >>> # d = { u:d for u,d in torchaudio.kaldi_io.read_vec_flt_scp(file) } """ @@ -82,14 +81,13 @@ def read_vec_flt_scp(file_or_fd): def read_vec_flt_ark(file_or_fd): r"""Create generator of (key,vector) tuples, which reads from the ark file/stream. - Arguments: - file_or_fd (string/File Descriptor): ark, gzipped ark, pipe or opened file descriptor. + Args: + file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor Returns: - generator[key (string), vec (Tensor)] - - Example:: + Generator[str, torch.Tensor]: The string is the key and the tensor is the vector read from file + Example >>> # read ark to a 'dictionary' >>> d = { u:d for u,d in torchaudio.kaldi_io.read_vec_flt_ark(file) } """ @@ -97,16 +95,15 @@ def read_vec_flt_ark(file_or_fd): def read_mat_scp(file_or_fd): - r"""Create generator of (key,matrix) tuples, read according to kaldi scp. - Arguments: - file_or_fd (string/File Descriptor): scp, gzipped scp, pipe or opened file descriptor.
+ Args: + file_or_fd (str/FileDescriptor): scp, gzipped scp, pipe or opened file descriptor Returns: - generator[key (string), mat (Tensor)] - - Example:: + Generator[str, torch.Tensor]: The string is the key and the tensor is the matrix read from file + Example >>> # read scp to a 'dictionary' >>> d = { u:d for u,d in torchaudio.kaldi_io.read_mat_scp(file) } """ @@ -116,14 +113,13 @@ def read_mat_scp(file_or_fd): def read_mat_ark(file_or_fd): r"""Create generator of (key,matrix) tuples, which reads from the ark file/stream. - Arguments: - file_or_fd (string/File Descriptor): ark, gzipped ark, pipe or opened file descriptor. + Args: + file_or_fd (str/FileDescriptor): ark, gzipped ark, pipe or opened file descriptor Returns: - generator[key (string), mat (Tensor)] - - Example:: + Generator[str, torch.Tensor]: The string is the key and the tensor is the matrix read from file + Example >>> # read ark to a 'dictionary' >>> d = { u:d for u,d in torchaudio.kaldi_io.read_mat_ark(file) } """ diff --git a/torchaudio/legacy.py b/torchaudio/legacy.py index ad81377b05..3d477fac3c 100644 --- a/torchaudio/legacy.py +++ b/torchaudio/legacy.py @@ -8,51 +8,50 @@ def load(filepath, out=None, normalization=None, num_frames=0, offset=0): - """Loads an audio file from disk into a Tensor. The default options have + r"""Loads an audio file from disk into a Tensor. The default options have changed as of torchaudio 0.2 and this function maintains option defaults from version 0.1. Args: - filepath (string): path to audio file - out (Tensor, optional): an output Tensor to use instead of creating one + filepath (str): Path to audio file + out (torch.Tensor, optional): An output Tensor to use instead of creating one. (Default: ``None``) normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31` - (assumes 16-bit depth audio, and normalizes to `[0, 1]`. - If `number`, then output is divided by that number - num_frames (int, optional): number of frames to load. -1 to load everything after the offset. - offset (int, optional): number of frames from the start of the file to begin data loading. - - Returns: tuple(Tensor, int) - - Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels - - int: the sample-rate of the audio (as listed in the metadata of the file) - - Example:: - + (assumes 16-bit depth audio), and normalizes to `[0, 1]`. If `number`, then output is divided by that + number. (Default: ``None``) + num_frames (int, optional): Number of frames to load. -1 to load everything after the + offset. (Default: ``0``) + offset (int, optional): Number of frames from the start of the file to begin data + loading. (Default: ``0``) + + Returns: + Tuple[torch.Tensor, int]: The output tensor is of size `[L x C]` where L is the number of audio frames, + C is the number of channels. The integer is the sample-rate of the audio (as listed in the metadata of + the file) + + Example >>> data, sample_rate = torchaudio.legacy.load('foo.mp3') >>> print(data.size()) torch.Size([278756, 2]) >>> print(sample_rate) 44100 - """ return torchaudio.load(filepath, out, normalization, False, num_frames, offset) def save(filepath, src, sample_rate, precision=32): - """Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc. + r"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc. The default options have changed as of torchaudio 0.2 and this function maintains option defaults from version 0.1.
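To make the 0.1-era defaults concrete, a short sketch ('foo.wav' is a placeholder file):
>>> data, sample_rate = torchaudio.legacy.load('foo.wav')  # size [L x C], unnormalized by default
>>> torchaudio.legacy.save('copy.wav', data, sample_rate)  # expects [L x C], 32-bit precision by default
By contrast, ``torchaudio.load``/``torchaudio.save`` default to channels-first `[C x L]` tensors, normalization on load, and 16-bit precision on save.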
     Args:
-        filepath (string): path to audio file
-        src (Tensor): an input 2D Tensor of shape `[L x C]` where L is
-            the number of audio frames, C is the number of channels
-        sample_rate (int): the sample-rate of the audio to be saved
-        precision (int, optional): the bit-precision of the audio to be saved
-
-    Example::
+        filepath (str): Path to audio file
+        src (torch.Tensor): An input 2D Tensor of shape `[L x C]` where L is
+            the number of audio frames, C is the number of channels
+        sample_rate (int): The sample-rate of the audio to be saved
+        precision (int, optional): The bit-precision of the audio to be saved. (Default: ``32``)
+
+    Example
        >>> data, sample_rate = torchaudio.legacy.load('foo.mp3')
        >>> torchaudio.legacy.save('foo.wav', data, sample_rate)
-
     """
     torchaudio.save(filepath, src, sample_rate, precision, False)
diff --git a/torchaudio/sox_effects.py b/torchaudio/sox_effects.py
index 564d10b1c2..2c709503eb 100644
--- a/torchaudio/sox_effects.py
+++ b/torchaudio/sox_effects.py
@@ -10,61 +10,59 @@ def effect_names():
     Returns: list[str]

-    Example::
+    Example
        >>> EFFECT_NAMES = torchaudio.sox_effects.effect_names()
     """
     return _torch_sox.get_effect_names()


 def SoxEffect():
-    """Create an object for passing sox effect information between python and c++
+    r"""Create an object for passing sox effect information between Python and C++

-    Returns: SoxEffect(object)
-      - ename (str), name of effect
-      - eopts (list[str]), list of effect options
+    Returns:
+        SoxEffect: An object with the following attributes: ename (str) which is the
+        name of the effect, and eopts (List[str]) which is a list of effect options.
     """
     return _torch_sox.SoxEffect()


 class SoxEffectsChain(object):
-    """SoX effects chain class.
+    r"""SoX effects chain class.

     Args:
         normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
-            (assumes signed 32-bit audio), and normalizes to `[0, 1]`.
-            If `number`, then output is divided by that number
-            If `callable`, then the output is passed as a parameter
-            to the given function, then the output is divided by
-            the result.
-        channels_first (bool, optional): Set channels first or length first in result. Default: ``True``
+            (assumes signed 32-bit audio), and normalizes to `[0, 1]`. If `number`, then output is divided by that
+            number. If `callable`, then the output is passed as a parameter to the given function, then the
+            output is divided by the result. (Default: ``True``)
+        channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
         out_siginfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
-            audio type cannot be automatically determined
+            audio type cannot be automatically determined. (Default: ``None``)
         out_encinfo (sox_encodinginfo_t, optional): a sox_encodinginfo_t type, which could be set if the
-            audio type cannot be automatically determined
-        filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
-
-    Returns: tuple(Tensor, int)
-      - Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames and
-        C is the number of channels
-      - int: the sample rate of the audio (as listed in the metadata of the file)
-
-    Example::
-
-        class MyDataset(Dataset):
-            def __init__(self, audiodir_path):
-                self.data = [os.path.join(audiodir_path, fn) for fn in os.listdir(audiodir_path)]
-                self.E = torchaudio.sox_effects.SoxEffectsChain()
-                self.E.append_effect_to_chain("rate", [16000])  # resample to 16000hz
-                self.E.append_effect_to_chain("channels", ["1"])  # mono signal
-            def __getitem__(self, index):
-                fn = self.data[index]
-                self.E.set_input_file(fn)
-                x, sr = self.E.sox_build_flow_effects()
-                return x, sr
-
-            def __len__(self):
-                return len(self.data)
-
+            audio type cannot be automatically determined. (Default: ``None``)
+        filetype (str, optional): a filetype or extension to be set if sox cannot determine it
+            automatically. (Default: ``'raw'``)
+
+    Returns:
+        Tuple[torch.Tensor, int]: An output Tensor of size `[C x L]` or `[L x C]` where L is the number
+        of audio frames and C is the number of channels. An integer which is the sample rate of the
+        audio (as listed in the metadata of the file)
+
+    Example
+        >>> class MyDataset(Dataset):
+        >>>     def __init__(self, audiodir_path):
+        >>>         self.data = [os.path.join(audiodir_path, fn) for fn in os.listdir(audiodir_path)]
+        >>>         self.E = torchaudio.sox_effects.SoxEffectsChain()
+        >>>         self.E.append_effect_to_chain("rate", [16000])  # resample to 16000 Hz
+        >>>         self.E.append_effect_to_chain("channels", ["1"])  # mono signal
+        >>>     def __getitem__(self, index):
+        >>>         fn = self.data[index]
+        >>>         self.E.set_input_file(fn)
+        >>>         x, sr = self.E.sox_build_flow_effects()
+        >>>         return x, sr
+        >>>
+        >>>     def __len__(self):
+        >>>         return len(self.data)
+        >>>
        >>> torchaudio.initialize_sox()
        >>> ds = MyDataset(path_to_audio_files)
        >>> for sig, sr in ds:
@@ -87,7 +85,11 @@ def __init__(self, normalization=True, channels_first=True, out_siginfo=None, ou
         self.channels_first = channels_first

     def append_effect_to_chain(self, ename, eargs=None):
-        """Append effect to a sox effects chain.
+        r"""Append effect to a sox effects chain.
+
+        Args:
+            ename (str): The name of the effect
+            eargs (List[str], optional): A list of effect options. (Default: ``None``)
         """
         e = SoxEffect()
         # check if we have a valid effect
@@ -106,7 +108,15 @@ def append_effect_to_chain(self, ename, eargs=None):
         self.chain.append(e)

     def sox_build_flow_effects(self, out=None):
-        """Build effects chain and flow effects from input file to output tensor
+        r"""Build effects chain and flow effects from input file to output tensor
+
+        Args:
+            out (torch.Tensor, optional): Tensor to write the output to. (Default: ``None``)
+
+        Returns:
+            Tuple[torch.Tensor, int]: An output Tensor of size `[C x L]` or `[L x C]` where L is the number
+            of audio frames and C is the number of channels. An integer which is the sample rate of the
+            audio (as listed in the metadata of the file)
         """
         # initialize output tensor
         if out is not None:
@@ -134,12 +144,15 @@ def sox_build_flow_effects(self, out=None):
         return out, sr

     def clear_chain(self):
-        """Clear effects chain in python
+        r"""Clear effects chain in Python
         """
         self.chain = []

     def set_input_file(self, input_file):
-        """Set input file for input of chain
+        r"""Set input file for input of chain
+
+        Args:
+            input_file (str): The path to the input file.
         """
         self.input_file = input_file
diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py
index 467bff96c3..cdd079dccf 100644
--- a/torchaudio/transforms.py
+++ b/torchaudio/transforms.py
@@ -23,17 +23,17 @@ class Spectrogram(torch.jit.ScriptModule):
     r"""Create a spectrogram from a audio signal

     Args:
-        n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
-        win_length (int): Window size. (Default: `n_fft`)
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
+        win_length (int): Window size. (Default: ``n_fft``)
         hop_length (int, optional): Length of hop between STFT windows. (
-            Default: `win_length // 2`)
-        pad (int): Two sided padding of signal. (Default: 0)
+            Default: ``win_length // 2``)
+        pad (int): Two-sided padding of signal. (Default: ``0``)
         window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
-            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
+            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
         power (int): Exponent for the magnitude spectrogram,
-            (must be > 0) e.g., 1 for energy, 2 for power, etc.
-        normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
+            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
+        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
+        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
     """
     __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

@@ -42,7 +42,7 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None,
                  power=2, normalized=False, wkwargs=None):
         super(Spectrogram, self).__init__()
         self.n_fft = n_fft
-        # number of fft bins. the returned STFT result will have n_fft // 2 + 1
+        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
         # number of frequecies due to onesided=True in torch.stft
         self.win_length = win_length if win_length is not None else n_fft
         self.hop_length = hop_length if hop_length is not None else self.win_length // 2
@@ -56,12 +56,12 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None,
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of size (c, n)
+            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

         Returns:
-            torch.Tensor: Channels x frequency x time (c, f, t), where channels
-            is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
-            fourier bins, and time is the number of window hops (n_frames).
+            torch.Tensor: Dimension (channel, freq, time), where channel
+            is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
+            Fourier bins, and time is the number of window hops (n_frames).
""" return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, self.win_length, self.power, self.normalized) @@ -76,9 +76,9 @@ class AmplitudeToDB(torch.jit.ScriptModule): Args: stype (str): scale of input tensor ('power' or 'magnitude'). The - power being the elementwise square of the magnitude. (Default: 'power') + power being the elementwise square of the magnitude. (Default: ``'power'``) top_db (float, optional): minimum negative cut-off in decibels. A reasonable number - is 80. + is 80. (Default: ``None``) """ __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] @@ -109,17 +109,17 @@ def forward(self, x): class MelScale(torch.jit.ScriptModule): r"""This turns a normal STFT into a mel frequency STFT, using a conversion - matrix. This uses triangular filter banks. + matrix. This uses triangular filter banks. - User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). + User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)). Args: - n_mels (int): Number of mel filterbanks. (Default: 128) - sample_rate (int): Sample rate of audio signal. (Default: 16000) - f_min (float): Minimum frequency. (Default: 0.) - f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`) + n_mels (int): Number of mel filterbanks. (Default: ``128``) + sample_rate (int): Sample rate of audio signal. (Default: ``16000``) + f_min (float): Minimum frequency. (Default: ``0.``) + f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``) n_stft (int, optional): Number of bins in STFT. Calculated from first input - if `None` is given. See `n_fft` in `Spectrogram`. + if None is given. See ``n_fft`` in :class:`Spectrogram`. """ __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] @@ -138,10 +138,10 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N def forward(self, specgram): r""" Args: - specgram (torch.Tensor): a spectrogram STFT of size (c, f, t) + specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time) Returns: - torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) + torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) """ if self.fb.numel() == 0: tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels) @@ -149,7 +149,8 @@ def forward(self, specgram): self.fb.resize_(tmp_fb.size()) self.fb.copy_(tmp_fb) - # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...) + # (channel, frequency, time).transpose(...) dot (frequency, n_mels) + # -> (channel, time, n_mels).transpose(...) mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) return mel_specgram @@ -158,28 +159,28 @@ class MelSpectrogram(torch.jit.ScriptModule): r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram and MelScale. - Sources: + Sources * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html Args: - sample_rate (int): Sample rate of audio signal. (Default: 16000) - win_length (int): Window size. (Default: `n_fft`) + sample_rate (int): Sample rate of audio signal. (Default: ``16000``) + win_length (int): Window size. (Default: ``n_fft``) hop_length (int, optional): Length of hop between STFT windows. 
( - Default: `win_length // 2`) - n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins - f_min (float): Minimum frequency. (Default: 0.) - f_max (float, optional): Maximum frequency. (Default: `None`) - pad (int): Two sided padding of signal. (Default: 0) - n_mels (int): Number of mel filterbanks. (Default: 128) + Default: ``win_length // 2``) + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins + f_min (float): Minimum frequency. (Default: ``0.``) + f_max (float, optional): Maximum frequency. (Default: ``None``) + pad (int): Two sided padding of signal. (Default: ``0``) + n_mels (int): Number of mel filterbanks. (Default: ``128``) window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor - that is applied/multiplied to each frame/window. (Default: `torch.hann_window`) - wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`) + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``) - Example: + Example >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True) - >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t) + >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (channel, n_mels, time) """ __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] @@ -204,10 +205,10 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) Returns: - torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t) + torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) """ specgram = self.spectrogram(waveform) mel_specgram = self.mel_scale(specgram) @@ -226,12 +227,13 @@ class MFCC(torch.jit.ScriptModule): a full clip. Args: - sample_rate (int): Sample rate of audio signal. (Default: 16000) - n_mfcc (int): Number of mfc coefficients to retain - dct_type (int): type of DCT (discrete cosine transform) to use - norm (string, optional): norm to use - log_mels (bool): whether to use log-mel spectrograms instead of db-scaled - melkwargs (dict, optional): arguments for MelSpectrogram + sample_rate (int): Sample rate of audio signal. (Default: ``16000``) + n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``) + dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``) + norm (str, optional): norm to use. (Default: ``'ortho'``) + log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default: + ``False``) + melkwargs (dict, optional): arguments for MelSpectrogram. 
(Default: ``None``) """ __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] @@ -263,10 +265,10 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of size (c, n) + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) Returns: - torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t) + torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time) """ mel_specgram = self.MelSpectrogram(waveform) if self.log_mels: @@ -274,7 +276,8 @@ def forward(self, waveform): mel_specgram = torch.log(mel_specgram + log_offset) else: mel_specgram = self.amplitude_to_DB(mel_specgram) - # (c, `n_mels`, t).tranpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).tranpose(...) + # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) + # -> (channel, time, n_mfcc).tranpose(...) mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) return mfcc @@ -287,7 +290,7 @@ class MuLawEncoding(torch.jit.ScriptModule): returns a signal encoded with values from 0 to quantization_channels - 1 Args: - quantization_channels (int): Number of channels (Default: 256) + quantization_channels (int): Number of channels (Default: ``256``) """ __constants__ = ['quantization_channels'] @@ -315,7 +318,7 @@ class MuLawDecoding(torch.jit.ScriptModule): and returns a signal scaled between -1 and 1. Args: - quantization_channels (int): Number of channels (Default: 256) + quantization_channels (int): Number of channels (Default: ``256``) """ __constants__ = ['quantization_channels'] @@ -340,11 +343,11 @@ class Resample(torch.nn.Module): be given. Args: - orig_freq (float): The original frequency of the signal - new_freq (float): The desired frequency - resampling_method (str): The resampling method (Default: 'sinc_interpolation') + orig_freq (float): The original frequency of the signal. (Default: ``16000``) + new_freq (float): The desired frequency. (Default: ``16000``) + resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``) """ - def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'): + def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'): super(Resample, self).__init__() self.orig_freq = orig_freq self.new_freq = new_freq @@ -353,10 +356,10 @@ def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'): def forward(self, waveform): r""" Args: - waveform (torch.Tensor): The input signal of size (c, n) + waveform (torch.Tensor): The input signal of dimension (channel, time) Returns: - torch.Tensor: Output signal of size (c, m) + torch.Tensor: Output signal of dimension (channel, time) """ if self.resampling_method == 'sinc_interpolation': return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
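Since the hunks above only show docstrings, a few end-to-end sketches may help reviewers check the documented signatures. First, the Kaldi readers together with the legacy I/O; this assumes the kaldi_io dependency is installed, and the paths ('feats.ark', 'foo.mp3', 'foo_copy.wav') are hypothetical placeholders:

    import torchaudio

    # The readers yield (key, matrix) pairs lazily, so large archives
    # never have to be materialized all at once.
    feats = {key: mat for key, mat in torchaudio.kaldi_io.read_mat_ark('feats.ark')}

    # The legacy loader keeps the 0.1-era defaults: an unnormalized tensor
    # of shape [L x C] plus the sample rate read from the file metadata.
    waveform, sample_rate = torchaudio.legacy.load('foo.mp3')
    torchaudio.legacy.save('foo_copy.wav', waveform, sample_rate)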
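For the SoX chain, a more compact variant of the docstring's Dataset example, applied to a single hypothetical file; it assumes the usual pairing of initialize_sox with a matching shutdown_sox call:

    import torchaudio

    torchaudio.initialize_sox()
    E = torchaudio.sox_effects.SoxEffectsChain()
    E.append_effect_to_chain("rate", [16000])    # resample to 16000 Hz
    E.append_effect_to_chain("channels", ["1"])  # downmix to mono
    E.set_input_file('foo.mp3')                  # hypothetical input path
    waveform, sample_rate = E.sox_build_flow_effects()
    torchaudio.shutdown_sox()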
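To make the (channel, time) -> (channel, freq, time) shape conventions above concrete, a minimal sketch using a synthetic waveform in place of a real recording:

    import torch
    import torchaudio

    waveform = torch.randn(1, 16000)  # one channel, one second at 16 kHz

    # (channel, time) -> (channel, n_fft // 2 + 1, n_frames)
    specgram = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)

    # Power spectrogram -> decibels, clamped to an 80 dB dynamic range.
    specgram_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80.0)(specgram)

    # (channel, time) -> (channel, n_mels, time) -> (channel, n_mfcc, time)
    mel_specgram = torchaudio.transforms.MelSpectrogram(sample_rate=16000)(waveform)
    mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=40)(waveform)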
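Finally, a round-trip sketch for the mu-law pair and the new Resample defaults; the input is synthetic, and the error comment is indicative only since mu-law quantization is lossy:

    import torch
    import torchaudio

    waveform = torch.rand(1, 16000) * 2 - 1  # synthetic signal scaled to [-1, 1]

    encode = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
    decode = torchaudio.transforms.MuLawDecoding(quantization_channels=256)
    restored = decode(encode(waveform))       # integer codes in [0, 255] -> [-1, 1]
    print((waveform - restored).abs().max())  # small but nonzero quantization error

    # Both frequencies now default to 16000, so explicit rates make intent clear.
    resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=8000)
    downsampled = resample(waveform)          # time axis shrinks by new_freq / orig_freq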