diff --git a/test/test_vocab.py b/test/test_vocab.py index c78cd5c708..1f938731b7 100644 --- a/test/test_vocab.py +++ b/test/test_vocab.py @@ -241,3 +241,20 @@ def test_build_vocab_iterator(self): expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + + def test_vocab_specials(self): + token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} + sorted_by_freq_tuples = OrderedDict(sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)) + specials = ["", "", "", "pad"] + + v1 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials) + expected_itos = specials + ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + expected_stoi = {x: index for index, x in enumerate(expected_itos)} + self.assertEqual(v1.get_itos(), expected_itos) + self.assertEqual(dict(v1.get_stoi()), expected_stoi) + + v2 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials, special_first=False) + expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + specials + expected_stoi = {x: index for index, x in enumerate(expected_itos)} + self.assertEqual(v2.get_itos(), expected_itos) + self.assertEqual(dict(v2.get_stoi()), expected_stoi) diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py index bdea76f0a6..835886e96f 100644 --- a/torchtext/vocab/vocab_factory.py +++ b/torchtext/vocab/vocab_factory.py @@ -6,7 +6,9 @@ ) -def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab: +def vocab(ordered_dict: Dict, min_freq: int = 1, + specials: Optional[List[str]] = None, + special_first: bool = True) -> Vocab: r"""Factory method for creating a vocab object which maps tokens to indices. Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab. @@ -15,6 +17,8 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab: Args: ordered_dict: Ordered Dictionary mapping tokens to their corresponding occurance frequencies. min_freq: The minimum frequency needed to include a token in the vocabulary. + specials: Special symbols to add. The order of supplied tokens will be preserved. + special_first: Indicates whether to insert symbols at the beginning or at the end. Returns: torchtext.vocab.Vocab: A `Vocab` object @@ -29,11 +33,10 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab: >>> print(v1['a']) #prints 1 >>> print(v1['out of vocab']) #raise RuntimeError since default index is not set >>> tokens = ['e', 'd', 'c', 'b', 'a'] - >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens])) >>> #adding token and default index >>> unk_token = '' >>> default_index = -1 - >>> if unk_token not in v2: v2.insert_token(unk_token, 0) + >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens]), specials=[unk_token]) >>> v2.set_default_index(default_index) >>> print(v2['']) #prints 0 >>> print(v2['out of vocab']) #prints -1 @@ -41,12 +44,20 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab: >>> v2.set_default_index(v2[unk_token]) >>> v2['out of vocab'] is v2[unk_token] #prints True """ + specials = specials or [] + for token in specials: + ordered_dict.pop(token, None) tokens = [] for token, freq in ordered_dict.items(): if freq >= min_freq: tokens.append(token) + if special_first: + tokens[0:0] = specials + else: + tokens.extend(specials) + return Vocab(VocabPybind(tokens, None)) @@ -79,20 +90,10 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O for tokens in iterator: counter.update(tokens) - if specials is not None: - for tok in specials: - del counter[tok] - sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0]) sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True) ordered_dict = OrderedDict(sorted_by_freq_tuples) - if specials is not None: - if special_first: - specials = specials[::-1] - for symbol in specials: - ordered_dict.update({symbol: min_freq}) - ordered_dict.move_to_end(symbol, last=not special_first) - - word_vocab = vocab(ordered_dict, min_freq=min_freq) + word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [], + special_first=special_first) return word_vocab