Skip to content

Commit c378c48

Browse files
committed
Revert "Revert "remove other decoders, and unused dataset class.""
This reverts commit 7b00bc8.
1 parent c4ccab0 commit c378c48

File tree

3 files changed

+2
-239
lines changed

3 files changed

+2
-239
lines changed
Lines changed: 0 additions & 218 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,4 @@
1-
from collections import Counter
2-
3-
import torch
41
from torch import topk
5-
from tqdm import tqdm
6-
7-
8-
class GreedyIterableDecoder:
9-
def __init__(self, blank_label=0, collapse_repeated=True):
10-
self.blank_label = blank_label
11-
self.collapse_repeated = collapse_repeated
12-
13-
def __call__(self, output):
14-
arg_maxes = torch.argmax(output, dim=-1)
15-
decodes = []
16-
for args in arg_maxes:
17-
decode = []
18-
for j, index in enumerate(args):
19-
if index != self.blank_label:
20-
if self.collapse_repeated and j != 0 and index == args[j - 1]:
21-
continue
22-
decode.append(index.item())
23-
decode = torch.tensor(decode)
24-
decodes.append(decode)
25-
# decodes = torch.tensor(decodes)
26-
decodes = torch.nn.utils.rnn.pad_sequence(decodes, batch_first=True)
27-
return decodes
282

293

