diff --git a/test/test_vocab.py b/test/test_vocab.py index 1f938731b7..6b04746e42 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- from collections import OrderedDict import os + +import pytest import torch from test.common.torchtext_test_case import TorchtextTestCase from torchtext.vocab import ( @@ -258,3 +260,34 @@ def test_vocab_specials(self): expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v2.get_itos(), expected_itos) self.assertEqual(dict(v2.get_stoi()), expected_stoi) + + def test_build_vocab_sorts_descending_frequency_then_lexicographically(self): + it = [["a", "b"], ["a", "b"]] + vocab = build_vocab_from_iterator(it) + self.assertEqual(vocab["a"], 0) + self.assertEqual(vocab["b"], 1) + + it = [["a", "b"], ["b"]] + vocab = build_vocab_from_iterator(it) + self.assertEqual(vocab["b"], 0) + self.assertEqual(vocab["a"], 1) + + def test_build_vocab_from_iterator_max_tokens(self): + it = [["hello", "world"], ["hello"]] + max_tokens = 1 + specials = ["<unk>", "<pad>"] + self.assertLess(max_tokens, len(specials)) + with pytest.raises(AssertionError): + build_vocab_from_iterator(it, specials=specials, max_tokens=max_tokens) + + max_tokens = 3 + vocab = build_vocab_from_iterator(it, specials=specials, special_first=True, max_tokens=max_tokens) + self.assertEqual(vocab["<unk>"], 0) + self.assertEqual(vocab["<pad>"], 1) + self.assertEqual(vocab["hello"], 2) + + max_tokens = 3 + vocab = build_vocab_from_iterator(it, specials=specials, special_first=False, max_tokens=max_tokens) + self.assertEqual(vocab["hello"], 0) + self.assertEqual(vocab["<unk>"], 1) + self.assertEqual(vocab["<pad>"], 2) diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py index 835886e96f..6a7880c8a3 100644 --- a/torchtext/vocab/vocab_factory.py +++ b/torchtext/vocab/vocab_factory.py @@ -49,6 +49,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1, ordered_dict.pop(token, None) tokens = [] + # Save room for special 
tokens for token, freq in ordered_dict.items(): if freq >= min_freq: tokens.append(token) @@ -61,7 +62,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1, return Vocab(VocabPybind(tokens, None)) -def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True) -> Vocab: +def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True, max_tokens: Optional[int] = None) -> Vocab: """ Build a Vocab from an iterator. @@ -70,6 +71,7 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O min_freq: The minimum frequency needed to include a token in the vocabulary. specials: Special symbols to add. The order of supplied tokens will be preserved. special_first: Indicates whether to insert symbols at the beginning or at the end. + max_tokens: If provided, creates the vocab from the `max_tokens - len(specials)` most frequent tokens. Returns: @@ -90,10 +92,16 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O for tokens in iterator: counter.update(tokens) - sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) - sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) - ordered_dict = OrderedDict(sorted_by_freq_tuples) + specials = specials or [] + + # First sort by descending frequency, then lexicographically + sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0])) + + if max_tokens is None: + ordered_dict = OrderedDict(sorted_by_freq_tuples) + else: + assert len(specials) < max_tokens, "len(specials) >= max_tokens, so the vocab will be entirely special tokens." 
+ ordered_dict = OrderedDict(sorted_by_freq_tuples[:max_tokens - len(specials)]) - word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [], - special_first=special_first) + word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials, special_first=special_first) return word_vocab