Skip to content

Commit b9325fe

Browse files
committed
Revert "speech commands. typo in loss."
This reverts commit 5051f69.
1 parent 91344b0 commit b9325fe

File tree

3 files changed

+47
-72
lines changed

3 files changed

+47
-72
lines changed

examples/pipeline_wav2letter/datasets.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from torch import Tensor
66
from torchaudio.datasets import LIBRISPEECH
77

8-
from speechcommands import SPEECHCOMMANDS
9-
108

119
def pad_sequence(sequences, padding_value=0.0):
1210
# type: (List[Tensor], float) -> Tensor
@@ -161,36 +159,6 @@ def create(tags, cache=True):
161159
return tuple(create(dataset) for dataset in datasets)
162160

163161

164-
def split_process_speechcommands(
165-
datasets, transforms, language_model, root,
166-
):
167-
def create(tags, cache=True):
168-
169-
if isinstance(tags, str):
170-
tags = [tags]
171-
if isinstance(transforms, list):
172-
transform_list = transforms
173-
else:
174-
transform_list = [transforms]
175-
176-
data = torch.utils.data.ConcatDataset(
177-
[
178-
Processed(
179-
SPEECHCOMMANDS(root, split=tag, download=False,),
180-
transform,
181-
language_model.encode,
182-
)
183-
for tag, transform in zip(tags, transform_list)
184-
]
185-
)
186-
187-
data = MapMemoryCache(data)
188-
return data
189-
190-
# For performance, we cache all datasets
191-
return tuple(create(dataset) for dataset in datasets)
192-
193-
194162
def collate_factory(model_length_function, transforms=None):
195163

196164
if transforms is None:

examples/pipeline_wav2letter/engine.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torchaudio.models.wav2letter import Wav2Letter
1313

1414
from ctc_decoders import GreedyDecoder
15-
from datasets import collate_factory, split_process_librispeech, split_process_speechcommands
15+
from datasets import collate_factory, split_process_librispeech
1616
from languagemodels import LanguageModel
1717
from metrics import levenshtein_distance
1818
from transforms import Normalize, ToMono, UnsqueezeFirst
@@ -289,10 +289,7 @@ def main(rank, args):
289289
)
290290
elif args.model_input_type == "waveform":
291291
transforms = torch.nn.Sequential(transforms, UnsqueezeFirst())
292-
# assert args.bins == 1, "waveform model input type only supports bins == 1"
293-
if args.bins != 1:
294-
logging.warn("waveform model input type only supports bins == 1")
295-
args.bins = 1
292+
assert args.bins == 1, "waveform model input type only supports bins == 1"
296293
else:
297294
raise NotImplementedError(
298295
f"Selected model input type {args.model_input_type} not supported"
@@ -323,23 +320,13 @@ def main(rank, args):
323320

324321
# Dataset
325322

326-
if args.speechcommands:
327-
training, validation = split_process_speechcommands(
328-
["training", "validation"],
329-
[transforms, transforms],
330-
language_model,
331-
root="/private/home/vincentqb/audio-pytorch/examples/pipeline_wav2letter/",
332-
# root=args.dataset_root,
333-
# folder_in_archive=args.dataset_folder_in_archive,
334-
)
335-
else:
336-
training, validation = split_process_librispeech(
337-
[args.dataset_train, args.dataset_valid],
338-
[transforms, transforms],
339-
language_model,
340-
root=args.dataset_root,
341-
folder_in_archive=args.dataset_folder_in_archive,
342-
)
323+
training, validation = split_process_librispeech(
324+
[args.dataset_train, args.dataset_valid],
325+
[transforms, transforms],
326+
language_model,
327+
root=args.dataset_root,
328+
folder_in_archive=args.dataset_folder_in_archive,
329+
)
343330

344331
# Decoder
345332

@@ -408,12 +395,12 @@ def main(rank, args):
408395

409396
best_loss = 1.0
410397

411-
checkpoint_exists = args.checkpoint and os.path.isfile(args.checkpoint)
398+
checkpoint_exists = os.path.isfile(args.checkpoint)
412399

413400
if args.distributed:
414401
torch.distributed.barrier()
415402

416-
if args.checkpoint and checkpoint_exists:
403+
if args.checkpoint and checkpoint_exists and args.resume:
417404
logging.info("Checkpoint loading %s", args.checkpoint)
418405
checkpoint = torch.load(args.checkpoint)
419406

@@ -427,7 +414,13 @@ def main(rank, args):
427414
logging.info(
428415
"Checkpoint loaded '%s' at epoch %s", args.checkpoint, checkpoint["epoch"]
429416
)
430-
elif args.checkpoint and main_rank:
417+
elif args.checkpoint and checkpoint_exists:
418+
raise RuntimeError(
419+
"Checkpoint already exists. Add --resume to resume, or manually delete existing file."
420+
)
421+
elif args.checkpoint and args.resume:
422+
raise RuntimeError("Checkpoint not found")
423+
elif args.checkpoint and main_rank and args.checkpoint:
431424
save_checkpoint(
432425
{
433426
"epoch": args.start_epoch,
@@ -439,6 +432,8 @@ def main(rank, args):
439432
False,
440433
args.checkpoint,
441434
)
435+
elif not args.checkpoint and args.resume:
436+
raise RuntimeError("Checkpoint not provided. Use --checkpoint to specify.")
442437

443438
if args.distributed:
444439
torch.distributed.barrier()
@@ -464,16 +459,33 @@ def main(rank, args):
464459
not args.reduce_lr_valid,
465460
)
466461

467-
loss = evaluate(
468-
model,
469-
criterion,
470-
loader_validation,
471-
decoder,
472-
language_model,
473-
devices[0],
474-
epoch,
475-
not main_rank,
476-
)
462+
if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:
463+
464+
loss = evaluate(
465+
model,
466+
criterion,
467+
loader_validation,
468+
decoder,
469+
language_model,
470+
devices[0],
471+
epoch,
472+
not main_rank,
473+
)
474+
475+
is_best = loss < best_loss
476+
best_loss = min(loss, best_loss)
477+
if main_rank and args.checkpoint:
478+
save_checkpoint(
479+
{
480+
"epoch": epoch + 1,
481+
"state_dict": model.state_dict(),
482+
"best_loss": best_loss,
483+
"optimizer": optimizer.state_dict(),
484+
"scheduler": scheduler.state_dict(),
485+
},
486+
is_best,
487+
args.checkpoint,
488+
)
477489

478490
if args.reduce_lr_valid and isinstance(scheduler, ReduceLROnPlateau):
479491
scheduler.step(loss)

examples/pipeline_wav2letter/main.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,6 @@ def parse_args():
163163
parser.add_argument(
164164
"--world-size", type=int, default=8, help="the world size to initiate DPP"
165165
)
166-
parser.add_argument(
167-
"--speechcommands",
168-
action="store_true",
169-
help="select speechcommands instead of librispeech",
170-
)
171166

172167
args = parser.parse_args()
173168
logging.info(args)

0 commit comments

Comments
 (0)