Skip to content

Commit e45a9cb

Browse files
committed
dither tests
1 parent e16773f commit e45a9cb

File tree

4 files changed

+42
-27
lines changed

4 files changed

+42
-27
lines changed

test/assets/sinewave_soxdither.wav

250 KB
Binary file not shown.
250 KB
Binary file not shown.

test/test_functional.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,21 @@ def test_scale_to_interval(self):
268268
self.assertTrue(torch.max(waveform_scaled) <= scaled)
269269
self.assertTrue(torch.min(waveform_scaled) >= -scaled)
270270

271+
def test_dither(self):
272+
waveform, sample_rate = torchaudio.load(self.test_filepath)
273+
waveform_dithered = F.dither(waveform)
274+
waveform_dithered_noiseshaped = F.dither(waveform, noise_shaping=True)
275+
276+
test_filepath_sox_dither = os.path.join(self.test_dirpath, "assets", "sinewave_soxdither.wav")
277+
sox_dither_waveform, sox_sr = torchaudio.load(test_filepath_sox_dither)
278+
279+
self.assertTrue(torch.allclose(waveform_dithered, sox_dither_waveform, rtol=1e-03, atol=1e-03))
280+
281+
test_filepath_sox_dither_ns = os.path.join(self.test_dirpath, "assets", "sinewave_soxdither_noiseshaping.wav")
282+
sox_dither_waveform_ns, sox_sr = torchaudio.load(test_filepath_sox_dither_ns)
283+
284+
self.assertTrue(torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, rtol=1e-03, atol=1e-03))
285+
271286

272287
def _num_stft_bins(signal_len, fft_len, hop_length, pad):
273288
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length

torchaudio/functional.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -842,21 +842,20 @@ def scale_to_interval(waveform, interval=1.0):
842842

843843

844844
def _add_noise_shaping(dithered_waveform, waveform):
845-
r"""Noise shaping is calculated by error:
846-
error[n] = dithered[n] - original[n]
847-
noise_shaped_waveform[n] = dithered[n] + error[n-1]
848-
"""
849-
error = dithered_waveform - waveform
850-
851-
# add error[n-1] to dithered_waveform[n], so offset the error by 1 index
852-
error0_offset = torch.cat((torch.zeros(1), error[0]))
853-
error[0] = error0_offset[:waveform.size()[1]]
854-
error1_offset = torch.cat((torch.zeros(1), error[1]))
855-
error[1] = error1_offset[:waveform.size()[1]]
845+
r"""Noise shaping is calculated by error:
846+
error[n] = dithered[n] - original[n]
847+
noise_shaped_waveform[n] = dithered[n] + error[n-1]
848+
"""
849+
error = dithered_waveform - waveform
856850

857-
noise_shaped = dithered_waveform + error
851+
# add error[n-1] to dithered_waveform[n], so offset the error by 1 index
852+
for index in range(error.size()[0]):
853+
err = error[index]
854+
error_offset = torch.cat((torch.zeros(1), err))
855+
error[index] = error_offset[:waveform.size()[1]]
858856

859-
return noise_shaped
857+
noise_shaped = dithered_waveform + error
858+
return noise_shaped
860859

861860

862861
def dither(waveform, probability_density_function="TPDF", noise_shaping=False, filter=None):
@@ -890,42 +889,43 @@ def dither(waveform, probability_density_function="TPDF", noise_shaping=False, f
890889

891890
number_of_bits = 16
892891

893-
up_scaling = 2 ** (number_of_bits-1) - 2
892+
up_scaling = 2 ** (number_of_bits - 1) - 2
894893
signal_scaled = waveform * up_scaling
895894
down_scaling = 2 ** (number_of_bits - 1)
896895

896+
dithered = waveform
897+
897898
if (probability_density_function == "RPDF"):
898899
RPDF_dither = waveform[0][random.randint(1, wave_size)] - 0.5
899900

900901
signal_scaled_RPDF_dithered = signal_scaled + RPDF_dither
901902
quantised_signal_scaled_RPDF_dithered = torch.round(signal_scaled_RPDF_dithered)
902903
quantised_signal_RPDF_dithered = quantised_signal_scaled_RPDF_dithered / down_scaling
903904

904-
if noise_shaped: return _add_noise_shaping(quantised_signal_RPDF_dithered, waveform)
905-
906-
return quantised_signal_RPDF_dithered
905+
dithered = quantised_signal_RPDF_dithered
907906
elif (probability_density_function == "GPDF"):
908907
gaussian_dither = (waveform[0][random.randint(1, wave_size)]
909-
+ waveform[0][random.randint(1, wave_size)]
910-
+ waveform[0][random.randint(1, wave_size)]
911-
+ waveform[0][random.randint(1, wave_size)]
912-
+ waveform[0][random.randint(1, wave_size)]
913-
+ waveform[0][random.randint(1, wave_size)]) / 6
908+
+ waveform[0][random.randint(1, wave_size)]
909+
+ waveform[0][random.randint(1, wave_size)]
910+
+ waveform[0][random.randint(1, wave_size)]
911+
+ waveform[0][random.randint(1, wave_size)]
912+
+ waveform[0][random.randint(1, wave_size)]) / 6
914913

915914
signal_scaled_gaussian_dithered = signal_scaled + gaussian_dither
916915
quantised_signal_scaled_gaussian_dithered = torch.round(signal_scaled_gaussian_dithered)
917916
quantised_signal_gaussian_dithered = quantised_signal_scaled_gaussian_dithered / down_scaling
918917

919-
if noise_shaped: return _add_noise_shaping(quantised_signal_gaussian_dithered, waveform)
920-
921-
return quantised_signal_gaussian_dithered
918+
dithered = quantised_signal_gaussian_dithered
922919
else:
923920
TPDF_dither = waveform[0][random.randint(1, wave_size)] - waveform[0][random.randint(1, wave_size)]
924921

925922
signal_scaled_TPDF_dithered = signal_scaled + TPDF_dither
926923
quantised_signal_scaled_TPDF_dithered = torch.round(signal_scaled_TPDF_dithered)
927924
quantised_signal_TPDF_dithered = quantised_signal_scaled_TPDF_dithered / down_scaling
928925

929-
if noise_shaped: return _add_noise_shaping(quantised_signal_TPDF_dithered, waveform)
926+
dithered = quantised_signal_TPDF_dithered
930927

931-
return quantised_signal_TPDF_dithered
928+
if noise_shaping:
929+
return _add_noise_shaping(dithered, waveform)
930+
else:
931+
return dithered

0 commit comments

Comments
 (0)