diff --git a/docs/api/corpus.rst b/docs/api/corpus.rst
index 6db3d46ef..702a5d511 100644
--- a/docs/api/corpus.rst
+++ b/docs/api/corpus.rst
@@ -36,11 +36,21 @@ TNC
---
.. autofunction:: pythainlp.corpus.tnc.word_freqs
+.. autofunction:: pythainlp.corpus.tnc.unigram_word_freqs
+.. autofunction:: pythainlp.corpus.tnc.bigram_word_freqs
+.. autofunction:: pythainlp.corpus.tnc.trigram_word_freqs
TTC
---
.. autofunction:: pythainlp.corpus.ttc.word_freqs
+.. autofunction:: pythainlp.corpus.ttc.unigram_word_freqs
+
+OSCAR
+-----
+
+.. autofunction:: pythainlp.corpus.oscar.word_freqs
+.. autofunction:: pythainlp.corpus.oscar.unigram_word_freqs
Util
----
diff --git a/docs/api/generate.rst b/docs/api/generate.rst
new file mode 100644
index 000000000..02459dfc3
--- /dev/null
+++ b/docs/api/generate.rst
@@ -0,0 +1,16 @@
+.. currentmodule:: pythainlp.generate
+
+pythainlp.generate
+==================
+The :mod:`pythainlp.generate` module provides Thai text generation with PyThaiNLP.
+
+Modules
+-------
+
+.. autoclass:: Unigram
+ :members:
+.. autoclass:: Bigram
+ :members:
+.. autoclass:: Trigram
+ :members:
+.. autofunction:: pythainlp.generate.thai2fit.gen_sentence
\ No newline at end of file
diff --git a/pythainlp/corpus/oscar.py b/pythainlp/corpus/oscar.py
new file mode 100644
index 000000000..085f5bc41
--- /dev/null
+++ b/pythainlp/corpus/oscar.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+"""
+Thai unigram word frequency from OSCAR Corpus (icu word tokenize)
+
+Credit: Korakot Chaovavanich
+https://web.facebook.com/groups/colab.thailand/permalink/1524070061101680/
+"""
+
+__all__ = [
+ "word_freqs",
+ "unigram_word_freqs"
+]
+
+from collections import defaultdict
+from typing import List, Tuple
+
+from pythainlp.corpus import get_corpus_path
+
+_FILENAME = "oscar_icu"
+
+
+def word_freqs() -> List[Tuple[str, int]]:
+    """
+    Get word frequency from OSCAR Corpus (icu word tokenize)
+
+    The first line of the corpus file is skipped as a header. Every other
+    line is stripped and split on ","; the first field is the word and the
+    second field is its count. Lines with fewer than two fields and words
+    containing a double quote are ignored.
+
+    :return: (word, frequency) pairs in file order
+    :rtype: List[Tuple[str, int]]
+    """
+    freqs = []
+    _path = get_corpus_path(_FILENAME)
+    with open(_path, "r", encoding="utf-8") as f:
+        # Skip the header. next() with a default tolerates an empty file,
+        # where deleting the first element of a list would raise IndexError.
+        next(f, None)
+        for line in f:
+            _temp = line.strip().split(",")
+            if len(_temp) < 2:
+                continue
+            word = _temp[0]
+            if '"' in word:
+                continue
+            # A lone-space token needs no special case: strip() above has
+            # already reduced it to "".
+            freqs.append((word, int(_temp[1])))
+
+    return freqs
+
+
+def unigram_word_freqs() -> defaultdict:
+    """
+    Get unigram word frequency from OSCAR Corpus (icu word tokenize)
+
+    The first line of the corpus file is skipped as a header. Every other
+    line is stripped and split on ","; the first field is the word and the
+    last field is its count. Lines with fewer than two fields and words
+    containing a double quote are ignored.
+
+    :return: mapping of word to frequency (unknown words read as 0)
+    :rtype: defaultdict
+    """
+    _path = get_corpus_path(_FILENAME)
+    _word_freqs = defaultdict(int)
+    with open(_path, "r", encoding="utf-8-sig") as fh:
+        # Skip the header without failing on an empty file.
+        next(fh, None)
+        for line in fh:
+            _temp = line.strip().split(",")
+            # Blank or malformed lines have no count field; skip them
+            # instead of crashing in int().
+            if len(_temp) < 2 or '"' in _temp[0]:
+                continue
+            # A lone-space token is already "" here because of strip().
+            _word_freqs[_temp[0]] = int(_temp[-1])
+
+    return _word_freqs
diff --git a/pythainlp/corpus/tnc.py b/pythainlp/corpus/tnc.py
index db836ea17..5f80ab972 100644
--- a/pythainlp/corpus/tnc.py
+++ b/pythainlp/corpus/tnc.py
@@ -1,18 +1,25 @@
# -*- coding: utf-8 -*-
"""
Thai National Corpus word frequency
-
-Credit: Korakot Chaovavanich
-https://www.facebook.com/photo.php?fbid=363640477387469&set=gm.434330506948445&type=3&permPage=1
"""
-__all__ = ["word_freqs"]
+__all__ = [
+ "word_freqs",
+ "unigram_word_freqs",
+ "bigram_word_freqs",
+ "trigram_word_freqs"
+]
+from collections import defaultdict
from typing import List, Tuple
from pythainlp.corpus import get_corpus
+from pythainlp.corpus import get_corpus_path
+
_FILENAME = "tnc_freq.txt"
+_BIGRAM = "tnc_bigram_word_freqs"
+_TRIGRAM = "tnc_trigram_word_freqs"
def word_freqs() -> List[Tuple[str, int]]:
@@ -20,6 +27,8 @@ def word_freqs() -> List[Tuple[str, int]]:
Get word frequency from Thai National Corpus (TNC)
\n(See: `dev/pythainlp/corpus/tnc_freq.txt\
`_)
+
+ Credit: Korakot Chaovavanich https://bit.ly/3wSkZsF
"""
lines = list(get_corpus(_FILENAME))
word_freqs = []
@@ -29,3 +38,45 @@ def word_freqs() -> List[Tuple[str, int]]:
word_freqs.append((word_freq[0], int(word_freq[1])))
return word_freqs
+
+
+def unigram_word_freqs() -> defaultdict:
+    """
+    Get unigram word frequency from Thai National Corpus (TNC)
+
+    Each corpus line is stripped and split on " "; the first field is the
+    word and the last field is its count. Lines with fewer than two fields
+    are skipped.
+
+    :return: mapping of word to frequency (unknown words read as 0)
+    :rtype: defaultdict
+    """
+    freqs = defaultdict(int)
+    for entry in list(get_corpus(_FILENAME)):
+        fields = entry.strip().split(" ")
+        if len(fields) < 2:
+            continue
+        freqs[fields[0]] = int(fields[-1])
+
+    return freqs
+
+
+def bigram_word_freqs() -> defaultdict:
+    """
+    Get bigram word frequency from Thai National Corpus (TNC)
+
+    Each line is stripped and split on " "; the first two fields are the
+    words and the last field is the count.
+
+    :return: mapping of (word1, word2) to frequency
+    :rtype: defaultdict
+    """
+    _path = get_corpus_path(_BIGRAM)
+    _word_freqs = defaultdict(int)
+    with open(_path, "r", encoding="utf-8-sig") as fh:
+        for line in fh:
+            _temp = line.strip().split(" ")
+            # Two words plus a count are required; a blank or short line
+            # would otherwise raise IndexError or store a bogus key.
+            if len(_temp) < 3:
+                continue
+            _word_freqs[(_temp[0], _temp[1])] = int(_temp[-1])
+
+    return _word_freqs
+
+
+def trigram_word_freqs() -> defaultdict:
+    """
+    Get trigram word frequency from Thai National Corpus (TNC)
+
+    Each line is stripped and split on " "; the first three fields are the
+    words and the last field is the count.
+
+    :return: mapping of (word1, word2, word3) to frequency
+    :rtype: defaultdict
+    """
+    _path = get_corpus_path(_TRIGRAM)
+    _word_freqs = defaultdict(int)
+    with open(_path, "r", encoding="utf-8-sig") as fh:
+        for line in fh:
+            _temp = line.strip().split(" ")
+            # Three words plus a count are required; a blank or short line
+            # would otherwise raise IndexError or store a bogus key.
+            if len(_temp) < 4:
+                continue
+            _word_freqs[(_temp[0], _temp[1], _temp[2])] = int(_temp[-1])
+
+    return _word_freqs
diff --git a/pythainlp/corpus/ttc.py b/pythainlp/corpus/ttc.py
index 0de0069c7..c3ffa0c0d 100644
--- a/pythainlp/corpus/ttc.py
+++ b/pythainlp/corpus/ttc.py
@@ -6,8 +6,12 @@
https://www.facebook.com/photo.php?fbid=363640477387469&set=gm.434330506948445&type=3&permPage=1
"""
-__all__ = ["word_freqs"]
+__all__ = [
+ "word_freqs",
+ "unigram_word_freqs"
+]
+from collections import defaultdict
from typing import List, Tuple
from pythainlp.corpus import get_corpus
@@ -29,3 +33,17 @@ def word_freqs() -> List[Tuple[str, int]]:
word_freqs.append((word_freq[0], int(word_freq[1])))
return word_freqs
+
+
+def unigram_word_freqs() -> defaultdict:
+    """
+    Get unigram word frequency from Thai Textbook Corpus (TTC)
+
+    Each corpus line is stripped and split on " "; the first field is the
+    word and the last field is its count. Lines with fewer than two fields
+    are skipped.
+
+    :return: mapping of word to frequency (unknown words read as 0)
+    :rtype: defaultdict
+    """
+    table = defaultdict(int)
+    for raw in list(get_corpus(_FILENAME)):
+        parts = raw.strip().split(" ")
+        if len(parts) >= 2:
+            word, count = parts[0], parts[-1]
+            table[word] = int(count)
+
+    return table
diff --git a/pythainlp/generate/__init__.py b/pythainlp/generate/__init__.py
new file mode 100644
index 000000000..fffac652c
--- /dev/null
+++ b/pythainlp/generate/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+"""
+Thai text generation
+"""
+
+__all__ = [
+ "Unigram",
+ "Bigram",
+ "Trigram"
+]
+
+from pythainlp.generate.core import Unigram, Bigram, Trigram
diff --git a/pythainlp/generate/core.py b/pythainlp/generate/core.py
new file mode 100644
index 000000000..7e513c287
--- /dev/null
+++ b/pythainlp/generate/core.py
@@ -0,0 +1,295 @@
+# -*- coding: utf-8 -*-
+"""
+Text generator using n-gram language model
+
+code from
+https://towardsdatascience.com/understanding-word-n-grams-and-n-gram-probability-in-natural-language-processing-9d9eef0fa058
+"""
+import random
+from pythainlp.corpus.tnc import unigram_word_freqs as tnc_word_freqs_unigram
+from pythainlp.corpus.tnc import bigram_word_freqs as tnc_word_freqs_bigram
+from pythainlp.corpus.tnc import trigram_word_freqs as tnc_word_freqs_trigram
+from pythainlp.corpus.ttc import unigram_word_freqs as ttc_word_freqs_unigram
+from pythainlp.corpus.oscar import (
+ unigram_word_freqs as oscar_word_freqs_unigram
+)
+from typing import List, Union
+
+
+class Unigram:
+    """
+    Text generator using Unigram
+
+    :param str name: corpus name
+        * *tnc* - Thai National Corpus (default)
+        * *ttc* - Thai Textbook Corpus (TTC)
+        * *oscar* - OSCAR Corpus
+    :raises ValueError: if *name* is not one of the corpora listed above
+    """
+    def __init__(self, name: str = "tnc"):
+        if name == "tnc":
+            self.counts = tnc_word_freqs_unigram()
+        elif name == "ttc":
+            self.counts = ttc_word_freqs_unigram()
+        elif name == "oscar":
+            self.counts = oscar_word_freqs_unigram()
+        else:
+            # Fail here with a clear message rather than with an
+            # AttributeError on the first use of self.counts below.
+            raise ValueError(
+                "Unknown corpus name %r; expected 'tnc', 'ttc' or 'oscar'"
+                % (name,)
+            )
+        self.word = list(self.counts.keys())
+        self.n = sum(self.counts[i] for i in self.word)
+        self.prob = {
+            i: self.counts[i] / self.n for i in self.word
+        }
+        self._word_prob = {}
+
+    def gen_sentence(
+        self,
+        start_seq: str = None,
+        N: int = 3,
+        prob: float = 0.001,
+        output_str: bool = True,
+        duplicate: bool = False
+    ) -> Union[List[str], str]:
+        """
+        :param str start_seq: word to begin with; a random corpus word is
+            used when None. The word is lower-cased.
+        :param int N: number of words to add after the first word.
+        :param float prob: minimum relative frequency a corpus word must
+            have to be eligible.
+        :param bool output_str: output is str
+        :param bool duplicate: allow a word to appear more than once
+
+        :return: list words or str words
+        :rtype: List[str], str
+
+        :Example:
+        ::
+
+            from pythainlp.generate import Unigram
+
+            gen = Unigram()
+
+            gen.gen_sentence("แมว")
+            # output: 'แมวเวลานะนั้น'
+        """
+        if start_seq is None:
+            start_seq = random.choice(self.word)
+        rand_text = start_seq.lower()
+        self._word_prob = {
+            i: self.counts[i] / self.n for i in self.word
+            if self.counts[i] / self.n >= prob
+        }
+        return self._next_word(
+            rand_text,
+            N,
+            output_str,
+            prob=prob,
+            duplicate=duplicate
+        )
+
+    def _next_word(
+        self,
+        text: str,
+        N: int,
+        output_str: bool,
+        prob: float,
+        duplicate: bool = False
+    ) -> Union[List[str], str]:
+        # Append up to N random eligible words to the seed word.
+        self.words = [text]
+        self._word_list = list(self._word_prob.keys())
+        if duplicate is False:
+            # Draw without replacement from the words not used yet. The
+            # count is capped by what is actually available, so the seed
+            # word being one of the candidates can no longer make the
+            # draw spin forever.
+            fresh = [w for w in self._word_list if w not in self.words]
+            count = max(0, min(N, len(fresh)))
+            self.words.extend(random.sample(fresh, count))
+        else:
+            count = min(N, len(self._word_list))
+            self.words.extend(
+                random.choice(self._word_list) for _ in range(count)
+            )
+
+        if output_str:
+            return "".join(self.words)
+        return self.words
+
+
+class Bigram:
+    """
+    Text generator using Bigram
+
+    :param str name: corpus name
+        * *tnc* - Thai National Corpus (default)
+    :raises ValueError: if *name* is not a supported corpus
+    """
+    def __init__(self, name: str = "tnc"):
+        if name != "tnc":
+            # Without this check an unknown name leaves self.uni/self.bi
+            # unset and fails later with an AttributeError.
+            raise ValueError(
+                "Unknown corpus name %r; expected 'tnc'" % (name,)
+            )
+        self.uni = tnc_word_freqs_unigram()
+        self.bi = tnc_word_freqs_bigram()
+        self.uni_keys = list(self.uni.keys())
+        self.bi_keys = list(self.bi.keys())
+        self.words = [i[-1] for i in self.bi_keys]
+
+    def prob(self, t1: str, t2: str) -> float:
+        """
+        probability word
+
+        :param str t1: text 1
+        :param str t2: text 2
+
+        :return: count of (t1, t2) divided by count of t1, or 0.0 when
+            t1 has no count
+        :rtype: float
+        """
+        # .get() keeps lookups of unseen words from inserting new keys
+        # into the frequency tables.
+        try:
+            return self.bi.get((t1, t2), 0) / self.uni.get(t1, 0)
+        except ZeroDivisionError:
+            return 0.0
+
+    def gen_sentence(
+        self,
+        start_seq: str = None,
+        N: int = 4,
+        prob: float = 0.001,
+        output_str: bool = True,
+        duplicate: bool = False
+    ) -> Union[List[str], str]:
+        """
+        :param str start_seq: word to begin with; a random word is used
+            when None.
+        :param int N: maximum number of words to add after the first word.
+        :param float prob: minimum bigram probability a next word must
+            have to be eligible.
+        :param bool output_str: output is str
+        :param bool duplicate: allow a word to appear more than once
+
+        :return: list words or str words
+        :rtype: List[str], str
+
+        :Example:
+        ::
+
+            from pythainlp.generate import Bigram
+
+            gen = Bigram()
+
+            gen.gen_sentence("แมว")
+            # output: 'แมวไม่ได้รับเชื้อมัน'
+        """
+        if start_seq is None:
+            start_seq = random.choice(self.words)
+        self.late_word = start_seq
+        self.list_word = [start_seq]
+
+        for _ in range(N):
+            # Bigrams that continue from the most recent word.
+            candidates = [j for j in self.bi_keys if j[0] == self.late_word]
+            if not duplicate:
+                candidates = [
+                    j for j in candidates if j[1] not in self.list_word
+                ]
+            eligible = [
+                j for j in candidates if self.prob(j[0], j[1]) >= prob
+            ]
+            if not eligible:
+                break
+            # Choose among the bigrams themselves, so candidates that
+            # share a probability value are all reachable.
+            self.items = random.choice(eligible)
+            self.late_word = self.items[-1]
+            self.list_word.append(self.late_word)
+        if output_str:
+            return ''.join(self.list_word)
+        return self.list_word
+
+
+class Trigram:
+    """
+    Text generator using Trigram
+
+    :param str name: corpus name
+        * *tnc* - Thai National Corpus (default)
+    :raises ValueError: if *name* is not a supported corpus
+    """
+    def __init__(self, name: str = "tnc"):
+        if name != "tnc":
+            # Without this check an unknown name leaves the frequency
+            # tables unset and fails later with an AttributeError.
+            raise ValueError(
+                "Unknown corpus name %r; expected 'tnc'" % (name,)
+            )
+        self.uni = tnc_word_freqs_unigram()
+        self.bi = tnc_word_freqs_bigram()
+        self.ti = tnc_word_freqs_trigram()
+        self.uni_keys = list(self.uni.keys())
+        self.bi_keys = list(self.bi.keys())
+        self.ti_keys = list(self.ti.keys())
+        self.words = [i[-1] for i in self.bi_keys]
+
+    def prob(self, t1: str, t2: str, t3: str) -> float:
+        """
+        probability word
+
+        :param str t1: text 1
+        :param str t2: text 2
+        :param str t3: text 3
+
+        :return: count of (t1, t2, t3) divided by count of (t1, t2), or
+            0.0 when (t1, t2) has no count
+        :rtype: float
+        """
+        # .get() keeps lookups of unseen n-grams from inserting new keys
+        # into the frequency tables.
+        try:
+            return self.ti.get((t1, t2, t3), 0) / self.bi.get((t1, t2), 0)
+        except ZeroDivisionError:
+            return 0.0
+
+    def gen_sentence(
+        self,
+        start_seq: str = None,
+        N: int = 4,
+        prob: float = 0.001,
+        output_str: bool = True,
+        duplicate: bool = False
+    ) -> Union[List[str], str]:
+        """
+        :param str start_seq: word to begin with. A random bigram that
+            starts with this word is used as the seed; if there is none,
+            the word is returned on its own. A two-word tuple is also
+            accepted, and a random bigram is used when None.
+        :param int N: maximum number of words to add after the seed.
+        :param float prob: minimum trigram probability a next word must
+            have to be eligible.
+        :param bool output_str: output is str
+        :param bool duplicate: allow a word to appear more than once
+
+        :return: list words or str words
+        :rtype: List[str], str
+
+        :Example:
+        ::
+
+            from pythainlp.generate import Trigram
+
+            gen = Trigram()
+
+            gen.gen_sentence()
+            # output: 'ยังทำตัวเป็นเซิร์ฟเวอร์คือ'
+        """
+        if start_seq is None:
+            start_seq = random.choice(self.bi_keys)
+        elif isinstance(start_seq, str):
+            # Trigram keys are matched against a two-word tuple, so a bare
+            # word is first widened to a bigram that begins with it.
+            starts = [k for k in self.bi_keys if k[0] == start_seq]
+            if not starts:
+                return start_seq if output_str else [start_seq]
+            start_seq = random.choice(starts)
+        else:
+            start_seq = tuple(start_seq)
+        self.late_word = start_seq
+        self.list_word = [start_seq]
+
+        for _ in range(N):
+            # Trigrams whose first two words are the most recent pair.
+            candidates = [
+                j for j in self.ti_keys if j[:2] == self.late_word
+            ]
+            if not duplicate:
+                candidates = [
+                    j for j in candidates if j[1:] not in self.list_word
+                ]
+            eligible = [
+                j for j in candidates
+                if self.prob(j[0], j[1], j[2]) >= prob
+            ]
+            if not eligible:
+                break
+            # Choose among the trigrams themselves, so candidates that
+            # share a probability value are all reachable.
+            self.items = random.choice(eligible)
+            self.late_word = self.items[1:]
+            self.list_word.append(self.late_word)
+
+        # Consecutive pairs overlap by one word: take the seed pair in
+        # full, then only the last word of every later pair.
+        sequence = list(self.list_word[0])
+        sequence.extend(pair[-1] for pair in self.list_word[1:])
+        self.listdata = []
+        for word in sequence:
+            # Repeated words are dropped only when duplicates are off.
+            if duplicate or word not in self.listdata:
+                self.listdata.append(word)
+        if output_str:
+            return ''.join(self.listdata)
+        return self.listdata
diff --git a/pythainlp/generate/thai2fit.py b/pythainlp/generate/thai2fit.py
new file mode 100644
index 000000000..f299c6648
--- /dev/null
+++ b/pythainlp/generate/thai2fit.py
@@ -0,0 +1,126 @@
+# -*- coding: utf-8 -*-
+"""
+Thai2fit: Thai Wikipedia Language Model for Text Generation
+
+Code from
+https://github.com/PyThaiNLP/tutorials/blob/master/source/notebooks/text_generation.ipynb
+"""
+__all__ = [
+ "gen_sentence"
+]
+
+import pandas as pd
+import random
+import pickle
+from typing import List, Union
+
+# fastai
+import fastai
+from fastai.text import *
+
+# pythainlp
+from pythainlp.ulmfit import *
+
+# NOTE: everything in this section runs when the module is imported.
+
+# get dummy data
+imdb = untar_data(URLs.IMDB_SAMPLE)
+dummy_df = pd.read_csv(imdb/'texts.csv')
+
+# get vocab
+# Use _THWIKI_LSTM when the star import above provides it; otherwise fall
+# back to THWIKI_LSTM. Only a missing name is expected here, so only
+# NameError is caught.
+thwiki = ""
+try:
+    thwiki = _THWIKI_LSTM
+except NameError:
+    thwiki = THWIKI_LSTM
+
+# Close the vocabulary file as soon as it has been read.
+with open(thwiki['itos_fname'], 'rb') as _itos_file:
+    thwiki_itos = pickle.load(_itos_file)
+thwiki_vocab = fastai.text.transform.Vocab(thwiki_itos)
+
+# dummy databunch
+tt = Tokenizer(
+    tok_func=ThaiTokenizer,
+    lang='th',
+    pre_rules=pre_rules_th,
+    post_rules=post_rules_th
+)
+processor = [
+    TokenizeProcessor(tokenizer=tt, chunksize=10000, mark_fields=False),
+    NumericalizeProcessor(vocab=thwiki_vocab, max_vocab=60000, min_freq=3)
+]
+data_lm = (
+    TextList.from_df(dummy_df, imdb, cols=['text'], processor=processor)
+    .split_by_rand_pct(0.2)
+    .label_for_lm()
+    .databunch(bs=64)
+)
+
+
+data_lm.sanity_check()
+
+# model hyper-parameters passed to language_model_learner below
+config = dict(
+    emb_sz=400,
+    n_hid=1550,
+    n_layers=4,
+    pad_token=1,
+    qrnn=False,
+    tie_weights=True,
+    out_bias=True,
+    output_p=0.25,
+    hidden_p=0.1,
+    input_p=0.2,
+    embed_p=0.02,
+    weight_p=0.15
+)
+trn_args = dict(drop_mult=0.9, clip=0.12, alpha=2, beta=1)
+
+learn = language_model_learner(
+    data_lm,
+    AWD_LSTM,
+    config=config,
+    pretrained=False,
+    **trn_args
+)
+
+# load pretrained models
+learn.load_pretrained(**thwiki)
+
+
+def gen_sentence(
+ start_seq: str = None,
+ N: int = 4,
+ prob: float = 0.001,
+ output_str: bool = True
+) -> Union[List[str], str]:
+ """
+ Text generator using Thai2fit
+
+ :param str start_seq: text to begin with; a random item of the
+ thai2fit vocabulary is used when None.
+ :param int N: number of words, passed to ``learn.predict``.
+ :param float prob: passed to ``learn.predict`` as ``min_p``.
+ :param bool output_str: output is str
+
+ :return: list words or str words
+ :rtype: List[str], str
+
+ :Example:
+ ::
+
+ from pythainlp.generate.thai2fit import gen_sentence
+
+ gen_sentence()
+ # output: 'แคทรียา อิงลิช (นักแสดง'
+
+ gen_sentence("แมว")
+ # output: 'แมว คุณหลวง '
+ """
+ if start_seq is None:
+ start_seq = random.choice(list(thwiki_itos))
+ # predict() gives back one string; '-*-' is used as the separator so
+ # the result can be split back into a list of pieces.
+ list_word = learn.predict(
+ start_seq,
+ N,
+ temperature=0.8,
+ min_p=prob,
+ sep='-*-'
+ ).split('-*-')
+ if output_str:
+ return ''.join(list_word)
+ return list_word
diff --git a/setup.py b/setup.py
index 2f172be68..d6706f196 100644
--- a/setup.py
+++ b/setup.py
@@ -63,6 +63,7 @@
"wangchanberta": ["transformers", "sentencepiece"],
"mt5": ["transformers>=4.6.0", "sentencepiece>=0.1.91"],
"wordnet": ["nltk>=3.3.*"],
+ "generate": ["fastai<2.0"],
"sefr_cut": ["sefr_cut"],
"full": [
"PyYAML>=5.3.1",
@@ -79,6 +80,7 @@
"sentencepiece>=0.1.91",
"ssg>=0.0.6",
"torch>=1.0.0",
+ "fastai<2.0",
"bpemb",
"transformers>=4.6.0",
"sefr_cut"
diff --git a/tests/test_corpus.py b/tests/test_corpus.py
index 2f31f66fc..792f70fd3 100644
--- a/tests/test_corpus.py
+++ b/tests/test_corpus.py
@@ -11,6 +11,7 @@
get_corpus_db_detail,
get_corpus_default_db,
get_corpus_path,
+ oscar,
provinces,
remove,
thai_family_names,
@@ -104,11 +105,19 @@ def test_corpus(self):
self.assertIsNotNone(download(name="test", version="0.1"))
self.assertIsNotNone(remove("test"))
+ def test_oscar(self):
+ # Smoke test: both OSCAR frequency loaders must return a value.
+ self.assertIsNotNone(oscar.word_freqs())
+ self.assertIsNotNone(oscar.unigram_word_freqs())
+
def test_tnc(self):
+ # Smoke test: every TNC frequency loader must return a value.
self.assertIsNotNone(tnc.word_freqs())
+ self.assertIsNotNone(tnc.unigram_word_freqs())
+ self.assertIsNotNone(tnc.bigram_word_freqs())
+ self.assertIsNotNone(tnc.trigram_word_freqs())
def test_ttc(self):
+ # Smoke test: both TTC frequency loaders must return a value.
self.assertIsNotNone(ttc.word_freqs())
+ self.assertIsNotNone(ttc.unigram_word_freqs())
def test_wordnet(self):
self.assertIsInstance(wordnet.langs(), list)
diff --git a/tests/test_generate.py b/tests/test_generate.py
new file mode 100644
index 000000000..6405c679e
--- /dev/null
+++ b/tests/test_generate.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+
+import unittest
+
+from pythainlp.generate import Unigram, Bigram, Trigram
+from pythainlp.generate.thai2fit import gen_sentence
+
+
+class TestGeneratePackage(unittest.TestCase):
+    """Smoke tests for the n-gram generators and the thai2fit generator."""
+
+    def _assert_generates(self, generator):
+        """Run the four standard gen_sentence calls; each must return a value."""
+        self.assertIsNotNone(generator.gen_sentence("ผม"))
+        self.assertIsNotNone(generator.gen_sentence("ผม", output_str=False))
+        self.assertIsNotNone(generator.gen_sentence())
+        self.assertIsNotNone(generator.gen_sentence(duplicate=True))
+
+    def test_unigram(self):
+        # Same checks for every corpus the unigram generator supports.
+        for corpus in ("tnc", "ttc", "oscar"):
+            self._assert_generates(Unigram(corpus))
+
+    def test_bigram(self):
+        self._assert_generates(Bigram())
+
+    def test_trigram(self):
+        self._assert_generates(Trigram())
+
+    def test_thai2fit(self):
+        seeded = gen_sentence("กาลครั้งหนึ่งนานมาแล้ว")
+        self.assertIsNotNone(seeded)
+        unseeded = gen_sentence()
+        self.assertIsNotNone(unseeded)
diff --git a/tests/test_ulmfit.py b/tests/test_ulmfit.py
index a713bd8b6..3aa807704 100644
--- a/tests/test_ulmfit.py
+++ b/tests/test_ulmfit.py
@@ -30,6 +30,15 @@
ungroup_emoji,
)
from pythainlp.ulmfit.tokenizer import BaseTokenizer
+import pandas as pd
+import random
+import pickle
+# fastai
+import fastai
+from fastai.text import *
+
+# pythainlp
+from pythainlp.ulmfit import *
class TestUlmfitPackage(unittest.TestCase):
@@ -198,3 +207,71 @@ def test_process_thai_dense(self):
]
self.assertEqual(actual, expect)
+
+    def test_document_vector(self):
+        """Build a thai2fit language-model learner and check document_vector."""
+        imdb = untar_data(URLs.IMDB_SAMPLE)
+        dummy_df = pd.read_csv(imdb/'texts.csv')
+        # Use _THWIKI_LSTM when it is available; otherwise fall back to
+        # THWIKI_LSTM. Only a missing name is expected, so only NameError
+        # is caught.
+        thwiki = ""
+        try:
+            thwiki = _THWIKI_LSTM
+        except NameError:
+            thwiki = THWIKI_LSTM
+        # Close the vocabulary file as soon as it has been read.
+        with open(thwiki['itos_fname'], 'rb') as itos_file:
+            thwiki_itos = pickle.load(itos_file)
+        thwiki_vocab = fastai.text.transform.Vocab(thwiki_itos)
+        tt = Tokenizer(
+            tok_func=ThaiTokenizer,
+            lang='th',
+            pre_rules=pre_rules_th,
+            post_rules=post_rules_th
+        )
+        processor = [
+            TokenizeProcessor(
+                tokenizer=tt, chunksize=10000, mark_fields=False
+            ),
+            NumericalizeProcessor(
+                vocab=thwiki_vocab, max_vocab=60000, min_freq=3
+            )
+        ]
+        data_lm = (
+            TextList.from_df(
+                dummy_df,
+                imdb,
+                cols=['text'],
+                processor=processor
+            )
+            .split_by_rand_pct(0.2)
+            .label_for_lm()
+            .databunch(bs=64)
+        )
+        data_lm.sanity_check()
+        config = dict(
+            emb_sz=400,
+            n_hid=1550,
+            n_layers=4,
+            pad_token=1,
+            qrnn=False,
+            tie_weights=True,
+            out_bias=True,
+            output_p=0.25,
+            hidden_p=0.1,
+            input_p=0.2,
+            embed_p=0.02,
+            weight_p=0.15
+        )
+        trn_args = dict(drop_mult=0.9, clip=0.12, alpha=2, beta=1)
+        learn = language_model_learner(
+            data_lm,
+            AWD_LSTM,
+            config=config,
+            pretrained=False,
+            **trn_args
+        )
+        learn.load_pretrained(**thwiki)
+        # Default aggregation and agg="sum" must both return a value.
+        self.assertIsNotNone(
+            document_vector('วันนี้วันดีปีใหม่', learn, data_lm)
+        )
+        self.assertIsNotNone(
+            document_vector('วันนี้วันดีปีใหม่', learn, data_lm, agg="sum")
+        )
+        # An unsupported aggregation name must be rejected.
+        with self.assertRaises(ValueError):
+            document_vector('วันนี้วันดีปีใหม่', learn, data_lm, agg='abc')