 ==========================================
 
 This tutorial will show you how to correctly format an audio dataset and
-then train/test an audio classifier network on the dataset. First, let’s
-import the common torch packages such as
-`torchaudio <https://github.com/pytorch/audio>`__ and can be
+then train/test an audio classifier network on the dataset.
+
+Colab offers a GPU option. In the menu tabs, select “Runtime”, then
+“Change runtime type”. In the pop-up that follows, you can choose GPU.
+After the change, your runtime should automatically restart (which means
+information from executed cells disappears).
+
+First, let’s import the common torch packages such as
+`torchaudio <https://github.com/pytorch/audio>`__ that can be
 installed by following the instructions on the website.
 
 """
 
 # Uncomment the following lines to run in Google Colab
-# !pip install torch
-# !pip install torchaudio
+
+# GPU:
+# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU:
+# !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# For the interactive demo at the end:
+# !pip install pydub
 
 import os
+from base64 import b64decode
+from io import BytesIO
 
 import IPython.display as ipd
 import matplotlib.pyplot as plt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torchaudio
+from google.colab import output as colab_output
+from pydub import AudioSegment
 from torchaudio.datasets import SPEECHCOMMANDS
 from tqdm import tqdm
 
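+# You can confirm that the GPU runtime is active with:
+print("CUDA available:", torch.cuda.is_available())
+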
 ######################################################################
@@ -73,11 +90,12 @@ def load_list(filename):
             self._walker = load_list("testing_list.txt")
         elif subset == "training":
             excludes = load_list("validation_list.txt") + load_list("testing_list.txt")
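+            # use a set for O(1) membership tests in the filter below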
+            excludes = set(excludes)
             self._walker = [w for w in self._walker if w not in excludes]
 
 
+# Create training and testing split of the data. We do not use validation in this tutorial.
 train_set = SubsetSC("training")
-# valid_set = SubsetSC("validation")
 test_set = SubsetSC("testing")
 
 waveform, sample_rate, label, speaker_id, utterance_number = train_set[0]
@@ -92,15 +110,14 @@ def load_list(filename):
 print("Shape of waveform: {}".format(waveform.size()))
 print("Sample rate of waveform: {}".format(sample_rate))
 
-plt.figure();
 plt.plot(waveform.t().numpy());
 
 
 ######################################################################
 # Let’s find the list of labels available in the dataset.
 #
 
-labels = list(set(datapoint[2] for datapoint in train_set))
+labels = sorted(list(set(datapoint[2] for datapoint in train_set)))
 labels
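+# Expected output: the 35 command words, sorted alphabetically from
+# "backward" to "zero".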
 
 
@@ -170,6 +187,7 @@ def encode(word):
 # encoding.
 #
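+# For example, ``labels[encode("yes")]`` should give back the word "yes",
+# assuming ``encode`` returns a word’s index in ``labels``.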
 
+
 def pad_sequence(batch):
     # Make all tensors in a batch the same length by padding with zeros
     batch = [item.t() for item in batch]
@@ -184,9 +202,9 @@ def collate_fn(batch):
 
     tensors, targets = [], []
 
-    # Apply transform and encode
+    # Gather in lists, and encode labels as indices
     for waveform, _, label, *_ in batch:
-        tensors += [transform(waveform)]
+        tensors += [waveform]
         targets += [encode(label)]
 
     # Group the list of tensors into a batched tensor
@@ -196,20 +214,31 @@ def collate_fn(batch):
     return tensors, targets
 
 
-batch_size = 128
+batch_size = 256
 
-if device == 'cuda':
+if device == "cuda":
     num_workers = 1
     pin_memory = True
 else:
     num_workers = 0
     pin_memory = False
 
 train_loader = torch.utils.data.DataLoader(
-    train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
+    train_set,
+    batch_size=batch_size,
+    shuffle=True,
+    collate_fn=collate_fn,
+    num_workers=num_workers,
+    pin_memory=pin_memory,
 )
 test_loader = torch.utils.data.DataLoader(
-    test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=pin_memory,
+    test_set,
+    batch_size=batch_size,
+    shuffle=False,
+    drop_last=False,
+    collate_fn=collate_fn,
+    num_workers=num_workers,
+    pin_memory=pin_memory,
 )
 
 
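+######################################################################
+# As a quick sanity check, we can draw one batch and confirm that the
+# collation produced fixed-size tensors. The time dimension below assumes
+# the one-second, 16 kHz clips of SpeechCommands:
+#
+
+tensors, targets = next(iter(train_loader))
+print(tensors.shape)  # e.g. torch.Size([256, 1, 16000])
+print(targets.shape)  # torch.Size([256])
+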
@@ -232,21 +261,21 @@ def collate_fn(batch):
 
 
 class M5(nn.Module):
-    def __init__(self, stride=16, n_channel=32, n_output=35):
+    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
         super().__init__()
-        self.conv1 = nn.Conv1d(1, n_channel, 80, stride=stride)
+        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
         self.bn1 = nn.BatchNorm1d(n_channel)
         self.pool1 = nn.MaxPool1d(4)
-        self.conv2 = nn.Conv1d(n_channel, n_channel, 3)
+        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
         self.bn2 = nn.BatchNorm1d(n_channel)
         self.pool2 = nn.MaxPool1d(4)
-        self.conv3 = nn.Conv1d(n_channel, 2*n_channel, 3)
-        self.bn3 = nn.BatchNorm1d(2*n_channel)
+        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
+        self.bn3 = nn.BatchNorm1d(2 * n_channel)
         self.pool3 = nn.MaxPool1d(4)
-        self.conv4 = nn.Conv1d(2*n_channel, 2*n_channel, 3)
-        self.bn4 = nn.BatchNorm1d(2*n_channel)
+        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
+        self.bn4 = nn.BatchNorm1d(2 * n_channel)
         self.pool4 = nn.MaxPool1d(4)
-        self.fc1 = nn.Linear(2*n_channel, n_output)
+        self.fc1 = nn.Linear(2 * n_channel, n_output)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -267,7 +296,7 @@ def forward(self, x):
         return F.log_softmax(x, dim=2)
 
 
-model = M5(n_output=len(labels))
+model = M5(n_input=transformed.shape[0], n_output=len(labels))
 model.to(device)
 print(model)
 
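+######################################################################
+# As an illustrative check of the architecture, pass a dummy one-second
+# clip through the untrained model (the 8000 samples assume the 8 kHz
+# resampling applied by ``transform`` above):
+#
+
+with torch.no_grad():
+    dummy = torch.zeros(1, 1, 8000, device=device)  # (batch, channel, time)
+    print(model(dummy).shape)  # torch.Size([1, 1, 35]): log-probabilities over the labels
+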
@@ -284,7 +313,7 @@ def count_parameters(model):
 # We will use the same optimization technique used in the paper, an Adam
 # optimizer with weight decay set to 0.0001. At first, we will train with
 # a learning rate of 0.01, but we will use a ``scheduler`` to decrease it
-# to 0.001 during training.
+# to 0.001 during training after 20 epochs.
 #
 
 optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
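+# A step scheduler consistent with this schedule (a sketch; ``step_size=20``
+# with ``gamma=0.1`` drops the learning rate tenfold after 20 epochs):
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)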
@@ -296,11 +325,9 @@ def count_parameters(model):
 # --------------------------------
 #
 # Now let’s define a training function that will feed our training data
-# into the model and perform the backward pass and optimization steps.
-#
-# Finally, we can train and test the network. We will train the network
-# for ten epochs then reduce the learn rate and train for ten more epochs.
-# The network will be tested after each epoch to see how the accuracy
+# into the model and perform the backward pass and optimization steps. For
+# training, the loss we will use is the negative log-likelihood. The
+# network will then be tested after each epoch to see how the accuracy
 # varies during the training.
 #
 
@@ -312,6 +339,8 @@ def train(model, epoch, log_interval):
         data = data.to(device)
         target = target.to(device)
 
+        # apply transform and model on whole batch directly on device
+        data = transform(data)
         output = model(data)
 
         # negative log-likelihood for a tensor of size (batch x 1 x n_output)
@@ -323,10 +352,10 @@ def train(model, epoch, log_interval):
 
         # print training stats
         if batch_idx % log_interval == 0:
-            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}')
+            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}")
 
-        if 'pbar' in globals():
-            pbar.update()
+        # advance the progress bar by one batch's share of an epoch
+        if "pbar" in globals() and "pbar_update" in globals():
+            pbar.update(pbar_update)
 
 
 ######################################################################
@@ -346,24 +375,28 @@ def argmax(tensor):
 
 def number_of_correct(pred, target):
     # compute number of correct predictions
-    return pred.squeeze().eq(target).cpu().sum().item()
+    return pred.squeeze().eq(target).sum().item()
 
 
 def test(model, epoch):
     model.eval()
     correct = 0
     for data, target in test_loader:
+
         data = data.to(device)
         target = target.to(device)
 
+        # apply transform and model on whole batch directly on device
+        data = transform(data)
         output = model(data)
+
         pred = argmax(output)
         correct += number_of_correct(pred, target)
 
-        if 'pbar' in globals():
-            pbar.update()
+        # advance the progress bar by one batch's share of an epoch
+        if "pbar" in globals() and "pbar_update" in globals():
+            pbar.update(pbar_update)
 
-    print(f'\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')
+    print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")
 
 
 ######################################################################
@@ -375,21 +408,28 @@ def test(model, epoch):
 
 log_interval = 20
 n_epoch = 2
+# each train/test batch advances the bar by its share of one epoch
+pbar_update = 1 / (len(train_loader) + len(test_loader))
+
+# The transform needs to live on the same device as the model and the data.
+transform = transform.to(device)
 
-with tqdm(total=n_epoch * (len(train_loader) + len(test_loader))) as pbar:
-    for epoch in range(1, n_epoch+1):
+with tqdm(total=n_epoch) as pbar:
+    for epoch in range(1, n_epoch + 1):
         train(model, epoch, log_interval)
         test(model, epoch)
         scheduler.step()
 
 
 ######################################################################
-# Let’s look at the last words in the train set, and see how the model did
-# on it.
+# The network should be more than 65% accurate on the test set after 2
+# epochs, and 85% after 21 epochs. Let’s look at the last words in the
+# train set, and see how the model did on them.
 #
 
 
+
 def predict(waveform):
-    # Take a waveform and use the model to predict
+    # Use the model to predict the label of the waveform
+    waveform = waveform.to(device)
     waveform = transform(waveform)
     output = model(waveform.unsqueeze(0))
     output = argmax(output).squeeze()
@@ -410,9 +450,9 @@ def predict(waveform):
 for i, (waveform, sample_rate, utterance, *_) in enumerate(test_set):
     output = predict(waveform)
     if output != utterance:
         ipd.Audio(waveform.numpy(), rate=sample_rate)
         print(f"Data point #{i}. Expected: {utterance}. Predicted: {output}.")
         break
 else:
     print("All examples in this dataset were correctly classified!")
     print("In this case, let's just look at the last data point")
@@ -421,17 +461,59 @@ def predict(waveform):
 
 
 ######################################################################
-# Feel free to try with one of your own recordings!
+# Feel free to try with one of your own recordings of one of the labels!
+# For example, using Colab, say “Go” while executing the cell below. This
+# will record one second of audio and try to classify it.
 #
 
 
+RECORD = """
+const sleep = time => new Promise(resolve => setTimeout(resolve, time))
+const b2text = blob => new Promise(resolve => {
+    const reader = new FileReader()
+    reader.onloadend = e => resolve(e.srcElement.result)
+    reader.readAsDataURL(blob)
+})
+var record = time => new Promise(async resolve => {
+    stream = await navigator.mediaDevices.getUserMedia({ audio: true })
+    recorder = new MediaRecorder(stream)
+    chunks = []
+    recorder.ondataavailable = e => chunks.push(e.data)
+    recorder.start()
+    await sleep(time)
+    recorder.onstop = async () => {
+        blob = new Blob(chunks)
+        text = await b2text(blob)
+        resolve(text)
+    }
+    recorder.stop()
+})
+"""
+
+
+def record(seconds=1):
+    display(ipd.Javascript(RECORD))
+    print(f"Recording started for {seconds} seconds.")
+    s = colab_output.eval_js("record(%d)" % (seconds * 1000))
+    print("Recording ended.")
+    b = b64decode(s.split(",")[1])
+
+    fileformat = "wav"
+    filename = f"_audio.{fileformat}"
+    AudioSegment.from_file(BytesIO(b)).export(filename, format=fileformat)
+
+    return torchaudio.load(filename)
+
+
+waveform, sample_rate = record()
+print(f"Predicted: {predict(waveform)}.")
+ipd.Audio(waveform.numpy(), rate=sample_rate)
+
+
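+######################################################################
+# If you are not running this in Colab, you can instead load your own
+# recording from disk and classify it the same way. This is a sketch: the
+# filename is hypothetical, and the clip should be mono at the dataset's
+# 16 kHz sampling rate.
+#
+
+# waveform, sample_rate = torchaudio.load("my_recording.wav")
+# print(f"Predicted: {predict(waveform)}.")
+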
 ######################################################################
 # Conclusion
 # ----------
 #
-# The network should be more than 70% accurate on the test set after 2
-# epochs, 80% after 14 epochs, and 85% after 21 epochs.
-#
 # In this tutorial, we used torchaudio to load a dataset and resample the
 # signal. We then defined a neural network that we trained to
 # recognize a given command. There are also other data preprocessing