Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 457747f

Browse files
committed
revert potential breakage where lexicographical sorting is done. Simplify logic for sorting.
1 parent 4123122 commit 457747f

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

test/test_vocab.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,17 @@ def test_vocab_specials(self):
261261
self.assertEqual(v2.get_itos(), expected_itos)
262262
self.assertEqual(dict(v2.get_stoi()), expected_stoi)
263263

264+
def test_build_vocab_sorts_descending_frequency_then_lexigraphically(self):
265+
it = [["a", "b"], ["a", "b"]]
266+
vocab = build_vocab_from_iterator(it)
267+
self.assertEqual(vocab["a"], 0)
268+
self.assertEqual(vocab["b"], 1)
269+
270+
it = [["a", "b"], ["b"]]
271+
vocab = build_vocab_from_iterator(it)
272+
self.assertEqual(vocab["b"], 0)
273+
self.assertEqual(vocab["a"], 1)
274+
264275
def test_build_vocab_from_iterator_max_tokens(self):
265276
it = [["hello", "world"], ["hello"]]
266277
max_tokens = 1

torchtext/vocab/vocab_factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,15 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
9393
counter.update(tokens)
9494

9595
specials = specials or []
96+
97+
# First sort by descending frequency, then lexicographically
98+
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
99+
96100
if max_tokens is None:
97-
ordered_dict = OrderedDict(counter.most_common())
101+
ordered_dict = OrderedDict(sorted_by_freq_tuples)
98102
else:
99103
assert len(specials) < max_tokens, "len(specials) >= max_tokens, so the vocab will be entirely special tokens."
100-
ordered_dict = OrderedDict(counter.most_common(max_tokens - len(specials)))
104+
ordered_dict = OrderedDict(sorted_by_freq_tuples[:max_tokens - len(specials)])
101105

102106
word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials, special_first=special_first)
103107
return word_vocab

0 commit comments

Comments
 (0)