Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit e1d66cf

Browse files
authored
add max_tokens kwarg to vocab factory. (#1525)
1 parent 03afb7e commit e1d66cf

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

test/test_vocab.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
22
from collections import OrderedDict
33
import os
4+
5+
import pytest
46
import torch
57
from test.common.torchtext_test_case import TorchtextTestCase
68
from torchtext.vocab import (
@@ -258,3 +260,34 @@ def test_vocab_specials(self):
258260
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
259261
self.assertEqual(v2.get_itos(), expected_itos)
260262
self.assertEqual(dict(v2.get_stoi()), expected_stoi)
263+
264+
def test_build_vocab_sorts_descending_frequency_then_lexigraphically(self):
    """Tokens are indexed by descending frequency, ties broken lexicographically.

    NOTE(review): "lexigraphically" in the method name is a typo for
    "lexicographically"; the name is kept as-is to preserve the test's
    public identifier.
    """
    # Equal frequencies: lexicographic order decides, so "a" comes first.
    tied_corpus = [["a", "b"], ["a", "b"]]
    v = build_vocab_from_iterator(tied_corpus)
    self.assertEqual(v["a"], 0)
    self.assertEqual(v["b"], 1)

    # "b" occurs twice vs. once for "a": frequency outranks lexicographic order.
    skewed_corpus = [["a", "b"], ["b"]]
    v = build_vocab_from_iterator(skewed_corpus)
    self.assertEqual(v["b"], 0)
    self.assertEqual(v["a"], 1)
274+
275+
def test_build_vocab_from_iterator_max_tokens(self):
    """`max_tokens` caps total vocab size; specials count toward the cap.

    A cap that cannot fit anything beyond the specials must be rejected,
    and `special_first` controls whether specials occupy the lowest or
    highest indices.
    """
    corpus = [["hello", "world"], ["hello"]]
    special_tokens = ["<unk>", "<pad>"]

    # Cap smaller than the specials list leaves no room for corpus tokens.
    budget = 1
    self.assertLess(budget, len(special_tokens))
    with pytest.raises(AssertionError):
        build_vocab_from_iterator(corpus, specials=special_tokens, max_tokens=budget)

    # Budget of 3 admits the two specials plus the single most frequent
    # corpus token ("hello"); "world" is cut.
    budget = 3
    v = build_vocab_from_iterator(corpus, specials=special_tokens, special_first=True, max_tokens=budget)
    self.assertEqual(v["<unk>"], 0)
    self.assertEqual(v["<pad>"], 1)
    self.assertEqual(v["hello"], 2)

    # Same budget with specials appended instead of prepended.
    v = build_vocab_from_iterator(corpus, specials=special_tokens, special_first=False, max_tokens=budget)
    self.assertEqual(v["hello"], 0)
    self.assertEqual(v["<unk>"], 1)
    self.assertEqual(v["<pad>"], 2)

torchtext/vocab/vocab_factory.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
4949
ordered_dict.pop(token, None)
5050

5151
tokens = []
52+
# Save room for special tokens
5253
for token, freq in ordered_dict.items():
5354
if freq >= min_freq:
5455
tokens.append(token)
@@ -61,7 +62,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
6162
return Vocab(VocabPybind(tokens, None))
6263

6364

64-
def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True) -> Vocab:
65+
def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True, max_tokens: Optional[int] = None) -> Vocab:
6566
"""
6667
Build a Vocab from an iterator.
6768
@@ -70,6 +71,7 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
7071
min_freq: The minimum frequency needed to include a token in the vocabulary.
7172
specials: Special symbols to add. The order of supplied tokens will be preserved.
7273
special_first: Indicates whether to insert symbols at the beginning or at the end.
74+
max_tokens: If provided, creates the vocab from the `max_tokens - len(specials)` most frequent tokens.
7375
7476
7577
Returns:
@@ -90,10 +92,16 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
9092
for tokens in iterator:
9193
counter.update(tokens)
9294

93-
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
94-
sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
95-
ordered_dict = OrderedDict(sorted_by_freq_tuples)
95+
specials = specials or []
96+
97+
# First sort by descending frequency, then lexicographically
98+
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
99+
100+
if max_tokens is None:
101+
ordered_dict = OrderedDict(sorted_by_freq_tuples)
102+
else:
103+
assert len(specials) < max_tokens, "len(specials) >= max_tokens, so the vocab will be entirely special tokens."
104+
ordered_dict = OrderedDict(sorted_by_freq_tuples[:max_tokens - len(specials)])
96105

97-
word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
98-
special_first=special_first)
106+
word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials, special_first=special_first)
99107
return word_vocab

0 commit comments

Comments
 (0)