44if __name__ == '__main__' :
55 mp .set_start_method ('forkserver' )
66
7+
78import argparse
89import collections
910import cProfile
3031import torch .distributed as dist
3132import torchaudio
3233from matplotlib import pyplot as plt
33- from tabulate import tabulate
3434from torch import nn , topk
3535from torch .optim import SGD , Adadelta , Adam
3636from torch .optim .lr_scheduler import ExponentialLR , ReduceLROnPlateau
3737from torch .utils .data import DataLoader
3838from torchaudio .datasets import LIBRISPEECH , SPEECHCOMMANDS
3939from torchaudio .datasets .utils import bg_iterator , diskcache_iterator
40- from torchaudio .transforms import MFCC , Resample
4140from torchaudio .models .wav2letter import Wav2Letter
41+ from torchaudio .transforms import MFCC , Resample
4242from tqdm .notebook import tqdm as tqdm
4343
44+ from tabulate import tabulate
45+
46+
47+
4448print ("start time: {}" .format (str (datetime .now ())), flush = True )
4549
4650try :
@@ -228,16 +232,10 @@ def save_checkpoint(state, is_best, filename=CHECKPOINT_filename):
228232
229233device = "cuda" if torch .cuda .is_available () else "cpu"
230234num_devices = torch .cuda .device_count ()
231- # num_devices = 1
232235print (num_devices , "GPUs" , flush = True )
233236
234237# max number of sentences per batch
235238batch_size = args .batch_size
236- # batch_size = 2048
237- # batch_size = 512
238- # batch_size = 256
239- # batch_size = 64
240- # batch_size = 1
241239
242240training_percentage = 90.
243241validation_percentage = 5.
@@ -254,7 +252,7 @@ def save_checkpoint(state, is_best, filename=CHECKPOINT_filename):
254252non_blocking = True
255253
256254
257- # text preprocessing
255+ # Text preprocessing
258256
259257char_blank = "*"
260258char_space = " "
@@ -550,8 +548,9 @@ def filter_speechcommands(tag, training_percentage, data):
550548 testing_percentage = (
551549 100. - training_percentage - validation_percentage )
552550
553- def which_set_filter (x ): return which_set (
554- x , validation_percentage , testing_percentage ) == tag
551+ def which_set_filter (x ):
552+ return which_set (x , validation_percentage , testing_percentage ) == tag
553+
555554 data ._walker = list (filter (which_set_filter , data ._walker ))
556555 return data
557556
@@ -1180,4 +1179,3 @@ def forward_decode(inputs, targets, decoder):
11801179)
11811180print (s .getvalue (), flush = True )
11821181print ("stop time: {}" .format (str (datetime .now ())), flush = True )
1183-
0 commit comments