Skip to content

Commit 0a0a371

Browse files
committed
Fix _short_word() returns
1 parent e941968 commit 0a0a371

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pythainlp/transliterate/w2p.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
 import codecs
 import os
 import re
+from typing import Union
 
 import numpy as np
 from pythainlp.corpus import download, get_corpus_path
@@ -20,6 +21,8 @@
2021
+ " ณิฑชฉซทรํฬฏ–ัฃวก่ปผ์ฆบี๊ธฌญะไษ๋นโภ?"
2122
)
2223

24+
_MODEL_NAME = "thai_w2p"
25+
2326

2427
class _Hparams:
2528
batch_size = 256
@@ -52,10 +55,10 @@ def __init__(self):
         self.graphemes = hp.graphemes
         self.phonemes = hp.phonemes
         self.g2idx, self.idx2g, self.p2idx, self.idx2p = _load_vocab()
-        self.checkpoint = get_corpus_path("thai_w2p")
+        self.checkpoint = get_corpus_path(_MODEL_NAME)
         if self.checkpoint is None:
-            download("thai_w2p")
-            self.checkpoint = get_corpus_path("thai_w2p")
+            download(_MODEL_NAME)
+            self.checkpoint = get_corpus_path(_MODEL_NAME)
         self._load_variables()
 
     def _load_variables(self):
@@ -129,18 +132,19 @@ def _encode(self, word: str) -> np.ndarray:
129132

130133
return x
131134

132-
def _short_word(self, word: str) -> str:
135+
def _short_word(self, word: str) -> Union[str, None]:
133136
self.word = word
134137
if self.word.endswith("."):
135138
self.word = self.word.replace(".", "")
136-
self.word = "-".join([_j + "อ" for _j in list(self.word)])
137-
138-
return self.word
139+
self.word = "-".join([i + "อ" for i in list(self.word)])
140+
return self.word
141+
return None
139142

140143
def _predict(self, word: str) -> str:
141144
short_word = self._short_word(word)
142145
if short_word is not None:
143146
return short_word
147+
144148
# encoder
145149
enc = self._encode(word)
146150
enc = self._gru(

0 commit comments

Comments
 (0)