@@ -110,17 +110,19 @@ def _grucell(self, x, h, w_ih, w_hh, b_ih, b_hh):
110110
111111 return h
112112
113- def _gru (self , x , steps , w_ih , w_hh , b_ih , b_hh , h0 = None ):
113+ def _gru (self , x , steps , w_ih , w_hh , b_ih , b_hh , h0 = None ) -> np . ndarray :
114114 if h0 is None :
115115 h0 = np .zeros ((x .shape [0 ], w_hh .shape [1 ]), np .float32 )
116116 h = h0 # initial hidden state
117+
117118 outputs = np .zeros ((x .shape [0 ], steps , w_hh .shape [1 ]), np .float32 )
118119 for t in range (steps ):
119120 h = self ._grucell (x [:, t , :], h , w_ih , w_hh , b_ih , b_hh ) # (b, h)
120121 outputs [:, t , ::] = h
122+
121123 return outputs
122124
123- def _encode (self , word : str ):
125+ def _encode (self , word : str ) -> np . ndarray :
124126 chars = list (word ) + ["</s>" ]
125127 x = [self .g2idx .get (char , self .g2idx ["<unk>" ]) for char in chars ]
126128 x = np .take (self .enc_emb , np .expand_dims (x , 0 ), axis = 0 )
@@ -132,7 +134,8 @@ def _short_word(self, word: str) -> str:
132134 if self .word .endswith ("." ):
133135 self .word = self .word .replace ("." , "" )
134136 self .word = "-" .join ([_j + "อ" for _j in list (self .word )])
135- return self .word
137+
138+ return self .word
136139
137140 def _predict (self , word : str ) -> str :
138141 short_word = self ._short_word (word )
@@ -156,7 +159,7 @@ def _predict(self, word: str) -> str:
156159 h = last_hidden
157160
158161 preds = []
159- for _i in range (20 ):
162+ for _ in range (20 ):
160163 h = self ._grucell (
161164 dec ,
162165 h ,
@@ -173,6 +176,7 @@ def _predict(self, word: str) -> str:
173176 dec = np .take (self .dec_emb , [pred ], axis = 0 )
174177
175178 preds = [self .idx2p .get (idx , "<unk>" ) for idx in preds ]
179+
176180 return preds
177181
178182 def __call__ (self , word : str ) -> str :
0 commit comments