304
class GreedyDecoder:
@@ -39,195 +13,3 @@ def __call__(self, outputs):
3913
"""
4014
_, indices = topk(outputs, k=1, dim=-1)
4115
return indices[..., 0]
42-
43-
44-
def zeros_like(m):
    """Return a zero-filled nested-list matrix with the same shape as ``m``."""
    rows, cols = len(m), len(m[0])
    return [[0 for _ in range(cols)] for _ in range(rows)]
46-
47-
48-
def zeros(d1, d2):
    """Build a ``d1`` x ``d2`` matrix of zeros as nested Python lists."""
    return [[0] * d2 for _ in range(d1)]
50-
51-
52-
def apply_transpose(f, m):
    """Apply ``f`` to every column of ``m`` (i.e. every row of its transpose)."""
    return [f(column) for column in zip(*m)]
54-
55-
56-
def argmax(l):
    """Return the index of the largest value in ``l``; the earliest index wins ties.

    Raises ValueError on an empty sequence (propagated from ``max``).
    """
    index, _ = max(enumerate(l), key=lambda pair: pair[1])
    return index
58-
59-
60-
def add1d2d(m1, m2):
    """Broadcast-add vector ``m1`` over the rows of matrix ``m2``.

    Result[i][j] == m2[i][j] + m1[i]; length follows the shorter of the two.
    """
    result = []
    for row, offset in zip(m2, m1):
        result.append([value + offset for value in row])
    return result
62-
63-
64-
def add1d1d(v1, v2):
    """Element-wise sum of two vectors, truncated to the shorter length."""
    total = []
    for first, second in zip(v1, v2):
        total.append(first + second)
    return total
66-
67-
68-
class ListViterbiDecoder:
    """First-order Viterbi decoder implemented with plain Python lists.

    A transition score matrix is estimated by counting label n-grams over
    ``data_loader``; decoding then runs the classic Viterbi recursion over
    per-frame emission scores.

    NOTE(review): `_build_transitions` unpacks counter keys as ``(k1, k2)``,
    so only ``n == 2`` (bigrams) actually works despite the ``n`` parameter.
    """

    def __init__(self, data_loader, vocab_size, n=2, progress_bar=False):
        # data_loader is expected to yield (input, label) pairs; only the
        # labels are consumed when building transitions — TODO confirm.
        self._transitions = self._build_transitions(
            data_loader, vocab_size, n, progress_bar
        )

    def __call__(self, emissions):
        # Decode each item of the batch independently; keep only the path
        # ([0]) and discard the score ([1]) returned by _decode.
        return torch.tensor([self._decode(emissions[i].tolist(), self._transitions)[0] for i in range(len(emissions))])

    @staticmethod
    def _build_transitions(data_loader, vocab_size, n=2, progress_bar=False):
        # Accumulate raw n-gram counts of consecutive label symbols.

        # Count n-grams
        count = Counter()
        for _, label in tqdm(data_loader, disable=not progress_bar):
            count += Counter(a for a in zip(*(label[i:] for i in range(n))))

        # Write as matrix
        # Unnormalized counts: transitions[k1][k2] = #occurrences of (k1, k2).
        transitions = zeros(vocab_size, vocab_size)
        for (k1, k2), v in count.items():
            transitions[k1][k2] = v

        return transitions

    @staticmethod
    def _decode(emissions, transitions):
        # emissions: (sequence_length, vocab_size) nested lists of scores.
        # Returns (viterbi_path, viterbi_score).
        # NOTE(review): this first assignment is dead — `scores` is
        # immediately rebound to emissions[0] two lines below.
        scores = zeros_like(emissions)
        back_pointers = zeros_like(emissions)
        scores = emissions[0]

        # Generate most likely scores and paths for each step in sequence
        # back_pointers[0] keeps its zero row; it is skipped by the
        # back_pointers[1:] slice during backtracking below.
        for i in range(1, len(emissions)):
            score_with_transition = add1d2d(scores, transitions)
            max_score_with_transition = apply_transpose(max, score_with_transition)
            scores = add1d1d(emissions[i], max_score_with_transition)
            back_pointers[i] = apply_transpose(argmax, score_with_transition)

        # Generate the most likely path
        # Start from the best final state, then follow back-pointers in reverse.
        viterbi = [argmax(scores)]
        for bp in reversed(back_pointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        viterbi_score = max(scores)

        return viterbi, viterbi_score
113-
114-
115-
class ViterbiDecoder:
    """Viterbi decoder over a dense transition matrix built from label bigram
    counts, implemented with torch tensor operations.

    ``top_k`` is hard-coded to 1, so only the single best path per sequence
    is produced.
    """

    def __init__(self, data_loader, vocab_size, n=2, progress_bar=False):
        self.vocab_size = vocab_size
        self.n = n
        # Only the best path is tracked (see _viterbi_decode).
        self.top_k = 1
        self.progress_bar = progress_bar

        self._build_transitions(data_loader)

    def _build_transitions(self, data_loader):
        # Estimate transition scores by counting label n-grams, then
        # normalize each row by its maximum (clamped to at least 1.0).

        # Count n-grams

        c = Counter()
        for _, label in tqdm(data_loader, disable=not self.progress_bar):
            # label elements are 0-d tensors (hence .item()) — TODO confirm
            # against the dataset's collate function.
            count = Counter(
                tuple(b.item() for b in a)
                for a in zip(*(label[i:] for i in range(self.n)))
            )
            c += count

        # Encode as transition matrix

        # Counter keys become sparse indices, counts become values.
        ind = torch.tensor([a for (a, _) in c.items()]).t()
        val = torch.tensor([b for (_, b) in c.items()], dtype=torch.float)

        transitions = (
            torch.sparse_coo_tensor(
                indices=ind, values=val, size=[self.vocab_size, self.vocab_size]
            )
            .coalesce()
            .to_dense()
        )
        # Row-wise normalization; max(1.0, row_max) avoids dividing by zero
        # for symbols that never occur as a bigram prefix.
        transitions = transitions / torch.max(
            torch.tensor(1.0), transitions.max(dim=1)[0]
        ).unsqueeze(1)

        self.transitions = transitions

    def _viterbi_decode(self, tag_sequence: torch.Tensor):
        """
        Perform Viterbi decoding in log space over a sequence given a transition matrix
        specifying pairwise (transition) potentials between tags and a matrix of shape
        (sequence_length, num_tags) specifying unary potentials for possible tags per
        timestep.

        Parameters
        ----------
        tag_sequence : torch.Tensor, required.
            A tensor of shape (sequence_length, num_tags) representing scores for
            a set of tags over a given sequence.

        Returns
        -------
        viterbi_path : List[int]
            The tag indices of the maximum likelihood tag sequence.
        viterbi_score : float
            The score of the viterbi path.
        """
        sequence_length, num_tags = tag_sequence.size()

        path_scores = []
        path_indices = []
        # At the beginning, the maximum number of permutations is 1; therefore, we unsqueeze(0)
        # to allow for 1 permutation.
        path_scores.append(tag_sequence[0, :].unsqueeze(0))
        # assert path_scores[0].size() == (n_permutations, num_tags)

        # Evaluate the scores for all possible paths.
        for timestep in range(1, sequence_length):
            # Add pairwise potentials to current scores.
            # assert path_scores[timestep - 1].size() == (n_permutations, num_tags)
            summed_potentials = (
                path_scores[timestep - 1].unsqueeze(2) + self.transitions
            )
            summed_potentials = summed_potentials.view(-1, num_tags)

            # Best pairwise potential path score from the previous timestep.
            max_k = min(summed_potentials.size()[0], self.top_k)
            scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
            # assert scores.size() == (n_permutations, num_tags)
            # assert paths.size() == (n_permutations, num_tags)

            scores = tag_sequence[timestep, :] + scores
            # assert scores.size() == (n_permutations, num_tags)
            path_scores.append(scores)
            path_indices.append(paths.squeeze())

        # Construct the most likely sequence backwards.
        path_scores = path_scores[-1].view(-1)
        max_k = min(path_scores.size()[0], self.top_k)
        viterbi_scores, best_paths = torch.topk(path_scores, k=max_k, dim=0)

        viterbi_paths = []
        for i in range(max_k):

            viterbi_path = [best_paths[i].item()]
            for backward_timestep in reversed(path_indices):
                viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))

            # Reverse the backward path.
            viterbi_path.reverse()

            # Viterbi paths uses (num_tags * n_permutations) nodes; therefore, we need to modulo.
            viterbi_path = [j % num_tags for j in viterbi_path]
            viterbi_paths.append(viterbi_path)

        return viterbi_paths, viterbi_scores

    def __call__(self, tag_sequence: torch.Tensor):
        # tag_sequence is iterated over dim 1 and sliced to
        # (sequence_length, num_tags) per item, so the expected layout is
        # (sequence_length, batch, num_tags).

        outputs = []
        scores = []
        for i in range(tag_sequence.shape[1]):
            paths, score = self._viterbi_decode(tag_sequence[:, i, :])
            outputs.append(paths)
            scores.append(score)

        # NOTE(review): each `score` from _viterbi_decode is 1-D (topk over a
        # flattened 1-D tensor), so the 3-D indexing [:, 0, :] on the
        # concatenation looks suspect — verify against a real call.
        return torch.tensor(outputs).transpose(0, -1), torch.cat(scores)[:, 0, :]

examples/pipeline_wav2letter/engine.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@
1313
from torchaudio.transforms import MFCC, Resample
1414
from torchaudio.models.wav2letter import Wav2Letter
1515

16-
from ctc_decoders import (
17-
GreedyDecoder,
18-
GreedyIterableDecoder,
19-
ListViterbiDecoder,
20-
ViterbiDecoder,
21-
)
16+
from ctc_decoders import GreedyDecoder
2217
from datasets import collate_factory, split_process_librispeech
2318
from languagemodels import LanguageModel
2419
from metrics import levenshtein_distance
@@ -248,10 +243,6 @@ def evaluate(
248243
outputs, targets, decoder, language_model, loss_value, metric
249244
)
250245

251-
# TODO Remove before merge pull request
252-
if SIGNAL_RECEIVED:
253-
break
254-
255246
metric.flush()
256247

257248
return avg_loss
@@ -338,12 +329,6 @@ def main(rank, args):
338329

339330
if args.decoder == "greedy":
340331
decoder = GreedyDecoder()
341-
elif args.decoder == "greedyiter":
342-
decoder = GreedyIterableDecoder()
343-
elif args.decoder == "viterbi":
344-
decoder = ListViterbiDecoder(
345-
training, len(language_model), progress_bar=args.progress_bar
346-
)
347332
else:
348333
raise ValueError("Selected decoder not supported")
349334

examples/pipeline_wav2letter/main.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ def parse_args():
1313
choices=["waveform", "mfcc"],
1414
help="input type for model",
1515
)
16-
parser.add_argument(
1716
parser.add_argument(
1817
"--freq-mask", default=0, type=int, help="maximal width of frequency mask",
1918
)
@@ -118,10 +117,7 @@ def parse_args():
118117
help="rho parameter for Adadelta",
119118
)
120119
parser.add_argument(
121-
"--clip-grad",
122-
metavar="NORM",
123-
type=float,
124-
help="value to clip gradient at",
120+
"--clip-grad", metavar="NORM", type=float, help="value to clip gradient at",
125121
)
126122
parser.add_argument(
127123
"--dataset-root", type=str, help="specify dataset root folder",

0 commit comments

Comments
 (0)