|
7 | 7 | import codecs |
8 | 8 | import os |
9 | 9 | import re |
| 10 | +from typing import Union |
10 | 11 |
|
11 | 12 | import numpy as np |
12 | 13 | from pythainlp.corpus import download, get_corpus_path |
|
20 | 21 | + " ณิฑชฉซทรํฬฏ–ัฃวก่ปผ์ฆบี๊ธฌญะไษ๋นโภ?" |
21 | 22 | ) |
22 | 23 |
|
| 24 | +_MODEL_NAME = "thai_w2p" |
| 25 | + |
23 | 26 |
|
24 | 27 | class _Hparams: |
25 | 28 | batch_size = 256 |
@@ -52,10 +55,10 @@ def __init__(self): |
52 | 55 | self.graphemes = hp.graphemes |
53 | 56 | self.phonemes = hp.phonemes |
54 | 57 | self.g2idx, self.idx2g, self.p2idx, self.idx2p = _load_vocab() |
55 | | - self.checkpoint = get_corpus_path("thai_w2p") |
| 58 | + self.checkpoint = get_corpus_path(_MODEL_NAME) |
56 | 59 | if self.checkpoint is None: |
57 | | - download("thai_w2p") |
58 | | - self.checkpoint = get_corpus_path("thai_w2p") |
| 60 | + download(_MODEL_NAME) |
| 61 | + self.checkpoint = get_corpus_path(_MODEL_NAME) |
59 | 62 | self._load_variables() |
60 | 63 |
|
61 | 64 | def _load_variables(self): |
@@ -129,18 +132,19 @@ def _encode(self, word: str) -> np.ndarray: |
129 | 132 |
|
130 | 133 | return x |
131 | 134 |
|
132 | | - def _short_word(self, word: str) -> str: |
| 135 | + def _short_word(self, word: str) -> Union[str, None]: |
133 | 136 | self.word = word |
134 | 137 | if self.word.endswith("."): |
135 | 138 | self.word = self.word.replace(".", "") |
136 | | - self.word = "-".join([_j + "อ" for _j in list(self.word)]) |
137 | | - |
138 | | - return self.word |
| 139 | + self.word = "-".join([i + "อ" for i in list(self.word)]) |
| 140 | + return self.word |
| 141 | + return None |
139 | 142 |
|
140 | 143 | def _predict(self, word: str) -> str: |
141 | 144 | short_word = self._short_word(word) |
142 | 145 | if short_word is not None: |
143 | 146 | return short_word |
| 147 | + |
144 | 148 | # encoder |
145 | 149 | enc = self._encode(word) |
146 | 150 | enc = self._gru( |
|
0 commit comments