Skip to content

Commit 5b5473e

Browse files
committed
Initial commit for gain logic
1 parent 5023bd2 commit 5b5473e

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

test/assets/sinewave_soxgain5.wav

250 KB
Binary file not shown.

test/test_functional.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
import unittest
99
import common_utils
10+
import os
1011

1112
from torchaudio.common_utils import IMPORT_LIBROSA
1213

@@ -247,6 +248,19 @@ def test_create_fb(self):
247248
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
248249
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
249250

251+
def test_gain(self):
252+
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
253+
test_filepath = os.path.join(test_dirpath, "assets", "sinewave.wav")
254+
255+
waveform, sample_rate = torchaudio.load(test_filepath)
256+
waveform_gain = F.gain(waveform, 5)
257+
self.assertTrue(waveform_gain.abs().max().item(), 1.)
258+
259+
test_filepath_sox_gain = os.path.join(test_dirpath, "assets", "sinewave_soxgain5.wav")
260+
sox_gain_waveform, sox_sr = torchaudio.load(test_filepath_sox_gain)
261+
262+
self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform))
263+
250264

251265
def _num_stft_bins(signal_len, fft_len, hop_length, pad):
252266
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length

torchaudio/functional.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,21 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
801801
return torch.nn.functional.conv1d(
802802
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
803803
) / denom
804+
805+
def gain(waveform, gain_db):
806+
r"""Apply amplification or attenuation to the audio signal.
807+
808+
Args:
809+
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
810+
gain_db(float, default=0.0) Gain adjustment in decibels (dB).
811+
812+
Returns:
813+
torch.Tensor: waveform amplified by gain_db
814+
"""
815+
assert waveform.dim() == 2
816+
817+
#gain_db = 20log(Tensor_output/tensor_input)
818+
ratio = 10 ** (gain_db/20)
819+
820+
return waveform * ratio
821+

0 commit comments

Comments
 (0)