From e46e2f83689177f6f57327014742210ddd458793 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 06:47:46 -0700 Subject: [PATCH 01/45] checkpoint --- torchtext/csrc/register_bindings.cpp | 1 - torchtext/csrc/vocab.cpp | 61 ++++++++++++++-------------- torchtext/csrc/vocab.h | 10 +---- 3 files changed, 32 insertions(+), 40 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 920167af5a..15c7f99b63 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -53,7 +53,6 @@ PYBIND11_MODULE(_torchtext, m) { py::class_(m, "Vocab") .def(py::init, std::string>()) .def_readonly("itos_", &Vocab::itos_) - .def_readonly("unk_token_", &Vocab::unk_token_) .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 1b462a8967..772abe7d0a 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -8,13 +8,11 @@ namespace torchtext { -Vocab::Vocab(const StringList &tokens, const IndexDict &stoi, - const std::string &unk_token, const int64_t unk_index) - : unk_index_(std::move(unk_index)), stoi_(std::move(stoi)), - itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {} +Vocab::Vocab(const StringList &tokens, const IndexDict &stoi) + : stoi_(std::move(stoi)), itos_(std::move(tokens)) {} -Vocab::Vocab(const StringList &tokens, const std::string &unk_token) - : itos_(std::move(tokens)), unk_token_(std::move(unk_token)) { +Vocab::Vocab(const StringList &tokens) + : itos_(std::move(tokens)) { stoi_.reserve(tokens.size()); for (std::size_t i = 0; i < tokens.size(); i++) { // tokens should not have any duplicates @@ -28,7 +26,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token) } stoi_[std::move(tokens[i])] = i; } - unk_index_ = stoi_.find(unk_token)->second; + //unk_index_ = stoi_.find(unk_token)->second; } int64_t Vocab::__len__() const { return stoi_.size(); } @@ -38,7 +36,12 @@ int64_t Vocab::__getitem__(const std::string &token) const { if (item != stoi_.end()) { return item->second; } - return unk_index_; + else if (unk_index_ != -1) { + return unk_index_; + } + else + throw std::runtime_error("UNK index has not been set up yet. Call set_unk_index() function to set up the UNK index"); + } void Vocab::append_token(const std::string &token) { @@ -89,7 +92,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - unk_index_ = stoi_.find(unk_token_)->second; + // unk_index_ = stoi_.find(unk_token_)->second; } std::string Vocab::lookup_token(const int64_t &index) { @@ -207,8 +210,7 @@ struct CompareTokens { }; std::tuple -_concat_tokens(std::vector> chunk_counters, - const std::string &unk_token, const int64_t min_freq, +_concat_tokens(std::vector> chunk_counters, const int64_t min_freq, const int64_t num_lines, const bool sort_tokens) { TORCH_CHECK(chunk_counters.size() > 0, "There must be at least 1 chunk to concatenate!"); @@ -256,13 +258,13 @@ _concat_tokens(std::vector> chunk_counters, } // insert unk_token if not present - if (tokens_freq.find(unk_token) == tokens_freq.end()) { - std::cerr << "The `unk_token` " << unk_token - << " wasn't found in the `ordered_dict`. Adding the `unk_token` " - "to the beginning of the Vocab." - << std::endl; - - unique_tokens.insert(unique_tokens.begin(), unk_token); + // if (tokens_freq.find(unk_token) == tokens_freq.end()) { + // std::cerr << "The `unk_token` " << unk_token + // << " wasn't found in the `ordered_dict`. Adding the `unk_token` " + // "to the beginning of the Vocab." + // << std::endl; + // + // unique_tokens.insert(unique_tokens.begin(), unk_token); } // create stoi @@ -280,7 +282,6 @@ _concat_tokens(std::vector> chunk_counters, constexpr int64_t GRAIN_SIZE = 13107; Vocab _load_vocab_from_file(const std::string &file_path, - const std::string &unk_token, const int64_t min_freq, const int64_t num_cpus) { std::cerr << "[INFO] Reading file " << file_path << std::endl; @@ -323,15 +324,14 @@ Vocab _load_vocab_from_file(const std::string &file_path, IndexDict stoi; StringList tokens; std::tie(stoi, tokens) = - _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, false); + _concat_tokens(chunk_counters, min_freq, num_lines, false); - int64_t unk_index = stoi.find(unk_token)->second; + // int64_t unk_index = stoi.find(unk_token)->second; - return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index); + return Vocab(std::move(tokens), std::move(stoi)); } Vocab _load_vocab_from_raw_text_file(const std::string &file_path, - const std::string &unk_token, const int64_t min_freq, const int64_t num_cpus, py::object fn) { std::cerr << "[INFO] Reading file " << file_path << std::endl; @@ -375,16 +375,15 @@ Vocab _load_vocab_from_raw_text_file(const std::string &file_path, IndexDict stoi; StringList tokens; std::tie(stoi, tokens) = - _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, true); - int64_t unk_index = stoi.find(unk_token)->second; + _concat_tokens(chunk_counters, min_freq, num_lines, true); + // int64_t unk_index = stoi.find(unk_token)->second; - return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index); + return Vocab(std::move(tokens), std::move(stoi)); } VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { std::vector integers; StringList strings = self->itos_; - strings.push_back(self->unk_token_); std::vector tensors; VocabStates states = std::make_tuple(self->version_str_, std::move(integers), @@ -421,11 +420,11 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { "Expected `integers` and `tensors` states to be empty."); } - if (version_str.compare("0.0.1") >= 0) { - std::string unk_token = strings.back(); - strings.pop_back(); // remove last element which is unk_token + //if (version_str.compare("0.0.1") >= 0) { + // std::string unk_token = strings.back(); + // strings.pop_back(); // remove last element which is unk_token - return c10::make_intrusive(std::move(strings), std::move(unk_token)); + return c10::make_intrusive(std::move(strings)); } #ifdef _MSC_VER std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: " diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 005f833c02..6c02898aee 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -18,13 +18,9 @@ struct Vocab : torch::CustomClassHolder { public: const std::string version_str_ = "0.0.1"; StringList itos_; - std::string unk_token_; - explicit Vocab(const std::vector &tokens, - const std::string &unk_token); - explicit Vocab(const StringList &tokens, const IndexDict &stoi, - - const std::string &unk_token, const int64_t unk_index); + explicit Vocab(const std::vector &tokens); + explicit Vocab(const StringList &tokens, const IndexDict &stoi) int64_t __len__() const; int64_t __getitem__(const std::string &token) const; void append_token(const std::string &token); @@ -39,10 +35,8 @@ struct Vocab : torch::CustomClassHolder { c10::intrusive_ptr _get_vocab_from_states(VocabStates states); VocabStates _set_vocab_states(const c10::intrusive_ptr &self); Vocab _load_vocab_from_file(const std::string &file_path, - const std::string &unk_token, const int64_t min_freq, const int64_t num_cpus); Vocab _load_vocab_from_raw_text_file(const std::string &file_path, - const std::string &unk_token, const int64_t min_freq, const int64_t num_cpus, py::object tokenizer); From 312903646183b7475464c32c59f6703e0583e669 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 08:04:13 -0700 Subject: [PATCH 02/45] checkpoint --- torchtext/csrc/vocab.h | 4 ++-- torchtext/experimental/vocab.py | 34 +++++++++------------------------ 2 files changed, 11 insertions(+), 27 deletions(-) diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 6c02898aee..fb6df3d3f4 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -12,7 +12,7 @@ typedef std::tuple, std::vector, struct Vocab : torch::CustomClassHolder { private: - int64_t unk_index_; + int64_t unk_index_ = -1; IndexDict stoi_; public: @@ -20,7 +20,7 @@ struct Vocab : torch::CustomClassHolder { StringList itos_; explicit Vocab(const std::vector &tokens); - explicit Vocab(const StringList &tokens, const IndexDict &stoi) + explicit Vocab(const StringList &tokens, const IndexDict &stoi); int64_t __len__() const; int64_t __getitem__(const std::string &token) const; void append_token(const std::string &token); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 5fdf5751ec..518cedb28d 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -def vocab_from_raw_text_file(file_object, jited_tokenizer, min_freq=1, unk_token='', num_cpus=4): +def vocab_from_raw_text_file(file_object, jited_tokenizer, min_freq=1, num_cpus=4): r"""Create a `Vocab` object from a raw text file. The `file_object` can contain any raw text. This function applies a generic JITed tokenizer in @@ -31,7 +31,6 @@ def vocab_from_raw_text_file(file_object, jited_tokenizer, min_freq=1, unk_token jited_tokenizer (ScriptModule): a tokenizer that has been JITed using `torch.jit.script` min_freq: The minimum frequency needed to include a token in the vocabulary. Values less than 1 will be set to 1. Default: 1. - unk_token: The default unknown token to use. Default: ''. num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4. Returns: @@ -46,11 +45,11 @@ def vocab_from_raw_text_file(file_object, jited_tokenizer, min_freq=1, unk_token >>> jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) >>> v = vocab_from_raw_text_file(f, jit_tokenizer) """ - vocab_obj = _load_vocab_from_raw_text_file(file_object.name, unk_token, min_freq, num_cpus, jited_tokenizer) + vocab_obj = _load_vocab_from_raw_text_file(file_object.name, min_freq, num_cpus, jited_tokenizer) return Vocab(vocab_obj) -def vocab_from_file(file_object, min_freq=1, unk_token='', num_cpus=4): +def vocab_from_file(file_object, min_freq=1, num_cpus=4): r"""Create a `Vocab` object from a text file. The `file_object` should contain tokens separated by new lines. Note that the vocab will be created in the order that the tokens first appear in the file (and not by the frequency of tokens). @@ -63,7 +62,6 @@ def vocab_from_file(file_object, min_freq=1, unk_token='', num_cpus=4): file_object (FileObject): a file like object to read data from. min_freq: The minimum frequency needed to include a token in the vocabulary. Values less than 1 will be set to 1. Default: 1. - unk_token: The default unknown token to use. Default: ''. num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4. Returns: @@ -73,18 +71,17 @@ def vocab_from_file(file_object, min_freq=1, unk_token='', num_cpus=4): >>> f = open('vocab.txt', 'r') >>> v = vocab_from_file(f) """ - vocab_obj = _load_vocab_from_file(file_object.name, unk_token, min_freq, num_cpus) + vocab_obj = _load_vocab_from_file(file_object.name, min_freq, num_cpus) return Vocab(vocab_obj) -def build_vocab_from_iterator(iterator, min_freq=1, unk_token=''): +def build_vocab_from_iterator(iterator, min_freq=1): """ Build a Vocab from an iterator. Arguments: iterator: Iterator used to build Vocab. Must yield list or iterator of tokens. min_freq: The minimum frequency needed to include a token in the vocabulary. Values less than 1 will be set to 1. Default: 1. - unk_token: The default unknown token to use. Default: ''. """ counter = Counter() @@ -92,25 +89,20 @@ def build_vocab_from_iterator(iterator, min_freq=1, unk_token=''): counter.update(tokens) sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True) ordered_dict = OrderedDict(sorted_by_freq_tuples) - word_vocab = vocab(ordered_dict, min_freq=min_freq, unk_token=unk_token) + word_vocab = vocab(ordered_dict, min_freq=min_freq) return word_vocab -def vocab(ordered_dict, min_freq=1, unk_token=''): +def vocab(ordered_dict, min_freq=1): r"""Factory method for creating a vocab object which maps tokens to indices. Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab. Therefore if sorting by token frequency is important to the user, the `ordered_dict` should be created in a way to reflect this. - Additionally, the if the `unk_token` isn't found inside of the `ordered_dict`, it will be added to the end of the vocab. Arguments: ordered_dict (collections.OrderedDict): object holding the frequencies of each token found in the data. min_freq: The minimum frequency needed to include a token in the vocabulary. Values less than 1 will be set to 1. Default: 1. - unk_token: The default unknown token to use. Default: ''. - - Raises: - ValueError: if a default `unk_token` isn't provided. Examples: >>> from torchtext.experimental.vocab import vocab @@ -122,19 +114,11 @@ def vocab(ordered_dict, min_freq=1, unk_token=''): >>> tokens = ['e', 'd', 'c', 'b', 'a'] >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens])) """ - if not unk_token: - raise ValueError("A default unk token wasn't provided.") - tokens = [] for token, freq in ordered_dict.items(): if freq >= min_freq: tokens.append(token) - - if unk_token not in tokens: - tokens.insert(0, unk_token) - warnings.warn("The `unk_token` '{}' wasn't found in the `ordered_dict`. Adding the `unk_token` " - "to the beginning of the Vocab.".format(unk_token), RuntimeWarning) - return Vocab(VocabPybind(tokens, unk_token)) + return Vocab(VocabPybind(tokens)) class Vocab(nn.Module): @@ -261,5 +245,5 @@ def get_itos(self) -> List[str]: def to_ivalue(self): r"""Return a JITable Vocab. """ - cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_) + cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_) return Vocab(cpp_vocab) From 441f3922187a019d4757ff8cefda6a64750e1b16 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 08:38:28 -0700 Subject: [PATCH 03/45] checkpoint --- torchtext/csrc/register_bindings.cpp | 4 ++-- torchtext/csrc/vocab.cpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 15c7f99b63..b13248ab3c 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -51,7 +51,7 @@ PYBIND11_MODULE(_torchtext, m) { .def("__len__", &Vectors::__len__); py::class_(m, "Vocab") - .def(py::init, std::string>()) + .def(py::init>()) .def_readonly("itos_", &Vocab::itos_) .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) @@ -134,7 +134,7 @@ static auto sentencepiece = static auto vocab = torch::class_("torchtext", "Vocab") - .def(torch::init()) + .def(torch::init()) .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 772abe7d0a..6b5545ff74 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -265,7 +265,7 @@ _concat_tokens(std::vector> chunk_counters, const int // << std::endl; // // unique_tokens.insert(unique_tokens.begin(), unk_token); - } + // } // create stoi IndexDict stoi; @@ -420,9 +420,9 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { "Expected `integers` and `tensors` states to be empty."); } - //if (version_str.compare("0.0.1") >= 0) { - // std::string unk_token = strings.back(); - // strings.pop_back(); // remove last element which is unk_token + if (version_str.compare("0.0.1") >= 0) { + std::string unk_token = strings.back(); + strings.pop_back(); // remove last element which is unk_token return c10::make_intrusive(std::move(strings)); } From bc87fb0591a2fbdb0c67cb1df317487dce619b78 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 09:23:22 -0700 Subject: [PATCH 04/45] checkpoint --- torchtext/csrc/register_bindings.cpp | 4 ++++ torchtext/csrc/vocab.cpp | 11 +++++++++++ torchtext/csrc/vocab.h | 2 ++ torchtext/experimental/vocab.py | 18 ++++++++++++++++++ 4 files changed, 35 insertions(+) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index b13248ab3c..e0208c4829 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -56,6 +56,8 @@ PYBIND11_MODULE(_torchtext, m) { .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) + .def("set_unk_index", &Vocab::set_unk_index) + .def("return_unk_index", &Vocab::return_unk_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) @@ -138,6 +140,8 @@ static auto vocab = .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) + .def("set_unk_index", &Vocab::set_unk_index) + .def("return_unk_index", &Vocab::return_unk_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 6b5545ff74..f7f1142f78 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -95,6 +95,17 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // unk_index_ = stoi_.find(unk_token_)->second; } +void Vocab::set_unk_index(const int64_t index) { + if (unk_index_ != -1) + std::cerr << "UNK index has been assigned. You are resetting the UNK index here." + << index << std::endl; + unk_index_ = index; +} + +int64_t Vocab::return_unk_index() const { + return unk_index_; +} + std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { #ifdef _MSC_VER diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index fb6df3d3f4..9f8ff9914e 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -25,6 +25,8 @@ struct Vocab : torch::CustomClassHolder { int64_t __getitem__(const std::string &token) const; void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); + void set_unk_index(const int64_t index); + int64_t return_unk_index() const; std::string lookup_token(const int64_t &index); std::vector lookup_tokens(const std::vector &indices); std::vector lookup_indices(const std::vector &tokens); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 518cedb28d..d9aba258c5 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -179,6 +179,24 @@ def insert_token(self, token: str, index: int) -> None: """ self.vocab.insert_token(token, index) + @torch.jit.export + def set_unk_index(self, index: int) -> None: + r""" + Args: + index (int): the unknown index. + + """ + self.vocab.set_unk_index(index) + + @torch.jit.export + def return_unk_index(self) -> int: + r""" + return: + index (int): the unknown index. + + """ + self.vocab.return_unk_index() + @torch.jit.export def append_token(self, token: str) -> None: r""" From 8ba6b2155bcaad8d33d3ef16a54cb15710b085c5 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 09:28:30 -0700 Subject: [PATCH 05/45] checkpoint --- test/experimental/test_vocab.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 626db2e726..c9b3434a5a 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -17,13 +17,20 @@ def tearDown(self): torch._C._jit_clear_class_registry() torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() - def test_has_unk(self): + def test_has_no_unk(self): c = OrderedDict() v = vocab(c) + self.assertEqual(v.return_unk_index(), -1) # check if unk is mapped to the first index - self.assertEqual(v['not_in_it'], 0) - self.assertEqual(v[''], 0) + with self.assertRaises(RuntimeError): + v['not_in_it'] + with self.assertRaises(RuntimeError): + v[''] + + v.insert_token('not_in_it', 0) + v.set_unk_index(0) + self.assertEqual(v.return_unk_index(), 0) def test_new_unk(self): c = OrderedDict() From 6b9f015ccede93bc84fb3c9f768585bcd5dcab55 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 09:48:52 -0700 Subject: [PATCH 06/45] update tests --- test/experimental/test_vocab.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index c9b3434a5a..99fb6a90dd 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -32,14 +32,6 @@ def test_has_no_unk(self): v.set_unk_index(0) self.assertEqual(v.return_unk_index(), 0) - def test_new_unk(self): - c = OrderedDict() - v = vocab(c, unk_token="") - - # check if new_unk is mapped to the first index - self.assertEqual(v[''], 0) - self.assertEqual(v['not_in_it'], 0) - def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) @@ -189,9 +181,10 @@ def test_errors_vocab_python(self): sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) c = OrderedDict(sorted_by_freq_tuples) - with self.assertRaises(ValueError): + with self.assertRaises(RuntimeError): # Test proper error raised when setting unk token to None - vocab(c, unk_token=None) + vocab(c) + vocab['not_in_vocab'] def test_vocab_load_and_save(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} From 3d21a18574c5486a7c23efa427ce5fc67fce18d4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 09:57:36 -0700 Subject: [PATCH 07/45] clang --- torchtext/csrc/vocab.cpp | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index f7f1142f78..c5f2003022 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -11,8 +11,7 @@ namespace torchtext { Vocab::Vocab(const StringList &tokens, const IndexDict &stoi) : stoi_(std::move(stoi)), itos_(std::move(tokens)) {} -Vocab::Vocab(const StringList &tokens) - : itos_(std::move(tokens)) { +Vocab::Vocab(const StringList &tokens) : itos_(std::move(tokens)) { stoi_.reserve(tokens.size()); for (std::size_t i = 0; i < tokens.size(); i++) { // tokens should not have any duplicates @@ -26,7 +25,7 @@ Vocab::Vocab(const StringList &tokens) } stoi_[std::move(tokens[i])] = i; } - //unk_index_ = stoi_.find(unk_token)->second; + // unk_index_ = stoi_.find(unk_token)->second; } int64_t Vocab::__len__() const { return stoi_.size(); } @@ -35,13 +34,12 @@ int64_t Vocab::__getitem__(const std::string &token) const { const auto &item = stoi_.find(token); if (item != stoi_.end()) { return item->second; - } - else if (unk_index_ != -1) { + } else if (unk_index_ != -1) { return unk_index_; - } - else - throw std::runtime_error("UNK index has not been set up yet. Call set_unk_index() function to set up the UNK index"); - + } else + throw std::runtime_error( + "UNK index has not been set up yet. Call set_unk_index() function to " + "set up the UNK index"); } void Vocab::append_token(const std::string &token) { @@ -97,14 +95,13 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { void Vocab::set_unk_index(const int64_t index) { if (unk_index_ != -1) - std::cerr << "UNK index has been assigned. You are resetting the UNK index here." - << index << std::endl; + std::cerr + << "UNK index has been assigned. You are resetting the UNK index here." + << index << std::endl; unk_index_ = index; } -int64_t Vocab::return_unk_index() const { - return unk_index_; -} +int64_t Vocab::return_unk_index() const { return unk_index_; } std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { @@ -221,8 +218,9 @@ struct CompareTokens { }; std::tuple -_concat_tokens(std::vector> chunk_counters, const int64_t min_freq, - const int64_t num_lines, const bool sort_tokens) { +_concat_tokens(std::vector> chunk_counters, + const int64_t min_freq, const int64_t num_lines, + const bool sort_tokens) { TORCH_CHECK(chunk_counters.size() > 0, "There must be at least 1 chunk to concatenate!"); @@ -271,7 +269,8 @@ _concat_tokens(std::vector> chunk_counters, const int // insert unk_token if not present // if (tokens_freq.find(unk_token) == tokens_freq.end()) { // std::cerr << "The `unk_token` " << unk_token - // << " wasn't found in the `ordered_dict`. Adding the `unk_token` " + // << " wasn't found in the `ordered_dict`. Adding the `unk_token` + // " // "to the beginning of the Vocab." // << std::endl; // From 0efe7b2d5c71af2d7bb0988eba15ef17835aba87 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 10:11:45 -0700 Subject: [PATCH 08/45] flake8 --- torchtext/experimental/vocab.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index d9aba258c5..d9ecac3a41 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -1,6 +1,5 @@ import logging from typing import Dict, List -import warnings from collections import Counter, OrderedDict import torch import torch.nn as nn From dce7080f078097783526699596420fc37328db67 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 10:36:22 -0700 Subject: [PATCH 09/45] checkpoint --- test/experimental/test_vocab.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 99fb6a90dd..58c36a3f62 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -70,7 +70,7 @@ def test_vocab_append_token(self): v = vocab(c) v.append_token('b') - expected_itos = ['', 'a', 'b'] + expected_itos = ['a', 'b'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.get_itos(), expected_itos) @@ -91,7 +91,7 @@ def test_vocab_basic(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.get_itos(), expected_itos) @@ -105,7 +105,7 @@ def test_vocab_jit(self): v = vocab(c, min_freq=3) jit_v = torch.jit.script(v.to_ivalue()) - expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} assert not v.is_jitable @@ -193,7 +193,7 @@ def test_vocab_load_and_save(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] + expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.get_itos(), expected_itos) @@ -210,7 +210,7 @@ def test_build_vocab_iterator(self): iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']] v = build_vocab_from_iterator(iterator) - expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low'] + expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) From 660e05166e3b2b3760df9a9a3adeef08f9103948 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 10:57:18 -0700 Subject: [PATCH 10/45] checkpoint --- test/experimental/test_transforms_with_asset.py | 7 ++++--- torchtext/experimental/vocab.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/experimental/test_transforms_with_asset.py b/test/experimental/test_transforms_with_asset.py index 610ff6beaf..37d5dae082 100644 --- a/test/experimental/test_transforms_with_asset.py +++ b/test/experimental/test_transforms_with_asset.py @@ -131,7 +131,8 @@ def test_vocab_from_file(self): asset_name = 'vocab_test.txt' asset_path = get_asset_path(asset_name) with open(asset_path, 'r') as f: - v = vocab_from_file(f, unk_token='') + v = vocab_from_file(f) + v.insert_token('', 0) expected_itos = ['', 'b', 'a', 'c'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} self.assertEqual(v.get_itos(), expected_itos) @@ -143,8 +144,8 @@ def test_vocab_from_raw_text_file(self): with open(asset_path, 'r') as f: tokenizer = basic_english_normalize() jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) - v = vocab_from_raw_text_file(f, jit_tokenizer, unk_token='') - expected_itos = ['', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed', + v = vocab_from_raw_text_file(f, jit_tokenizer) + expected_itos = ["'", 'after', 'talks', '.', 'are', 'at', 'disappointed', 'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent', 'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner', 'unions', 'with', 'workers'] diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index d9ecac3a41..403f10830b 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -194,7 +194,7 @@ def return_unk_index(self) -> int: index (int): the unknown index. """ - self.vocab.return_unk_index() + return self.vocab.return_unk_index() @torch.jit.export def append_token(self, token: str) -> None: From 6b9368bda78cbf62195d1c19c935745403b09e1c Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 11:14:22 -0700 Subject: [PATCH 11/45] CI --- test/experimental/test_vocab.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 58c36a3f62..bbbc0c5cfe 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -82,7 +82,7 @@ def test_vocab_len(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c) - self.assertEqual(len(v), 4) + self.assertEqual(len(v), 3) def test_vocab_basic(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} @@ -123,7 +123,7 @@ def test_vocab_forward(self): jit_v = torch.jit.script(v.to_ivalue()) tokens = ['b', 'a', 'c'] - expected_indices = [2, 1, 3] + expected_indices = [1, 0, 2] self.assertEqual(v(tokens), expected_indices) self.assertEqual(jit_v(tokens), expected_indices) @@ -142,7 +142,7 @@ def test_vocab_lookup_tokens(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c) - indices = [2, 1, 3] + indices = [1, 0, 2] expected_tokens = ['b', 'a', 'c'] self.assertEqual(v.lookup_tokens(indices), expected_tokens) @@ -154,7 +154,7 @@ def test_vocab_lookup_indices(self): v = vocab(c) tokens = ['b', 'a', 'c'] - expected_indices = [2, 1, 3] + expected_indices = [1, 0, 2] self.assertEqual(v.lookup_indices(tokens), expected_indices) From ffeb7ab9a54743f88cebe0fd469db916ad617126 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 11:57:34 -0700 Subject: [PATCH 12/45] checkpoint --- test/experimental/test_vocab.py | 6 +++--- torchtext/csrc/vocab.cpp | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index bbbc0c5cfe..2a71741239 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -134,7 +134,7 @@ def test_vocab_lookup_token(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c) - self.assertEqual(v.lookup_token(1), 'a') + self.assertEqual(v.lookup_token(0), 'a') def test_vocab_lookup_tokens(self): token_to_freq = {'a': 2, 'b': 2, 'c': 2} @@ -180,11 +180,11 @@ def test_errors_vocab_python(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) c = OrderedDict(sorted_by_freq_tuples) + vocab(c) with self.assertRaises(RuntimeError): # Test proper error raised when setting unk token to None - vocab(c) - vocab['not_in_vocab'] + vocab(['not_in_vocab']) def test_vocab_load_and_save(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index c5f2003022..01de2d76b7 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -430,10 +430,6 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { "Expected `integers` and `tensors` states to be empty."); } - if (version_str.compare("0.0.1") >= 0) { - std::string unk_token = strings.back(); - strings.pop_back(); // remove last element which is unk_token - return c10::make_intrusive(std::move(strings)); } #ifdef _MSC_VER From ca1dbbb9c020ec500a77497e3e03ff9472db906c Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 12:02:19 -0700 Subject: [PATCH 13/45] checkpoint --- torchtext/csrc/vocab.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 01de2d76b7..b512550e2d 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -430,8 +430,6 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { "Expected `integers` and `tensors` states to be empty."); } - return c10::make_intrusive(std::move(strings)); - } #ifdef _MSC_VER std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: " << version_str << std::endl; From 40d6c06adab786aada19d2601a406c70675dc00e Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 12:30:08 -0700 Subject: [PATCH 14/45] checkpoint --- test/experimental/test_vocab.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 2a71741239..4b62b5e11c 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -180,11 +180,11 @@ def test_errors_vocab_python(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) c = OrderedDict(sorted_by_freq_tuples) - vocab(c) + v = vocab(c) with self.assertRaises(RuntimeError): # Test proper error raised when setting unk token to None - vocab(['not_in_vocab']) + v(['not_in_vocab']) def test_vocab_load_and_save(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} From 4246ba2498421388b01d927b649ba49f80533244 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 12:33:08 -0700 Subject: [PATCH 15/45] checkpoint --- torchtext/csrc/vocab.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index b512550e2d..9588e578a5 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -430,6 +430,8 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { "Expected `integers` and `tensors` states to be empty."); } + if (version_str.compare("0.0.1") >= 0) + return c10::make_intrusive(std::move(strings)); #ifdef _MSC_VER std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: " << version_str << std::endl; From 6f6cfad8e4342bad26b94de49316beceea99f872 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 13:07:48 -0700 Subject: [PATCH 16/45] update save/load in vocab --- test/experimental/test_vocab.py | 3 ++- torchtext/csrc/vocab.cpp | 14 +++++++++----- torchtext/csrc/vocab.h | 2 +- torchtext/experimental/vocab.py | 1 + 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 4b62b5e11c..97a59b118f 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -192,7 +192,7 @@ def test_vocab_load_and_save(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - + v.set_unk_index(1) expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} @@ -205,6 +205,7 @@ def test_vocab_load_and_save(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) + self.assertEqual(loaded_v.return_unk_index(), v.return_unk_index()) def test_build_vocab_iterator(self): iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 9588e578a5..bdf6835b24 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -397,13 +397,13 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { std::vector tensors; VocabStates states = std::make_tuple(self->version_str_, std::move(integers), - std::move(strings), std::move(tensors)); + std::move(strings), self->return_unk_index(), std::move(tensors)); return states; } c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { auto state_size = std::tuple_size::value; - if (state_size != 4) { + if (state_size != 5) { #ifdef _MSC_VER std::cerr << "[RuntimeError] Expected deserialized Vocab to have 4 states " "but found " @@ -417,7 +417,8 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { auto &version_str = std::get<0>(states); auto &integers = std::get<1>(states); auto &strings = std::get<2>(states); - auto &tensors = std::get<3>(states); + auto &integer = std::get<3>(states); + auto &tensors = std::get<4>(states); // check integers and tensors are empty if (integers.size() != 0 || tensors.size() != 0) { @@ -430,8 +431,11 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { "Expected `integers` and `tensors` states to be empty."); } - if (version_str.compare("0.0.1") >= 0) - return c10::make_intrusive(std::move(strings)); + if (version_str.compare("0.0.1") >= 0) { + auto vocab_instance = c10::make_intrusive(std::move(strings)); + vocab_instance->set_unk_index(integer); + return vocab_instance; + } #ifdef _MSC_VER std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: " << version_str << std::endl; diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 9f8ff9914e..a9d242b261 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -7,7 +7,7 @@ typedef std::vector StringList; typedef ska_ordered::order_preserving_flat_hash_map IndexDict; typedef std::tuple, std::vector, - std::vector> + int64_t, std::vector> VocabStates; struct Vocab : torch::CustomClassHolder { diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 403f10830b..54ca933bb5 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -263,4 +263,5 @@ def to_ivalue(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_) + cpp_vocab.set_unk_index(self.vocab.return_unk_index()) return Vocab(cpp_vocab) From a0d5fc2a3f616744be90a4ed82dbae6ad2a6a40c Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 13:30:53 -0700 Subject: [PATCH 17/45] checkpooint --- test/experimental/test_vocab.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 97a59b118f..b071aaa95c 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -176,6 +176,8 @@ def test_errors_vocab_cpp(self): v = vocab(c) v.lookup_token(100) + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 def test_errors_vocab_python(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) From a49c10dc31749bf4aa8b94725bd9d361e7fd4c6e Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 14:02:25 -0700 Subject: [PATCH 18/45] checkpoint --- test/experimental/test_vocab.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index b071aaa95c..c32fc79875 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -178,6 +178,7 @@ def test_errors_vocab_cpp(self): # we separate out these errors because Windows runs into seg faults when propagating # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_errors_vocab_python(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) From f76d9b1b28d73734d13ad276f1251fa083a2bb98 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 15:46:00 -0700 Subject: [PATCH 19/45] checkpoint --- torchtext/csrc/vocab.cpp | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index bdf6835b24..37666936ee 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -25,7 +25,6 @@ Vocab::Vocab(const StringList &tokens) : itos_(std::move(tokens)) { } stoi_[std::move(tokens[i])] = i; } - // unk_index_ = stoi_.find(unk_token)->second; } int64_t Vocab::__len__() const { return stoi_.size(); } @@ -266,17 +265,6 @@ _concat_tokens(std::vector> chunk_counters, unique_tokens.push_back(token_freq_pair.first); } - // insert unk_token if not present - // if (tokens_freq.find(unk_token) == tokens_freq.end()) { - // std::cerr << "The `unk_token` " << unk_token - // << " wasn't found in the `ordered_dict`. Adding the `unk_token` - // " - // "to the beginning of the Vocab." - // << std::endl; - // - // unique_tokens.insert(unique_tokens.begin(), unk_token); - // } - // create stoi IndexDict stoi; stoi.reserve(num_lines); @@ -396,8 +384,9 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { StringList strings = self->itos_; std::vector tensors; - VocabStates states = std::make_tuple(self->version_str_, std::move(integers), - std::move(strings), self->return_unk_index(), std::move(tensors)); + VocabStates states = std::make_tuple( + self->version_str_, std::move(integers), std::move(strings), + self->return_unk_index(), std::move(tensors)); return states; } From 7bec937860c954a0b246fbf17e514310b7ea74e0 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 9 Oct 2020 15:47:39 -0700 Subject: [PATCH 20/45] skip test for windows --- test/experimental/test_vocab.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index c32fc79875..2502a85d00 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -17,6 +17,9 @@ def tearDown(self): torch._C._jit_clear_class_registry() torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_has_no_unk(self): c = OrderedDict() v = vocab(c) From ba7e561db69c62f42bcf9022d6dbe87e69101129 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 12 Oct 2020 09:01:45 -0700 Subject: [PATCH 21/45] update unk_index with insert_token --- test/experimental/test_vocab.py | 4 ++++ torchtext/csrc/vocab.cpp | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 2502a85d00..51772c424b 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -50,21 +50,25 @@ def test_vocab_insert_token(self): # add item to end v = vocab(c) + v.set_unk_index(0) v.insert_token('b', 2) expected_itos = ['', 'a', 'b'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} + self.assertEqual(v.return_unk_index(), 0) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) # add item to middle v = vocab(c) + v.set_unk_index(0) v.insert_token('b', 0) expected_itos = ['b', '', 'a'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} + self.assertEqual(v.return_unk_index(), 1) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 37666936ee..17e7e4518f 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -89,7 +89,8 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - // unk_index_ = stoi_.find(unk_token_)->second; + if + index <= unk_index_ { unk_index_ = unk_index_ + 1; } } void Vocab::set_unk_index(const int64_t index) { From ed7be7ddcfca70e82dfa119edd26439e9a6f4682 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 12 Oct 2020 09:13:26 -0700 Subject: [PATCH 22/45] checkpoint --- torchtext/csrc/vocab.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 17e7e4518f..8733baf4d5 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -89,8 +89,9 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - if - index <= unk_index_ { unk_index_ = unk_index_ + 1; } + if (index <= unk_index_) { + unk_index_ = unk_index_ + 1; + } } void Vocab::set_unk_index(const int64_t index) { From e4e1e05cdd7596e245b7c45379f480a53fa76a43 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 12 Oct 2020 13:11:27 -0700 Subject: [PATCH 23/45] change unk_index to fallback_index --- test/experimental/test_vocab.py | 18 +++++++++--------- torchtext/csrc/register_bindings.cpp | 8 ++++---- torchtext/csrc/vocab.cpp | 22 +++++++++++----------- torchtext/csrc/vocab.h | 6 +++--- torchtext/experimental/vocab.py | 10 +++++----- 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 51772c424b..9c7136c4dc 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -23,7 +23,7 @@ def tearDown(self): def test_has_no_unk(self): c = OrderedDict() v = vocab(c) - self.assertEqual(v.return_unk_index(), -1) + self.assertEqual(v.return_fallback_index(), -1) # check if unk is mapped to the first index with self.assertRaises(RuntimeError): @@ -32,8 +32,8 @@ def test_has_no_unk(self): v[''] v.insert_token('not_in_it', 0) - v.set_unk_index(0) - self.assertEqual(v.return_unk_index(), 0) + v.set_fallback_index(0) + self.assertEqual(v.return_fallback_index(), 0) def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} @@ -50,25 +50,25 @@ def test_vocab_insert_token(self): # add item to end v = vocab(c) - v.set_unk_index(0) + v.set_fallback_index(0) v.insert_token('b', 2) expected_itos = ['', 'a', 'b'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} - self.assertEqual(v.return_unk_index(), 0) + self.assertEqual(v.return_fallback_index(), 0) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) # add item to middle v = vocab(c) - v.set_unk_index(0) + v.set_fallback_index(0) v.insert_token('b', 0) expected_itos = ['b', '', 'a'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} - self.assertEqual(v.return_unk_index(), 1) + self.assertEqual(v.return_fallback_index(), 1) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) @@ -202,7 +202,7 @@ def test_vocab_load_and_save(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - v.set_unk_index(1) + v.set_fallback_index(1) expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} @@ -215,7 +215,7 @@ def test_vocab_load_and_save(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) - self.assertEqual(loaded_v.return_unk_index(), v.return_unk_index()) + self.assertEqual(loaded_v.return_fallback_index(), v.return_fallback_index()) def test_build_vocab_iterator(self): iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index e0208c4829..465746a85b 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -56,8 +56,8 @@ PYBIND11_MODULE(_torchtext, m) { .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) - .def("set_unk_index", &Vocab::set_unk_index) - .def("return_unk_index", &Vocab::return_unk_index) + .def("set_fallback_index", &Vocab::set_fallback_index) + .def("return_fallback_index", &Vocab::return_fallback_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) @@ -140,8 +140,8 @@ static auto vocab = .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) - .def("set_unk_index", &Vocab::set_unk_index) - .def("return_unk_index", &Vocab::return_unk_index) + .def("set_fallback_index", &Vocab::set_fallback_index) + .def("return_fallback_index", &Vocab::return_fallback_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 8733baf4d5..0373de9fcf 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -33,11 +33,11 @@ int64_t Vocab::__getitem__(const std::string &token) const { const auto &item = stoi_.find(token); if (item != stoi_.end()) { return item->second; - } else if (unk_index_ != -1) { - return unk_index_; + } else if (fallback_index_ != -1) { + return fallback_index_; } else throw std::runtime_error( - "UNK index has not been set up yet. Call set_unk_index() function to " + "UNK index has not been set up yet. Call set_fallback_index() function to " "set up the UNK index"); } @@ -89,20 +89,20 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - if (index <= unk_index_) { - unk_index_ = unk_index_ + 1; + if (fallback_index_ != -1 && index <= fallback_index_) { + fallback_index_ = fallback_index_ + 1; } } -void Vocab::set_unk_index(const int64_t index) { - if (unk_index_ != -1) +void Vocab::set_fallback_index(const int64_t index) { + if (fallback_index_ != -1) std::cerr << "UNK index has been assigned. You are resetting the UNK index here." << index << std::endl; - unk_index_ = index; + fallback_index_ = index; } -int64_t Vocab::return_unk_index() const { return unk_index_; } +int64_t Vocab::return_fallback_index() const { return fallback_index_; } std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { @@ -388,7 +388,7 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { VocabStates states = std::make_tuple( self->version_str_, std::move(integers), std::move(strings), - self->return_unk_index(), std::move(tensors)); + self->return_fallback_index(), std::move(tensors)); return states; } @@ -424,7 +424,7 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { if (version_str.compare("0.0.1") >= 0) { auto vocab_instance = c10::make_intrusive(std::move(strings)); - vocab_instance->set_unk_index(integer); + vocab_instance->set_fallback_index(integer); return vocab_instance; } #ifdef _MSC_VER diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index a9d242b261..ef9247a19b 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -12,7 +12,7 @@ typedef std::tuple, std::vector, struct Vocab : torch::CustomClassHolder { private: - int64_t unk_index_ = -1; + int64_t fallback_index_ = -1; IndexDict stoi_; public: @@ -25,8 +25,8 @@ struct Vocab : torch::CustomClassHolder { int64_t __getitem__(const std::string &token) const; void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); - void set_unk_index(const int64_t index); - int64_t return_unk_index() const; + void set_fallback_index(const int64_t index); + int64_t return_fallback_index() const; std::string lookup_token(const int64_t &index); std::vector lookup_tokens(const std::vector &indices); std::vector lookup_indices(const std::vector &tokens); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 54ca933bb5..6598e0a9f1 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -179,22 +179,22 @@ def insert_token(self, token: str, index: int) -> None: self.vocab.insert_token(token, index) @torch.jit.export - def set_unk_index(self, index: int) -> None: + def set_fallback_index(self, index: int) -> None: r""" Args: index (int): the unknown index. """ - self.vocab.set_unk_index(index) + self.vocab.set_fallback_index(index) @torch.jit.export - def return_unk_index(self) -> int: + def return_fallback_index(self) -> int: r""" return: index (int): the unknown index. """ - return self.vocab.return_unk_index() + return self.vocab.return_fallback_index() @torch.jit.export def append_token(self, token: str) -> None: @@ -263,5 +263,5 @@ def to_ivalue(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_) - cpp_vocab.set_unk_index(self.vocab.return_unk_index()) + cpp_vocab.set_fallback_index(self.vocab.return_fallback_index()) return Vocab(cpp_vocab) From 279ba95296633e378dccc02996b57b2e18880491 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 12 Oct 2020 13:13:44 -0700 Subject: [PATCH 24/45] checkpoint --- torchtext/csrc/vocab.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 0373de9fcf..159a676b14 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -37,8 +37,8 @@ int64_t Vocab::__getitem__(const std::string &token) const { return fallback_index_; } else throw std::runtime_error( - "UNK index has not been set up yet. Call set_fallback_index() function to " - "set up the UNK index"); + "The fallback index has not been set up yet. Call set_fallback_index() function to " + "set up the fallback index"); } void Vocab::append_token(const std::string &token) { From 1cdb82e41c5a38f46aa7bbd9f62269b7b0e42dc9 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 12 Oct 2020 13:44:07 -0700 Subject: [PATCH 25/45] checkpoint --- torchtext/csrc/vocab.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 159a676b14..a11495458b 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -36,9 +36,9 @@ int64_t Vocab::__getitem__(const std::string &token) const { } else if (fallback_index_ != -1) { return fallback_index_; } else - throw std::runtime_error( - "The fallback index has not been set up yet. Call set_fallback_index() function to " - "set up the fallback index"); + throw std::runtime_error("The fallback index has not been set up yet. Call " + "set_fallback_index() function to " + "set up the fallback index"); } void Vocab::append_token(const std::string &token) { From ccd3166dc37f46eb1f31397408ffa6242fc7e160 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sun, 18 Oct 2020 15:27:46 -0700 Subject: [PATCH 26/45] switch to default --- test/experimental/test_vocab.py | 18 +++++++++--------- torchtext/csrc/register_bindings.cpp | 8 ++++---- torchtext/csrc/vocab.cpp | 26 +++++++++++++------------- torchtext/csrc/vocab.h | 6 +++--- torchtext/experimental/vocab.py | 10 +++++----- 5 files changed, 34 insertions(+), 34 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 9c7136c4dc..a7a263d703 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -23,7 +23,7 @@ def tearDown(self): def test_has_no_unk(self): c = OrderedDict() v = vocab(c) - self.assertEqual(v.return_fallback_index(), -1) + self.assertEqual(v.return_default_index(), -1) # check if unk is mapped to the first index with self.assertRaises(RuntimeError): @@ -32,8 +32,8 @@ def test_has_no_unk(self): v[''] v.insert_token('not_in_it', 0) - v.set_fallback_index(0) - self.assertEqual(v.return_fallback_index(), 0) + v.set_default_index(0) + self.assertEqual(v.return_default_index(), 0) def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} @@ -50,25 +50,25 @@ def test_vocab_insert_token(self): # add item to end v = vocab(c) - v.set_fallback_index(0) + v.set_default_index(0) v.insert_token('b', 2) expected_itos = ['', 'a', 'b'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} - self.assertEqual(v.return_fallback_index(), 0) + self.assertEqual(v.return_default_index(), 0) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) # add item to middle v = vocab(c) - v.set_fallback_index(0) + v.set_default_index(0) v.insert_token('b', 0) expected_itos = ['b', '', 'a'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} - self.assertEqual(v.return_fallback_index(), 1) + self.assertEqual(v.return_default_index(), 1) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) @@ -202,7 +202,7 @@ def test_vocab_load_and_save(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) - v.set_fallback_index(1) + v.set_default_index(1) expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} @@ -215,7 +215,7 @@ def test_vocab_load_and_save(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) - self.assertEqual(loaded_v.return_fallback_index(), v.return_fallback_index()) + self.assertEqual(loaded_v.return_default_index(), v.return_default_index()) def test_build_vocab_iterator(self): iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 465746a85b..cf1d8f66c2 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -56,8 +56,8 @@ PYBIND11_MODULE(_torchtext, m) { .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) - .def("set_fallback_index", &Vocab::set_fallback_index) - .def("return_fallback_index", &Vocab::return_fallback_index) + .def("set_default_index", &Vocab::set_default_index) + .def("return_default_index", &Vocab::return_default_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) @@ -140,8 +140,8 @@ static auto vocab = .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) - .def("set_fallback_index", &Vocab::set_fallback_index) - .def("return_fallback_index", &Vocab::return_fallback_index) + .def("set_default_index", &Vocab::set_default_index) + .def("return_default_index", &Vocab::return_default_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index a11495458b..b57dfac691 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -33,12 +33,12 @@ int64_t Vocab::__getitem__(const std::string &token) const { const auto &item = stoi_.find(token); if (item != stoi_.end()) { return item->second; - } else if (fallback_index_ != -1) { - return fallback_index_; + } else if (default_index_ != -1) { + return default_index_; } else - throw std::runtime_error("The fallback index has not been set up yet. Call " - "set_fallback_index() function to " - "set up the fallback index"); + throw std::runtime_error("The default index has not been set up yet. Call " + "set_default_index() function to " + "set up the default index"); } void Vocab::append_token(const std::string &token) { @@ -89,20 +89,20 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - if (fallback_index_ != -1 && index <= fallback_index_) { - fallback_index_ = fallback_index_ + 1; + if (default_index_ != -1 && index <= default_index_) { + default_index_ = default_index_ + 1; } } -void Vocab::set_fallback_index(const int64_t index) { - if (fallback_index_ != -1) +void Vocab::set_default_index(const int64_t index) { + if (default_index_ != -1) std::cerr << "UNK index has been assigned. You are resetting the UNK index here." << index << std::endl; - fallback_index_ = index; + default_index_ = index; } -int64_t Vocab::return_fallback_index() const { return fallback_index_; } +int64_t Vocab::return_default_index() const { return default_index_; } std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { @@ -388,7 +388,7 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { VocabStates states = std::make_tuple( self->version_str_, std::move(integers), std::move(strings), - self->return_fallback_index(), std::move(tensors)); + self->return_default_index(), std::move(tensors)); return states; } @@ -424,7 +424,7 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { if (version_str.compare("0.0.1") >= 0) { auto vocab_instance = c10::make_intrusive(std::move(strings)); - vocab_instance->set_fallback_index(integer); + vocab_instance->set_default_index(integer); return vocab_instance; } #ifdef _MSC_VER diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index ef9247a19b..ae745ed7a4 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -12,7 +12,7 @@ typedef std::tuple, std::vector, struct Vocab : torch::CustomClassHolder { private: - int64_t fallback_index_ = -1; + int64_t default_index_ = -1; IndexDict stoi_; public: @@ -25,8 +25,8 @@ struct Vocab : torch::CustomClassHolder { int64_t __getitem__(const std::string &token) const; void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); - void set_fallback_index(const int64_t index); - int64_t return_fallback_index() const; + void set_default_index(const int64_t index); + int64_t return_default_index() const; std::string lookup_token(const int64_t &index); std::vector lookup_tokens(const std::vector &indices); std::vector lookup_indices(const std::vector &tokens); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 6598e0a9f1..1489b11f84 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -179,22 +179,22 @@ def insert_token(self, token: str, index: int) -> None: self.vocab.insert_token(token, index) @torch.jit.export - def set_fallback_index(self, index: int) -> None: + def set_default_index(self, index: int) -> None: r""" Args: index (int): the unknown index. """ - self.vocab.set_fallback_index(index) + self.vocab.set_default_index(index) @torch.jit.export - def return_fallback_index(self) -> int: + def return_default_index(self) -> int: r""" return: index (int): the unknown index. """ - return self.vocab.return_fallback_index() + return self.vocab.return_default_index() @torch.jit.export def append_token(self, token: str) -> None: @@ -263,5 +263,5 @@ def to_ivalue(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_) - cpp_vocab.set_fallback_index(self.vocab.return_fallback_index()) + cpp_vocab.set_default_index(self.vocab.return_default_index()) return Vocab(cpp_vocab) From c61c12788f716c2af4da869ffe32a03b0eb5ff12 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 19 Oct 2020 13:21:36 -0700 Subject: [PATCH 27/45] checkpoint --- test/experimental/test_vocab.py | 10 +++++----- torchtext/csrc/register_bindings.cpp | 4 ++-- torchtext/csrc/vocab.cpp | 4 ++-- torchtext/csrc/vocab.h | 2 +- torchtext/experimental/vocab.py | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index a7a263d703..a741693502 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -23,7 +23,7 @@ def tearDown(self): def test_has_no_unk(self): c = OrderedDict() v = vocab(c) - self.assertEqual(v.return_default_index(), -1) + self.assertEqual(v.get_default_index(), -1) # check if unk is mapped to the first index with self.assertRaises(RuntimeError): @@ -33,7 +33,7 @@ def test_has_no_unk(self): v.insert_token('not_in_it', 0) v.set_default_index(0) - self.assertEqual(v.return_default_index(), 0) + self.assertEqual(v.get_default_index(), 0) def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} @@ -56,7 +56,7 @@ def test_vocab_insert_token(self): expected_itos = ['', 'a', 'b'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} - self.assertEqual(v.return_default_index(), 0) + self.assertEqual(v.get_default_index(), 0) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) @@ -68,7 +68,7 @@ def test_vocab_insert_token(self): expected_itos = ['b', '', 'a'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} - self.assertEqual(v.return_default_index(), 1) + self.assertEqual(v.get_default_index(), 1) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) @@ -215,7 +215,7 @@ def test_vocab_load_and_save(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) - self.assertEqual(loaded_v.return_default_index(), v.return_default_index()) + self.assertEqual(loaded_v.get_default_index(), v.get_default_index()) def test_build_vocab_iterator(self): iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index cf1d8f66c2..3c17cbe4d5 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -57,7 +57,7 @@ PYBIND11_MODULE(_torchtext, m) { .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("set_default_index", &Vocab::set_default_index) - .def("return_default_index", &Vocab::return_default_index) + .def("get_default_index", &Vocab::get_default_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) @@ -141,7 +141,7 @@ static auto vocab = .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("set_default_index", &Vocab::set_default_index) - .def("return_default_index", &Vocab::return_default_index) + .def("get_default_index", &Vocab::get_default_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index b57dfac691..e8eec0b6c6 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -102,7 +102,7 @@ void Vocab::set_default_index(const int64_t index) { default_index_ = index; } -int64_t Vocab::return_default_index() const { return default_index_; } +int64_t Vocab::get_default_index() const { return default_index_; } std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { @@ -388,7 +388,7 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { VocabStates states = std::make_tuple( self->version_str_, std::move(integers), std::move(strings), - self->return_default_index(), std::move(tensors)); + self->get_default_index(), std::move(tensors)); return states; } diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index ae745ed7a4..3d37285734 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -26,7 +26,7 @@ struct Vocab : torch::CustomClassHolder { void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); void set_default_index(const int64_t index); - int64_t return_default_index() const; + int64_t get_default_index() const; std::string lookup_token(const int64_t &index); std::vector lookup_tokens(const std::vector &indices); std::vector lookup_indices(const std::vector &tokens); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 1489b11f84..6c690aa918 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -188,13 +188,13 @@ def set_default_index(self, index: int) -> None: self.vocab.set_default_index(index) @torch.jit.export - def return_default_index(self) -> int: + def get_default_index(self) -> int: r""" return: index (int): the unknown index. """ - return self.vocab.return_default_index() + return self.vocab.get_default_index() @torch.jit.export def append_token(self, token: str) -> None: @@ -263,5 +263,5 @@ def to_ivalue(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_) - cpp_vocab.set_default_index(self.vocab.return_default_index()) + cpp_vocab.set_default_index(self.vocab.get_default_index()) return Vocab(cpp_vocab) From 6b7e5ada5750317bf695fb97c557c38759349c40 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 20 Oct 2020 15:10:11 -0700 Subject: [PATCH 28/45] update test --- test/experimental/test_vocab.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index a741693502..fdf1638422 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -34,6 +34,8 @@ def test_has_no_unk(self): v.insert_token('not_in_it', 0) v.set_default_index(0) self.assertEqual(v.get_default_index(), 0) + self.assertEqual(v['not_in_it'], 0) + self.assertEqual(v[''], 0) def test_vocab_get_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} From 685822746aaa77603a8ff30a95d1cd21ea5915b9 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 20 Oct 2020 15:24:03 -0700 Subject: [PATCH 29/45] add one more test for inserting existing token --- test/experimental/test_vocab.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index fdf1638422..ad66434c11 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -74,6 +74,24 @@ def test_vocab_insert_token(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") + def test_insert_existing_token(self): + c = OrderedDict({'a': 2, 'b': 2, 'c': 2}) + + # add item to end + v = vocab(c) + v.insert_token('', 2) + v.set_default_index(2) + + with self.assertRaises(RuntimeError): + # Test proper error raised when setting a token out of bounds + v.insert_token('', 1) + + v.insert_token('d', 1) + self.assertEqual(v['not_in_it'], 3) + def test_vocab_append_token(self): c = OrderedDict({'a': 2}) v = vocab(c) From b40d8dd1cb9b7811b16e873ab50028a9443f6b1c Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 21 Oct 2020 12:08:12 -0700 Subject: [PATCH 30/45] use c10::optional for default index --- test/experimental/test_transforms_with_asset.py | 8 ++++++++ test/experimental/test_vocab.py | 9 ++++++++- torchtext/csrc/vocab.cpp | 14 ++++++-------- torchtext/csrc/vocab.h | 2 +- torchtext/experimental/vocab.py | 7 +++++-- 5 files changed, 28 insertions(+), 12 deletions(-) diff --git a/test/experimental/test_transforms_with_asset.py b/test/experimental/test_transforms_with_asset.py index 37d5dae082..2c006f3331 100644 --- a/test/experimental/test_transforms_with_asset.py +++ b/test/experimental/test_transforms_with_asset.py @@ -13,6 +13,8 @@ vocab_from_file, vocab_from_raw_text_file, ) +import unittest +import platform import shutil import tempfile import os @@ -26,6 +28,9 @@ class TestTransformsWithAsset(TorchtextTestCase): + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_vocab_transform(self): asset_name = 'vocab_test2.txt' asset_path = get_asset_path(asset_name) @@ -170,6 +175,9 @@ def test_builtin_pretrained_sentencepiece_processor(self): ref_results = [13, 1465, 12824, 304, 24935, 5771, 3776] self.assertEqual(spm_transform(test_sample), ref_results) + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_text_sequential_transform(self): asset_name = 'vocab_test2.txt' asset_path = get_asset_path(asset_name) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index ad66434c11..76733239b9 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -23,7 +23,8 @@ def tearDown(self): def test_has_no_unk(self): c = OrderedDict() v = vocab(c) - self.assertEqual(v.get_default_index(), -1) + with self.assertRaisesRegex(RuntimeError, 'bad optional access'): + v.get_default_index() # check if unk is mapped to the first index with self.assertRaises(RuntimeError): @@ -124,6 +125,9 @@ def test_vocab_basic(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_vocab_jit(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) @@ -141,6 +145,9 @@ def test_vocab_jit(self): self.assertEqual(jit_v.get_itos(), expected_itos) self.assertEqual(dict(jit_v.get_stoi()), expected_stoi) + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_vocab_forward(self): token_to_freq = {'a': 2, 'b': 2, 'c': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index e8eec0b6c6..d5812285b3 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -33,8 +33,8 @@ int64_t Vocab::__getitem__(const std::string &token) const { const auto &item = stoi_.find(token); if (item != stoi_.end()) { return item->second; - } else if (default_index_ != -1) { - return default_index_; + } else if (default_index_.has_value()) { + return default_index_.value(); } else throw std::runtime_error("The default index has not been set up yet. Call " "set_default_index() function to " @@ -89,20 +89,20 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - if (default_index_ != -1 && index <= default_index_) { - default_index_ = default_index_ + 1; + if (default_index_.has_value() && index <= *default_index_) { + default_index_ = *default_index_ + 1; } } void Vocab::set_default_index(const int64_t index) { - if (default_index_ != -1) + if (default_index_.has_value()) std::cerr << "UNK index has been assigned. You are resetting the UNK index here." << index << std::endl; default_index_ = index; } -int64_t Vocab::get_default_index() const { return default_index_; } +int64_t Vocab::get_default_index() const { return default_index_.value(); } std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { @@ -326,8 +326,6 @@ Vocab _load_vocab_from_file(const std::string &file_path, std::tie(stoi, tokens) = _concat_tokens(chunk_counters, min_freq, num_lines, false); - // int64_t unk_index = stoi.find(unk_token)->second; - return Vocab(std::move(tokens), std::move(stoi)); } diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 3d37285734..ca06b52841 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -12,7 +12,7 @@ typedef std::tuple, std::vector, struct Vocab : torch::CustomClassHolder { private: - int64_t default_index_ = -1; + c10::optional default_index_ = {}; IndexDict stoi_; public: diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 6c690aa918..093109583e 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -263,5 +263,8 @@ def to_ivalue(self): r"""Return a JITable Vocab. """ cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_) - cpp_vocab.set_default_index(self.vocab.get_default_index()) - return Vocab(cpp_vocab) + try: + cpp_vocab.set_default_index(self.vocab.get_default_index()) + return Vocab(cpp_vocab) + except RuntimeError: + return Vocab(cpp_vocab) From 4ec66e68392c46f84359a9cad9f95329d4761166 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 22 Oct 2020 07:23:28 -0700 Subject: [PATCH 31/45] checkpoint --- torchtext/csrc/vocab.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index d5812285b3..eb86e1d9a9 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -90,7 +90,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token if (default_index_.has_value() && index <= *default_index_) { - default_index_ = *default_index_ + 1; + default_index_ = default_index_.value() + 1; } } From c5e3773b9dfee04797f3439582fb12fed1023c5d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 2 Nov 2020 07:44:18 -0800 Subject: [PATCH 32/45] Update docs --- torchtext/experimental/vocab.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 566d68d509..df82e53c84 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -43,6 +43,9 @@ def vocab_from_raw_text_file(file_object, jited_tokenizer, min_freq=1, num_cpus= >>> tokenizer = basic_english_normalize() >>> jit_tokenizer = torch.jit.script(tokenizer.to_ivalue()) >>> v = vocab_from_raw_text_file(f, jit_tokenizer) + >>> v.insert_token('', 0) + >>> v.set_default_index(0) + >>> v.get_default_index() """ vocab_obj = _load_vocab_from_raw_text_file(file_object.name, min_freq, num_cpus, jited_tokenizer) return Vocab(vocab_obj) @@ -72,6 +75,9 @@ def vocab_from_file(file_object, min_freq=1, num_cpus=4): >>> from torchtext.experimental.vocab import vocab_from_file >>> f = open('vocab.txt', 'r') >>> v = vocab_from_file(f) + >>> v.insert_token('', 0) + >>> v.set_default_index(0) + >>> v.get_default_index() """ vocab_obj = _load_vocab_from_file(file_object.name, min_freq, num_cpus) return Vocab(vocab_obj) @@ -85,6 +91,16 @@ def build_vocab_from_iterator(iterator, min_freq=1): iterator: Iterator used to build Vocab. Must yield list or iterator of tokens. min_freq: The minimum frequency needed to include a token in the vocabulary. Values less than 1 will be set to 1. Default: 1. + + Examples: + >>> from torchtext.experimental.vocab import build_vocab_from_iterator + >>> tokens = [['this', 'is', 'an', 'example', 'for', 'vocab']] + >>> v = build_vocab_from_iterator(tokens) + >>> v.insert_token('', 0) + >>> v.set_default_index(0) + >>> v.get_default_index() + >>> tokens_iter = iter([['this', 'is', 'an'], ['example', 'for', 'vocab']]) + >>> v1 = build_vocab_from_iterator(tokens_iter) """ counter = Counter() @@ -114,6 +130,9 @@ def vocab(ordered_dict, min_freq=1): >>> sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True) >>> ordered_dict = OrderedDict(sorted_by_freq_tuples) >>> v1 = vocab(ordered_dict) + >>> v1.insert_token('', 0) + >>> v1.set_default_index(0) + >>> v1.get_default_index() >>> tokens = ['e', 'd', 'c', 'b', 'a'] >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens])) """ From f3ed7675b6d22d5d367e63b8a8c5a6c52df4f38e Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 2 Nov 2020 15:22:29 -0800 Subject: [PATCH 33/45] checkpoint --- torchtext/csrc/vocab.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index eb86e1d9a9..b39f9e1b6c 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -374,7 +374,6 @@ Vocab _load_vocab_from_raw_text_file(const std::string &file_path, StringList tokens; std::tie(stoi, tokens) = _concat_tokens(chunk_counters, min_freq, num_lines, true); - // int64_t unk_index = stoi.find(unk_token)->second; return Vocab(std::move(tokens), std::move(stoi)); } From a9b27decfd094f0b907df0ecf305fb4784e2aeff Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 3 Nov 2020 13:43:06 -0800 Subject: [PATCH 34/45] set_default_index if the saved vocab has a default index --- torchtext/csrc/vocab.cpp | 12 +++++++++--- torchtext/csrc/vocab.h | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index b39f9e1b6c..937e9b479c 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -383,9 +383,13 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { StringList strings = self->itos_; std::vector tensors; + c10::optional default_index = {}; + if (self->default_index_.has_value()) + default_index = self->default_index_.value(); + VocabStates states = std::make_tuple( self->version_str_, std::move(integers), std::move(strings), - self->get_default_index(), std::move(tensors)); + default_index, std::move(tensors)); return states; } @@ -405,7 +409,7 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { auto &version_str = std::get<0>(states); auto &integers = std::get<1>(states); auto &strings = std::get<2>(states); - auto &integer = std::get<3>(states); + auto &default_index = std::get<3>(states); auto &tensors = std::get<4>(states); // check integers and tensors are empty @@ -421,7 +425,9 @@ c10::intrusive_ptr _get_vocab_from_states(VocabStates states) { if (version_str.compare("0.0.1") >= 0) { auto vocab_instance = c10::make_intrusive(std::move(strings)); - vocab_instance->set_default_index(integer); + if (default_index.has_value()) + vocab_instance->set_default_index(default_index.value()); + return vocab_instance; } #ifdef _MSC_VER diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index ca06b52841..7e83a0d3e2 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -12,12 +12,12 @@ typedef std::tuple, std::vector, struct Vocab : torch::CustomClassHolder { private: - c10::optional default_index_ = {}; IndexDict stoi_; public: const std::string version_str_ = "0.0.1"; StringList itos_; + c10::optional default_index_ = {}; explicit Vocab(const std::vector &tokens); explicit Vocab(const StringList &tokens, const IndexDict &stoi); From 53a353f6d60f5662a94562e34254032cb37147e5 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 3 Nov 2020 14:10:19 -0800 Subject: [PATCH 35/45] checkpoint --- torchtext/csrc/vocab.cpp | 2 +- torchtext/csrc/vocab.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 937e9b479c..ce365f3a7e 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -383,7 +383,7 @@ VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { StringList strings = self->itos_; std::vector tensors; - c10::optional default_index = {}; + c10::optional default_index = {}; if (self->default_index_.has_value()) default_index = self->default_index_.value(); diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 7e83a0d3e2..f3f0fd62de 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -7,7 +7,7 @@ typedef std::vector StringList; typedef ska_ordered::order_preserving_flat_hash_map IndexDict; typedef std::tuple, std::vector, - int64_t, std::vector> + c10::optional, std::vector> VocabStates; struct Vocab : torch::CustomClassHolder { From aeb9995bf4a8bc7be7f1a0295a99ea75ac31cd16 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 23 Dec 2020 08:47:33 -0800 Subject: [PATCH 36/45] checkpoint --- torchtext/csrc/register_bindings.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 3b71dc073a..5797898914 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -15,12 +15,11 @@ namespace py = pybind11; namespace { Vocab build_vocab_from_text_file(const std::string &file_path, - const std::string &unk_token, const int64_t min_freq, const int64_t num_cpus, py::object fn) { torch::jit::script::Module module(*torch::jit::as_module(fn)); - return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, module); + return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module); } } // namespace @@ -100,7 +99,7 @@ PYBIND11_MODULE(_torchtext, m) { })); py::class_>(m, "Vocab") - .def(py::init, std::string>()) + .def(py::init>()) .def_readonly("itos_", &Vocab::itos_) .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) @@ -203,7 +202,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { }); m.class_("Vocab") - .def(torch::init()) + .def(torch::init()) .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) From 588cce45d1caeefbf8d430c0f87ac92bdcac5eb8 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 23 Dec 2020 09:01:34 -0800 Subject: [PATCH 37/45] checkpoint --- torchtext/csrc/register_bindings.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 5797898914..db4dcf47e5 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -206,6 +206,8 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { .def("__getitem__", &Vocab::__getitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) + .def("set_default_index", &Vocab::set_default_index) + .def("get_default_index", &Vocab::get_default_index) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) From 67ec4669c4cf9195b300cc4c3190576bf9c21371 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 23 Dec 2020 09:18:02 -0800 Subject: [PATCH 38/45] checkpoint --- test/experimental/test_with_asset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py index 03df6b97a0..a107168f49 100644 --- a/test/experimental/test_with_asset.py +++ b/test/experimental/test_with_asset.py @@ -20,8 +20,6 @@ import shutil import tempfile import os -import unittest -import platform from torchtext.experimental.vectors import ( GloVe, build_vectors, @@ -185,7 +183,7 @@ def test_vocab_from_file(self): asset_name = 'vocab_test.txt' asset_path = get_asset_path(asset_name) with open(asset_path, 'r') as f: - v = load_vocab_from_file(f, unk_token='') + v = load_vocab_from_file(f) v.insert_token('', 0) expected_itos = ['', 'b', 'a', 'c'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} From 4567d7c28f155607a4c7c659a5f96d8b41bf0aef Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 23 Dec 2020 11:08:02 -0800 Subject: [PATCH 39/45] add setitem func for vocab --- torchtext/csrc/register_bindings.cpp | 2 ++ torchtext/csrc/vocab.cpp | 15 +++++++++++++++ torchtext/csrc/vocab.h | 1 + torchtext/experimental/vocab.py | 12 ++++++++++++ 4 files changed, 30 insertions(+) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index db4dcf47e5..a7f6f054c9 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -102,6 +102,7 @@ PYBIND11_MODULE(_torchtext, m) { .def(py::init>()) .def_readonly("itos_", &Vocab::itos_) .def("__getitem__", &Vocab::__getitem__) + .def("__setitem__", &Vocab::__setitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("set_default_index", &Vocab::set_default_index) @@ -204,6 +205,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { m.class_("Vocab") .def(torch::init()) .def("__getitem__", &Vocab::__getitem__) + .def("__setitem__", &Vocab::__setitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("set_default_index", &Vocab::set_default_index) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 99b3c11699..695e7a468d 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -40,6 +40,21 @@ int64_t Vocab::__getitem__(const std::string &token) const { "set up the default index"); } +void __setitem__(const std::string &token, const int64_t &index) { + if (index < 0 || index > static_cast(stoi_.size())) { +#ifdef _MSC_VER + std::cerr << "[RuntimeError] Specified index " << index + << " is out of bounds of the size of stoi dictionary: " + << stoi_.size() << std::endl; +#endif + throw std::runtime_error( + "Specified index " + std::to_string(index) + + " is out of bounds of the size of stoi dictionary: " + + std::to_string(stoi_.size()) + "."); + } + +} + void Vocab::append_token(const std::string &token) { if (stoi_.find(token) == stoi_.end()) { // Note: we can't do `stoi_[token] = stoi_.size()` because of a bug diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index cfa169a8a1..c8575a4a13 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -22,6 +22,7 @@ struct Vocab : torch::CustomClassHolder { explicit Vocab(const StringList &tokens, const IndexDict &stoi); int64_t __len__() const; int64_t __getitem__(const std::string &token) const; + void __setitem__(const std::string &token, const int64_t &index); void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); void set_default_index(const int64_t index); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 01c73ceacf..09140ffebd 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -189,6 +189,18 @@ def __getitem__(self, token: str) -> int: """ return self.vocab[token] + @torch.jit.export + def __setitem__(self, token: str, index: int) -> None: + r""" + Args: + token (str): the token used to lookup the corresponding index. + index (int): the index corresponding to the associated token. + + Raises: + RuntimeError: if `index` not between [0, Vocab.size()] or if token already exists in the vocab. + """ + self.vocab[token] = index + @torch.jit.export def insert_token(self, token: str, index: int) -> None: r""" From 15f29be1c50c2a7b62a1eb2eb88db12a5acee937 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 23 Dec 2020 11:15:20 -0800 Subject: [PATCH 40/45] checkpoint --- torchtext/csrc/vocab.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 695e7a468d..40b7b4d91d 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -40,7 +40,7 @@ int64_t Vocab::__getitem__(const std::string &token) const { "set up the default index"); } -void __setitem__(const std::string &token, const int64_t &index) { +void Vocab::__setitem__(const std::string &token, const int64_t &index) { if (index < 0 || index > static_cast(stoi_.size())) { #ifdef _MSC_VER std::cerr << "[RuntimeError] Specified index " << index From 33f31d9e9127d850aed6d2c2f31f842158d8887d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 28 Dec 2020 10:34:53 -0800 Subject: [PATCH 41/45] add delete_token func --- torchtext/csrc/vocab.cpp | 25 +++++++++++++++++++++++++ torchtext/csrc/vocab.h | 1 + 2 files changed, 26 insertions(+) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 40b7b4d91d..466a5e5103 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -108,6 +108,31 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { } } +void Vocab::_delete_token(const std::string &token) { + const auto &item = stoi_.find(token); + // if item already in stoi we throw an error + if (item == stoi_.end()) { +#ifdef _MSC_VER + std::cerr << "[RuntimeError] Token " << token + << " doesn't exist in the Vocab" + << std::endl; +#endif + throw std::runtime_error("Token " + token + + " doesn't exist in the Vocab" + "."); + } + for (size_t i = item->second + 1; i < itos_.size(); i++) { + stoi_[itos_[i]] = i - 1; + } + stoi_.erase(token); + itos_.erase(itos_.begin() + item->second); + + // need to update unk_index in case token equals unk_token or token + // inserted before unk_token + if (default_index_.has_value() && item->second < *default_index_) { + default_index_ = default_index_.value() - 1; + } +} + void Vocab::set_default_index(const int64_t index) { if (default_index_.has_value()) std::cerr diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index c8575a4a13..79b0987253 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -25,6 +25,7 @@ struct Vocab : torch::CustomClassHolder { void __setitem__(const std::string &token, const int64_t &index); void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); + void _delete_token(const std::string &token); void set_default_index(const int64_t index); int64_t get_default_index() const; std::string lookup_token(const int64_t &index); From 6beae11ca66992f800dad33874124f5a921e3e2d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 28 Dec 2020 11:41:44 -0800 Subject: [PATCH 42/45] implement setitem func and add a test --- test/experimental/test_vocab.py | 20 ++++++++++++++++++++ torchtext/csrc/vocab.cpp | 6 ++++++ 2 files changed, 26 insertions(+) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index f072231fbd..3e0a494164 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -48,6 +48,26 @@ def test_vocab_get_item(self): self.assertEqual(v['a'], 1) self.assertEqual(v['b'], 2) + def test_vocab_set_item(self): + token_to_freq = {'': 2, 'a': 2, 'b': 2} + sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) + c = OrderedDict(sorted_by_freq_tuples) + v = vocab(c, min_freq=2) + + v.set_default_index(0) + v['b'] = 1 + self.assertEqual(v[''], 0) + self.assertEqual(v['a'], 2) + self.assertEqual(v['b'], 1) + self.assertEqual(v['not_in_it'], 0) + + v.set_default_index(1) + v['a'] = 0 + self.assertEqual(v[''], 1) + self.assertEqual(v['a'], 0) + self.assertEqual(v['b'], 2) + self.assertEqual(v['not_in_it'], 2) + def test_vocab_insert_token(self): c = OrderedDict({'': 2, 'a': 2}) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 466a5e5103..ae1ab3a68a 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -53,6 +53,12 @@ void Vocab::__setitem__(const std::string &token, const int64_t &index) { std::to_string(stoi_.size()) + "."); } + const auto &item = stoi_.find(token); + if (item != stoi_.end()) { + _delete_token(token); + } + + insert_token(token, index); } void Vocab::append_token(const std::string &token) { From 82fce2b0d60d7828820800da98cf591d34ebcda2 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 6 Jan 2021 13:44:00 -0800 Subject: [PATCH 43/45] add __delete__ func to Vocab and remind users not to reassign an existing token. They have to delete it before reassignment. --- test/experimental/test_vocab.py | 18 ++++++++++-------- torchtext/csrc/register_bindings.cpp | 2 ++ torchtext/csrc/vocab.cpp | 12 +++++++----- torchtext/csrc/vocab.h | 2 +- torchtext/experimental/vocab.py | 12 ++++++++++++ 5 files changed, 32 insertions(+), 14 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 3e0a494164..1a47caf7cc 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -48,6 +48,9 @@ def test_vocab_get_item(self): self.assertEqual(v['a'], 1) self.assertEqual(v['b'], 2) + # we separate out these errors because Windows runs into seg faults when propagating + # exceptions from C++ using pybind11 + @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.") def test_vocab_set_item(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) @@ -55,19 +58,18 @@ def test_vocab_set_item(self): v = vocab(c, min_freq=2) v.set_default_index(0) + with self.assertRaises(RuntimeError): + v['b'] = 1 + del v['b'] + self.assertEqual(v[''], 0) + self.assertEqual(v['a'], 1) + self.assertEqual(v['not_in_it'], 0) + v['b'] = 1 self.assertEqual(v[''], 0) - self.assertEqual(v['a'], 2) self.assertEqual(v['b'], 1) self.assertEqual(v['not_in_it'], 0) - v.set_default_index(1) - v['a'] = 0 - self.assertEqual(v[''], 1) - self.assertEqual(v['a'], 0) - self.assertEqual(v['b'], 2) - self.assertEqual(v['not_in_it'], 2) - def test_vocab_insert_token(self): c = OrderedDict({'': 2, 'a': 2}) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index a7f6f054c9..b1c17dd79b 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -103,6 +103,7 @@ PYBIND11_MODULE(_torchtext, m) { .def_readonly("itos_", &Vocab::itos_) .def("__getitem__", &Vocab::__getitem__) .def("__setitem__", &Vocab::__setitem__) + .def("__delitem__", &Vocab::__delitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("set_default_index", &Vocab::set_default_index) @@ -206,6 +207,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { .def(torch::init()) .def("__getitem__", &Vocab::__getitem__) .def("__setitem__", &Vocab::__setitem__) + .def("__delitem__", &Vocab::__delitem__) .def("__len__", &Vocab::__len__) .def("insert_token", &Vocab::insert_token) .def("set_default_index", &Vocab::set_default_index) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index ae1ab3a68a..229d27eaa9 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -53,12 +53,14 @@ void Vocab::__setitem__(const std::string &token, const int64_t &index) { std::to_string(stoi_.size()) + "."); } - const auto &item = stoi_.find(token); + auto item = stoi_.find(token); if (item != stoi_.end()) { - _delete_token(token); + throw std::runtime_error( + "Token " + token + + " has already been in the Vocab. Please delete it first by call del func."); } - - insert_token(token, index); + item->first = token; + itos_[index] = token; } void Vocab::append_token(const std::string &token) { @@ -114,7 +116,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { } } -void Vocab::_delete_token(const std::string &token) { +void Vocab::__delitem__(const std::string &token) { const auto &item = stoi_.find(token); // if item already in stoi we throw an error if (item == stoi_.end()) { diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 79b0987253..609549cd18 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -25,7 +25,7 @@ struct Vocab : torch::CustomClassHolder { void __setitem__(const std::string &token, const int64_t &index); void append_token(const std::string &token); void insert_token(const std::string &token, const int64_t &index); - void _delete_token(const std::string &token); + void __delitem__(const std::string &token); void set_default_index(const int64_t index); int64_t get_default_index() const; std::string lookup_token(const int64_t &index); diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index a54d760ca9..b65b8dd423 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -201,6 +201,18 @@ def __setitem__(self, token: str, index: int) -> None: """ self.vocab[token] = index + @torch.jit.export + def __delitem__(self, token: str) -> None: + r"""Delete token from vocab and shift all the following tokens to left by 1. + + Args: + token (str): the token to be deleted. + + Raises: + RuntimeError: if `token` is not in the vocab. + """ + del self.vocab[token] + @torch.jit.export def insert_token(self, token: str, index: int) -> None: r""" From 0c2dab7c511fb04299213a2a2541bf50e5271148 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Thu, 7 Jan 2021 20:17:20 -0500 Subject: [PATCH 44/45] checkpoint --- test/experimental/test_vocab.py | 2 ++ torchtext/csrc/vocab.cpp | 8 +++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 1a47caf7cc..5a4d1bb77b 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -64,11 +64,13 @@ def test_vocab_set_item(self): self.assertEqual(v[''], 0) self.assertEqual(v['a'], 1) self.assertEqual(v['not_in_it'], 0) + self.assertEqual(v['b'], 0) v['b'] = 1 self.assertEqual(v[''], 0) self.assertEqual(v['b'], 1) self.assertEqual(v['not_in_it'], 0) + self.assertEqual(v['a'], 0) def test_vocab_insert_token(self): c = OrderedDict({'': 2, 'a': 2}) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 229d27eaa9..264ae6f139 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -59,8 +59,14 @@ void Vocab::__setitem__(const std::string &token, const int64_t &index) { "Token " + token + " has already been in the Vocab. Please delete it first by call del func."); } - item->first = token; + + if (index == static_cast(stoi_.size())) append_token(token); + else { + auto it = stoi_.find(itos_[index]); + stoi_.erase(it); + stoi_[token] = index; itos_[index] = token; + } } void Vocab::append_token(const std::string &token) { From 2f9ac0ea58cb53c35c7fe9e4e878d9f86d5572e6 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 8 Jan 2021 09:55:55 -0500 Subject: [PATCH 45/45] checkpoint --- torchtext/experimental/vocab.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index b65b8dd423..2186aefc99 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -191,13 +191,15 @@ def __getitem__(self, token: str) -> int: @torch.jit.export def __setitem__(self, token: str, index: int) -> None: - r""" + r"""Set token to a specific index. The original token assigned to index is + replaced by the new token. + Args: token (str): the token used to lookup the corresponding index. index (int): the index corresponding to the associated token. Raises: - RuntimeError: if `index` not between [0, Vocab.size()] or if token already exists in the vocab. + RuntimeError: if `index` not between [0, Vocab.size()] """ self.vocab[token] = index