def remove(self, word: str) -> None:
    """
    Remove a word from the trie.
    If the word is not found, do nothing.

    :param str word: a word
    """
    # The word set is checked first: a word that is not in it has nothing
    # to unlink from the node tree.
    if word not in self.words:
        return
    self.words.remove(word)

    # Walk down to the node that terminates the word, recording every
    # (parent, child, char) edge so the path can be pruned afterwards.
    node = self.root
    path = []
    for ch in word:
        child = node.children[ch]
        path.append((node, child, ch))
        node = child

    # Clear the end marker on the terminal node. ``node`` is used instead
    # of the loop variable so this also works for an empty word: the loop
    # never runs, ``node`` is still the root, and no name is left unbound.
    node.end = False

    # Prune upwards: drop nodes that neither end another word nor lead to
    # one; stop at the first node that is still needed.
    for parent, child, ch in reversed(path):
        if child.end or child.children:
            break
        del parent.children[ch]  # unlink from the parent's dict
:param str|Iterable[str]|pythainlp.util.Trie dict_source: a path to dictionary file or a list of words or a pythainlp.util.Trie object - :return: a trie object created from a dictionary input + :return: a trie object :rtype: pythainlp.util.Trie """ trie = None diff --git a/tests/test_util.py b/tests/test_util.py index 6ab88674c..aca0db883 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -388,6 +388,15 @@ def test_trie(self): self.assertEqual(len(trie), 4) self.assertEqual(len(trie.prefixes("ทดสอบ")), 2) + trie.remove("ทบ") + trie.remove("ทด") + self.assertEqual(len(trie), 2) + + trie = Trie([]) + self.assertEqual(len(trie), 0) + trie.remove("หมด") + self.assertEqual(len(trie), 0) + self.assertIsNotNone(dict_trie(Trie(["ลอง", "ลาก"]))) self.assertIsNotNone(dict_trie(("ลอง", "สร้าง", "Trie", "ลน"))) self.assertIsNotNone(dict_trie(["ลอง", "สร้าง", "Trie", "ลน"]))