from collections import Counter

import torch
from torch import topk
from tqdm import tqdm


class GreedyIterableDecoder:
    """Greedy CTC decoder: remove blanks and collapse repeated labels, per batch entry."""

    def __init__(self, blank_label=0, collapse_repeated=True):
        self.blank_label = blank_label
        self.collapse_repeated = collapse_repeated

    def __call__(self, output):
        # Pick the most likely class at every frame.
        arg_maxes = torch.argmax(output, dim=-1)
        decodes = []
        for args in arg_maxes:
            decode = []
            for j, index in enumerate(args):
                if index != self.blank_label:
                    # Skip a frame that repeats the previous frame's label.
                    if self.collapse_repeated and j != 0 and index == args[j - 1]:
                        continue
                    decode.append(index.item())
            decodes.append(torch.tensor(decode))
        # Sequences have different lengths after collapsing, so pad to the longest one.
        decodes = torch.nn.utils.rnn.pad_sequence(decodes, batch_first=True)
        return decodes
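

# A minimal usage sketch (shapes and the blank index are assumptions for illustration):
# GreedyIterableDecoder expects per-frame class scores of shape (batch, time, num_classes)
# and returns a right-padded batch of collapsed label sequences.
def _example_greedy_iterable_decode():
    batch, time, num_classes = 2, 50, 29
    emissions = torch.randn(batch, time, num_classes)  # e.g. CTC log-probabilities
    decoder = GreedyIterableDecoder(blank_label=0, collapse_repeated=True)
    decoded = decoder(emissions)  # (batch, longest decoded length), padded with 0
    return decoded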


class GreedyDecoder:
    def __call__(self, outputs):
        """Return the most likely class label at each step by taking the top-1 score."""
        _, indices = topk(outputs, k=1, dim=-1)
        return indices[..., 0]
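

# A minimal usage sketch (assumed shapes, for illustration): GreedyDecoder takes the
# argmax over the trailing class dimension, so any (..., num_classes) tensor works
# and the leading dimensions are preserved.
def _example_greedy_decode():
    outputs = torch.randn(50, 2, 29)  # e.g. (time, batch, num_classes)
    labels = GreedyDecoder()(outputs)  # -> (50, 2) tensor of class indices
    return labels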


def zeros_like(m):
    # Return a list-of-lists of zeros with the same shape as the 2-D list ``m``.
    return zeros(len(m), len(m[0]))


def zeros(d1, d2):
    # Return a ``d1`` x ``d2`` list-of-lists filled with zeros.
    return list(list(0 for _ in range(d2)) for _ in range(d1))


def apply_transpose(f, m):
    # Apply ``f`` to each column of the 2-D list ``m``.
    return list(map(f, zip(*m)))


def argmax(l):
    # Index of the largest element of ``l``.
    return max(range(len(l)), key=lambda i: l[i])


def add1d2d(m1, m2):
    # Add the scalar ``m1[i]`` to every element of row ``i`` of ``m2``.
    return [[v2 + v1 for v2 in m2_row] for m2_row, v1 in zip(m2, m1)]


def add1d1d(v1, v2):
    # Element-wise sum of two equal-length lists.
    return [e + s for e, s in zip(v1, v2)]
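

# A small illustrative check of the list-based helpers above (values chosen arbitrarily):
# add1d2d adds scores[i] to every entry of row i of transitions, apply_transpose reduces
# column-wise, and argmax returns the index of the largest element.
def _example_list_helpers():
    scores = [1.0, 3.0]                     # previous-step score per tag
    transitions = [[0.0, 1.0], [2.0, 0.0]]  # transitions[prev][next]
    combined = add1d2d(scores, transitions)              # [[1.0, 2.0], [5.0, 3.0]]
    best_prev_score = apply_transpose(max, combined)     # column-wise max -> [5.0, 3.0]
    best_prev_index = apply_transpose(argmax, combined)  # column-wise argmax -> [1, 1]
    return best_prev_score, best_prev_index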


class ListViterbiDecoder:
    """Viterbi decoder implemented with plain Python lists."""

    def __init__(self, data_loader, vocab_size, n=2, progress_bar=False):
        self._transitions = self._build_transitions(
            data_loader, vocab_size, n, progress_bar
        )

    def __call__(self, emissions):
        return torch.tensor(
            [
                self._decode(emissions[i].tolist(), self._transitions)[0]
                for i in range(len(emissions))
            ]
        )

    @staticmethod
    def _build_transitions(data_loader, vocab_size, n=2, progress_bar=False):
        # Count n-grams over the labels in the data set
        count = Counter()
        for _, label in tqdm(data_loader, disable=not progress_bar):
            count += Counter(a for a in zip(*(label[i:] for i in range(n))))

        # Write the counts as a vocab_size x vocab_size matrix
        transitions = zeros(vocab_size, vocab_size)
        for (k1, k2), v in count.items():
            transitions[k1][k2] = v

        return transitions

    @staticmethod
    def _decode(emissions, transitions):
        back_pointers = zeros_like(emissions)
        scores = emissions[0]

        # Generate most likely scores and paths for each step in the sequence
        for i in range(1, len(emissions)):
            score_with_transition = add1d2d(scores, transitions)
            max_score_with_transition = apply_transpose(max, score_with_transition)
            scores = add1d1d(emissions[i], max_score_with_transition)
            back_pointers[i] = apply_transpose(argmax, score_with_transition)

        # Follow the back pointers to generate the most likely path
        viterbi = [argmax(scores)]
        for bp in reversed(back_pointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        viterbi_score = max(scores)

        return viterbi, viterbi_score
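

# A minimal usage sketch (toy data, for illustration): it assumes labels are plain Python
# sequences of ints, since this class does not call .item() on them. Transitions are
# bigram counts over the labels, and decoding runs a pure-Python Viterbi pass over
# per-frame emission scores.
def _example_list_viterbi_decode():
    vocab_size = 4
    toy_loader = [(None, [0, 1, 2, 1]), (None, [2, 1, 0, 3])]  # (input, label) pairs
    decoder = ListViterbiDecoder(toy_loader, vocab_size, n=2)
    emissions = torch.randn(2, 10, vocab_size)  # (batch, time, num_tags)
    return decoder(emissions)  # -> (batch, time) tensor of tag indices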


class ViterbiDecoder:
    """Viterbi decoder implemented with tensor operations."""

    def __init__(self, data_loader, vocab_size, n=2, progress_bar=False):
        self.vocab_size = vocab_size
        self.n = n
        self.top_k = 1
        self.progress_bar = progress_bar

        self._build_transitions(data_loader)

    def _build_transitions(self, data_loader):
        # Count n-grams over the labels in the data set
        c = Counter()
        for _, label in tqdm(data_loader, disable=not self.progress_bar):
            count = Counter(
                tuple(b.item() for b in a)
                for a in zip(*(label[i:] for i in range(self.n)))
            )
            c += count

        # Encode the counts as a dense transition matrix
        ind = torch.tensor([a for (a, _) in c.items()]).t()
        val = torch.tensor([b for (_, b) in c.items()], dtype=torch.float)

        transitions = (
            torch.sparse_coo_tensor(
                indices=ind, values=val, size=[self.vocab_size, self.vocab_size]
            )
            .coalesce()
            .to_dense()
        )
        # Normalize each row by its largest count (clamped below at 1 to avoid division by zero)
        transitions = transitions / torch.max(
            torch.tensor(1.0), transitions.max(dim=1)[0]
        ).unsqueeze(1)

        self.transitions = transitions

    def _viterbi_decode(self, tag_sequence: torch.Tensor):
        """
        Perform Viterbi decoding in log space over a sequence given a transition matrix
        specifying pairwise (transition) potentials between tags and a matrix of shape
        (sequence_length, num_tags) specifying unary potentials for possible tags per
        timestep.

        Parameters
        ----------
        tag_sequence : torch.Tensor, required.
            A tensor of shape (sequence_length, num_tags) representing scores for
            a set of tags over a given sequence.

        Returns
        -------
        viterbi_path : List[int]
            The tag indices of the maximum likelihood tag sequence.
        viterbi_score : float
            The score of the viterbi path.
        """
        sequence_length, num_tags = tag_sequence.size()

        path_scores = []
        path_indices = []
        # At the beginning, the maximum number of permutations is 1; therefore, we unsqueeze(0)
        # to allow for 1 permutation.
        path_scores.append(tag_sequence[0, :].unsqueeze(0))
        # assert path_scores[0].size() == (n_permutations, num_tags)

        # Evaluate the scores for all possible paths.
        for timestep in range(1, sequence_length):
            # Add pairwise potentials to current scores.
            # assert path_scores[timestep - 1].size() == (n_permutations, num_tags)
            summed_potentials = (
                path_scores[timestep - 1].unsqueeze(2) + self.transitions
            )
            summed_potentials = summed_potentials.view(-1, num_tags)

            # Best pairwise potential path score from the previous timestep.
            max_k = min(summed_potentials.size()[0], self.top_k)
            scores, paths = torch.topk(summed_potentials, k=max_k, dim=0)
            # assert scores.size() == (n_permutations, num_tags)
            # assert paths.size() == (n_permutations, num_tags)

            scores = tag_sequence[timestep, :] + scores
            # assert scores.size() == (n_permutations, num_tags)
            path_scores.append(scores)
            path_indices.append(paths.squeeze())

        # Construct the most likely sequence backwards.
        path_scores = path_scores[-1].view(-1)
        max_k = min(path_scores.size()[0], self.top_k)
        viterbi_scores, best_paths = torch.topk(path_scores, k=max_k, dim=0)

        viterbi_paths = []
        for i in range(max_k):
            viterbi_path = [best_paths[i].item()]
            for backward_timestep in reversed(path_indices):
                viterbi_path.append(int(backward_timestep.view(-1)[viterbi_path[-1]]))

            # Reverse the backward path.
            viterbi_path.reverse()

            # The flattened view indexes (num_tags * n_permutations) nodes; therefore, we need to modulo.
            viterbi_path = [j % num_tags for j in viterbi_path]
            viterbi_paths.append(viterbi_path)

        return viterbi_paths, viterbi_scores

    def __call__(self, tag_sequence: torch.Tensor):
        outputs = []
        scores = []
        for i in range(tag_sequence.shape[1]):
            paths, score = self._viterbi_decode(tag_sequence[:, i, :])
            outputs.append(paths)
            scores.append(score)

        # outputs: one top-k list of paths per batch entry -> (sequence_length, top_k, batch_size)
        # scores: one top-k score tensor per batch entry -> (batch_size, top_k)
        return torch.tensor(outputs).transpose(0, -1), torch.stack(scores)
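

# A minimal usage sketch (toy data, for illustration): labels must be tensors here because
# _build_transitions calls .item() on them, and emissions are expected as
# (time, batch, num_tags), matching the (sequence_length, num_tags) slices passed to
# _viterbi_decode.
def _example_viterbi_decode():
    vocab_size = 4
    toy_loader = [(None, torch.tensor([0, 1, 2, 1])), (None, torch.tensor([2, 1, 0, 3]))]
    decoder = ViterbiDecoder(toy_loader, vocab_size, n=2)
    emissions = torch.randn(10, 2, vocab_size)  # (time, batch, num_tags)
    paths, scores = decoder(emissions)  # paths: (time, top_k, batch), scores: (batch, top_k)
    return paths, scores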