Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions test/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,20 @@ def test_build_vocab_iterator(self):
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v.get_itos(), expected_itos)
self.assertEqual(dict(v.get_stoi()), expected_stoi)

def test_vocab_specials(self):
token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
sorted_by_freq_tuples = OrderedDict(sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True))
specials = ["<unk>", "<bos>", "<eos>", "pad"]

v1 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials)
expected_itos = specials + ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v1.get_itos(), expected_itos)
self.assertEqual(dict(v1.get_stoi()), expected_stoi)

v2 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials, special_first=False)
expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + specials
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v2.get_itos(), expected_itos)
self.assertEqual(dict(v2.get_stoi()), expected_stoi)
31 changes: 16 additions & 15 deletions torchtext/vocab/vocab_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
)


def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
def vocab(ordered_dict: Dict, min_freq: int = 1,
specials: Optional[List[str]] = None,
special_first: bool = True) -> Vocab:
r"""Factory method for creating a vocab object which maps tokens to indices.

Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab.
Expand All @@ -15,6 +17,8 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
Args:
ordered_dict: Ordered Dictionary mapping tokens to their corresponding occurance frequencies.
min_freq: The minimum frequency needed to include a token in the vocabulary.
specials: Special symbols to add. The order of supplied tokens will be preserved.
special_first: Indicates whether to insert symbols at the beginning or at the end.

Returns:
torchtext.vocab.Vocab: A `Vocab` object
Expand All @@ -29,24 +33,31 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
>>> print(v1['a']) #prints 1
>>> print(v1['out of vocab']) #raise RuntimeError since default index is not set
>>> tokens = ['e', 'd', 'c', 'b', 'a']
>>> v2 = vocab(OrderedDict([(token, 1) for token in tokens]))
>>> #adding <unk> token and default index
>>> unk_token = '<unk>'
>>> default_index = -1
>>> if unk_token not in v2: v2.insert_token(unk_token, 0)
>>> v2 = vocab(OrderedDict([(token, 1) for token in tokens]), specials=[unk_token])
>>> v2.set_default_index(default_index)
>>> print(v2['<unk>']) #prints 0
>>> print(v2['out of vocab']) #prints -1
>>> #make default index same as index of unk_token
>>> v2.set_default_index(v2[unk_token])
>>> v2['out of vocab'] is v2[unk_token] #prints True
"""
specials = specials or []
for token in specials:
ordered_dict.pop(token, None)

tokens = []
for token, freq in ordered_dict.items():
if freq >= min_freq:
tokens.append(token)

if special_first:
tokens[0:0] = specials
else:
tokens.extend(specials)

return Vocab(VocabPybind(tokens, None))


Expand Down Expand Up @@ -79,20 +90,10 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
for tokens in iterator:
counter.update(tokens)

if specials is not None:
for tok in specials:
del counter[tok]

sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)

if specials is not None:
if special_first:
specials = specials[::-1]
for symbol in specials:
ordered_dict.update({symbol: min_freq})
ordered_dict.move_to_end(symbol, last=not special_first)

word_vocab = vocab(ordered_dict, min_freq=min_freq)
word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
special_first=special_first)
return word_vocab