From 86bee5074f8fcac353356d0c3374833ea2a1551b Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sun, 16 Jan 2022 09:16:49 -0500
Subject: [PATCH 1/5] add max_tokens kwarg to vocab and vocab factory.

---
 test/test_vocab.py               | 17 +++++++++++++++++
 torchtext/vocab/vocab_factory.py | 17 +++++++++++++----
 2 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/test/test_vocab.py b/test/test_vocab.py
index 1f938731b7..21a6b427cc 100644
--- a/test/test_vocab.py
+++ b/test/test_vocab.py
@@ -258,3 +258,20 @@ def test_vocab_specials(self):
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v2.get_itos(), expected_itos)
         self.assertEqual(dict(v2.get_stoi()), expected_stoi)
+
+    def test_vocab_max_tokens(self):
+        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        max_tokens = 1
+        v = vocab(c, min_freq=2, max_tokens=max_tokens)
+
+        self.assertEqual(len(v), max_tokens)
+        self.assertEqual(v['<unk>'], 0)
+
+        max_tokens = 2
+        v = vocab(c, min_freq=2, max_tokens=max_tokens)
+
+        self.assertEqual(len(v), max_tokens)
+        self.assertEqual(v['<unk>'], 0)
+        self.assertEqual(v['a'], 1)
diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index 835886e96f..9b3371e3c7 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -8,7 +8,8 @@

 def vocab(ordered_dict: Dict, min_freq: int = 1,
           specials: Optional[List[str]] = None,
-          special_first: bool = True) -> Vocab:
+          special_first: bool = True,
+          max_tokens: Optional[int] = None) -> Vocab:
     r"""Factory method for creating a vocab object which maps tokens to indices.

     Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab.
@@ -19,6 +20,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
         min_freq: The minimum frequency needed to include a token in the vocabulary.
         specials: Special symbols to add. The order of supplied tokens will be preserved.
         special_first: Indicates whether to insert symbols at the beginning or at the end.
+        max_tokens: If provided, creates the vocab from the `max_tokens` most frequent tokens.

     Returns:
         torchtext.vocab.Vocab: A `Vocab` object
@@ -49,7 +51,12 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
             ordered_dict.pop(token, None)

     tokens = []
-    for token, freq in ordered_dict.items():
+    # Save room for special tokens
+    max_tokens = (max_tokens or float('inf')) - len(specials)
+    for i, (token, freq) in enumerate(ordered_dict.items()):
+        # Save room
+        if i >= max_tokens:
+            break
         if freq >= min_freq:
             tokens.append(token)
@@ -61,7 +68,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
     return Vocab(VocabPybind(tokens, None))


-def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True) -> Vocab:
+def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: Optional[List[str]] = None, special_first: bool = True, max_tokens: Optional[int] = None) -> Vocab:
     """
     Build a Vocab from an iterator.
@@ -70,6 +77,7 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
         min_freq: The minimum frequency needed to include a token in the vocabulary.
         specials: Special symbols to add. The order of supplied tokens will be preserved.
         special_first: Indicates whether to insert symbols at the beginning or at the end.
+        max_tokens: If provided, creates the vocab from the `max_tokens` most frequent tokens.

     Returns:
@@ -94,6 +102,7 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
     sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
     ordered_dict = OrderedDict(sorted_by_freq_tuples)

+
     word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
-                       special_first=special_first)
+                       special_first=special_first, max_tokens=max_tokens)
     return word_vocab

From e41074b7435a4e28a1b13dc2bab34201ba80ee9f Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sun, 16 Jan 2022 09:34:48 -0500
Subject: [PATCH 2/5] fix flake.

---
 torchtext/vocab/vocab_factory.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index 9b3371e3c7..eba7df707c 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -102,7 +102,6 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
     sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
     ordered_dict = OrderedDict(sorted_by_freq_tuples)

-
     word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
                        special_first=special_first, max_tokens=max_tokens)
     return word_vocab

From dc2119fae011c610bfd85f5a960d7f87446ba1da Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Sun, 16 Jan 2022 09:41:25 -0500
Subject: [PATCH 3/5] make docstring more clear about the interaction between
 special and non-special tokens when max_tokens is provided.

---
 torchtext/vocab/vocab_factory.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index eba7df707c..72ce5e18f0 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -20,7 +20,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
         min_freq: The minimum frequency needed to include a token in the vocabulary.
         specials: Special symbols to add. The order of supplied tokens will be preserved.
         special_first: Indicates whether to insert symbols at the beginning or at the end.
-        max_tokens: If provided, creates the vocab from the `max_tokens` most frequent tokens.
+        max_tokens: If provided, creates the vocab from the `max_tokens - len(specials)` most frequent tokens.

     Returns:
         torchtext.vocab.Vocab: A `Vocab` object
@@ -77,7 +77,7 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
         min_freq: The minimum frequency needed to include a token in the vocabulary.
         specials: Special symbols to add. The order of supplied tokens will be preserved.
         special_first: Indicates whether to insert symbols at the beginning or at the end.
-        max_tokens: If provided, creates the vocab from the `max_tokens` most frequent tokens.
+        max_tokens: If provided, creates the vocab from the `max_tokens - len(specials)` most frequent tokens.

     Returns:

From 41231223c5f1bb0187187f4d7931e906a229bcef Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Thu, 20 Jan 2022 06:30:22 -0500
Subject: [PATCH 4/5] remove max_tokens from vocab and force builders to
 handle that logic.
---
 test/test_vocab.py               | 35 ++++++++++++++++++--------------
 torchtext/vocab/vocab_factory.py | 22 ++++++++------------
 2 files changed, 29 insertions(+), 28 deletions(-)

diff --git a/test/test_vocab.py b/test/test_vocab.py
index 21a6b427cc..dd1dea0b89 100644
--- a/test/test_vocab.py
+++ b/test/test_vocab.py
@@ -1,6 +1,8 @@
 # -*- coding: utf-8 -*-
 from collections import OrderedDict
 import os
+
+import pytest
 import torch
 from test.common.torchtext_test_case import TorchtextTestCase
 from torchtext.vocab import (
@@ -259,19 +261,22 @@ def test_vocab_specials(self):
         self.assertEqual(v2.get_itos(), expected_itos)
         self.assertEqual(dict(v2.get_stoi()), expected_stoi)

-    def test_vocab_max_tokens(self):
-        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
-        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
-        c = OrderedDict(sorted_by_freq_tuples)
+    def test_build_vocab_from_iterator_max_tokens(self):
+        it = [["hello", "world"], ["hello"]]
         max_tokens = 1
-        v = vocab(c, min_freq=2, max_tokens=max_tokens)
-
-        self.assertEqual(len(v), max_tokens)
-        self.assertEqual(v['<unk>'], 0)
-
-        max_tokens = 2
-        v = vocab(c, min_freq=2, max_tokens=max_tokens)
-
-        self.assertEqual(len(v), max_tokens)
-        self.assertEqual(v['<unk>'], 0)
-        self.assertEqual(v['a'], 1)
+        specials = ["<unk>", "<pad>"]
+        self.assertLess(max_tokens, len(specials))
+        with pytest.raises(AssertionError):
+            build_vocab_from_iterator(it, specials=specials, max_tokens=max_tokens)
+
+        max_tokens = 3
+        vocab = build_vocab_from_iterator(it, specials=specials, special_first=True, max_tokens=max_tokens)
+        self.assertEqual(vocab["<unk>"], 0)
+        self.assertEqual(vocab["<pad>"], 1)
+        self.assertEqual(vocab["hello"], 2)
+
+        max_tokens = 3
+        vocab = build_vocab_from_iterator(it, specials=specials, special_first=False, max_tokens=max_tokens)
+        self.assertEqual(vocab["hello"], 0)
+        self.assertEqual(vocab["<unk>"], 1)
+        self.assertEqual(vocab["<pad>"], 2)
diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index 72ce5e18f0..af35f1c1dd 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -8,8 +8,7 @@

 def vocab(ordered_dict: Dict, min_freq: int = 1,
           specials: Optional[List[str]] = None,
-          special_first: bool = True,
-          max_tokens: Optional[int] = None) -> Vocab:
+          special_first: bool = True) -> Vocab:
     r"""Factory method for creating a vocab object which maps tokens to indices.

     Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab.
@@ -20,7 +19,6 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,
         min_freq: The minimum frequency needed to include a token in the vocabulary.
         specials: Special symbols to add. The order of supplied tokens will be preserved.
         special_first: Indicates whether to insert symbols at the beginning or at the end.
-        max_tokens: If provided, creates the vocab from the `max_tokens - len(specials)` most frequent tokens.

     Returns:
         torchtext.vocab.Vocab: A `Vocab` object
@@ -52,11 +50,7 @@ def vocab(ordered_dict: Dict, min_freq: int = 1,

     tokens = []
     # Save room for special tokens
-    max_tokens = (max_tokens or float('inf')) - len(specials)
-    for i, (token, freq) in enumerate(ordered_dict.items()):
-        # Save room
-        if i >= max_tokens:
-            break
+    for token, freq in ordered_dict.items():
         if freq >= min_freq:
             tokens.append(token)
@@ -98,10 +92,12 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
     for tokens in iterator:
         counter.update(tokens)

-    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
-    sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
-    ordered_dict = OrderedDict(sorted_by_freq_tuples)
+    specials = specials or []
+    if max_tokens is None:
+        ordered_dict = OrderedDict(counter.most_common())
+    else:
+        assert len(specials) < max_tokens, "len(specials) >= max_tokens, so the vocab will be entirely special tokens."
+        ordered_dict = OrderedDict(counter.most_common(max_tokens - len(specials)))

-    word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
-                       special_first=special_first, max_tokens=max_tokens)
+    word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials, special_first=special_first)
     return word_vocab

From 457747f43ea3ce649a6a61bbf9b7fee9a08802ef Mon Sep 17 00:00:00 2001
From: Elijah Rippeth
Date: Thu, 20 Jan 2022 15:48:56 -0500
Subject: [PATCH 5/5] revert potential breakage where lexicographical sorting
 is done. Simplify logic for sorting.

---
 test/test_vocab.py               | 11 +++++++++++
 torchtext/vocab/vocab_factory.py |  8 ++++++--
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/test/test_vocab.py b/test/test_vocab.py
index dd1dea0b89..6b04746e42 100644
--- a/test/test_vocab.py
+++ b/test/test_vocab.py
@@ -261,6 +261,17 @@ def test_vocab_specials(self):
         self.assertEqual(v2.get_itos(), expected_itos)
         self.assertEqual(dict(v2.get_stoi()), expected_stoi)

+    def test_build_vocab_sorts_descending_frequency_then_lexicographically(self):
+        it = [["a", "b"], ["a", "b"]]
+        vocab = build_vocab_from_iterator(it)
+        self.assertEqual(vocab["a"], 0)
+        self.assertEqual(vocab["b"], 1)
+
+        it = [["a", "b"], ["b"]]
+        vocab = build_vocab_from_iterator(it)
+        self.assertEqual(vocab["b"], 0)
+        self.assertEqual(vocab["a"], 1)
+
     def test_build_vocab_from_iterator_max_tokens(self):
         it = [["hello", "world"], ["hello"]]
         max_tokens = 1
diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index af35f1c1dd..6a7880c8a3 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -93,11 +93,15 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
         counter.update(tokens)

     specials = specials or []
+
+    # First sort by descending frequency, then lexicographically
+    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
+
     if max_tokens is None:
-        ordered_dict = OrderedDict(counter.most_common())
+        ordered_dict = OrderedDict(sorted_by_freq_tuples)
     else:
         assert len(specials) < max_tokens, "len(specials) >= max_tokens, so the vocab will be entirely special tokens."
-        ordered_dict = OrderedDict(counter.most_common(max_tokens - len(specials)))
+        ordered_dict = OrderedDict(sorted_by_freq_tuples[:max_tokens - len(specials)])

     word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials, special_first=special_first)
     return word_vocab
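
For reference, a minimal usage sketch of the end-state behavior this series converges on (assuming a torchtext build with all five patches applied; the toy iterator and the `<unk>`/`<pad>` specials mirror the ones used in the new tests):

    from torchtext.vocab import build_vocab_from_iterator

    # "hello" occurs twice, "world" once.
    it = [["hello", "world"], ["hello"]]

    # max_tokens=3 with two specials leaves room for exactly one regular
    # token, so only the most frequent token ("hello") is kept.
    v = build_vocab_from_iterator(it, specials=["<unk>", "<pad>"], max_tokens=3)
    assert len(v) == 3
    assert v["<unk>"] == 0 and v["<pad>"] == 1  # specials come first by default
    assert v["hello"] == 2                      # "world" was truncated away

    # Frequency ties are broken lexicographically (patch 5/5): "a" and "b"
    # both occur once, so "a" sorts first.
    v2 = build_vocab_from_iterator([["b", "a"]])
    assert v2["a"] == 0 and v2["b"] == 1

Note that the truncation happens before specials are inserted, so `max_tokens` bounds the size of the final vocab including specials, not the number of corpus tokens considered.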