 information from executed cells disappear).
 
 First, let’s import the common torch packages such as
-`` torchaudio <https://github.com/pytorch/audio>``\ \_ that can be
-installed by following the instructions on the website.
+`torchaudio <https://github.com/pytorch/audio>`__ that can be installed
+by following the instructions on the website.
 
 """
 
 # Uncomment the following line to run in Google Colab
 
-# GPU:
-# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
-
 # CPU:
 # !pip install torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
 
+# GPU:
+# !pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
+
 # For interactive demo at the end:
 # !pip install pydub
 
-import os
-from base64 import b64decode
-from io import BytesIO
-
-import IPython.display as ipd
-import matplotlib.pyplot as plt
-from tqdm.notebook import tqdm
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torchaudio
-from google.colab import output as colab_output
-from pydub import AudioSegment
-from torchaudio.datasets import SPEECHCOMMANDS
+
+import matplotlib.pyplot as plt
+import IPython.display as ipd
+from tqdm.notebook import tqdm
+
 
 ######################################################################
 # Let’s check if a CUDA GPU is available and select our device. Running
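The device-selection cell that this comment introduces falls outside the visible hunks. A minimal sketch of what it typically looks like in this tutorial (an assumption, not part of the diff):

    import torch

    # use a CUDA GPU when one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)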
@@ -58,22 +52,26 @@
 # ---------------------
 #
 # We use torchaudio to download and represent the dataset. Here we use
-# SpeechCommands, which is a datasets of 35 commands spoken by different
-# people. The dataset ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset``
-# version of the dataset.
+# `SpeechCommands <https://arxiv.org/abs/1804.03209>`__, which is a
+# dataset of 35 commands spoken by different people. The dataset
+# ``SPEECHCOMMANDS`` is a ``torch.utils.data.Dataset`` version of the
+# dataset. In this dataset, all audio files are about 1 second long (and
+# so about 16000 time frames long).
 #
-# The actual loading and formatting steps happen in the access function
-# ``__getitem__``. In ``__getitem__``, we use ``torchaudio.load()`` to
-# convert the audio files to tensors. ``torchaudio.load()`` returns a
-# tuple containing the newly created tensor along with the sampling
-# frequency of the audio file (16kHz for SpeechCommands). In this dataset,
-# all audio files are about 1 second long (and so about 16000 time frames
-# long).
+# The actual loading and formatting steps happen when a data point is
+# accessed, and torchaudio takes care of converting the audio files to
+# tensors. To load an audio file directly instead, use
+# ``torchaudio.load()``. It returns a tuple containing the newly created
+# tensor along with the sampling frequency of the audio file (16kHz for
+# SpeechCommands).
 #
-# Here we wrap it to split it into standard training, validation, testing
-# subsets.
+# Going back to the dataset, here we create a subclass that splits it into
+# the standard training, validation, and testing subsets.
 #
 
+from torchaudio.datasets import SPEECHCOMMANDS
+import os
+
 
 class SubsetSC(SPEECHCOMMANDS):
     def __init__(self, subset: str = None):
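As the rewritten comment explains, ``torchaudio.load()`` returns a ``(tensor, sample_rate)`` tuple. A minimal sketch, assuming a hypothetical local file ``speech.wav``:

    import torchaudio

    # "speech.wav" is a placeholder path used only for illustration
    waveform, sample_rate = torchaudio.load("speech.wav")
    print(waveform.shape, sample_rate)  # e.g. torch.Size([1, 16000]) 16000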
@@ -168,11 +166,22 @@ def load_list(filename):
 #
 
 
-def encode(word):
+def label_to_index(word):
+    # Return the position of the word in labels
     return torch.tensor(labels.index(word))
 
 
-encode("yes")
+def index_to_label(index):
+    # Return the word corresponding to the index in labels
+    # This is the inverse of label_to_index
+    return labels[index]
+
+
+word_start = "yes"
+index = label_to_index(word_start)
+word_recovered = index_to_label(index)
+
+print(word_start, "-->", index, "-->", word_recovered)
 
 
 ######################################################################
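Both helpers rely on a module-level ``labels`` list that is defined outside this hunk. In the tutorial it is derived from the training set; a sketch under that assumption:

    # sorted list of the distinct labels; label_to_index and index_to_label
    # then map words to positions in this list and back
    labels = sorted(set(datapoint[2] for datapoint in train_set))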
@@ -202,10 +211,10 @@ def collate_fn(batch):
 
     tensors, targets = [], []
 
-    # Gather in lists, and encode labels
+    # Gather in lists, and encode labels as indices
    for waveform, _, label, *_ in batch:
         tensors += [waveform]
-        targets += [encode(label)]
+        targets += [label_to_index(label)]
 
     # Group the list of tensors into a batched tensor
     tensors = pad_sequence(tensors)
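The ``pad_sequence`` called here is defined earlier in the tutorial, outside this diff; it zero-pads clips to a common length so they can be stacked. The underlying ``torch.nn.utils.rnn.pad_sequence`` behaves like this self-contained sketch:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # two "waveforms" of different lengths
    a = torch.ones(3)
    b = torch.ones(5)

    # the shorter tensor is right-padded with zeros so both stack into one batch
    batch = pad_sequence([a, b], batch_first=True)
    print(batch.shape)  # torch.Size([2, 5])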
@@ -250,8 +259,8 @@ def collate_fn(batch):
 # the raw audio data. Usually more advanced transforms are applied to the
 # audio data, however CNNs can be used to accurately process the raw data.
 # The specific architecture is modeled after the M5 network architecture
-# described in `` this paper <https://arxiv.org/pdf/1610.00087.pdf>``\ \_.
-# An important aspect of models processing raw audio data is the receptive
+# described in `this paper <https://arxiv.org/pdf/1610.00087.pdf>`__. An
+# important aspect of models processing raw audio data is the receptive
 # field of their first layer’s filters. Our model’s first filter is length
 # 80 so when processing audio sampled at 8kHz the receptive field is
 # around 10ms (and at 4kHz, around 20 ms). This size is similar to speech
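The 10 ms / 20 ms figures follow from dividing the kernel length by the sampling rate; a quick check of the arithmetic:

    kernel_size = 80  # length of the model's first filter

    for sample_rate in (8000, 4000):
        duration_ms = 1000 * kernel_size / sample_rate
        print(f"{sample_rate} Hz -> {duration_ms:.0f} ms receptive field")
    # 8000 Hz -> 10 ms receptive field
    # 4000 Hz -> 20 ms receptive field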
@@ -352,10 +361,12 @@ def train(model, epoch, log_interval):
 
         # print training stats
         if batch_idx % log_interval == 0:
-            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss:.6f}")
+            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
 
-        if "pbar" in globals() and "pbar_update" in globals():
-            pbar.update(pbar_update)
+        # update progress bar
+        pbar.update(pbar_update)
+        # record loss
+        losses.append(loss.item())
 
 
 ######################################################################
@@ -368,16 +379,16 @@ def train(model, epoch, log_interval):
 #
 
 
-def argmax(tensor):
-    # index of the max log-probability
-    return tensor.max(-1)[1]
-
-
 def number_of_correct(pred, target):
-    # compute number of correct predictions
+    # count number of correct predictions
     return pred.squeeze().eq(target).sum().item()
 
 
+def get_likely_index(tensor):
+    # find most likely label index for each element in the batch
+    return tensor.argmax(dim=-1)
+
+
 def test(model, epoch):
     model.eval()
     correct = 0
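To see what the new ``get_likely_index`` helper does, here is a standalone example on dummy log-probabilities (the ``(batch, 1, n_classes)`` shape mirrors the model output, but the values are made up):

    import torch

    # fake log-probabilities for a batch of 2 examples over 4 classes
    output = torch.tensor([[[-2.0, -0.1, -3.0, -2.5]],
                           [[-0.2, -1.9, -2.8, -3.1]]])

    pred = output.argmax(dim=-1)  # what get_likely_index computes
    print(pred.squeeze())         # tensor([1, 0])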
@@ -390,11 +401,11 @@ def test(model, epoch):
         data = transform(data)
         output = model(data)
 
-        pred = argmax(output)
+        pred = get_likely_index(output)
         correct += number_of_correct(pred, target)
 
-        if "pbar" in globals() and "pbar_update" in globals():
-            pbar.update(pbar_update)
+        # update progress bar
+        pbar.update(pbar_update)
 
     print(f"\nTest Epoch: {epoch}\tAccuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n")
 
@@ -408,17 +419,22 @@ def test(model, epoch):
 
 log_interval = 20
 n_epoch = 2
+
 pbar_update = 1 / (len(train_loader) + len(test_loader))
+losses = []
 
 # The transform needs to live on the same device as the model and the data.
 transform = transform.to(device)
-
 with tqdm(total=n_epoch) as pbar:
     for epoch in range(1, n_epoch + 1):
         train(model, epoch, log_interval)
         test(model, epoch)
         scheduler.step()
 
+# Let's plot the training loss versus the number of iterations.
+# plt.plot(losses);
+# plt.title("training loss");
+
 
 ######################################################################
 # The network should be more than 65% accurate on the test set after 2
@@ -427,14 +443,14 @@ def test(model, epoch):
 #
 
 
-def predict(waveform):
+def predict(tensor):
     # Use the model to predict the label of the waveform
-    waveform = waveform.to(device)
-    waveform = transform(waveform)
-    output = model(waveform.unsqueeze(0))
-    output = argmax(output).squeeze()
-    output = labels[output]
-    return output
+    tensor = tensor.to(device)
+    tensor = transform(tensor)
+    tensor = model(tensor.unsqueeze(0))
+    tensor = get_likely_index(tensor)
+    tensor = index_to_label(tensor.squeeze())
+    return tensor
 
 
 waveform, sample_rate, utterance, *_ = train_set[-1]
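The hunk cuts off after fetching this data point; in the tutorial the clip is then played back and classified, roughly along these lines (a continuation reconstructed from context, not part of the diff):

    # listen to the clip, then compare the true label with the prediction
    ipd.Audio(waveform.numpy(), rate=sample_rate)
    print(f"Expected: {utterance}. Predicted: {predict(waveform)}.")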
@@ -466,6 +482,11 @@ def predict(waveform):
 # will record one second of audio and try to classify it.
 #
 
+from google.colab import output as colab_output
+from base64 import b64decode
+from io import BytesIO
+from pydub import AudioSegment
+
 
 RECORD = """
 const sleep = time => new Promise(resolve => setTimeout(resolve, time))
@@ -501,7 +522,6 @@ def record(seconds=1):
501522 fileformat = "wav"
502523 filename = f"_audio.{ fileformat } "
503524 AudioSegment .from_file (BytesIO (b )).export (filename , format = fileformat )
504-
505525 return torchaudio .load (filename )
506526
507527
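With ``record`` returning the same ``(waveform, sample_rate)`` tuple as ``torchaudio.load``, the interactive demo presumably ends along these lines (a hypothetical usage sketch; the actual call site is outside this diff):

    # record one second from the browser microphone and classify it
    waveform, sample_rate = record()
    print(f"Predicted: {predict(waveform)}.")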