Skip to content

Commit e4f9c9b

Browse files
authored
Merge pull request #815 from PyThaiNLP/add-small100
Add small100 to pythainlp.translate
2 parents bf74de6 + b6212d9 commit e4f9c9b

File tree

4 files changed

+460
-6
lines changed

4 files changed

+460
-6
lines changed

pythainlp/translate/core.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,26 @@ class Translate:
4141
"""
4242

4343
def __init__(
44-
self, src_lang: str, target_lang: str, use_gpu: bool = False
44+
self, src_lang: str, target_lang: str, engine: str="default", use_gpu: bool = False
4545
) -> None:
4646
"""
4747
:param str src_lang: source language
4848
:param str target_lang: target language
49+
:param str engine: Machine Translation engine
4950
:param bool use_gpu: load model to gpu (Default is False)
5051
52+
**Options for engine**
53+
* *default* - The default engine for each language pair.
54+
* *small100* - A multilingual machine translation model (covering 100 languages)
55+
5156
**Options for source & target language**
5257
* *th* - *en* - Thai to English
5358
* *en* - *th* - English to Thai
5459
* *th* - *zh* - Thai to Chinese
5560
* *zh* - *th* - Chinese to Thai
5661
* *th* - *fr* - Thai to French
62+
* *th* - *xx* - Thai to xx (xx is language code). It uses small100 model.
63+
* *xx* - *th* - xx to Thai (xx is language code). It uses small100 model.
5764
5865
:Example:
5966
@@ -66,10 +73,21 @@ def __init__(
6673
# output: I love cat.
6774
"""
6875
self.model = None
69-
self.load_model(src_lang, target_lang, use_gpu)
70-
71-
def load_model(self, src_lang: str, target_lang: str, use_gpu: bool):
72-
if src_lang == "th" and target_lang == "en":
76+
self.engine = engine
77+
self.src_lang = src_lang
78+
self.use_gpu = use_gpu
79+
self.target_lang = target_lang
80+
self.load_model()
81+
82+
def load_model(self):
83+
src_lang = self.src_lang
84+
target_lang = self.target_lang
85+
use_gpu = self.use_gpu
86+
if self.engine == "small100":
87+
from .small100 import Small100Translator
88+
89+
self.model = Small100Translator(use_gpu)
90+
elif src_lang == "th" and target_lang == "en":
7391
from pythainlp.translate.en_th import ThEnTranslator
7492

7593
self.model = ThEnTranslator(use_gpu)
@@ -100,4 +118,6 @@ def translate(self, text) -> str:
100118
:return: translated text in target language
101119
:rtype: str
102120
"""
121+
if self.engine == "small100":
122+
return self.model.translate(text, tgt_lang=self.target_lang)
103123
return self.model.translate(text)

pythainlp/translate/small100.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from transformers import M2M100ForConditionalGeneration
2+
from .tokenization_small100 import SMALL100Tokenizer
3+
4+
class Small100Translator:
    """
    Machine Translation with small100 model

    - Huggingface https://huggingface.co/alirezamsh/small100

    :param bool use_gpu: load model to gpu (Default is False)
    :param str pretrained: name or path of the pretrained checkpoint
        used for both the model and the tokenizer
        (Default is "alirezamsh/small100")
    """

    def __init__(
        self,
        use_gpu: bool = False,
        pretrained: str = "alirezamsh/small100",
    ) -> None:
        self.pretrained = pretrained
        self.model = M2M100ForConditionalGeneration.from_pretrained(
            self.pretrained
        )
        # The tokenizer is tied to one target language, so it is created
        # lazily in translate() and rebuilt only when that language changes.
        self.tokenizer = None
        self.tgt_lang = None
        if use_gpu:
            self.model = self.model.cuda()

    def translate(self, text: str, tgt_lang: str = "en") -> str:
        """
        Translate text from any supported language to the target language

        :param str text: input text in source language
        :param str tgt_lang: target language code (Default is "en")
        :return: translated text in target language
        :rtype: str

        :Example:

        ::

            from pythainlp.translate.small100 import Small100Translator

            mt = Small100Translator()

            # Translate text from Thai to English
            mt.translate("ทดสอบระบบ", tgt_lang="en")
            # output: 'Testing system'

            # Translate text from Thai to Chinese
            mt.translate("ทดสอบระบบ", tgt_lang="zh")
            # output: '系统测试'

            # Translate text from Thai to French
            mt.translate("ทดสอบระบบ", tgt_lang="fr")
            # output: 'Test du système'

        """
        if self.tokenizer is None or tgt_lang != self.tgt_lang:
            self.tokenizer = SMALL100Tokenizer.from_pretrained(
                self.pretrained, tgt_lang=tgt_lang
            )
            self.tgt_lang = tgt_lang

        encoded = self.tokenizer(text, return_tensors="pt")
        # The tokenizer output is not placed on the model's device by this
        # class; when the model was moved to the GPU the inputs must follow
        # it, otherwise generate() is called with mismatched devices.
        device = self.model.device
        inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

        generated = self.model.generate(**inputs)
        return self.tokenizer.batch_decode(
            generated, skip_special_tokens=True
        )[0]

0 commit comments

Comments
 (0)