 from torchaudio.models.wav2letter import Wav2Letter

 from ctc_decoders import GreedyDecoder
-from datasets import collate_factory, split_process_librispeech, split_process_speechcommands
+from datasets import collate_factory, split_process_librispeech
 from languagemodels import LanguageModel
 from metrics import levenshtein_distance
 from transforms import Normalize, ToMono, UnsqueezeFirst
@@ -289,10 +289,7 @@ def main(rank, args):
         )
     elif args.model_input_type == "waveform":
         transforms = torch.nn.Sequential(transforms, UnsqueezeFirst())
-        # assert args.bins == 1, "waveform model input type only supports bins == 1"
-        if args.bins != 1:
-            logging.warn("waveform model input type only supports bins == 1")
-            args.bins = 1
+        assert args.bins == 1, "waveform model input type only supports bins == 1"
     else:
         raise NotImplementedError(
             f"Selected model input type {args.model_input_type} not supported"
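This hunk replaces the earlier soft fallback (warn and force `args.bins = 1`) with a hard assert, so a misconfigured waveform run fails fast instead of being silently corrected. A minimal sketch of the behavior this enforces, using a hypothetical check_waveform_args helper and an argparse-style namespace (neither is part of the script):

from argparse import Namespace


def check_waveform_args(args):
    # mirrors the assertion added above: waveform input implies a single feature bin
    if args.model_input_type == "waveform":
        assert args.bins == 1, "waveform model input type only supports bins == 1"


check_waveform_args(Namespace(model_input_type="waveform", bins=1))  # passes
try:
    check_waveform_args(Namespace(model_input_type="waveform", bins=13))
except AssertionError as err:
    print(err)  # waveform model input type only supports bins == 1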
@@ -323,23 +320,13 @@ def main(rank, args):
 
     # Dataset
 
-    if args.speechcommands:
-        training, validation = split_process_speechcommands(
-            ["training", "validation"],
-            [transforms, transforms],
-            language_model,
-            root="/private/home/vincentqb/audio-pytorch/examples/pipeline_wav2letter/",
-            # root=args.dataset_root,
-            # folder_in_archive=args.dataset_folder_in_archive,
-        )
-    else:
-        training, validation = split_process_librispeech(
-            [args.dataset_train, args.dataset_valid],
-            [transforms, transforms],
-            language_model,
-            root=args.dataset_root,
-            folder_in_archive=args.dataset_folder_in_archive,
-        )
+    training, validation = split_process_librispeech(
+        [args.dataset_train, args.dataset_valid],
+        [transforms, transforms],
+        language_model,
+        root=args.dataset_root,
+        folder_in_archive=args.dataset_folder_in_archive,
+    )
 
     # Decoder
 
@@ -408,12 +395,12 @@ def main(rank, args):
 
     best_loss = 1.0
 
-    checkpoint_exists = args.checkpoint and os.path.isfile(args.checkpoint)
+    checkpoint_exists = os.path.isfile(args.checkpoint)
 
     if args.distributed:
         torch.distributed.barrier()
 
-    if args.checkpoint and checkpoint_exists:
+    if args.checkpoint and checkpoint_exists and args.resume:
         logging.info("Checkpoint loading %s", args.checkpoint)
         checkpoint = torch.load(args.checkpoint)
 
@@ -427,7 +414,13 @@ def main(rank, args):
         logging.info(
             "Checkpoint loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
         )
-    elif args.checkpoint and main_rank:
+    elif args.checkpoint and checkpoint_exists:
+        raise RuntimeError(
+            "Checkpoint already exists. Add --resume to resume, or manually delete existing file."
+        )
+    elif args.checkpoint and args.resume:
+        raise RuntimeError("Checkpoint not found")
+    elif args.checkpoint and main_rank and args.checkpoint:
         save_checkpoint(
             {
                 "epoch": args.start_epoch,
@@ -439,6 +432,8 @@ def main(rank, args):
             False,
             args.checkpoint,
         )
+    elif not args.checkpoint and args.resume:
+        raise RuntimeError("Checkpoint not provided. Use --checkpoint to specify.")
 
     if args.distributed:
         torch.distributed.barrier()
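Taken together, the three hunks above rework checkpoint handling around a new --resume flag: an existing checkpoint is only loaded when --resume is passed, and the ambiguous cases now raise instead of silently reusing or overwriting a file. A standalone sketch of the resulting decision logic, with hypothetical load_from / save_initial callables standing in for the script's torch.load and save_checkpoint calls:

import os


def handle_checkpoint(checkpoint, resume, main_rank, load_from, save_initial):
    # mirrors the branching introduced in the diff
    exists = os.path.isfile(checkpoint) if checkpoint else False

    if checkpoint and exists and resume:
        return load_from(checkpoint)  # resume training from the existing file
    elif checkpoint and exists:
        raise RuntimeError(
            "Checkpoint already exists. Add --resume to resume, "
            "or manually delete existing file."
        )
    elif checkpoint and resume:
        raise RuntimeError("Checkpoint not found")
    elif checkpoint and main_rank:
        save_initial(checkpoint)  # fresh run: write an initial checkpoint for later epochs
    elif not checkpoint and resume:
        raise RuntimeError("Checkpoint not provided. Use --checkpoint to specify.")
    return None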
@@ -464,16 +459,33 @@ def main(rank, args):
             not args.reduce_lr_valid,
         )
 
-        loss = evaluate(
-            model,
-            criterion,
-            loader_validation,
-            decoder,
-            language_model,
-            devices[0],
-            epoch,
-            not main_rank,
-        )
+        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
+
+            loss = evaluate(
+                model,
+                criterion,
+                loader_validation,
+                decoder,
+                language_model,
+                devices[0],
+                epoch,
+                not main_rank,
+            )
+
+            is_best = loss < best_loss
+            best_loss = min(loss, best_loss)
+            if main_rank and args.checkpoint:
+                save_checkpoint(
+                    {
+                        "epoch": epoch + 1,
+                        "state_dict": model.state_dict(),
+                        "best_loss": best_loss,
+                        "optimizer": optimizer.state_dict(),
+                        "scheduler": scheduler.state_dict(),
+                    },
+                    is_best,
+                    args.checkpoint,
+                )
 
         if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
             scheduler.step(loss)
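The last hunk gates validation, and the checkpoint save that now lives next to it, so they run only every args.print_freq epochs and on the final epoch, rather than after every epoch. A small sketch of the schedule implied by `not (epoch + 1) % args.print_freq or epoch == args.epochs - 1`, using a hypothetical should_evaluate helper:

def should_evaluate(epoch: int, print_freq: int, epochs: int) -> bool:
    # evaluate every print_freq-th epoch (counting from 1) and always on the last epoch
    return (epoch + 1) % print_freq == 0 or epoch == epochs - 1


# e.g. print_freq=3, epochs=10: evaluation runs at epochs 2, 5, 8 (every third) and 9 (last)
assert [e for e in range(10) if should_evaluate(e, 3, 10)] == [2, 5, 8, 9]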