Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions test/test_vocab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
from collections import OrderedDict
import os

import pytest
import torch
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.vocab import (
Expand Down Expand Up @@ -258,3 +260,34 @@ def test_vocab_specials(self):
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
self.assertEqual(v2.get_itos(), expected_itos)
self.assertEqual(dict(v2.get_stoi()), expected_stoi)

def test_build_vocab_sorts_descending_frequency_then_lexigraphically(self):
    """Indices are assigned by descending frequency; ties in frequency
    are broken lexicographically."""
    # Equal frequencies ("a" and "b" both appear twice): lexicographic order decides.
    tied_counts = [["a", "b"], ["a", "b"]]
    v = build_vocab_from_iterator(tied_counts)
    self.assertEqual(v["a"], 0)
    self.assertEqual(v["b"], 1)

    # "b" occurs more often than "a": frequency order wins over alphabetical order.
    skewed_counts = [["a", "b"], ["b"]]
    v = build_vocab_from_iterator(skewed_counts)
    self.assertEqual(v["b"], 0)
    self.assertEqual(v["a"], 1)

def test_build_vocab_from_iterator_max_tokens(self):
    """`max_tokens` caps the vocab size; specials always survive the cut,
    and `special_first` controls whether they precede or follow real tokens."""
    it = [["hello", "world"], ["hello"]]
    specials = ["<unk>", "<pad>"]

    # A budget smaller than the number of specials is rejected outright.
    self.assertLess(1, len(specials))
    with pytest.raises(AssertionError):
        build_vocab_from_iterator(it, specials=specials, max_tokens=1)

    # Budget of 3 leaves room for exactly one real token; "hello" wins on frequency.
    v = build_vocab_from_iterator(it, specials=specials, special_first=True, max_tokens=3)
    self.assertEqual(v["<unk>"], 0)
    self.assertEqual(v["<pad>"], 1)
    self.assertEqual(v["hello"], 2)

    # Same budget, but specials are appended after the real tokens.
    v = build_vocab_from_iterator(it, specials=specials, special_first=False, max_tokens=3)
    self.assertEqual(v["hello"], 0)
    self.assertEqual(v["<unk>"], 1)
    self.assertEqual(v["<pad>"], 2)
20 changes: 14 additions & 6 deletions torchtext/vocab/vocab_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
ordered_dict.pop(token, None)

tokens = []
# Save room for special tokens
for token, freq in ordered_dict.items():
if freq >= min_freq:
tokens.append(token)
Expand All @@ -61,7 +62,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
return Vocab(VocabPybind(tokens, None))


def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True, max_tokens: Optional[int] = None) -> Vocab:
    """
    Build a Vocab from an iterator.

    Args:
        iterator: Iterator used to build Vocab. Must yield list or iterator of tokens.
        min_freq: The minimum frequency needed to include a token in the vocabulary.
        specials: Special symbols to add. The order of supplied tokens will be preserved.
        special_first: Indicates whether to insert symbols at the beginning or at the end.
        max_tokens: If provided, creates the vocab from the ``max_tokens - len(specials)``
            most frequent tokens.

    Returns:
        torchtext.vocab.Vocab: A ``Vocab`` object.

    Raises:
        AssertionError: If ``max_tokens`` is provided and is not strictly greater
            than ``len(specials)`` (the vocab would consist only of special tokens).
    """
    counter = Counter()
    for tokens in iterator:
        counter.update(tokens)

    specials = specials or []

    # Sort by descending frequency first, then lexicographically, in a single pass:
    # negating the count lets one ascending sort produce both orderings.
    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))

    if max_tokens is None:
        ordered_dict = OrderedDict(sorted_by_freq_tuples)
    else:
        # Specials are added on top of the ordered tokens by `vocab()`, so leave
        # room for them inside the max_tokens budget.
        # NOTE(review): `assert` is stripped under `python -O`; a ValueError might be
        # safer, but existing callers/tests expect AssertionError — confirm before changing.
        assert len(specials) < max_tokens, "len(specials) >= max_tokens, so the vocab will be entirely special tokens."
        ordered_dict = OrderedDict(sorted_by_freq_tuples[:max_tokens - len(specials)])

    word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials, special_first=special_first)
    return word_vocab