Skip to content

Commit 38a0a28

Browse files
committed
relative cer/wer.
1 parent 3decf09 commit 38a0a28

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

examples/pipeline/wav2letter.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,21 @@ def evaluate(model, criterion, data_loader, decoder, language_model, device):
260260
cers = [levenshtein_distance(a, b) for a, b in zip(target, output)]
261261
# cers_normalized = [d / len(a) for a, d in zip(target, cers)]
262262
cers = sum(cers)
263+
n = sum(len(t) for t in target)
263264
sums["cer"] += cers
264-
sums["total_chars"] += sum(len(t) for t in target)
265+
sums["cer_relative"] += cers / n
266+
sums["total_chars"] += n
265267

266268
output = [o.split(language_model.char_space) for o in output]
267269
target = [o.split(language_model.char_space) for o in target]
268270

269271
wers = [levenshtein_distance(a, b) for a, b in zip(target, output)]
270272
# wers_normalized = [d / len(a) for a, d in zip(target, wers)]
271273
wers = sum(wers)
274+
n = len(target)
272275
sums["wer"] += wers
273-
sums["total_words"] += len(target)
276+
sums["wer_relative"] += wers / n
277+
sums["total_words"] += n
274278

275279
avg_loss = sums["loss"] / len(data_loader)
276280
print(f"Validation loss: {avg_loss:.5f}", flush=True)

0 commit comments

Comments
 (0)