From 1b347dc0897cb62bc5b7b308ef442f6fccff5ec7 Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Tue, 9 Nov 2021 16:24:33 -0800
Subject: [PATCH 1/2] [Vocab] Refactor vocab factory method to accept special tokens as a keyword argument

---
 test/test_vocab.py               | 17 +++++++++++++++++
 torchtext/vocab/vocab_factory.py | 36 ++++++++++++++++++------------------
 2 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/test/test_vocab.py b/test/test_vocab.py
index c78cd5c708..1f938731b7 100644
--- a/test/test_vocab.py
+++ b/test/test_vocab.py
@@ -241,3 +241,20 @@ def test_build_vocab_iterator(self):
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
+
+    def test_vocab_specials(self):
+        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
+        sorted_by_freq_tuples = OrderedDict(sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True))
+        specials = ['<unk>', '<pad>', '<bos>', 'pad']
+
+        v1 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials)
+        expected_itos = specials + ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
+        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
+        self.assertEqual(v1.get_itos(), expected_itos)
+        self.assertEqual(dict(v1.get_stoi()), expected_stoi)
+
+        v2 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials, special_first=False)
+        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + specials
+        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
+        self.assertEqual(v2.get_itos(), expected_itos)
+        self.assertEqual(dict(v2.get_stoi()), expected_stoi)
diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index bdea76f0a6..4549ac29b9 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -6,7 +6,7 @@
 )
 
 
-def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
+def vocab(ordered_dict: Dict, min_freq: int = 1, specials: List[str] = [], special_first: bool = True) -> Vocab:
     r"""Factory method for creating a vocab object which maps tokens to indices.
 
     Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab.
@@ -15,6 +15,8 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
     Args:
         ordered_dict: Ordered Dictionary mapping tokens to their corresponding occurance frequencies.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
+        specials: Special symbols to add. The order of supplied tokens will be preserved.
+        special_first: Indicates whether to insert symbols at the beginning or at the end.
 
     Returns:
         torchtext.vocab.Vocab: A `Vocab` object
@@ -34,19 +36,27 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
         >>> unk_token = '<unk>'
         >>> default_index = -1
         >>> if unk_token not in v2: v2.insert_token(unk_token, 0)
-        >>> v2.set_default_index(default_index)
-        >>> print(v2['<unk>']) #prints 0
-        >>> print(v2['out of vocab']) #prints -1
+        >>> v3 = vocab(OrderedDict([(token, 1) for token in tokens]), specials=[unk_token])
+        >>> v3.set_default_index(default_index)
+        >>> print(v3['<unk>']) #prints 0
+        >>> print(v3['out of vocab']) #prints -1
         >>> #make default index same as index of unk_token
-        >>> v2.set_default_index(v2[unk_token])
-        >>> v2['out of vocab'] is v2[unk_token] #prints True
+        >>> v3.set_default_index(v3[unk_token])
+        >>> v3['out of vocab'] is v3[unk_token] #prints True
     """
+    for token in specials:
+        ordered_dict.pop(token, None)
     tokens = []
     for token, freq in ordered_dict.items():
         if freq >= min_freq:
             tokens.append(token)
 
+    if special_first:
+        tokens[0:0] = specials
+    else:
+        tokens.extend(specials)
+
     return Vocab(VocabPybind(tokens, None))
@@ -79,20 +89,10 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
     for tokens in iterator:
         counter.update(tokens)
 
-    if specials is not None:
-        for tok in specials:
-            del counter[tok]
-
     sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
     sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
     ordered_dict = OrderedDict(sorted_by_freq_tuples)
 
-    if specials is not None:
-        if special_first:
-            specials = specials[::-1]
-        for symbol in specials:
-            ordered_dict.update({symbol: min_freq})
-            ordered_dict.move_to_end(symbol, last=not special_first)
-
-    word_vocab = vocab(ordered_dict, min_freq=min_freq)
+    word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
+                       special_first=special_first)
     return word_vocab

From 61a7f746bbafad793c0bfafcf9756d3620cea1d7 Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Tue, 9 Nov 2021 16:24:33 -0800
Subject: [PATCH 2/2] [Vocab] Refactor vocab factory method to accept special tokens as a keyword argument

---
 test/test_vocab.py               | 17 +++++++++++++++++
 torchtext/vocab/vocab_factory.py | 31 ++++++++++++++++---------------
 2 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/test/test_vocab.py b/test/test_vocab.py
index c78cd5c708..1f938731b7 100644
--- a/test/test_vocab.py
+++ b/test/test_vocab.py
@@ -241,3 +241,20 @@ def test_build_vocab_iterator(self):
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
+
+    def test_vocab_specials(self):
+        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
+        sorted_by_freq_tuples = OrderedDict(sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True))
+        specials = ['<unk>', '<pad>', '<bos>', 'pad']
+
+        v1 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials)
+        expected_itos = specials + ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
+        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
+        self.assertEqual(v1.get_itos(), expected_itos)
+        self.assertEqual(dict(v1.get_stoi()), expected_stoi)
+
+        v2 = vocab(sorted_by_freq_tuples, min_freq=3, specials=specials, special_first=False)
+        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + specials
+        expected_stoi = {x: index for index, x in enumerate(expected_itos)}
+        self.assertEqual(v2.get_itos(), expected_itos)
+        self.assertEqual(dict(v2.get_stoi()), expected_stoi)
diff --git a/torchtext/vocab/vocab_factory.py b/torchtext/vocab/vocab_factory.py
index bdea76f0a6..835886e96f 100644
--- a/torchtext/vocab/vocab_factory.py
+++ b/torchtext/vocab/vocab_factory.py
@@ -6,7 +6,9 @@
 )
 
 
-def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
+def vocab(ordered_dict: Dict, min_freq: int = 1,
+          specials: Optional[List[str]] = None,
+          special_first: bool = True) -> Vocab:
     r"""Factory method for creating a vocab object which maps tokens to indices.
 
     Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab.
@@ -15,6 +17,8 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
     Args:
         ordered_dict: Ordered Dictionary mapping tokens to their corresponding occurance frequencies.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
+        specials: Special symbols to add. The order of supplied tokens will be preserved.
+        special_first: Indicates whether to insert symbols at the beginning or at the end.
 
     Returns:
         torchtext.vocab.Vocab: A `Vocab` object
@@ -29,11 +33,10 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
         >>> print(v1['a']) #prints 1
         >>> print(v1['out of vocab']) #raise RuntimeError since default index is not set
         >>> tokens = ['e', 'd', 'c', 'b', 'a']
-        >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens]))
         >>> #adding <unk> token and default index
         >>> unk_token = '<unk>'
         >>> default_index = -1
-        >>> if unk_token not in v2: v2.insert_token(unk_token, 0)
+        >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens]), specials=[unk_token])
         >>> v2.set_default_index(default_index)
         >>> print(v2['<unk>']) #prints 0
         >>> print(v2['out of vocab']) #prints -1
@@ -41,12 +44,20 @@ def vocab(ordered_dict: Dict, min_freq: int = 1) -> Vocab:
         >>> v2.set_default_index(v2[unk_token])
         >>> v2['out of vocab'] is v2[unk_token] #prints True
     """
+    specials = specials or []
+    for token in specials:
+        ordered_dict.pop(token, None)
     tokens = []
     for token, freq in ordered_dict.items():
         if freq >= min_freq:
             tokens.append(token)
 
+    if special_first:
+        tokens[0:0] = specials
+    else:
+        tokens.extend(specials)
+
     return Vocab(VocabPybind(tokens, None))
@@ -79,20 +90,10 @@ def build_vocab_from_iterator(iterator: Iterable, min_freq: int = 1, specials: O
     for tokens in iterator:
         counter.update(tokens)
 
-    if specials is not None:
-        for tok in specials:
-            del counter[tok]
-
     sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[0])
     sorted_by_freq_tuples.sort(key=lambda x: x[1], reverse=True)
     ordered_dict = OrderedDict(sorted_by_freq_tuples)
 
-    if specials is not None:
-        if special_first:
-            specials = specials[::-1]
-        for symbol in specials:
-            ordered_dict.update({symbol: min_freq})
-            ordered_dict.move_to_end(symbol, last=not special_first)
-
-    word_vocab = vocab(ordered_dict, min_freq=min_freq)
+    word_vocab = vocab(ordered_dict, min_freq=min_freq, specials=specials or [],
+                       special_first=special_first)
    return word_vocab
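
For reference (not part of the patches themselves): a minimal usage sketch of the refactored API, assuming a torchtext build with these changes applied. The token names, frequencies, and sample sentences below are illustrative only.

from collections import OrderedDict
from torchtext.vocab import vocab, build_vocab_from_iterator

# vocab(): special symbols are now passed directly as a keyword argument.
# Token -> frequency mapping; insertion order is respected by the factory.
counts = OrderedDict([('hello', 4), ('world', 3), ('rare', 1)])

# With the default special_first=True the specials take the lowest indices;
# min_freq filters only the regular tokens, never the specials.
v = vocab(counts, min_freq=2, specials=['<unk>', '<pad>'])
v.set_default_index(v['<unk>'])   # out-of-vocabulary lookups fall back to '<unk>'
print(v.get_itos())               # ['<unk>', '<pad>', 'hello', 'world']
print(v['hello'])                 # 2
print(v['never seen'])            # 0, the index of '<unk>'

# build_vocab_from_iterator() now delegates specials handling to vocab();
# special_first=False appends the specials after the regular tokens.
data = [['hello', 'world'], ['hello']]
v2 = build_vocab_from_iterator(data, specials=['<unk>'], special_first=False)
print(v2.get_itos())              # ['hello', 'world', '<unk>']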