@@ -443,7 +443,6 @@ def main(rank, args):
443443
444444 training , validation = split_process_librispeech (
445445 [args .dataset_train , args .dataset_valid ],
446- # [transforms_train, transforms_valid],
447446 [transforms , transforms ],
448447 language_model ,
449448 root = args .dataset_root ,
@@ -473,7 +472,6 @@ def main(rank, args):
473472 devices = list (range (rank * n , (rank + 1 ) * n ))
474473 model = model .to (devices [0 ])
475474 model = torch .nn .parallel .DistributedDataParallel (model , device_ids = devices )
476- # model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
477475 else :
478476 devices = ["cuda" if torch .cuda .is_available () else "cpu" ]
479477 model = model .to (devices [0 ], non_blocking = True )
@@ -526,8 +524,6 @@ def main(rank, args):
526524 criterion = torch .nn .CTCLoss (
527525 blank = language_model .mapping [char_blank ], zero_infinity = False
528526 )
529- # criterion = torch.nn.MSELoss()
530- # criterion = torch.nn.NLLLoss()
531527
532528 # Data Loader
533529
0 commit comments