From 3340265471c1b0015fd061cc7b86d8e815b7f13e Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 7 May 2021 17:38:56 -0400 Subject: [PATCH 1/9] unit testing for experimental vocab --- test/experimental/test_vocab.py | 9 +++++++++ torchtext/csrc/vocab.cpp | 18 ++++++++++++++++-- torchtext/csrc/vocab.h | 2 +- torchtext/experimental/vocab.py | 3 +++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index d091d7796b..83b19a8fa3 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -76,6 +76,10 @@ def test_vocab_insert_token(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + with self.assertRaises(RuntimeError) as context: + v.insert_token('b', 0) + + self.assertTrue("Token b already exists in the Vocab with index: 0" in str(context.exception)) def test_vocab_append_token(self): c = OrderedDict({'a': 2}) @@ -88,6 +92,11 @@ def test_vocab_append_token(self): self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(v.get_stoi()), expected_stoi) + with self.assertRaises(RuntimeError) as context: + v.append_token('b') + + self.assertTrue("Token b already exists in the Vocab with index: 2" in str(context.exception)) + def test_vocab_len(self): token_to_freq = {'a': 2, 'b': 2, 'c': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index 1831d46f39..fec59b094d 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -38,7 +38,6 @@ bool Vocab::__contains__(const c10::string_view &token) const { return false; } - int64_t Vocab::__getitem__(const c10::string_view &token) const { int64_t id = _find(token); if (stoi_[id] != -1) { @@ -47,7 +46,22 @@ int64_t Vocab::__getitem__(const c10::string_view &token) const { return unk_index_; } -void Vocab::append_token(const std::string &token) { _add(token); } +void Vocab::append_token(const std::string &token) { + // if item already in stoi we throw an error + auto token_position = _find(c10::string_view{token.data(), token.size()}); + if (stoi_[token_position] != -1) { +#ifdef _MSC_VER + std::cerr << "[RuntimeError] Token " << token + << " already exists in the Vocab with index: " + << stoi_[token_position] << std::endl; +#endif + throw std::runtime_error("Token " + token + + " already exists in the Vocab with index: " + + std::to_string(stoi_[token_position]) + "."); + } + + _add(token); +} void Vocab::insert_token(const std::string &token, const int64_t &index) { if (index < 0 || index > itos_.size()) { diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index d915c7de27..40fc38e0e2 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -44,7 +44,7 @@ struct Vocab : torch::CustomClassHolder { uint32_t _find(const c10::string_view &w) const { uint32_t stoi_size = stoi_.size(); uint32_t id = _hash(w) % stoi_size; - while (stoi_[id] != -1 && itos_[stoi_[id]]!= w) { + while (stoi_[id] != -1 && itos_[stoi_[id]] != w) { id = (id + 1) % stoi_size; } return id; diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py index 26f393ce36..acc6d96a4d 100644 --- a/torchtext/experimental/vocab.py +++ b/torchtext/experimental/vocab.py @@ -214,6 +214,9 @@ def append_token(self, token: str) -> None: r""" Args: token (str): the token used to lookup the corresponding index. 
+
+        Raises:
+            RuntimeError: if token already exists in the vocab
         """
         self.vocab.append_token(token)
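
A minimal sketch of the duplicate-token behavior these tests pin down (hypothetical snippet mirroring the new tests, not part of the patch; `vocab` is the experimental factory, which at this point still prepends '<unk>'):

    >>> from collections import OrderedDict
    >>> from torchtext.experimental.vocab import vocab
    >>> v = vocab(OrderedDict([('a', 2), ('b', 2)]))   # itos: ['<unk>', 'a', 'b']
    >>> v.append_token('b')
    RuntimeError: Token b already exists in the Vocab with index: 2.
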
From ea8bbbe1ed2fb44178dc84cd016a2107622d3ec0 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Sun, 9 May 2021 20:03:06 -0400
Subject: [PATCH 2/9] added APIs related to default index

---
 test/experimental/test_vocab.py      |  53 +++++++++--
 torchtext/csrc/register_bindings.cpp |  29 +++---
 torchtext/csrc/vocab.cpp             | 132 +++++++++++++++++----------
 torchtext/csrc/vocab.h               |  28 ++++--
 torchtext/experimental/vocab.py      |  38 +++++++-
 5 files changed, 200 insertions(+), 80 deletions(-)

diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index 83b19a8fa3..2dd87ee8e6 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -20,18 +20,12 @@ def tearDown(self):
     def test_has_unk(self):
         c = OrderedDict()
         v = vocab(c)
-
-        # check if unk is mapped to the first index
-        self.assertEqual(v['not_in_it'], 0)
         self.assertEqual(v['<unk>'], 0)

     def test_new_unk(self):
         c = OrderedDict()
         v = vocab(c, unk_token="<new_unk>")
-
-        # check if new_unk is mapped to the first index
         self.assertEqual(v['<new_unk>'], 0)
-        self.assertEqual(v['not_in_it'], 0)

     def test_vocab_membership(self):
         token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
@@ -54,6 +48,50 @@ def test_vocab_get_item(self):
         self.assertEqual(v['a'], 1)
         self.assertEqual(v['b'], 2)

+    def test_reassign_token(self):
+        token_to_freq = {'<unk>': 1, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=1)
+
+        self.assertEqual(v['<unk>'], 2)
+        self.assertEqual(v['a'], 0)
+        self.assertEqual(v['b'], 1)
+        v.reassign_token('<unk>', 0)
+        self.assertEqual(v['<unk>'], 0)
+        self.assertEqual(v['a'], 1)
+        self.assertEqual(v['b'], 2)
+
+        self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])
+
+        with self.assertRaises(RuntimeError):
+            v.reassign_token('not in vocab', 0)
+
+        with self.assertRaises(RuntimeError):
+            v.reassign_token('<unk>', 3)
+
+    def test_default_index(self):
+        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=2)
+
+        self.assertTrue(v.get_default_index() is None)
+        with self.assertRaises(RuntimeError):
+            v['not in vocab']
+
+        v.set_default_index(0)
+        self.assertEqual(v['not in vocab'], 0)
+
+    def test_default_index_jit(self):
+        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=2)
+        v.set_default_index(0)
+        v_jit = torch.jit.script(v)
+        self.assertEqual(v_jit['not in vocab'], 0)
+
     def test_vocab_insert_token(self):
         c = OrderedDict({'<unk>': 2, 'a': 2})

@@ -181,9 +219,6 @@ def test_vocab_lookup_indices(self):

         self.assertEqual(v.lookup_indices(tokens), expected_indices)

-    # we separate out these errors because Windows runs into seg faults when propagating
-    # exceptions from C++ using pybind11
-    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
     def test_errors_vocab_cpp(self):
         token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 0b325bcda6..83e3573f01 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -15,12 +15,10 @@ namespace py = pybind11;
 namespace {
 Vocab build_vocab_from_text_file(const std::string &file_path,
-                                 const std::string &unk_token,
                                  const int64_t min_freq,
                                  const int64_t num_cpus, py::object fn) {
   torch::jit::script::Module module(*torch::jit::as_module(fn));
-  return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
-                                     module);
+  return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module);
 }
 } // namespace
@@ -104,21 +102,25 @@ PYBIND11_MODULE(_torchtext, m) {
       }));
   py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
-      .def(py::init<std::vector<std::string>, std::string>())
+      .def(py::init<StringList, c10::optional<int64_t>>())
       .def_readonly("itos_", &Vocab::itos_)
-      .def_readonly("unk_token_", &Vocab::unk_token_)
-      .def("__contains__",
-           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
-             ssize_t length;
-             const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
-             return self->__contains__(c10::string_view{buffer, (size_t)length});
-           })
+      .def_readonly("default_index_", &Vocab::default_index_)
+      .def(
+          "__contains__",
+          [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
+            ssize_t length;
+            const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
+            return self->__contains__(c10::string_view{buffer, (size_t)length});
+          })
       .def("__getitem__",
           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
            ssize_t length;
            const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
            return self->__getitem__(c10::string_view{buffer, (size_t)length});
           })
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
+      .def("reassign_token", &Vocab::reassign_token)
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
@@ -234,7 +236,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       });

   m.class_<Vocab>("Vocab")
-      .def(torch::init<StringList, std::string>())
+      .def(torch::init<StringList, c10::optional<int64_t>>())
      .def("__contains__",
           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item) -> bool {
             return self->__contains__(c10::string_view{item}); })
@@ -242,6 +244,9 @@
      .def("__getitem__",
           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item) -> int64_t {
             return self->__getitem__(c10::string_view{item}); })
       .def("__len__", &Vocab::__len__)
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
+      .def("reassign_token", &Vocab::reassign_token)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index fec59b094d..0f857ef907 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -7,8 +7,24 @@
 #include <vocab.h> // @manual

 namespace torchtext {
-Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
-    : stoi_(MAX_VOCAB_SIZE, -1), unk_token_(std::move(unk_token)) {
+Vocab::Vocab(const StringList &tokens) : stoi_(MAX_VOCAB_SIZE, -1) {
+  for (std::size_t i = 0; i < tokens.size(); i++) {
+    // tokens should not have any duplicates
+    auto token_position =
+        _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    if (stoi_[token_position] != -1) {
+#ifdef _MSC_VER
+      std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
+                << tokens[i] << std::endl;
+#endif
+      throw std::runtime_error("Duplicate token found in tokens list: " +
+                               tokens[i]);
+    }
+    _add(tokens[i]);
+  }
+}
+Vocab::Vocab(const StringList &tokens, c10::optional<int64_t> default_index)
+    : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
   for (std::size_t i = 0; i < tokens.size(); i++) {
     // tokens should not have any duplicates
     auto token_position =
@@ -23,9 +39,6 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     }
     _add(tokens[i]);
   }
-
-  unk_index_ =
-      stoi_[_find(c10::string_view{unk_token.data(), unk_token.size()})];
 }

 int64_t Vocab::__len__() const { return itos_.size(); }
@@ -43,7 +56,51 @@ int64_t Vocab::__getitem__(const c10::string_view &token) const {
   if (stoi_[id] != -1) {
     return stoi_[id];
   }
-  return unk_index_;
+
+  if (default_index_.has_value()) {
+    return default_index_.value();
+  }
+#ifdef _MSC_VER
+  std::cerr << "[RuntimeError] Token " << std::string(token)
+            << " not found and default index is not set." << std::endl;
+#endif
+  throw std::runtime_error("[RuntimeError] Token " + std::string(token) +
+                           " not found and default index is not set.");
+}
+
+void Vocab::set_default_index(int64_t index) { default_index_ = index; }
+
+c10::optional<int64_t> Vocab::get_default_index() const {
+  return default_index_;
+}
+
+void Vocab::reassign_token(const std::string &token, const int64_t &index) {
+  // throw error if index is not valid
+  if (index < 0 || index >= itos_.size()) {
+#ifdef _MSC_VER
+    std::cerr << "[RuntimeError] Specified index " << index
+              << " is out of bounds of the size of stoi dictionary: "
+              << stoi_.size() << std::endl;
+#endif
+    throw std::runtime_error(
+        "[RuntimeError] Specified index " + std::to_string(index) +
+        " is out of bounds of the size of stoi dictionary: " +
+        std::to_string(stoi_.size()) + ".");
+  }
+
+  // throw error if token not found
+  auto id = _find(c10::string_view{token.data(), token.size()});
+  if (stoi_[id] == -1) {
+#ifdef _MSC_VER
+    std::cerr << "[RuntimeError] Token " << token << " not found in Vocab.\n";
+#endif
+    throw std::runtime_error("[RuntimeError] Token " + token +
+                             " not found in Vocab.");
+  }
+
+  _remove(token);
+  insert_token(token, index);
 }

 void Vocab::append_token(const std::string &token) {
@@ -71,7 +128,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
               << stoi_.size() << std::endl;
 #endif
     throw std::runtime_error(
-        "Specified index " + std::to_string(index) +
+        "[RuntimeError] Specified index " + std::to_string(index) +
         " is out of bounds of the size of stoi dictionary: " +
         std::to_string(stoi_.size()) + ".");
   }
@@ -84,7 +141,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
               << " already exists in the Vocab with index: "
               << stoi_[token_position] << std::endl;
 #endif
-    throw std::runtime_error("Token " + token +
+    throw std::runtime_error("[RuntimeError] Token " + token +
                              " already exists in the Vocab with index: " +
                              std::to_string(stoi_[token_position]) + ".");
   }
@@ -96,11 +153,6 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {

   itos_.insert(itos_.begin() + index, token);
   stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
-
-  // need to update unk_index in case token equals unk_token or token
-  // inserted before unk_token
-  unk_index_ =
-      stoi_[_find(c10::string_view{unk_token_.data(), unk_token_.size()})];
 }

 std::string Vocab::lookup_token(const int64_t &index) {
@@ -224,8 +276,8 @@ struct CompareTokens {

 StringList
 _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
-               const std::string &unk_token, const int64_t min_freq,
-               const int64_t num_lines, const bool sort_tokens) {
+               const int64_t min_freq, const int64_t num_lines,
+               const bool sort_tokens) {
   TORCH_CHECK(chunk_counters.size() > 0,
               "There must be at least 1 chunk to concatenate!");

@@ -271,24 +323,12 @@ _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
     unique_tokens.push_back(token_freq_pair.first);
   }

-  // insert unk_token if not present
-  if (tokens_freq.find(unk_token) == tokens_freq.end()) {
-    std::cerr << "The `unk_token` " << unk_token
-              << " wasn't found in the `ordered_dict`. Adding the `unk_token` "
-                 "to the beginning of the Vocab."
-              << std::endl;
-
-    unique_tokens.insert(unique_tokens.begin(), unk_token);
-  }
-
   return unique_tokens;
 }

 constexpr int64_t GRAIN_SIZE = 13107;
 Vocab _load_vocab_from_file(const std::string &file_path,
-                            const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus) {
-  std::cerr << "[INFO] Reading file " << file_path << std::endl;

   int64_t num_lines = _infer_lines(file_path);
   int64_t chunk_size = impl::divup(num_lines, num_cpus);
@@ -327,13 +367,12 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   cv.wait(lock, [&thread_count] { return thread_count == 0; });

   StringList tokens =
-      _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, false);
+      _concat_tokens(chunk_counters, min_freq, num_lines, false);

-  return Vocab(std::move(tokens), unk_token);
+  return Vocab(std::move(tokens), {});
 }

 Vocab _build_vocab_from_text_file(const std::string &file_path,
-                                  const std::string &unk_token,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer) {
@@ -373,40 +412,40 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   std::unique_lock<std::mutex> lock(m);
   cv.wait(lock, [&thread_count] { return thread_count == 0; });

-  StringList tokens =
-      _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, true);
+  StringList tokens = _concat_tokens(chunk_counters, min_freq, num_lines, true);

-  return Vocab(std::move(tokens), unk_token);
+  return Vocab(std::move(tokens), {});
 }

 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   std::vector<int64_t> integers;
   StringList strings = self->itos_;
-  strings.push_back(self->unk_token_);
   std::vector<torch::Tensor> tensors;

-  VocabStates states = std::make_tuple(self->version_str_, std::move(integers),
-                                       std::move(strings), std::move(tensors));
+  VocabStates states = std::make_tuple(self->version_str_, self->default_index_,
+                                       std::move(integers), std::move(strings),
+                                       std::move(tensors));
   return states;
 }

 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size<decltype(states)>::value;
-  if (state_size != 4) {
+  if (state_size != 5) {
 #ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Expected deserialized Vocab to have 4 states "
+    std::cerr << "[RuntimeError] Expected deserialized Vocab to have 5 states "
                  "but found "
              << state_size << " states." << std::endl;
 #endif
     throw std::runtime_error(
-        "Expected deserialized Vocab to have 4 states but found " +
+        "Expected deserialized Vocab to have 5 states but found " +
         std::to_string(state_size) + " states.");
   }
   auto &version_str = std::get<0>(states);
-  auto &integers = std::get<1>(states);
-  auto &strings = std::get<2>(states);
-  auto &tensors = std::get<3>(states);
+  auto &default_index = std::get<1>(states);
+  auto &integers = std::get<2>(states);
+  auto &strings = std::get<3>(states);
+  auto &tensors = std::get<4>(states);

   // check integers and tensors are empty
   if (integers.size() != 0 || tensors.size() != 0) {
@@ -419,11 +458,10 @@ c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
         "Expected `integers` and `tensors` states to be empty.");
   }

-  if (version_str.compare("0.0.1") >= 0) {
-    std::string unk_token = strings.back();
-    strings.pop_back(); // remove last element which is unk_token
-
-    return c10::make_intrusive<Vocab>(std::move(strings), std::move(unk_token));
+  if (version_str.compare("0.0.2") >= 0) {
+    auto deserialized_vocab = c10::make_intrusive<Vocab>(std::move(strings));
+    deserialized_vocab->default_index_ = default_index;
+    return deserialized_vocab;
   }
 #ifdef _MSC_VER
   std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: "
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 40fc38e0e2..b3f35aef46 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -1,3 +1,4 @@
+#include <c10/util/Optional.h>
 #include <flat_hash_map.h>
 #include <torch/script.h>
 namespace torchtext {
@@ -5,23 +6,28 @@ namespace torchtext {
 typedef std::vector<std::string> StringList;
 typedef ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>
     IndexDict;
-typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
-                   std::vector<torch::Tensor>>
+typedef std::tuple<std::string, c10::optional<int64_t>, std::vector<int64_t>,
+                   std::vector<std::string>, std::vector<torch::Tensor>>
     VocabStates;

 struct Vocab : torch::CustomClassHolder {
   static const int32_t MAX_VOCAB_SIZE = 30000000;
   int64_t unk_index_;
   std::vector<int64_t> stoi_;
-  const std::string version_str_ = "0.0.1";
+  const std::string version_str_ = "0.0.2";
   StringList itos_;
-  std::string unk_token_;
+  c10::optional<int64_t> default_index_ = {};

-  explicit Vocab(const std::vector<std::string> &tokens,
-                 const std::string &unk_token);
+  // TODO: [can we remove this?] we need to keep this constructor, otherwise torch binding gets
+  // compilation error: no matching constructor for initialization of 'torchtext::Vocab'
+  explicit Vocab(const StringList &tokens);
+  explicit Vocab(const StringList &tokens, c10::optional<int64_t> default_index);
   int64_t __len__() const;
   int64_t __getitem__(const c10::string_view &token) const;
   bool __contains__(const c10::string_view &token) const;
+  void set_default_index(int64_t index);
+  c10::optional<int64_t> get_default_index() const;
+  void reassign_token(const std::string &token, const int64_t &index);
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);
@@ -57,16 +63,22 @@ struct Vocab : torch::CustomClassHolder {
       stoi_[h] = itos_.size() - 1;
     }
   }
+
+  void _remove(const std::string &w) {
+    uint32_t h = _find(c10::string_view{w.data(), w.size()});
+    if (stoi_[h] != -1) {
+      stoi_[h] = -1;
+      itos_.erase(std::find(itos_.begin(), itos_.end(), w));
+    }
+  }
 };

 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
 Vocab _load_vocab_from_file(const std::string &file_path,
-                            const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus);
 Vocab _build_vocab_from_text_file(const std::string &file_path,
-                                  const std::string &unk_token,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer);
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 26f393ce36..acc6d96a4d 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List
+from typing import Dict, List, Optional
 import warnings
 from collections import Counter, OrderedDict
 import torch
@@ -136,7 +136,7 @@ def vocab(ordered_dict, min_freq=1, unk_token='<unk>'):
         tokens.insert(0, unk_token)
         warnings.warn("The `unk_token` '{}' wasn't found in the `ordered_dict`. Adding the `unk_token` "
                       "to the beginning of the Vocab.".format(unk_token), RuntimeWarning)
-    return Vocab(VocabPybind(tokens, unk_token))
+    return Vocab(VocabPybind(tokens, None))


 class Vocab(nn.Module):
@@ -197,6 +197,34 @@ def __getitem__(self, token: str) -> int:
         """
         return self.vocab[token]

+    @torch.jit.export
+    def set_default_index(self, index: int) -> None:
+        r"""
+        Args:
+            index (int): value of the default index. This index will be returned when an OOV token is queried.
+        """
+        self.vocab.set_default_index(index)
+
+    @torch.jit.export
+    def get_default_index(self) -> Optional[int]:
+        r"""
+        Returns:
+            index (Optional[int]): value of the default index if it is set.
+        """
+        return self.vocab.get_default_index()
+
+    @torch.jit.export
+    def reassign_token(self, token: str, index: int) -> None:
+        r"""
+        Args:
+            token (str): the token whose index is to be reassigned.
+            index (int): the new index for the token.
+
+        Raises:
+            RuntimeError: if token is not present in Vocab
+        """
+        self.vocab.reassign_token(token, index)
+
     @torch.jit.export
     def insert_token(self, token: str, index: int) -> None:
         r"""
@@ -278,5 +306,7 @@ def get_itos(self) -> List[str]:
     def __prepare_scriptable__(self):
         r"""Return a JITable Vocab.
         """
-        cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_)
-        return Vocab(cpp_vocab)
+        if not self.is_jitable:
+            cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.default_index_)
+            return Vocab(cpp_vocab)
+        return self
From 41ecdf326cad0d40da683ce747d675056d111d5a Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Sun, 9 May 2021 23:14:58 -0400
Subject: [PATCH 3/9] fixing tests and stylecheck

---
 test/experimental/test_vocab.py      | 2 --
 test/experimental/test_with_asset.py | 8 ++++----
 torchtext/csrc/vocab.cpp             | 2 +-
 torchtext/csrc/vocab.h               | 2 +-
 torchtext/experimental/vocab.py      | 8 ++++----
 5 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index 2dd87ee8e6..2701ddebaa 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
 from collections import OrderedDict
 import os
-import platform
 import torch
-import unittest
 from test.common.torchtext_test_case import TorchtextTestCase
 from torchtext.experimental.vocab import (
     vocab,
diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py
index 2055e0ab57..7bd39e773e 100644
--- a/test/experimental/test_with_asset.py
+++ b/test/experimental/test_with_asset.py
@@ -178,8 +178,8 @@ def test_glove_different_dims(self):
     def test_vocab_from_file(self):
         asset_name = 'vocab_test.txt'
         asset_path = get_asset_path(asset_name)
-        v = load_vocab_from_file(asset_path, unk_token='<unk>')
-        expected_itos = ['<unk>', 'b', 'a', 'c']
+        v = load_vocab_from_file(asset_path)
+        expected_itos = ['b', 'a', 'c']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
@@ -189,8 +189,8 @@ def test_vocab_from_raw_text_file(self):
         asset_path = get_asset_path(asset_name)
         tokenizer = basic_english_normalize()
         jit_tokenizer = torch.jit.script(tokenizer)
-        v = build_vocab_from_text_file(asset_path, jit_tokenizer, unk_token='<unk>')
-        expected_itos = ['<unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
+        v = build_vocab_from_text_file(asset_path, jit_tokenizer)
+        expected_itos = ["'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
                          'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall',
                          'parent', 'pension', 'representing', 'say', 'stricken', 't',
                          'they', 'turner', 'unions', 'with', 'workers']
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 0f857ef907..ef13f77b7f 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -23,7 +23,7 @@ Vocab::Vocab(const StringList &tokens) : stoi_(MAX_VOCAB_SIZE, -1) {
     _add(tokens[i]);
   }
 }
-Vocab::Vocab(const StringList &tokens, c10::optional<int64_t> default_index)
+Vocab::Vocab(const StringList &tokens, const c10::optional<int64_t> &default_index)
     : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index b3f35aef46..33e8d3bb85 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -21,7 +21,7 @@ struct Vocab : torch::CustomClassHolder {
   // TODO: [can we remove this?] we need to keep this constructor, otherwise torch binding gets
   // compilation error: no matching constructor for initialization of 'torchtext::Vocab'
   explicit Vocab(const StringList &tokens);
-  explicit Vocab(const StringList &tokens, c10::optional<int64_t> default_index);
+  explicit Vocab(const StringList &tokens, const c10::optional<int64_t> &default_index);
   int64_t __len__() const;
   int64_t __getitem__(const c10::string_view &token) const;
   bool __contains__(const c10::string_view &token) const;
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 3d0ebf3ec8..fdc569e9f2 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -19,7 +19,7 @@
 logger = logging.getLogger(__name__)


-def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token='<unk>', num_cpus=4):
+def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, num_cpus=4):
     r"""Create a `Vocab` object from a raw text file.

     The `file_path` can contain any raw text. This function applies a generic JITed tokenizer in
@@ -44,11 +44,11 @@
     >>> jit_tokenizer = torch.jit.script(tokenizer)
     >>> v = build_vocab_from_text_file('vocab.txt', jit_tokenizer)
     """
-    vocab_obj = _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, jited_tokenizer)
+    vocab_obj = _build_vocab_from_text_file(file_path, min_freq, num_cpus, jited_tokenizer)
     return Vocab(vocab_obj)


-def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
+def load_vocab_from_file(file_path, min_freq=1, num_cpus=4):
     r"""Create a `Vocab` object from a text file.
     The `file_path` should contain tokens separated by new lines.
     Format for txt file:
@@ -73,7 +73,7 @@ def load_vocab_from_file(file_path, min_freq=1, num_cpus=4):
     >>> v = load_vocab_from_file('vocab.txt')
     """

-    vocab_obj = _load_vocab_from_file(file_path, unk_token, min_freq, num_cpus)
+    vocab_obj = _load_vocab_from_file(file_path, min_freq, num_cpus)
     return Vocab(vocab_obj)

From 29fd747f833fc620dc0caf3013775fcdf0c76a66 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 10 May 2021 11:33:25 -0400
Subject: [PATCH 4/9] removed std::cerr for windows

---
 torchtext/csrc/vocab.cpp | 66 +++++-----------------------------------
 1 file changed, 7 insertions(+), 59 deletions(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index ef13f77b7f..4a4e6cf501 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -13,27 +13,20 @@ Vocab::Vocab(const StringList &tokens) : stoi_(MAX_VOCAB_SIZE, -1) {
     auto token_position =
         _find(c10::string_view{tokens[i].data(), tokens[i].size()});
     if (stoi_[token_position] != -1) {
-#ifdef _MSC_VER
-      std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
-                << tokens[i] << std::endl;
-#endif
       throw std::runtime_error("Duplicate token found in tokens list: " +
                                tokens[i]);
     }
     _add(tokens[i]);
   }
 }
-Vocab::Vocab(const StringList &tokens,
+Vocab::Vocab(const StringList &tokens,
              const c10::optional<int64_t> &default_index)
     : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
   for (std::size_t i = 0; i < tokens.size(); i++) {
     // tokens should not have any duplicates
     auto token_position =
         _find(c10::string_view{tokens[i].data(), tokens[i].size()});
     if (stoi_[token_position] != -1) {
-#ifdef _MSC_VER
-      std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
-                << tokens[i] << std::endl;
-#endif
       throw std::runtime_error("Duplicate token found in tokens list: " +
                                tokens[i]);
     }
@@ -60,10 +53,6 @@ int64_t Vocab::__getitem__(const c10::string_view &token) const {
   if (default_index_.has_value()) {
     return default_index_.value();
   }
-#ifdef _MSC_VER
-  std::cerr << "[RuntimeError] Token " << std::string(token)
-            << " not found and default index is not set." << std::endl;
-#endif
   throw std::runtime_error("[RuntimeError] Token " + std::string(token) +
                            " not found and default index is not set.");
 }
@@ -77,24 +66,15 @@ c10::optional<int64_t> Vocab::get_default_index() const {
 void Vocab::reassign_token(const std::string &token, const int64_t &index) {
   // throw error if index is not valid
   if (index < 0 || index >= itos_.size()) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of stoi dictionary: "
-              << stoi_.size() << std::endl;
-#endif
     throw std::runtime_error(
         "[RuntimeError] Specified index " + std::to_string(index) +
         " is out of bounds of the size of stoi dictionary: " +
         std::to_string(stoi_.size()) + ".");
   }

   // throw error if token not found
   auto id = _find(c10::string_view{token.data(), token.size()});
   if (stoi_[id] == -1) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Token " << token << " not found in Vocab.\n";
-#endif
     throw std::runtime_error("[RuntimeError] Token " + token +
                              " not found in Vocab.");
   }
@@ -107,11 +87,6 @@ void Vocab::append_token(const std::string &token) {
   // if item already in stoi we throw an error
   auto token_position = _find(c10::string_view{token.data(), token.size()});
   if (stoi_[token_position] != -1) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Token " << token
-              << " already exists in the Vocab with index: "
-              << stoi_[token_position] << std::endl;
-#endif
     throw std::runtime_error("Token " + token +
                              " already exists in the Vocab with index: " +
                              std::to_string(stoi_[token_position]) + ".");
@@ -122,11 +97,6 @@ void Vocab::append_token(const std::string &token) {

 void Vocab::insert_token(const std::string &token, const int64_t &index) {
   if (index < 0 || index > itos_.size()) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of stoi dictionary: "
-              << stoi_.size() << std::endl;
-#endif
     throw std::runtime_error(
         "[RuntimeError] Specified index " + std::to_string(index) +
         " is out of bounds of the size of stoi dictionary: " +
@@ -136,11 +106,6 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
   // if item already in stoi we throw an error
   auto token_position = _find(c10::string_view{token.data(), token.size()});
   if (stoi_[token_position] != -1) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Token " << token
-              << " already exists in the Vocab with index: "
-              << stoi_[token_position] << std::endl;
-#endif
     throw std::runtime_error("[RuntimeError] Token " + token +
                              " already exists in the Vocab with index: " +
                              std::to_string(stoi_[token_position]) + ".");
@@ -157,11 +122,6 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {

 std::string Vocab::lookup_token(const int64_t &index) {
   if (index < 0 || index > static_cast<int64_t>(itos_.size())) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of itos dictionary: "
-              << itos_.size() << std::endl;
-#endif
     throw std::runtime_error(
         "Specified index " + std::to_string(index) +
         " is out of bounds of the size of itos dictionary: " +
@@ -278,8 +238,11 @@ StringList
 _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
                const int64_t min_freq, const int64_t num_lines,
                const bool sort_tokens) {
-  TORCH_CHECK(chunk_counters.size() > 0,
-              "There must be at least 1 chunk to concatenate!");
+
+  if (chunk_counters.size() < 1) {
+    throw ::std::runtime_error(
+        "There must be at least 1 chunk to concatenate!");
+  }

   IndexDict tokens_freq;
   StringList unique_tokens;
@@ -376,7 +339,6 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer) {
-  std::cerr << "[INFO] Reading file " << file_path << std::endl;
   int64_t num_lines = _infer_lines(file_path);
   int64_t chunk_size = impl::divup(num_lines, num_cpus);
   // Launching a thread on less lines than this likely has too much overhead.
@@ -431,11 +393,6 @@ VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size<decltype(states)>::value;
   if (state_size != 5) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Expected deserialized Vocab to have 5 states "
-                 "but found "
-              << state_size << " states." << std::endl;
-#endif
     throw std::runtime_error(
         "Expected deserialized Vocab to have 5 states but found " +
         std::to_string(state_size) + " states.");
@@ -449,11 +406,6 @@ c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {

   // check integers and tensors are empty
   if (integers.size() != 0 || tensors.size() != 0) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Expected `integers` and `tensors` states to "
-                 "be empty."
-              << std::endl;
-#endif
     throw std::runtime_error(
         "Expected `integers` and `tensors` states to be empty.");
   }
@@ -463,10 +415,6 @@ c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
     deserialized_vocab->default_index_ = default_index;
     return deserialized_vocab;
   }
-#ifdef _MSC_VER
-  std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: "
-            << version_str << std::endl;
-#endif
   throw std::runtime_error(
       "Found unexpected version for serialized Vocab: " + version_str + ".");
 }

From 54d60fcd3d27cda142b136f410510c02c2eb0187 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 10 May 2021 14:34:16 -0400
Subject: [PATCH 5/9] using TORCH_CHECK macro for error handling

---
 test/experimental/test_vocab.py |  12 +--
 torchtext/csrc/vocab.cpp        | 162 ++++++++++++++------------------
 2 files changed, 79 insertions(+), 95 deletions(-)

diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index 2701ddebaa..27ff389cb9 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -62,9 +62,11 @@ def test_reassign_token(self):

         self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])

+        # token must exist for reassignment
         with self.assertRaises(RuntimeError):
             v.reassign_token('not in vocab', 0)

+        # index should be valid for reassignment
         with self.assertRaises(RuntimeError):
             v.reassign_token('<unk>', 3)

@@ -112,11 +114,10 @@ def test_vocab_insert_token(self):
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)

-        with self.assertRaises(RuntimeError) as context:
+        # token must not exist to be inserted
+        with self.assertRaises(RuntimeError):
             v.insert_token('b', 0)

-        self.assertTrue("Token b already exists in the Vocab with index: 0" in str(context.exception))
-
     def test_vocab_append_token(self):
         c = OrderedDict({'a': 2})
         v = vocab(c)
@@ -128,11 +129,10 @@ def test_vocab_append_token(self):
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)

-        with self.assertRaises(RuntimeError) as context:
+        # token must not exist to be appended
+        with self.assertRaises(RuntimeError):
             v.append_token('b')

-        self.assertTrue("Token b already exists in the Vocab with index: 2" in str(context.exception))
-
     def test_vocab_len(self):
         token_to_freq = {'a': 2, 'b': 2, 'c': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 4a4e6cf501..a6ae6cafd8 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -8,28 +8,25 @@ namespace torchtext {

 Vocab::Vocab(const StringList &tokens) : stoi_(MAX_VOCAB_SIZE, -1) {
-  for (std::size_t i = 0; i < tokens.size(); i++) {
-    // tokens should not have any duplicates
-    auto token_position =
-        _find(c10::string_view{tokens[i].data(), tokens[i].size()});
-    if (stoi_[token_position] != -1) {
-      throw std::runtime_error("Duplicate token found in tokens list: " +
-                               tokens[i]);
-    }
+  for (size_t i = 0; i < tokens.size(); i++) {
+    // throw error if duplicate token is found
+    auto id = _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    TORCH_CHECK(stoi_[id] == -1,
+                "Duplicate token found in tokens list: " + tokens[i]);
+
     _add(tokens[i]);
   }
 }
+
 Vocab::Vocab(const StringList &tokens,
              const c10::optional<int64_t> &default_index)
     : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
-  for (std::size_t i = 0; i < tokens.size(); i++) {
-    // tokens should not have any duplicates
-    auto token_position =
-        _find(c10::string_view{tokens[i].data(), tokens[i].size()});
-    if (stoi_[token_position] != -1) {
-      throw std::runtime_error("Duplicate token found in tokens list: " +
-                               tokens[i]);
-    }
+  for (size_t i = 0; i < tokens.size(); i++) {
+    // throw error if duplicate token is found
+    auto id = _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    TORCH_CHECK(stoi_[id] == -1,
+                "Duplicate token found in tokens list: " + tokens[i]);
+
     _add(tokens[i]);
   }
 }
@@ -46,15 +43,16 @@ bool Vocab::__contains__(const c10::string_view &token) const {

 int64_t Vocab::__getitem__(const c10::string_view &token) const {
   int64_t id = _find(token);
-  if (stoi_[id] != -1) {
+  if (stoi_[id] != -1)
     return stoi_[id];
-  }

-  if (default_index_.has_value()) {
-    return default_index_.value();
-  }
-  throw std::runtime_error("[RuntimeError] Token " + std::string(token) +
-                           " not found and default index is not set.");
+  // throw error if default_index_ is not set
+  TORCH_CHECK(default_index_.has_value(),
+              "Token " + std::string(token) +
+                  " not found and default index is not set");
+
+  // return default index if token is OOV
+  return default_index_.value();
 }

 void Vocab::set_default_index(int64_t index) { default_index_ = index; }
@@ -65,54 +63,42 @@ c10::optional<int64_t> Vocab::get_default_index() const {

 void Vocab::reassign_token(const std::string &token, const int64_t &index) {
   // throw error if index is not valid
-  if (index < 0 || index >= itos_.size()) {
-    throw std::runtime_error(
-        "[RuntimeError] Specified index " + std::to_string(index) +
-        " is out of bounds of the size of stoi dictionary: " +
-        std::to_string(stoi_.size()) + ".");
-  }
+  TORCH_CHECK(index >= 0 && index < __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));

   // throw error if token not found
   auto id = _find(c10::string_view{token.data(), token.size()});
-  if (stoi_[id] == -1) {
-    throw std::runtime_error("[RuntimeError] Token " + token +
-                             " not found in Vocab.");
-  }
+  TORCH_CHECK(stoi_[id] != -1, "Token " + token + " not found in Vocab");

   _remove(token);
   insert_token(token, index);
 }

 void Vocab::append_token(const std::string &token) {
-  // if item already in stoi we throw an error
-  auto token_position = _find(c10::string_view{token.data(), token.size()});
-  if (stoi_[token_position] != -1) {
-    throw std::runtime_error("Token " + token +
-                             " already exists in the Vocab with index: " +
-                             std::to_string(stoi_[token_position]) + ".");
-  }
+  // throw error if token already exists in vocab
+  auto id = _find(c10::string_view{token.data(), token.size()});
+  TORCH_CHECK(stoi_[id] == -1, "Token " + token +
+                                   " already exists in the Vocab with index: " +
+                                   std::to_string(stoi_[id]));

   _add(token);
 }

 void Vocab::insert_token(const std::string &token, const int64_t &index) {
-  if (index < 0 || index > itos_.size()) {
-    throw std::runtime_error(
-        "[RuntimeError] Specified index " + std::to_string(index) +
-        " is out of bounds of the size of stoi dictionary: " +
-        std::to_string(stoi_.size()) + ".");
-  }
+  // throw error if index is not valid
+  TORCH_CHECK(index >= 0 && index <= __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));

-  // if item already in stoi we throw an error
-  auto token_position = _find(c10::string_view{token.data(), token.size()});
-  if (stoi_[token_position] != -1) {
-    throw std::runtime_error("[RuntimeError] Token " + token +
-                             " already exists in the Vocab with index: " +
-                             std::to_string(stoi_[token_position]) + ".");
-  }
+  // throw error if token already exists in vocab
+  auto id = _find(c10::string_view{token.data(), token.size()});
+  TORCH_CHECK(stoi_[id] == -1, "Token " + token +
+                                   " already exists in the Vocab with index: " +
+                                   std::to_string(stoi_[id]));

   // need to offset all tokens greater than or equal index by 1
-  for (size_t i = index; i < itos_.size(); i++) {
+  for (size_t i = index; i < __len__(); i++) {
     stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
   }

@@ -121,20 +107,28 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
 }

 std::string Vocab::lookup_token(const int64_t &index) {
-  if (index < 0 || index > static_cast<int64_t>(itos_.size())) {
-    throw std::runtime_error(
-        "Specified index " + std::to_string(index) +
-        " is out of bounds of the size of itos dictionary: " +
-        std::to_string(itos_.size()) + ".");
-  }
+  // throw error if index is not valid
+  TORCH_CHECK(index >= 0 && index < __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));

   return itos_[index];
 }

 StringList Vocab::lookup_tokens(const std::vector<int64_t> &indices) {
+  // throw error if indices are not valid
+  for (size_t i = 0; i < indices.size(); i++) {
+    TORCH_CHECK(indices[i] >= 0 && indices[i] < __len__(),
+                "Specified index " + std::to_string(indices[i]) +
+                    " at position " + std::to_string(i) +
+                    " is out of bounds for vocab of size " +
+                    std::to_string(__len__()));
+  }
+
   std::vector<std::string> tokens(indices.size());
-  for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); i++) {
-    tokens[i] = lookup_token(indices[i]);
+  for (size_t i = 0; i < indices.size(); i++) {
+    tokens[i] = itos_[indices[i]];
   }
   return tokens;
 }
@@ -142,7 +136,7 @@ std::vector<int64_t>
 Vocab::lookup_indices(const std::vector<std::string> &tokens) {
   std::vector<int64_t> indices(tokens.size());
-  for (int64_t i = 0; i < static_cast<int64_t>(tokens.size()); i++) {
+  for (size_t i = 0; i < tokens.size(); i++) {
     indices[i] = __getitem__(tokens[i]);
   }
   return indices;
@@ -174,9 +168,7 @@ void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
                             const int64_t start_line, const int64_t end_line,
                             std::shared_ptr<IndexDict> counter) {
   std::ifstream fin(file_path, std::ios::in);
-  if (!fin.is_open()) {
-    throw std::runtime_error("Cannot open input file " + file_path + "\n");
-  }
+  TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);

   fin.seekg(offset);

@@ -198,9 +190,7 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
                                std::shared_ptr<IndexDict> counter,
                                torch::jit::script::Module &module) {
   std::ifstream fin(file_path, std::ios::in);
-  if (!fin.is_open()) {
-    throw std::runtime_error("Cannot open input file " + file_path + "\n");
-  }
+  TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);

   fin.seekg(offset);

@@ -239,10 +229,8 @@ _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
                const int64_t min_freq, const int64_t num_lines,
                const bool sort_tokens) {
-
-  if (chunk_counters.size() < 1) {
-    throw ::std::runtime_error(
-        "There must be at least 1 chunk to concatenate!");
-  }
+  TORCH_CHECK(chunk_counters.size() > 0,
+              "There must be at least 1 chunk to concatenate!");

   IndexDict tokens_freq;
   StringList unique_tokens;
@@ -392,11 +380,9 @@ VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {

 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size<decltype(states)>::value;
-  if (state_size != 5) {
-    throw std::runtime_error(
-        "Expected deserialized Vocab to have 5 states but found " +
-        std::to_string(state_size) + " states.");
-  }
+  TORCH_CHECK(state_size == 5,
+              "Expected deserialized Vocab to have 5 states but found " +
+                  std::to_string(state_size) + " states");

   auto &version_str = std::get<0>(states);
   auto &default_index = std::get<1>(states);
@@ -405,18 +391,16 @@ c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto &tensors = std::get<4>(states);

   // check integers and tensors are empty
-  if (integers.size() != 0 || tensors.size() != 0) {
-    throw std::runtime_error(
-        "Expected `integers` and `tensors` states to be empty.");
-  }
+  TORCH_CHECK(integers.size() == 0 && tensors.size() == 0,
+              "Expected `integers` and `tensors` states to be empty");

-  if (version_str.compare("0.0.2") >= 0) {
-    auto deserialized_vocab = c10::make_intrusive<Vocab>(std::move(strings));
-    deserialized_vocab->default_index_ = default_index;
-    return deserialized_vocab;
-  }
-  throw std::runtime_error(
-      "Found unexpected version for serialized Vocab: " + version_str + ".");
+  // throw error if version is not compatible
+  TORCH_CHECK(version_str.compare("0.0.2") >= 0,
+              "Found unexpected version for serialized Vocab: " + version_str);
+
+  auto deserialized_vocab = c10::make_intrusive<Vocab>(std::move(strings));
+  deserialized_vocab->default_index_ = default_index;
+  return deserialized_vocab;
 }
 } // namespace torchtext
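
With TORCH_CHECK in place, C++ failures reach Python as ordinary RuntimeError on every platform, including Windows, which is why the per-platform std::cerr blocks and the Windows test skip could be dropped. A hedged sketch of what callers can now rely on (tokens and indices illustrative):

    >>> from collections import OrderedDict
    >>> from torchtext.experimental.vocab import vocab
    >>> v = vocab(OrderedDict([('a', 2)]))
    >>> try:
    ...     v.lookup_token(100)          # out of bounds for this vocab
    ... except RuntimeError as e:
    ...     print('out of bounds' in str(e))
    True
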
From 95136feb676118831f632c0e94fab274815b41bd Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 10 May 2021 23:30:56 -0400
Subject: [PATCH 6/9] adding back support for unk token in factory functions

---
 test/experimental/test_with_asset.py |  8 ++++----
 torchtext/experimental/vocab.py      | 16 +++++++++-------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py
index 7bd39e773e..2055e0ab57 100644
--- a/test/experimental/test_with_asset.py
+++ b/test/experimental/test_with_asset.py
@@ -178,8 +178,8 @@ def test_glove_different_dims(self):
     def test_vocab_from_file(self):
         asset_name = 'vocab_test.txt'
         asset_path = get_asset_path(asset_name)
-        v = load_vocab_from_file(asset_path)
-        expected_itos = ['b', 'a', 'c']
+        v = load_vocab_from_file(asset_path, unk_token='<unk>')
+        expected_itos = ['<unk>', 'b', 'a', 'c']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
@@ -189,8 +189,8 @@ def test_vocab_from_raw_text_file(self):
         asset_path = get_asset_path(asset_name)
         tokenizer = basic_english_normalize()
         jit_tokenizer = torch.jit.script(tokenizer)
-        v = build_vocab_from_text_file(asset_path, jit_tokenizer)
-        expected_itos = ["'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
+        v = build_vocab_from_text_file(asset_path, jit_tokenizer, unk_token='<unk>')
+        expected_itos = ['<unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
                          'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall',
                          'parent', 'pension', 'representing', 'say', 'stricken', 't',
                          'they', 'turner', 'unions', 'with', 'workers']
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index fdc569e9f2..544a7a47b8 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -19,7 +19,7 @@
 logger = logging.getLogger(__name__)


-def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, num_cpus=4):
+def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token='<unk>', num_cpus=4):
     r"""Create a `Vocab` object from a raw text file.

     The `file_path` can contain any raw text. This function applies a generic JITed tokenizer in
@@ -30,7 +30,7 @@
         jited_tokenizer (ScriptModule): a tokenizer that has been JITed using `torch.jit.script`
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+        unk_token: The default unknown token to use. Default: '<unk>'. If not found in the text file, it will be inserted at index 0.
         num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4.

     Returns:
@@ -44,10 +45,12 @@
     >>> jit_tokenizer = torch.jit.script(tokenizer)
     >>> v = build_vocab_from_text_file('vocab.txt', jit_tokenizer)
     """
     vocab_obj = _build_vocab_from_text_file(file_path, min_freq, num_cpus, jited_tokenizer)
+    if unk_token not in vocab_obj:
+        vocab_obj.insert_token(unk_token, 0)
     return Vocab(vocab_obj)


-def load_vocab_from_file(file_path, min_freq=1, num_cpus=4):
+def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
     r"""Create a `Vocab` object from a text file.
     The `file_path` should contain tokens separated by new lines.
     Format for txt file:
@@ -62,7 +65,7 @@
         file_object (FileObject): a file like object to read data from.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+        unk_token: The default unknown token to use. Default: '<unk>'. If not found in the vocab file, it will be inserted at index 0.
         num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4.

     Returns:
@@ -74,6 +77,8 @@
     >>> v = load_vocab_from_file('vocab.txt')
     """

     vocab_obj = _load_vocab_from_file(file_path, min_freq, num_cpus)
+    if unk_token not in vocab_obj:
+        vocab_obj.insert_token(unk_token, 0)
     return Vocab(vocab_obj)


@@ -108,7 +113,7 @@ def vocab(ordered_dict, min_freq=1, unk_token='<unk>'):
         ordered_dict (collections.OrderedDict): object holding the frequencies of each token found in the data.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+        unk_token: The default unknown token to use. Default: '<unk>'. If not found in ordered_dict, it will be inserted at index 0.

     Raises:
         ValueError: if a default `unk_token` isn't provided.
@@ -134,8 +139,6 @@ def vocab(ordered_dict, min_freq=1, unk_token='<unk>'):

     if unk_token not in tokens:
         tokens.insert(0, unk_token)
-        warnings.warn("The `unk_token` '{}' wasn't found in the `ordered_dict`. Adding the `unk_token` "
-                      "to the beginning of the Vocab.".format(unk_token), RuntimeWarning)
     return Vocab(VocabPybind(tokens, None))

From 763bb02dfa7179e355ba80f260fa1bf32f42d3a3 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 10 May 2021 23:41:01 -0400
Subject: [PATCH 7/9] fixing style issues

---
 torchtext/experimental/vocab.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 544a7a47b8..5226e49705 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -1,6 +1,5 @@
 import logging
 from typing import Dict, List, Optional
-import warnings
 from collections import Counter, OrderedDict
 import torch
 import torch.nn as nn
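
A short sketch of the restored factory behavior (hypothetical; the file path is illustrative): handling of the unk token is no longer a C++ concern, the Python factories simply insert it at index 0 when it is absent, and OOV lookups go through the default-index mechanism only if the user opts in:

    >>> from torchtext.experimental.vocab import load_vocab_from_file
    >>> v = load_vocab_from_file('vocab.txt')   # '<unk>' prepended if missing
    >>> v.get_itos()[0]
    '<unk>'
    >>> v.set_default_index(v['<unk>'])         # enable OOV fallback explicitly
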
must not exist to be inserted - with self.assertRaises(RuntimeError): - v.insert_token('b', 0) def test_vocab_append_token(self): c = OrderedDict({'a': 2}) @@ -225,7 +218,7 @@ def test_errors_vocab_cpp(self): with self.assertRaises(RuntimeError): # Test proper error raised when setting a token out of bounds v = vocab(c, min_freq=3) - v.insert_token('new_token', 100) + v['new_token'] = 100 with self.assertRaises(RuntimeError): # Test proper error raised when looking up a token out of bounds @@ -247,6 +240,7 @@ def test_vocab_load_and_save(self): c = OrderedDict(sorted_by_freq_tuples) v = vocab(c, min_freq=3) + v.set_default_index(0) expected_itos = ['', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} @@ -260,6 +254,7 @@ def test_vocab_load_and_save(self): loaded_v = torch.load(vocab_path) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) + self.assertEqual(v['not in vocab'], 0) with self.subTest('torchscript'): vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt') @@ -269,6 +264,7 @@ def test_vocab_load_and_save(self): loaded_v = torch.load(vocab_path) self.assertEqual(v.get_itos(), expected_itos) self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi) + self.assertEqual(v['not in vocab'], 0) def test_build_vocab_iterator(self): iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 83e3573f01..3b6611156a 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -112,6 +112,7 @@ PYBIND11_MODULE(_torchtext, m) { const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length); return self->__contains__(c10::string_view{buffer, (size_t)length}); }) + .def("__setitem__", &Vocab::__setitem__) .def("__getitem__", [](c10::intrusive_ptr &self, const py::str &item) -> int64_t { ssize_t length; @@ -120,9 +121,7 @@ PYBIND11_MODULE(_torchtext, m) { }) .def("set_default_index", &Vocab::set_default_index) .def("get_default_index", &Vocab::get_default_index) - .def("reassign_token", &Vocab::reassign_token) .def("__len__", &Vocab::__len__) - .def("insert_token", &Vocab::insert_token) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) @@ -240,14 +239,13 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) { .def("__contains__", [](const c10::intrusive_ptr &self, const std::string &item) -> bool { return self->__contains__(c10::string_view{item}); }) + .def("__setitem__", &Vocab::__setitem__) .def("__getitem__", [](const c10::intrusive_ptr &self, const std::string &item) -> int64_t { return self->__getitem__(c10::string_view{item}); }) .def("__len__", &Vocab::__len__) .def("set_default_index", &Vocab::set_default_index) .def("get_default_index", &Vocab::get_default_index) - .def("reassign_token", &Vocab::reassign_token) - .def("insert_token", &Vocab::insert_token) .def("append_token", &Vocab::append_token) .def("lookup_token", &Vocab::lookup_token) .def("lookup_tokens", &Vocab::lookup_tokens) diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index a6ae6cafd8..96843f45e5 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -7,17 +7,6 @@ #include // @manual namespace torchtext { -Vocab::Vocab(const StringList &tokens) : stoi_(MAX_VOCAB_SIZE, -1) { - for (size_t i = 0; i < tokens.size(); i++) { - // throw error 
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index a6ae6cafd8..96843f45e5 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -7,17 +7,6 @@
 #include <vocab.h> // @manual
 
 namespace torchtext {
-Vocab::Vocab(const StringList &tokens) : stoi_(MAX_VOCAB_SIZE, -1) {
-  for (size_t i = 0; i < tokens.size(); i++) {
-    // throw error if duplicate token is found
-    auto id = _find(c10::string_view{tokens[i].data(), tokens[i].size()});
-    TORCH_CHECK(stoi_[id] == -1,
-                "Duplicate token found in tokens list: " + tokens[i]);
-
-    _add(tokens[i]);
-  }
-}
-
 Vocab::Vocab(const StringList &tokens,
              const c10::optional<int64_t> &default_index)
     : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
@@ -31,6 +20,8 @@ Vocab::Vocab(const StringList &tokens,
   }
 }
 
+Vocab::Vocab(const StringList &tokens) : Vocab(tokens, {}) {}
+
 int64_t Vocab::__len__() const { return itos_.size(); }
 
 bool Vocab::__contains__(const c10::string_view &token) const {
@@ -61,19 +52,30 @@ c10::optional<int64_t> Vocab::get_default_index() const {
   return default_index_;
 }
 
-void Vocab::reassign_token(const std::string &token, const int64_t &index) {
+void Vocab::__setitem__(const std::string &token, const int64_t &index) {
+  // reassignment scenario
+  if (__contains__(c10::string_view{token.data(),token.size()})) {
+    // throw error if index is not valid
+    TORCH_CHECK(index >= 0 && index < __len__(),
+                "Specified index " + std::to_string(index) +
+                    " is out of bounds for vocab of size " +
+                    std::to_string(__len__()));
+
+    _remove(token);
+  }
+
   // throw error if index is not valid
-  TORCH_CHECK(index >= 0 && index < __len__(),
+  TORCH_CHECK(index >= 0 && index <= __len__(),
               "Specified index " + std::to_string(index) +
                   " is out of bounds for vocab of size " +
                   std::to_string(__len__()));
 
-  // throw error if token not found
-  auto id = _find(c10::string_view{token.data(), token.size()});
-  TORCH_CHECK(stoi_[id] != -1, "Token " + token + " not found in Vocab");
+  // need to offset all tokens greater than or equal index by 1
+  for (size_t i = index; i < __len__(); i++) {
+    stoi_[_find(c10::string_view{itos_[i].data(),itos_[i].size()})] = i + 1;
+  }
 
-  _remove(token);
-  insert_token(token, index);
+  itos_.insert(itos_.begin() + index, token);
+  stoi_[_find(c10::string_view{token.data(),token.size()})] = index;
 }
 
 void Vocab::append_token(const std::string &token) {
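`__setitem__` above folds two behaviors into one call: if the token is already present it is reassigned (valid indices are [0, len)), otherwise it is inserted (valid indices are [0, len]). A rough Python model of that control flow, a sketch of the logic rather than the C++ itself:

    def setitem_model(itos, token, index):
        if token in itos:  # reassignment: remove first, then fall through to insertion
            if not (0 <= index < len(itos)):
                raise RuntimeError(f'Specified index {index} is out of bounds '
                                   f'for vocab of size {len(itos)}')
            itos.remove(token)
        if not (0 <= index <= len(itos)):  # insertion bounds, checked after any removal
            raise RuntimeError(f'Specified index {index} is out of bounds '
                               f'for vocab of size {len(itos)}')
        itos.insert(index, token)  # every token at position >= index shifts right by one

In the C++ version the shift is made visible to `stoi_` by the offset loop, which re-hashes every token at or after the insertion point, so each call is O(n) in the vocab size.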
@@ -86,26 +88,6 @@ void Vocab::append_token(const std::string &token) {
   _add(token);
 }
 
-void Vocab::insert_token(const std::string &token, const int64_t &index) {
-  // throw error if index is not valid
-  TORCH_CHECK(index >= 0 && index <= __len__(),
-              "Specified index " + std::to_string(index) +
-                  " is out of bounds for vocab of size " +
-                  std::to_string(__len__()));
-
-  // throw error if token not found
-  auto id = _find(c10::string_view{token.data(), token.size()});
-  TORCH_CHECK(stoi_[id] == -1, "Token " + token + " not found in Vocab");
-
-  // need to offset all tokens greater than or equal index by 1
-  for (size_t i = index; i < __len__(); i++) {
-    stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
-  }
-
-  itos_.insert(itos_.begin() + index, token);
-  stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
-}
-
 std::string Vocab::lookup_token(const int64_t &index) {
   // throw error if index is not valid
   TORCH_CHECK(index >= 0 && index < __len__(),
@@ -320,7 +302,7 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   StringList tokens =
       _concat_tokens(chunk_counters, min_freq, num_lines, false);
 
-  return Vocab(std::move(tokens), {});
+  return Vocab(std::move(tokens));
 }
 
 Vocab _build_vocab_from_text_file(const std::string &file_path,
@@ -364,7 +346,7 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   StringList tokens =
       _concat_tokens(chunk_counters, min_freq, num_lines, true);
 
-  return Vocab(std::move(tokens), {});
+  return Vocab(std::move(tokens));
 }
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
@@ -372,35 +354,38 @@ VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   StringList strings = self->itos_;
   std::vector<torch::Tensor> tensors;
 
-  VocabStates states = std::make_tuple(self->version_str_, self->default_index_,
-                                       std::move(integers), std::move(strings),
-                                       std::move(tensors));
+  if (self->default_index_.has_value()) {
+    integers.push_back(self->default_index_.value());
+  }
+
+  VocabStates states = std::make_tuple(self->version_str_, std::move(integers),
+                                       std::move(strings), std::move(tensors));
   return states;
 }
 
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size<VocabStates>::value;
-  TORCH_CHECK(state_size == 5,
-              "Expected deserialized Vocab to have 5 states but found " +
+  TORCH_CHECK(state_size == 4,
+              "Expected deserialized Vocab to have 4 states but found " +
                   std::to_string(state_size) + " states");
 
   auto &version_str = std::get<0>(states);
-  auto &default_index = std::get<1>(states);
-  auto &integers = std::get<2>(states);
-  auto &strings = std::get<3>(states);
-  auto &tensors = std::get<4>(states);
+  auto &integers = std::get<1>(states);
+  auto &strings = std::get<2>(states);
+  auto &tensors = std::get<3>(states);
 
-  // check integers and tensors are empty
-  TORCH_CHECK(integers.size() == 0 || tensors.size() == 0,
-              "Expected `integers` and `tensors` states to be empty");
+  // check tensors are empty
+  TORCH_CHECK(tensors.size() == 0, "Expected `tensors` states to be empty");
 
   // throw error if version is not compatible
   TORCH_CHECK(version_str.compare("0.0.2") >= 0,
              "Found unexpected version for serialized Vocab: " + version_str);
 
-  auto deserialized_vocab = c10::make_intrusive<Vocab>(std::move(strings));
-  deserialized_vocab->default_index_ = default_index;
-  return deserialized_vocab;
+  c10::optional<int64_t> default_index = {};
+  if (integers.size() > 0) {
+    default_index = integers[0];
+  }
+  return c10::make_intrusive<Vocab>(std::move(strings), default_index);
 }
 } // namespace torchtext
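With this change the serialized state shrinks from five fields to four, and the optional default index rides along as the sole element of `integers` instead of occupying its own tuple slot. Viewed as Python tuples (values illustrative, field order as in the C++ above):

    # (version, integers, strings, tensors)
    states = ('0.0.2', [0], ['<unk>', 'hello', 'world'], [])   # default index set to 0
    states = ('0.0.2', [],  ['<unk>', 'hello', 'world'], [])   # no default index

    # decode rule used by _deserialize_vocab:
    default_index = states[1][0] if states[1] else None

Because `integers` is otherwise unused, the single-element encoding stays unambiguous.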
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 33e8d3bb85..c8c4f2fc48 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -6,8 +6,8 @@ namespace torchtext {
 
 typedef std::vector<std::string> StringList;
 typedef ska_ordered::order_preserving_flat_hash_map<std::string, int64_t> IndexDict;
-typedef std::tuple<std::string, c10::optional<int64_t>, std::vector<int64_t>,
-                   std::vector<std::string>, std::vector<torch::Tensor>>
+typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
+                   std::vector<torch::Tensor>>
     VocabStates;
 
 struct Vocab : torch::CustomClassHolder {
@@ -18,18 +18,19 @@ struct Vocab : torch::CustomClassHolder {
   StringList itos_;
   c10::optional<int64_t> default_index_ = {};
 
-  // TODO: [can we remove this?] we need to keep this constructor, otherwise torch binding gets
-  // compilation error: no matching constructor for initialization of 'torchtext::Vocab'
+  // TODO: [can we remove this?] we need to keep this constructor, otherwise
+  // torch binding gets compilation error: no matching constructor for
+  // initialization of 'torchtext::Vocab'
   explicit Vocab(const StringList &tokens);
-  explicit Vocab(const StringList &tokens,const c10::optional<int64_t> &default_index);
+  explicit Vocab(const StringList &tokens,
+                 const c10::optional<int64_t> &default_index);
   int64_t __len__() const;
   int64_t __getitem__(const c10::string_view &token) const;
   bool __contains__(const c10::string_view &token) const;
   void set_default_index(int64_t index);
   c10::optional<int64_t> get_default_index() const;
-  void reassign_token(const std::string &token,const int64_t &index);
+  void __setitem__(const std::string &token, const int64_t &index);
   void append_token(const std::string &token);
-  void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);
   std::vector<std::string> lookup_tokens(const std::vector<int64_t> &indices);
   std::vector<int64_t>
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 5226e49705..63848ede0b 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -45,7 +45,7 @@ def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token
     """
     vocab_obj = _build_vocab_from_text_file(file_path, min_freq, num_cpus, jited_tokenizer)
     if unk_token not in vocab_obj:
-        vocab_obj.insert_token(unk_token, 0)
+        vocab_obj[unk_token] = 0
     return Vocab(vocab_obj)
 
 
@@ -76,7 +76,7 @@ def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
 
     vocab_obj = _load_vocab_from_file(file_path, min_freq, num_cpus)
     if unk_token not in vocab_obj:
-        vocab_obj.insert_token(unk_token, 0)
+        vocab_obj[unk_token] = 0
     return Vocab(vocab_obj)
 
 
@@ -215,28 +215,16 @@ def get_default_index(self) -> Optional[int]:
         return self.vocab.get_default_index()
 
     @torch.jit.export
-    def reassign_token(self, token: str, index: int) -> None:
+    def __setitem__(self, token: str, index: int) -> None:
         r"""
         Args:
             token (str): the token used to lookup the corresponding index.
             index (int): the index corresponding to the associated token.
-
-        Raises:
-            RuntimeError: If token is not present in Vocab
-        """
-        self.vocab.reassign_token(token, index)
-
-    @torch.jit.export
-    def insert_token(self, token: str, index: int) -> None:
-        r"""
-        Args:
-            token (str): the token used to lookup the corresponding index.
-            index (int): the index corresponding to the associated token.
 
         Raises:
-            RuntimeError: if `index` not between [0, Vocab.size()] or if token already exists in the vocab.
+            RuntimeError: If `index` is out of range [0, Vocab.size()) when the token already
+                exists, or out of range [0, Vocab.size()] when the token does not exist in the Vocab.
         """
-        self.vocab.insert_token(token, index)
+        self.vocab.__setitem__(token, index)
 
     @torch.jit.export
     def append_token(self, token: str) -> None:
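At this point in the series both file-based helpers install a missing unk token through the new `__setitem__` path, and index 0 stays reserved for it because inserting at position 0 shifts every existing index up by one. A sketch of the resulting behavior (the file name and its contents are illustrative):

    vocab_obj = _load_vocab_from_file('vocab.txt', min_freq=1, num_cpus=4)
    # suppose the file yields itos == ['hello', 'world']
    if '<unk>' not in vocab_obj:
        vocab_obj['<unk>'] = 0      # itos becomes ['<unk>', 'hello', 'world']
    assert vocab_obj['hello'] == 1  # shifted up by the insertion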
""" - self.vocab.insert_token(token, index) + self.vocab.__setitem__(token, index) @torch.jit.export def append_token(self, token: str) -> None: From 8aed597fe3fa4dc58ee2cfdcc160d1b9c5b3baba Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Tue, 11 May 2021 21:59:06 -0400 Subject: [PATCH 9/9] removing __setitem__ --- test/experimental/test_vocab.py | 29 +++++----------- torchtext/csrc/register_bindings.cpp | 6 ++-- torchtext/csrc/vocab.cpp | 51 ++++++++++++++++------------ torchtext/csrc/vocab.h | 3 +- torchtext/experimental/vocab.py | 22 ++++++++---- 5 files changed, 60 insertions(+), 51 deletions(-) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 9cebce9668..afb140c71e 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -55,16 +55,18 @@ def test_reassign_token(self): self.assertEqual(v[''], 2) self.assertEqual(v['a'], 0) self.assertEqual(v['b'], 1) - v[''] = 0 + v.reassign_token('', 0) self.assertEqual(v[''], 0) self.assertEqual(v['a'], 1) self.assertEqual(v['b'], 2) self.assertEqual(v.get_itos(), ['', 'a', 'b']) - # index should be valid for reassignment with self.assertRaises(RuntimeError): - v[''] = 3 + v.reassign_token('not in vocab', 0) + + with self.assertRaises(RuntimeError): + v.reassign_token('', 3) def test_default_index(self): token_to_freq = {'': 2, 'a': 2, 'b': 2} @@ -93,7 +95,7 @@ def test_vocab_insert_token(self): # add item to end v = vocab(c) - v['b'] = 2 + v.insert_token('b', 2) expected_itos = ['', 'a', 'b'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} @@ -103,7 +105,7 @@ def test_vocab_insert_token(self): # add item to middle v = vocab(c) - v['b'] = 0 + v.insert_token('b', 0) expected_itos = ['b', '', 'a'] expected_stoi = {x: index for index, x in enumerate(expected_itos)} @@ -187,6 +189,8 @@ def test_vocab_lookup_token(self): v = vocab(c) self.assertEqual(v.lookup_token(1), 'a') + with self.assertRaises(RuntimeError): + v.lookup_token(100) def test_vocab_lookup_tokens(self): token_to_freq = {'a': 2, 'b': 2, 'c': 2} @@ -210,21 +214,6 @@ def test_vocab_lookup_indices(self): self.assertEqual(v.lookup_indices(tokens), expected_indices) - def test_errors_vocab_cpp(self): - token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} - sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) - c = OrderedDict(sorted_by_freq_tuples) - - with self.assertRaises(RuntimeError): - # Test proper error raised when setting a token out of bounds - v = vocab(c, min_freq=3) - v['new_token'] = 100 - - with self.assertRaises(RuntimeError): - # Test proper error raised when looking up a token out of bounds - v = vocab(c) - v.lookup_token(100) - def test_errors_vocab_python(self): token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2} sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 3b6611156a..cf6656d12a 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -112,13 +112,14 @@ PYBIND11_MODULE(_torchtext, m) { const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length); return self->__contains__(c10::string_view{buffer, (size_t)length}); }) - .def("__setitem__", &Vocab::__setitem__) .def("__getitem__", [](c10::intrusive_ptr &self, const py::str &item) -> int64_t { ssize_t length; const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), 
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 3b6611156a..cf6656d12a 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -112,13 +112,14 @@ PYBIND11_MODULE(_torchtext, m) {
             const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
             return self->__contains__(c10::string_view{buffer, (size_t)length});
           })
-      .def("__setitem__", &Vocab::__setitem__)
       .def("__getitem__",
           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
             ssize_t length;
            const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
             return self->__getitem__(c10::string_view{buffer, (size_t)length});
           })
+      .def("reassign_token", &Vocab::reassign_token)
+      .def("insert_token", &Vocab::insert_token)
       .def("set_default_index", &Vocab::set_default_index)
       .def("get_default_index", &Vocab::get_default_index)
       .def("__len__", &Vocab::__len__)
@@ -239,10 +240,11 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def("__contains__",
           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item) -> bool {
             return self->__contains__(c10::string_view{item});
           })
-      .def("__setitem__", &Vocab::__setitem__)
       .def("__getitem__",
           [](const c10::intrusive_ptr<Vocab> &self, const std::string &item) -> int64_t {
             return self->__getitem__(c10::string_view{item});
           })
+      .def("reassign_token", &Vocab::reassign_token)
+      .def("insert_token", &Vocab::insert_token)
       .def("__len__", &Vocab::__len__)
       .def("set_default_index", &Vocab::set_default_index)
       .def("get_default_index", &Vocab::get_default_index)
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 96843f45e5..e659fbdb70 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -52,40 +52,47 @@ c10::optional<int64_t> Vocab::get_default_index() const {
   return default_index_;
 }
 
-void Vocab::__setitem__(const std::string &token, const int64_t &index) {
-  // reassignment scenario
-  if (__contains__(c10::string_view{token.data(),token.size()})) {
-    // throw error if index is not valid
-    TORCH_CHECK(index >= 0 && index < __len__(),
-                "Specified index " + std::to_string(index) +
-                    " is out of bounds for vocab of size " +
-                    std::to_string(__len__()));
+void Vocab::append_token(const std::string &token) {
+  // throw error if token already exists in vocab
+  auto id = _find(c10::string_view{token.data(), token.size()});
+  TORCH_CHECK(stoi_[id] == -1, "Token " + token +
+                                   " already exists in the Vocab with index: " +
+                                   std::to_string(stoi_[id]));
 
-    _remove(token);
-  }
+  _add(token);
+}
+
+void Vocab::reassign_token(const std::string &token, const int64_t &index) {
+  // throw error if index is not valid
+  TORCH_CHECK(index >= 0 && index < __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));
+
+  // throw error if token not found
+  TORCH_CHECK(__contains__(token), "Token " + token + " not found in Vocab");
+
+  _remove(token);
+  insert_token(token, index);
+}
+
+void Vocab::insert_token(const std::string &token, const int64_t &index) {
   // throw error if index is not valid
   TORCH_CHECK(index >= 0 && index <= __len__(),
               "Specified index " + std::to_string(index) +
                   " is out of bounds for vocab of size " +
                   std::to_string(__len__()));
 
+  // throw error if token already present
+  TORCH_CHECK(!__contains__(token),
+              "Token " + token + " already exists in the Vocab");
+
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < __len__(); i++) {
-    stoi_[_find(c10::string_view{itos_[i].data(),itos_[i].size()})] = i + 1;
+    stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
   }
 
   itos_.insert(itos_.begin() + index, token);
-  stoi_[_find(c10::string_view{token.data(),token.size()})] = index;
+  stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
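The offset loop in `insert_token` is what keeps `stoi_` consistent with `itos_` past the insertion point: every surviving token at position >= index has its stored id bumped by one before the new token is written in. The same bookkeeping in plain Python terms (a sketch of the logic, not the C++):

    def insert_token_model(itos, stoi, token, index):
        if not (0 <= index <= len(itos)):
            raise RuntimeError(f'Specified index {index} is out of bounds '
                               f'for vocab of size {len(itos)}')
        if token in stoi:
            raise RuntimeError(f'Token {token} already exists in the Vocab')
        for t in itos[index:]:      # shift everything at or after the slot
            stoi[t] += 1
        itos.insert(index, token)
        stoi[token] = index

Each call is O(n) in the vocab size, which suggests these mutators are meant for setup-time edits (such as installing '<unk>' at index 0) rather than hot-path use.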
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index c8c4f2fc48..06f98865d3 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -29,7 +29,8 @@ struct Vocab : torch::CustomClassHolder {
   bool __contains__(const c10::string_view &token) const;
   void set_default_index(int64_t index);
   c10::optional<int64_t> get_default_index() const;
-  void __setitem__(const std::string &token, const int64_t &index);
+  void reassign_token(const std::string &token, const int64_t &index);
+  void insert_token(const std::string &token, const int64_t &index);
   void append_token(const std::string &token);
   std::string lookup_token(const int64_t &index);
   std::vector<std::string> lookup_tokens(const std::vector<int64_t> &indices);
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 63848ede0b..a1aa2290d9 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -45,7 +45,7 @@ def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token
     """
     vocab_obj = _build_vocab_from_text_file(file_path, min_freq, num_cpus, jited_tokenizer)
     if unk_token not in vocab_obj:
-        vocab_obj[unk_token] = 0
+        vocab_obj.insert_token(unk_token, 0)
     return Vocab(vocab_obj)
 
 
@@ -76,7 +76,7 @@ def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
 
     vocab_obj = _load_vocab_from_file(file_path, min_freq, num_cpus)
    if unk_token not in vocab_obj:
-        vocab_obj[unk_token] = 0
+        vocab_obj.insert_token(unk_token, 0)
     return Vocab(vocab_obj)
 
 
@@ -215,16 +215,26 @@ def get_default_index(self) -> Optional[int]:
         return self.vocab.get_default_index()
 
     @torch.jit.export
-    def __setitem__(self, token: str, index: int) -> None:
+    def reassign_token(self, token: str, index: int) -> None:
         r"""
         Args:
             token (str): the token used to lookup the corresponding index.
             index (int): the index corresponding to the associated token.
 
         Raises:
-            RuntimeError: If `index` is out of range [0, Vocab.size()) when the token already
-                exists, or out of range [0, Vocab.size()] when the token does not exist in the Vocab.
+            RuntimeError: If `index` is not in range [0, Vocab.size()) or if the token is not present in the Vocab.
         """
-        self.vocab.__setitem__(token, index)
+        self.vocab.reassign_token(token, index)
+
+    @torch.jit.export
+    def insert_token(self, token: str, index: int) -> None:
+        r"""
+        Args:
+            token (str): the token used to lookup the corresponding index.
+            index (int): the index corresponding to the associated token.
+
+        Raises:
+            RuntimeError: If `index` is not in range [0, Vocab.size()] or if the token already exists in the Vocab.
+        """
+        self.vocab.insert_token(token, index)
 
     @torch.jit.export
     def append_token(self, token: str) -> None:
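Net of the series, the mutating surface ends up as `append_token`, `insert_token`, `reassign_token`, and `set_default_index`. A short end-to-end sketch against that final API (tokens and counts illustrative, factory defaults as in the tests above):

    from collections import OrderedDict
    from torchtext.experimental.vocab import vocab

    v = vocab(OrderedDict([('a', 2), ('b', 2)]))   # itos: ['<unk>', 'a', 'b']
    v.append_token('c')                            # ['<unk>', 'a', 'b', 'c']
    v.insert_token('z', 1)                         # ['<unk>', 'z', 'a', 'b', 'c']
    v.reassign_token('<unk>', 4)                   # ['z', 'a', 'b', 'c', '<unk>']
    v.set_default_index(v['<unk>'])
    assert v['never seen before'] == 4             # OOV falls back to '<unk>'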