diff --git a/docs/api/translate.rst b/docs/api/translate.rst index 093fbc38a..89b4537ac 100644 --- a/docs/api/translate.rst +++ b/docs/api/translate.rst @@ -16,3 +16,5 @@ Modules :members: translate .. autoclass:: ZhThTranslator :members: translate +.. autoclass:: Translate + :members: diff --git a/pythainlp/translate/__init__.py b/pythainlp/translate/__init__.py index 86004eaba..93a473277 100644 --- a/pythainlp/translate/__init__.py +++ b/pythainlp/translate/__init__.py @@ -8,9 +8,12 @@ "ThEnTranslator", "download_model_all", "ThZhTranslator", - "ZhThTranslator" + "ZhThTranslator", + "Translate" ] +from pythainlp.translate.core import Translate + from pythainlp.translate.en_th import ( EnThTranslator, ThEnTranslator, diff --git a/pythainlp/translate/core.py b/pythainlp/translate/core.py new file mode 100644 index 000000000..27bbac581 --- /dev/null +++ b/pythainlp/translate/core.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + + +class Translate: + def __init__(self, src_lang: str, target_lang: str) -> None: + """ + :param str src_lang: source language + :param str target_lang: target language + + **Options for source & target language** + * *th* - *en* - Thai to English + * *en* - *th* - English to Thai + * *th* - *zh* - Thai to Chinese + * *zh* - *th* - Chinese to Thai + """ + self.model = None + self.load_model(src_lang, target_lang) + + def load_model(self, src_lang: str, target_lang: str): + if src_lang == "th" and target_lang == "en": + from pythainlp.translate.en_th import ThEnTranslator + self.model = ThEnTranslator() + elif src_lang == "en" and target_lang == "th": + from pythainlp.translate.en_th import EnThTranslator + self.model = EnThTranslator() + elif src_lang == "th" and target_lang == "zh": + from pythainlp.translate.zh_th import ThZhTranslator + self.model = ThZhTranslator() + elif src_lang == "zh" and target_lang == "th": + from pythainlp.translate.zh_th import ZhThTranslator + self.model = ZhThTranslator() + else: + raise ValueError("Not support language!") + + def translate(self, text) -> str: + """ + Translate text + + :param str text: input text in source language + :return: translated text in target language + :rtype: str + """ + return self.model.translate(text) diff --git a/tests/test_translate.py b/tests/test_translate.py index b9a78c352..7da13e7f5 100644 --- a/tests/test_translate.py +++ b/tests/test_translate.py @@ -7,7 +7,8 @@ ThEnTranslator, ThZhTranslator, ZhThTranslator, - download_model_all + download_model_all, + Translate ) @@ -38,3 +39,29 @@ def test_translate(self): "我爱你", ) ) + self.th_en_translator = Translate('th', 'en') + self.assertIsNotNone( + self.th_en_translator.translate( + "แมวกินปลา", + ) + ) + self.en_th_translator = Translate('en', 'th') + self.assertIsNotNone( + self.en_th_translator.translate( + "the cat eats fish.", + ) + ) + self.th_zh_translator = Translate('th', 'zh') + self.assertIsNotNone( + self.th_zh_translator.translate( + "ผมรักคุณ", + ) + ) + self.zh_th_translator = Translate('zh', 'th') + self.assertIsNotNone( + self.zh_th_translator.translate( + "我爱你", + ) + ) + with self.assertRaises(ValueError): + self.th_cat_translator = Translate('th', 'cat')