Skip to content

Commit cb9f9e6

Browse files
authored
Testing audio_preprocessing_tutorial.py (DON
1 parent 276878f commit cb9f9e6

File tree

1 file changed

+233
-0
lines changed

1 file changed

+233
-0
lines changed
Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
"""
2+
Torchaudio Tutorial
3+
===================
4+
PyTorch is an open source deep learning platform that provides a
5+
seamless path from research prototyping to production deployment with
6+
GPU support.
7+
Significant effort in solving machine learning problems goes into data
8+
preparation. Torchaudio leverages PyTorch’s GPU support, and provides
9+
many tools to make data loading easy and more readable. In this
10+
tutorial, we will see how to load and preprocess data from a simple
11+
dataset.
12+
For this tutorial, please make sure the ``matplotlib`` package is
13+
installed for easier visualization.
14+
"""
15+
16+
import torch
17+
import torchaudio
18+
import matplotlib.pyplot as plt
19+
20+
21+
######################################################################
22+
# Opening a dataset
23+
# -----------------
24+
#
25+
26+
27+
######################################################################
28+
# Torchaudio supports loading sound files in the wav and mp3 format. We
29+
# call waveform the resulting raw audio signal.
30+
#
31+
32+
filename = "_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav"
33+
waveform, sample_rate = torchaudio.load(filename)
34+
35+
print("Shape of waveform: {}".format(waveform.size()))
36+
print("Sample rate of waveform: {}".format(sample_rate))
37+
38+
plt.figure()
39+
plt.plot(waveform.transpose(0,1).numpy())
40+
41+
42+
######################################################################
43+
# Transformations
44+
# ---------------
45+
#
46+
# Torchaudio supports a growing list of
47+
# `transformations <https://pytorch.org/audio/transforms.html>`_.
48+
#
49+
# - **Resample**: Resample waveform to a different sample rate.
50+
# - **Spectrogram**: Create a spectrogram from a waveform.
51+
# - **MelScale**: This turns a normal STFT into a Mel-frequency STFT,
52+
# using a conversion matrix.
53+
# - **SpectrogramToDB**: This turns a spectrogram from the
54+
# power/amplitude scale to the decibel scale.
55+
# - **MFCC**: Create the Mel-frequency cepstrum coefficients from a
56+
# waveform.
57+
# - **MelSpectrogram**: Create MEL Spectrograms from a waveform using the
58+
# STFT function in PyTorch.
59+
# - **MuLawEncoding**: Encode waveform based on mu-law companding.
60+
# - **MuLawDecoding**: Decode mu-law encoded waveform.
61+
#
62+
# Since all transforms are nn.Modules or jit.ScriptModules, they can be
63+
# used as part of a neural network at any point.
64+
#
65+
66+
67+
######################################################################
68+
# To start, we can look at the spectrogram on a log scale.
69+
#
70+
71+
specgram = torchaudio.transforms.Spectrogram()(waveform)
72+
73+
print("Shape of spectrogram: {}".format(specgram.size()))
74+
75+
plt.figure()
76+
plt.imshow(specgram.log2()[0,:,:].numpy(), cmap='gray')
77+
78+
79+
######################################################################
80+
# Or we can look at the Mel Spectrogram on a log scale.
81+
#
82+
83+
specgram = torchaudio.transforms.MelSpectrogram()(waveform)
84+
85+
print("Shape of spectrogram: {}".format(specgram.size()))
86+
87+
plt.figure()
88+
p = plt.imshow(specgram.log2()[0,:,:].detach().numpy(), cmap='gray')
89+
90+
91+
######################################################################
92+
# We can resample the waveform, one channel at a time.
93+
#
94+
95+
new_sample_rate = sample_rate/10
96+
97+
# Since Resample applies to a single channel, we resample the first channel here
98+
channel = 0
99+
transformed = torchaudio.transforms.Resample(sample_rate, new_sample_rate)(waveform[channel,:].view(1,-1))
100+
101+
print("Shape of transformed waveform: {}".format(transformed.size()))
102+
103+
plt.figure()
104+
plt.plot(transformed[0,:].numpy())
105+
106+
107+
######################################################################
108+
# As another example of transformations, we can encode the signal based on
109+
# Mu-Law encoding. But to do so, we need the signal to be between -1 and
110+
# 1. Since the tensor is just a regular PyTorch tensor, we can apply
111+
# standard operators on it.
112+
#
113+
114+
# Let's check if the tensor is in the interval [-1,1]
115+
print("Min of waveform: {}\nMax of waveform: {}\nMean of waveform: {}".format(waveform.min(), waveform.max(), waveform.mean()))
116+
117+
118+
######################################################################
119+
# Since the waveform is already between -1 and 1, we do not need to
120+
# normalize it.
121+
#
122+
123+
def normalize(tensor):
    """Center a tensor on zero and scale it into the interval [-1, 1].

    The mean over all elements is subtracted, then the result is divided by
    its largest absolute value so that the peak magnitude becomes 1.

    Args:
        tensor: Tensor to normalize; it must support ``mean()``, ``abs()``
            and ``max()`` (for example a floating-point ``torch.Tensor``).

    Returns:
        The centered and rescaled tensor. When the centered tensor is all
        zeros (e.g. a silent signal), it is returned as is instead of being
        divided by a zero peak.
    """
    # Subtract the mean, and scale to the interval [-1,1]
    tensor_minusmean = tensor - tensor.mean()
    peak = tensor_minusmean.abs().max()
    if peak == 0:
        # Dividing by a zero peak would turn every element into NaN (0 / 0).
        return tensor_minusmean
    return tensor_minusmean / peak
127+
128+
# Let's normalize to the full interval [-1,1]
129+
# waveform = normalize(waveform)
130+
131+
132+
######################################################################
133+
# Let’s encode the waveform.
134+
#
135+
136+
transformed = torchaudio.transforms.MuLawEncoding()(waveform)
137+
138+
print("Shape of transformed waveform: {}".format(transformed.size()))
139+
140+
plt.figure()
141+
plt.plot(transformed[0,:].numpy())
142+
143+
144+
######################################################################
145+
# And now decode.
146+
#
147+
148+
reconstructed = torchaudio.transforms.MuLawDecoding()(transformed)
149+
150+
print("Shape of recovered waveform: {}".format(reconstructed.size()))
151+
152+
plt.figure()
153+
plt.plot(reconstructed[0,:].numpy())
154+
155+
156+
######################################################################
157+
# We can finally compare the original waveform with its reconstructed
158+
# version.
159+
#
160+
161+
# Compute median relative difference: the element-wise absolute error is
# divided by the magnitude of the original sample, and the median over all
# elements is reported as a percentage.
err = ((waveform - reconstructed).abs() / waveform.abs()).median()

print("Median relative difference between original and MuLaw reconstructed signals: {:.2%}".format(err))
165+
166+
167+
######################################################################
168+
# Migrating to Torchaudio from Kaldi
169+
# ----------------------------------
170+
#
171+
# Users may be familiar with
172+
# `Kaldi <http://github.com/kaldi-asr/kaldi>`_, a toolkit for speech
173+
# recognition. Torchaudio offers compatibility with it in
174+
# ``torchaudio.kaldi_io``. It can indeed read from kaldi scp, or ark file
175+
# or streams with:
176+
#
177+
# - read_vec_int_ark
178+
# - read_vec_flt_scp
179+
# - read_vec_flt_arkfile/stream
180+
# - read_mat_scp
181+
# - read_mat_ark
182+
#
183+
# Torchaudio provides Kaldi-compatible transforms for ``spectrogram`` and
184+
# ``fbank`` with the benefit of GPU support, see
185+
# `here <compliance.kaldi.html>`__ for more information.
186+
#
187+
188+
n_fft = 400.0
189+
frame_length = n_fft / sample_rate * 1000.0
190+
frame_shift = frame_length / 2.0
191+
192+
params = {
193+
"channel": 0,
194+
"dither": 0.0,
195+
"window_type": "hanning",
196+
"frame_length": frame_length,
197+
"frame_shift": frame_shift,
198+
"remove_dc_offset": False,
199+
"round_to_power_of_two": False,
200+
"sample_frequency": sample_rate,
201+
}
202+
203+
specgram = torchaudio.compliance.kaldi.spectrogram(waveform, **params)
204+
205+
print("Shape of spectrogram: {}".format(specgram.size()))
206+
207+
plt.figure()
208+
plt.imshow(specgram.transpose(0,1).numpy(), cmap='gray')
209+
210+
211+
######################################################################
212+
# We also support computing the filterbank features from waveforms,
213+
# matching Kaldi’s implementation.
214+
#
215+
216+
fbank = torchaudio.compliance.kaldi.fbank(waveform, **params)
217+
218+
print("Shape of fbank: {}".format(fbank.size()))
219+
220+
plt.figure()
221+
plt.imshow(fbank.transpose(0,1).numpy(), cmap='gray')
222+
223+
224+
######################################################################
225+
# Conclusion
226+
# ----------
227+
#
228+
# We used an example raw audio signal, or waveform, to illustrate how to
229+
# open an audio file using Torchaudio, and how to pre-process and
230+
# transform such waveform. Given that Torchaudio is built on PyTorch,
231+
# these techniques can be used as building blocks for more advanced audio
232+
# applications, such as speech recognition, while leveraging GPUs.
233+
#

0 commit comments

Comments
 (0)