@@ -180,13 +180,36 @@ def pos_tag_sents(
 
 
 def pos_tag_transformers(
-    words: str, engine: str = "bert-base-th-cased-blackboard"
-):
+    sentence: str,
+    engine: str = "bert",
+    corpus: str = "blackboard",
+) -> List[List[Tuple[str, str]]]:
     """
-    "wangchanberta-ud-thai-pud-upos",
-    "mdeberta-v3-ud-thai-pud-upos",
-    "bert-base-th-cased-blackboard",
+    Marks sentences with part-of-speech (POS) tags.
+
+    :param str sentence: the sentence to tag
+    :param str engine:
+        * *bert* - BERT: Bidirectional Encoder Representations from Transformers (default)
+        * *wangchanberta* - version of airesearch/wangchanberta-base-att-spm-uncased fine-tuned on the PUD corpus (supports the PUD corpus only)
+        * *mdeberta* - mDeBERTa: Multilingual Decoding-enhanced BERT with disentangled attention (supports the PUD corpus only)
+    :param str corpus: the corpus used to create the language model for the tagger
+        * *blackboard* - `Blackboard Treebank <https://bitbucket.org/kaamanita/blackboard-treebank/src/master/>`_ (supports the bert engine only)
+        * *pud* - `Parallel Universal Dependencies (PUD)\
+          <https://github.com/UniversalDependencies/UD_Thai-PUD>`_ \
+          treebanks, which natively use Universal POS tags (supports the wangchanberta and mdeberta engines)
+    :return: a list of lists of tuples (word, POS tag)
+    :rtype: list[list[tuple[str, str]]]
 
+    :Example:
+
+    Labels POS for a given sentence::
+
+        from pythainlp.tag import pos_tag_transformers
+
+        sentence = "แมวทำอะไรตอนห้าโมงเช้า"
+        pos_tag_transformers(sentence, engine="bert", corpus="blackboard")
+        # output:
+        # [[('แมว', 'NOUN'), ('ทําอะไร', 'VERB'), ('ตอนห้าโมงเช้า', 'NOUN')]]
     """
 
     try:
@@ -196,28 +219,35 @@ def pos_tag_transformers(
         raise ImportError(
             "Not found transformers! Please install transformers by pip install transformers")
 
-    if not words:
+    if not sentence:
         return []
 
-    if engine == "wangchanberta-ud-thai-pud-upos":
-        model = AutoModelForTokenClassification.from_pretrained(
-            "Pavarissy/wangchanberta-ud-thai-pud-upos")
-        tokenizer = AutoTokenizer.from_pretrained("Pavarissy/wangchanberta-ud-thai-pud-upos")
-    elif engine == "mdeberta-v3-ud-thai-pud-upos":
-        model = AutoModelForTokenClassification.from_pretrained(
-            "Pavarissy/mdeberta-v3-ud-thai-pud-upos")
-        tokenizer = AutoTokenizer.from_pretrained("Pavarissy/mdeberta-v3-ud-thai-pud-upos")
-    elif engine == "bert-base-th-cased-blackboard":
-        model = AutoModelForTokenClassification.from_pretrained("lunarlist/pos_thai")
-        tokenizer = AutoTokenizer.from_pretrained("lunarlist/pos_thai")
+    _blackboard_support_engine = {
+        "bert": "lunarlist/pos_thai",
+    }
+
+    _pud_support_engine = {
+        "wangchanberta": "Pavarissy/wangchanberta-ud-thai-pud-upos",
+        "mdeberta": "Pavarissy/mdeberta-v3-ud-thai-pud-upos",
+    }
+
+    if corpus == "blackboard" and engine in _blackboard_support_engine:
+        base_model = _blackboard_support_engine.get(engine)
+        model = AutoModelForTokenClassification.from_pretrained(base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
+    elif corpus == "pud" and engine in _pud_support_engine:
+        base_model = _pud_support_engine.get(engine)
+        model = AutoModelForTokenClassification.from_pretrained(base_model)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
     else:
         raise ValueError(
-            "pos_tag_transformers not support {0} engine.".format(
-                engine
+            "pos_tag_transformers does not support the {0} engine or the {1} corpus.".format(
+                engine, corpus
             )
         )
 
-    pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, grouped_entities=True)
+    pipeline = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
 
-    outputs = pipeline(words)
-    return outputs
+    outputs = pipeline(sentence)
+    word_tags = [[(tag["word"], tag["entity_group"]) for tag in outputs]]
+    return word_tags
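
A minimal usage sketch of the revised API, assuming this diff is applied and that ``transformers`` (plus a backend such as ``torch``) is installed; the pretrained models are downloaded from the Hugging Face Hub on the first call::

    from pythainlp.tag import pos_tag_transformers

    sentence = "แมวทำอะไรตอนห้าโมงเช้า"

    # Default pairing: the bert engine with the blackboard corpus.
    pos_tag_transformers(sentence)

    # The PUD-trained engines must be paired with corpus="pud".
    pos_tag_transformers(sentence, engine="wangchanberta", corpus="pud")
    pos_tag_transformers(sentence, engine="mdeberta", corpus="pud")

    # Any other engine/corpus pairing raises ValueError.

Each call returns a list of lists of (word, POS tag) tuples. With ``aggregation_strategy="simple"``, the pipeline merges subword pieces back into whole words before the tags are read off, so the tags line up with words rather than with model tokens.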