diff --git a/pythainlp/translate/core.py b/pythainlp/translate/core.py
index 80225a307..f7d616b31 100644
--- a/pythainlp/translate/core.py
+++ b/pythainlp/translate/core.py
@@ -41,19 +41,26 @@ class Translate:
     """
 
     def __init__(
-        self, src_lang: str, target_lang: str, use_gpu: bool = False
+        self, src_lang: str, target_lang: str, engine: str = "default", use_gpu: bool = False
     ) -> None:
         """
         :param str src_lang: source language
         :param str target_lang: target language
+        :param str engine: machine translation engine
         :param bool use_gpu: load model to gpu (Default is False)
 
+        **Options for engine**
+            * *default* - the default engine for each language pair
+            * *small100* - a multilingual machine translation model (covering 100 languages)
+
        **Options for source & target language**
             * *th* - *en* - Thai to English
             * *en* - *th* - English to Thai
             * *th* - *zh* - Thai to Chinese
             * *zh* - *th* - Chinese to Thai
             * *th* - *fr* - Thai to French
+            * *th* - *xx* - Thai to xx (xx is a language code); uses the small100 model
+            * *xx* - *th* - xx to Thai (xx is a language code); uses the small100 model
 
         :Example:
 
@@ -66,10 +73,21 @@ def __init__(
             # output: I love cat.
         """
         self.model = None
-        self.load_model(src_lang, target_lang, use_gpu)
-
-    def load_model(self, src_lang: str, target_lang: str, use_gpu: bool):
-        if src_lang == "th" and target_lang == "en":
+        self.engine = engine
+        self.src_lang = src_lang
+        self.use_gpu = use_gpu
+        self.target_lang = target_lang
+        self.load_model()
+
+    def load_model(self):
+        src_lang = self.src_lang
+        target_lang = self.target_lang
+        use_gpu = self.use_gpu
+        if self.engine == "small100":
+            from .small100 import Small100Translator
+
+            self.model = Small100Translator(use_gpu)
+        elif src_lang == "th" and target_lang == "en":
             from pythainlp.translate.en_th import ThEnTranslator
 
             self.model = ThEnTranslator(use_gpu)
@@ -100,4 +118,6 @@ def translate(self, text) -> str:
         :return: translated text in target language
         :rtype: str
         """
+        if self.engine == "small100":
+            return self.model.translate(text, tgt_lang=self.target_lang)
         return self.model.translate(text)
diff --git a/pythainlp/translate/small100.py b/pythainlp/translate/small100.py
new file mode 100644
index 000000000..9a4fb5bd2
--- /dev/null
+++ b/pythainlp/translate/small100.py
@@ -0,0 +1,60 @@
+from transformers import M2M100ForConditionalGeneration
+from .tokenization_small100 import SMALL100Tokenizer
+
+class Small100Translator:
+    """
+    Machine translation with the small100 model
+
+    - Huggingface https://huggingface.co/alirezamsh/small100
+
+    :param bool use_gpu: load model to gpu (Default is False)
+    """
+
+    def __init__(
+        self,
+        use_gpu: bool = False,
+        pretrained: str = "alirezamsh/small100",
+    ) -> None:
+        self.pretrained = pretrained
+        self.model = M2M100ForConditionalGeneration.from_pretrained(self.pretrained)
+        self.tgt_lang = None
+        if use_gpu:
+            self.model = self.model.cuda()
+
+    def translate(self, text: str, tgt_lang: str = "en") -> str:
+        """
+        Translate text from the source language to the target language
+
+        :param str text: input text in source language
+        :param str tgt_lang: target language code (Default is "en")
+        :return: translated text in target language
+        :rtype: str
+
+        :Example:
+
+        ::
+
+            from pythainlp.translate.small100 import Small100Translator
+
+            mt = Small100Translator()
+
+            # Translate text from Thai to English
+            mt.translate("ทดสอบระบบ", tgt_lang="en")
+            # output: 'Testing system'
+
+            # Translate text from Thai to Chinese
+            mt.translate("ทดสอบระบบ", tgt_lang="zh")
+            # output: '系统测试'
+
+            # Translate text from Thai to French
+            mt.translate("ทดสอบระบบ", tgt_lang="fr")
+            # output: 'Test du système'
+
+        """
+        if tgt_lang != self.tgt_lang:
+            self.tokenizer = SMALL100Tokenizer.from_pretrained(self.pretrained, tgt_lang=tgt_lang)
+            self.tgt_lang = tgt_lang
+        self.translated = self.model.generate(
+            **self.tokenizer(text, return_tensors="pt")
+        )
+        return self.tokenizer.batch_decode(self.translated, skip_special_tokens=True)[0]
diff --git a/pythainlp/translate/tokenization_small100.py b/pythainlp/translate/tokenization_small100.py
new file mode 100644
index 000000000..7168076a5
--- /dev/null
+++ b/pythainlp/translate/tokenization_small100.py
@@ -0,0 +1,362 @@
+# Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
+# Written by Alireza Mohammadshahi
+# This is a modified version of https://github.com/huggingface/transformers/blob/main/src/transformers/models/m2m_100/tokenization_m2m_100.py
+# which is owned by the Fairseq Authors and The HuggingFace Inc. team.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for SMALL100."""
+import json
+import os
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import sentencepiece
+
+from transformers.tokenization_utils import BatchEncoding, PreTrainedTokenizer
+from transformers.utils import logging
+
+
+SPIECE_UNDERLINE = "▁"
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "spm_file": "sentencepiece.bpe.model",
+    "tokenizer_config_file": "tokenizer_config.json",
+}
+
+PRETRAINED_VOCAB_FILES_MAP = {
+    "vocab_file": {
+        "alirezamsh/small100": "https://huggingface.co/alirezamsh/small100/resolve/main/vocab.json",
+    },
+    "spm_file": {
+        "alirezamsh/small100": "https://huggingface.co/alirezamsh/small100/resolve/main/sentencepiece.bpe.model",
+    },
+    "tokenizer_config_file": {
+        "alirezamsh/small100": "https://huggingface.co/alirezamsh/small100/resolve/main/tokenizer_config.json",
+    },
+}
+
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "alirezamsh/small100": 1024,
+}
+
+# fmt: off
+FAIRSEQ_LANGUAGE_CODES = {
+    "m2m100": ["af", "am", "ar", "ast", "az", "ba", "be", "bg", "bn", "br", "bs", "ca", "ceb", "cs", "cy", "da", "de", "el", "en", "es", "et", "fa", "ff", "fi", "fr", "fy", "ga", "gd", "gl", "gu", "ha", "he", "hi", "hr", "ht", "hu", "hy", "id", "ig", "ilo", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "lb", "lg", "ln", "lo", "lt", "lv", "mg", "mk", "ml", "mn", "mr", "ms", "my", "ne", "nl", "no", "ns", "oc", "or", "pa", "pl", "ps", "pt", "ro", "ru", "sd", "si", "sk", "sl", "so", "sq", "sr", "ss", "su", "sv", "sw", "ta", "th", "tl", "tn", "tr", "uk", "ur", "uz", "vi", "wo", "xh", "yi", "yo", "zh", "zu"]
+}
+# fmt: on
+
+
+class SMALL100Tokenizer(PreTrainedTokenizer):
+    """
+    Construct a SMALL100 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+    this superclass for more information regarding those methods.
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        spm_file (`str`):
+            Path to [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .spm extension) that
+            contains the vocabulary.
+        tgt_lang (`str`, *optional*):
+            A string representing the target language.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sequence token.
+        sep_token (`str`, *optional*, defaults to `"</s>"`):
+            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+            sequence classification or for a text and a question for question answering. It is also used as the last
+            token of a sequence built with special tokens.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        language_codes (`str`, *optional*):
+            What language codes to use. Should be `"m2m100"`.
+        sp_model_kwargs (`dict`, *optional*):
+            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+            to set:
+            - `enable_sampling`: Enable subword regularization.
+            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+              - `nbest_size = {0,1}`: No sampling is performed.
+              - `nbest_size > 1`: samples from the nbest_size results.
+              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
+                using the forward-filtering-and-backward-sampling algorithm.
+            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+              BPE-dropout.
+
+    Examples:
+    ```python
+    >>> from tokenization_small100 import SMALL100Tokenizer
+    >>> tokenizer = SMALL100Tokenizer.from_pretrained("alirezamsh/small100", tgt_lang="ro")
+    >>> src_text = " UN Chief Says There Is No Military Solution in Syria"
+    >>> tgt_text = "Şeful ONU declară că nu există o soluţie militară în Siria"
+    >>> model_inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
+    >>> model(**model_inputs)  # should work
+    ```"""
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]
+
+    prefix_tokens: List[int] = []
+    suffix_tokens: List[int] = []
+
+    def __init__(
+        self,
+        vocab_file,
+        spm_file,
+        tgt_lang=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        pad_token="<pad>",
+        unk_token="<unk>",
+        language_codes="m2m100",
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
+        num_madeup_words=8,
+        **kwargs,
+    ) -> None:
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+        self.language_codes = language_codes
+        fairseq_language_code = FAIRSEQ_LANGUAGE_CODES[language_codes]
+        self.lang_code_to_token = {lang_code: f"__{lang_code}__" for lang_code in fairseq_language_code}
+
+        kwargs["additional_special_tokens"] = kwargs.get("additional_special_tokens", [])
+        kwargs["additional_special_tokens"] += [
+            self.get_lang_token(lang_code)
+            for lang_code in fairseq_language_code
+            if self.get_lang_token(lang_code) not in kwargs["additional_special_tokens"]
+        ]
+
+        super().__init__(
+            tgt_lang=tgt_lang,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            language_codes=language_codes,
+            sp_model_kwargs=self.sp_model_kwargs,
+            num_madeup_words=num_madeup_words,
+            **kwargs,
+        )
+
+        self.vocab_file = vocab_file
+        self.encoder = load_json(vocab_file)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.spm_file = spm_file
+        self.sp_model = load_spm(spm_file, self.sp_model_kwargs)
+
+        self.encoder_size = len(self.encoder)
+
+        self.lang_token_to_id = {
+            self.get_lang_token(lang_code): self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)
+        }
+        self.lang_code_to_id = {lang_code: self.encoder_size + i for i, lang_code in enumerate(fairseq_language_code)}
+        self.id_to_lang_token = {v: k for k, v in self.lang_token_to_id.items()}
+
+        self._tgt_lang = tgt_lang if tgt_lang is not None else "en"
+        self.cur_lang_id = self.get_lang_id(self._tgt_lang)
+        self.set_lang_special_tokens(self._tgt_lang)
+
+        self.num_madeup_words = num_madeup_words
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.encoder) + len(self.lang_token_to_id) + self.num_madeup_words
+
+    @property
+    def tgt_lang(self) -> str:
+        return self._tgt_lang
+
+    @tgt_lang.setter
+    def tgt_lang(self, new_tgt_lang: str) -> None:
+        self._tgt_lang = new_tgt_lang
+        self.set_lang_special_tokens(self._tgt_lang)
+
+    def _tokenize(self, text: str) -> List[str]:
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        if token in self.lang_token_to_id:
+            return self.lang_token_to_id[token]
+        return self.encoder.get(token, self.encoder[self.unk_token])
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (integer) into a token (str) using the decoder."""
+        if index in self.id_to_lang_token:
+            return self.id_to_lang_token[index]
+        return self.decoder.get(index, self.unk_token)
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        """Converts a sequence of tokens (strings for sub-words) into a single string."""
+        return self.sp_model.decode(tokens)
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` method.
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        prefix_ones = [1] * len(self.prefix_tokens)
+        suffix_ones = [1] * len(self.suffix_tokens)
+        if token_ids_1 is None:
+            return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
+        return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
+        adding special tokens. An MBART sequence has the following format, where `X` represents the sequence:
+        - `input_ids` (for encoder) `X [eos, src_lang_code]`
+        - `decoder_input_ids`: (for decoder) `X [eos, tgt_lang_code]`
+        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+        separator.
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+        Returns:
+            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+        """
+        if token_ids_1 is None:
+            if self.prefix_tokens is None:
+                return token_ids_0 + self.suffix_tokens
+            else:
+                return self.prefix_tokens + token_ids_0 + self.suffix_tokens
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        if self.prefix_tokens is None:
+            return token_ids_0 + token_ids_1 + self.suffix_tokens
+        else:
+            return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
+
+    def get_vocab(self) -> Dict:
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def __getstate__(self) -> Dict:
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        return state
+
+    def __setstate__(self, d: Dict) -> None:
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = load_spm(self.spm_file, self.sp_model_kwargs)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        save_dir = Path(save_directory)
+        if not save_dir.is_dir():
+            raise OSError(f"{save_directory} should be a directory")
+        vocab_save_path = save_dir / (
+            (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
+        )
+        spm_save_path = save_dir / (
+            (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["spm_file"]
+        )
+
+        save_json(self.encoder, vocab_save_path)
+
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
+            copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (str(vocab_save_path), str(spm_save_path))
+
+    def prepare_seq2seq_batch(
+        self,
+        src_texts: List[str],
+        tgt_texts: Optional[List[str]] = None,
+        tgt_lang: str = "ro",
+        **kwargs,
+    ) -> BatchEncoding:
+        self.tgt_lang = tgt_lang
+        self.set_lang_special_tokens(self.tgt_lang)
+        return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)
+
+    def _build_translation_inputs(self, raw_inputs, tgt_lang: Optional[str], **extra_kwargs):
+        """Used by the translation pipeline to prepare inputs for the generate function."""
+        if tgt_lang is None:
+            raise ValueError("Translation requires a `tgt_lang` for this model")
+        self.tgt_lang = tgt_lang
+        inputs = self(raw_inputs, add_special_tokens=True, **extra_kwargs)
+        return inputs
+
+    def _switch_to_input_mode(self):
+        self.set_lang_special_tokens(self.tgt_lang)
+
+    def _switch_to_target_mode(self):
+        self.prefix_tokens = None
+        self.suffix_tokens = [self.eos_token_id]
+
+    def set_lang_special_tokens(self, src_lang: str) -> None:
+        """Reset the special tokens to the tgt lang setting. Prefix=[tgt_lang_code] and suffix=[eos]."""
+        lang_token = self.get_lang_token(src_lang)
+        self.cur_lang_id = self.lang_token_to_id[lang_token]
+        self.prefix_tokens = [self.cur_lang_id]
+        self.suffix_tokens = [self.eos_token_id]
+
+    def get_lang_token(self, lang: str) -> str:
+        return self.lang_code_to_token[lang]
+
+    def get_lang_id(self, lang: str) -> int:
+        lang_token = self.get_lang_token(lang)
+        return self.lang_token_to_id[lang_token]
+
+
+def load_spm(path: str, sp_model_kwargs: Dict[str, Any]) -> sentencepiece.SentencePieceProcessor:
+    spm = sentencepiece.SentencePieceProcessor(**sp_model_kwargs)
+    spm.Load(str(path))
+    return spm
+
+
+def load_json(path: str) -> Union[Dict, List]:
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def save_json(data, path: str) -> None:
+    with open(path, "w") as f:
+        json.dump(data, f, indent=2)
diff --git a/tests/test_translate.py b/tests/test_translate.py
index fafbd12f8..9ae4fe44f 100644
--- a/tests/test_translate.py
+++ b/tests/test_translate.py
@@ -73,5 +73,17 @@ def test_translate(self):
                 "ทดสอบระบบ",
             )
         )
+        self.th_fr_translator = Translate('th', 'fr', engine="small100")
+        self.assertIsNotNone(
+            self.th_fr_translator.translate(
+                "ทดสอบระบบ",
+            )
+        )
+        self.th_ja_translator = Translate('th', 'ja', engine="small100")
+        self.assertIsNotNone(
+            self.th_ja_translator.translate(
+                "ทดสอบระบบ",
+            )
+        )
         with self.assertRaises(ValueError):
-            self.th_cat_translator = Translate('th', 'cat')
+            self.th_cat_translator = Translate('th', 'cat', engine="fkfj")
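
Usage sketch (not part of the patch): the snippet below illustrates how the new `engine` parameter introduced above is meant to be exercised, both through the high-level `Translate` wrapper and the lower-level `Small100Translator`. It assumes `Translate` is importable from `pythainlp.translate` (as in the existing tests) and that the `alirezamsh/small100` weights can be downloaded from the Hugging Face Hub at runtime; the translations in the comments are the ones quoted in the docstrings of this diff, not newly verified outputs.

```python
# Illustrative sketch only -- downloads "alirezamsh/small100" from the
# Hugging Face Hub on first use.
from pythainlp.translate import Translate
from pythainlp.translate.small100 import Small100Translator

# High-level API: with engine="small100", load_model() builds a
# Small100Translator and translate() forwards target_lang as tgt_lang.
th_fr = Translate("th", "fr", engine="small100")
print(th_fr.translate("ทดสอบระบบ"))  # docstring example: 'Test du système'

# Lower-level API: one model instance can serve several target languages;
# the tokenizer is reloaded only when tgt_lang changes between calls.
mt = Small100Translator()
print(mt.translate("ทดสอบระบบ", tgt_lang="en"))  # docstring example: 'Testing system'
print(mt.translate("ทดสอบระบบ", tgt_lang="zh"))  # docstring example: '系统测试'
```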