diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index 879c03e72d..1ce15b41d3 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -17,21 +17,26 @@ def tearDown(self):
         torch._C._jit_clear_class_registry()
         torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
 
-    def test_has_unk(self):
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
+    def test_has_no_unk(self):
         c = OrderedDict()
         v = vocab(c)
+        with self.assertRaisesRegex(RuntimeError, 'bad optional access'):
+            v.get_default_index()
 
-        # check if unk is mapped to the first index
-        self.assertEqual(v['not_in_it'], 0)
-        self.assertEqual(v['<unk>'], 0)
-
-    def test_new_unk(self):
-        c = OrderedDict()
-        v = vocab(c, unk_token="<new_unk>")
+        with self.assertRaises(RuntimeError):
+            v['not_in_it']
+        with self.assertRaises(RuntimeError):
+            v['<unk>']
 
-        # check if new_unk is mapped to the first index
-        self.assertEqual(v['<new_unk>'], 0)
+        v.insert_token('not_in_it', 0)
+        v.set_default_index(0)
+        self.assertEqual(v.get_default_index(), 0)
         self.assertEqual(v['not_in_it'], 0)
+        self.assertEqual(v['<unk>'], 0)
 
     def test_vocab_get_item(self):
         token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
@@ -43,35 +48,81 @@ def test_vocab_get_item(self):
         self.assertEqual(v['a'], 1)
         self.assertEqual(v['b'], 2)
 
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
+    def test_vocab_set_item(self):
+        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=2)
+
+        v.set_default_index(0)
+        with self.assertRaises(RuntimeError):
+            v['b'] = 1
+        del v['b']
+        self.assertEqual(v['<unk>'], 0)
+        self.assertEqual(v['a'], 1)
+        self.assertEqual(v['not_in_it'], 0)
+        self.assertEqual(v['b'], 0)
+
+        v['b'] = 1
+        self.assertEqual(v['<unk>'], 0)
+        self.assertEqual(v['b'], 1)
+        self.assertEqual(v['not_in_it'], 0)
+        self.assertEqual(v['a'], 0)
+
     def test_vocab_insert_token(self):
         c = OrderedDict({'<unk>': 2, 'a': 2})
 
         # add item to end
         v = vocab(c)
+        v.set_default_index(0)
         v.insert_token('b', 2)
 
         expected_itos = ['<unk>', 'a', 'b']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
+        self.assertEqual(v.get_default_index(), 0)
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
         # add item to middle
         v = vocab(c)
+        v.set_default_index(0)
         v.insert_token('b', 0)
 
         expected_itos = ['b', '<unk>', 'a']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
+        self.assertEqual(v.get_default_index(), 1)
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
+    def test_insert_existing_token(self):
+        c = OrderedDict({'a': 2, 'b': 2, 'c': 2})
+
+        # add item to end
+        v = vocab(c)
+        v.insert_token('<unk>', 2)
+        v.set_default_index(2)
+
+        with self.assertRaises(RuntimeError):
+            # Test proper error raised when inserting a token that already exists
+            v.insert_token('<unk>', 1)
+
+        v.insert_token('d', 1)
+        self.assertEqual(v['not_in_it'], 3)
+
     def test_vocab_append_token(self):
         c = OrderedDict({'a': 2})
         v = vocab(c)
         v.append_token('b')
 
-        expected_itos = ['<unk>', 'a', 'b']
+        expected_itos = ['a', 'b']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
         self.assertEqual(v.get_itos(), expected_itos)
@@ -83,7 +134,7 @@ def test_vocab_len(self):
         c = OrderedDict(sorted_by_freq_tuples)
         v = vocab(c)
 
-        self.assertEqual(len(v), 4)
+        self.assertEqual(len(v), 3)
 
     def test_vocab_basic(self):
         token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
@@ -92,12 +143,15 @@ def test_vocab_basic(self):
         c = OrderedDict(sorted_by_freq_tuples)
         v = vocab(c, min_freq=3)
 
-        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
+        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
     def test_vocab_jit(self):
         token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
@@ -106,7 +160,7 @@ def test_vocab_jit(self):
         v = vocab(c, min_freq=3)
         jit_v = torch.jit.script(v)
 
-        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
+        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
         assert not v.is_jitable
@@ -117,6 +171,9 @@ def test_vocab_jit(self):
         self.assertEqual(jit_v.get_itos(), expected_itos)
         self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
 
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
     def test_vocab_forward(self):
         token_to_freq = {'a': 2, 'b': 2, 'c': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
@@ -126,7 +183,7 @@ def test_vocab_forward(self):
         jit_v = torch.jit.script(v)
 
         tokens = ['b', 'a', 'c']
-        expected_indices = [2, 1, 3]
+        expected_indices = [1, 0, 2]
 
         self.assertEqual(v(tokens), expected_indices)
         self.assertEqual(jit_v(tokens), expected_indices)
@@ -137,7 +194,7 @@ def test_vocab_lookup_token(self):
         c = OrderedDict(sorted_by_freq_tuples)
         v = vocab(c)
 
-        self.assertEqual(v.lookup_token(1), 'a')
+        self.assertEqual(v.lookup_token(0), 'a')
 
     def test_vocab_lookup_tokens(self):
         token_to_freq = {'a': 2, 'b': 2, 'c': 2}
@@ -145,7 +202,7 @@ def test_vocab_lookup_tokens(self):
         c = OrderedDict(sorted_by_freq_tuples)
         v = vocab(c)
 
-        indices = [2, 1, 3]
+        indices = [1, 0, 2]
         expected_tokens = ['b', 'a', 'c']
 
         self.assertEqual(v.lookup_tokens(indices), expected_tokens)
@@ -157,7 +214,7 @@ def test_vocab_lookup_indices(self):
         v = vocab(c)
 
         tokens = ['b', 'a', 'c']
-        expected_indices = [2, 1, 3]
+        expected_indices = [1, 0, 2]
 
         self.assertEqual(v.lookup_indices(tokens), expected_indices)
@@ -179,14 +236,18 @@ def test_errors_vocab_cpp(self):
             v = vocab(c)
             v.lookup_token(100)
 
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
     def test_errors_vocab_python(self):
         token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
         c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c)
 
-        with self.assertRaises(ValueError):
-            # Test proper error raised when setting unk token to None
-            vocab(c, unk_token=None)
+        with self.assertRaises(RuntimeError):
+            # Test proper error raised when looking up an OOV token before a default index is set
+            v(['not_in_vocab'])
 
     def test_vocab_load_and_save(self):
         token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
@@ -194,8 +255,8 @@ def test_vocab_load_and_save(self):
         c = OrderedDict(sorted_by_freq_tuples)
         v = vocab(c, min_freq=3)
-
-        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
+        v.set_default_index(1)
+        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
         self.assertEqual(v.get_itos(), expected_itos)
@@ -221,7 +282,7 @@ def test_build_vocab_iterator(self):
         iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
                      'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'freq_low', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T']]
         v = build_vocab_from_iterator(iterator)
-        expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
+        expected_itos = ['ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world', 'freq_low']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py
index f900fb6752..92bcb80579 100644
--- a/test/experimental/test_with_asset.py
+++ b/test/experimental/test_with_asset.py
@@ -15,11 +15,11 @@
     load_vocab_from_file,
     build_vocab_from_text_file,
 )
+import unittest
+import platform
 import shutil
 import tempfile
 import os
-import unittest
-import platform
 from torchtext.experimental.vectors import (
     GloVe,
     build_vectors,
@@ -75,6 +75,9 @@ def test_wikitext103(self):
 
 
 class TestTransformsWithAsset(TorchtextTestCase):
+    # we separate out these errors because Windows runs into seg faults when propagating
+    # exceptions from C++ using pybind11
+    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
     def test_vocab_transform(self):
         asset_name = 'vocab_test2.txt'
         asset_path = get_asset_path(asset_name)
@@ -180,7 +183,8 @@ def test_vocab_from_file(self):
         asset_name = 'vocab_test.txt'
         asset_path = get_asset_path(asset_name)
         with open(asset_path, 'r') as f:
-            v = load_vocab_from_file(f, unk_token='<unk>')
+            v = load_vocab_from_file(f)
+            v.insert_token('<unk>', 0)
             expected_itos = ['<unk>', 'b', 'a', 'c']
             expected_stoi = {x: index for index, x in enumerate(expected_itos)}
             self.assertEqual(v.get_itos(), expected_itos)
@@ -192,8 +196,8 @@ def test_vocab_from_raw_text_file(self):
         with open(asset_path, 'r') as f:
             tokenizer = basic_english_normalize()
             jit_tokenizer = torch.jit.script(tokenizer)
-            v = build_vocab_from_text_file(f, jit_tokenizer, unk_token='<unk>')
-            expected_itos = ['<unk>', "'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
+            v = build_vocab_from_text_file(f, jit_tokenizer)
+            expected_itos = ["'", 'after', 'talks', '.', 'are', 'at', 'disappointed',
                             'fears', 'federal', 'firm', 'for', 'mogul', 'n', 'newall', 'parent',
                             'pension', 'representing', 'say', 'stricken', 't', 'they', 'turner',
                             'unions', 'with', 'workers']
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 4c3ef76399..b1c17dd79b 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -15,12 +15,11 @@ namespace py = pybind11;
 namespace {
 Vocab build_vocab_from_text_file(const std::string &file_path,
-                                 const std::string &unk_token,
                                  const int64_t min_freq,
                                  const int64_t num_cpus, py::object fn) {
   torch::jit::script::Module module(*torch::jit::as_module(fn));
-  return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
-                                     module);
+  return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module);
 }
 } // namespace
@@ -100,12 +99,15 @@ PYBIND11_MODULE(_torchtext, m) {
           }));
 
   py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
-      .def(py::init<std::vector<std::string>, std::string>())
+      .def(py::init<std::vector<std::string>>())
       .def_readonly("itos_", &Vocab::itos_)
-      .def_readonly("unk_token_", &Vocab::unk_token_)
       .def("__getitem__", &Vocab::__getitem__)
+      .def("__setitem__", &Vocab::__setitem__)
+      .def("__delitem__", &Vocab::__delitem__)
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
@@ -202,10 +204,14 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       });
 
   m.class_<Vocab>("Vocab")
-      .def(torch::init<StringList, std::string>())
+      .def(torch::init<StringList>())
       .def("__getitem__", &Vocab::__getitem__)
+      .def("__setitem__", &Vocab::__setitem__)
+      .def("__delitem__", &Vocab::__delitem__)
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 0e324dbbf5..264ae6f139 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -7,13 +7,10 @@
 namespace torchtext {
 
-Vocab::Vocab(const StringList &tokens, const IndexDict &stoi,
-             const std::string &unk_token, const int64_t unk_index)
-    : unk_index_(std::move(unk_index)), stoi_(std::move(stoi)),
-      itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {}
+Vocab::Vocab(const StringList &tokens, const IndexDict &stoi)
+    : stoi_(std::move(stoi)), itos_(std::move(tokens)) {}
 
-Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
-    : itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {
+Vocab::Vocab(const StringList &tokens) : itos_(std::move(tokens)) {
   stoi_.reserve(tokens.size());
   for (std::size_t i = 0; i < tokens.size(); i++) {
     // tokens should not have any duplicates
@@ -27,7 +24,6 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     }
     stoi_[std::move(tokens[i])] = i;
   }
-  unk_index_ = stoi_.find(unk_token)->second;
 }
 
 int64_t Vocab::__len__() const { return stoi_.size(); }
@@ -36,8 +32,41 @@ int64_t Vocab::__getitem__(const std::string &token) const {
   const auto &item = stoi_.find(token);
   if (item != stoi_.end()) {
     return item->second;
+  } else if (default_index_.has_value()) {
+    return default_index_.value();
+  } else
Call " + "set_default_index() function to " + "set up the default index"); +} + +void Vocab::__setitem__(const std::string &token, const int64_t &index) { + if (index < 0 || index > static_cast(stoi_.size())) { +#ifdef _MSC_VER + std::cerr << "[RuntimeError] Specified index " << index + << " is out of bounds of the size of stoi dictionary: " + << stoi_.size() << std::endl; +#endif + throw std::runtime_error( + "Specified index " + std::to_string(index) + + " is out of bounds of the size of stoi dictionary: " + + std::to_string(stoi_.size()) + "."); + } + + auto item = stoi_.find(token); + if (item != stoi_.end()) { + throw std::runtime_error( + "Token " + token + + " has already been in the Vocab. Please delete it first by call del func."); + } + + if (index == static_cast(stoi_.size())) append_token(token); + else { + auto it = stoi_.find(itos_[index]); + stoi_.erase(it); + stoi_[token] = index; + itos_[index] = token; } - return unk_index_; } void Vocab::append_token(const std::string &token) { @@ -88,9 +117,46 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) { // need to update unk_index in case token equals unk_token or token // inserted before unk_token - unk_index_ = stoi_.find(unk_token_)->second; + if (default_index_.has_value() && index <= *default_index_) { + default_index_ = default_index_.value() + 1; + } +} + +void Vocab::__delitem__(const std::string &token) { + const auto &item = stoi_.find(token); + // if item already in stoi we throw an error + if (item == stoi_.end()) { +#ifdef _MSC_VER + std::cerr << "[RuntimeError] Token " << token + << " doesn't exist in the Vocab" + << std::endl; +#endif + throw std::runtime_error("Token " + token + + " doesn't exist in the Vocab" + "."); + } + for (size_t i = item->second + 1; i < itos_.size(); i++) { + stoi_[itos_[i]] = i - 1; + } + stoi_.erase(token); + itos_.erase(itos_.begin() + item->second); + + // need to update unk_index in case token equals unk_token or token + // inserted before unk_token + if (default_index_.has_value() && item->second < *default_index_) { + default_index_ = default_index_.value() - 1; + } +} + +void Vocab::set_default_index(const int64_t index) { + if (default_index_.has_value()) + std::cerr + << "UNK index has been assigned. You are resetting the UNK index here." + << index << std::endl; + default_index_ = index; } +int64_t Vocab::get_default_index() const { return default_index_.value(); } + std::string Vocab::lookup_token(const int64_t &index) { if (index < 0 || index > static_cast(itos_.size())) { #ifdef _MSC_VER @@ -207,8 +273,8 @@ struct CompareTokens { std::tuple _concat_tokens(std::vector> chunk_counters, - const std::string &unk_token, const int64_t min_freq, - const int64_t num_lines, const bool sort_tokens) { + const int64_t min_freq, const int64_t num_lines, + const bool sort_tokens) { TORCH_CHECK(chunk_counters.size() > 0, "There must be at least 1 chunk to concatenate!"); @@ -254,16 +320,6 @@ _concat_tokens(std::vector> chunk_counters, unique_tokens.push_back(token_freq_pair.first); } - // insert unk_token if not present - if (tokens_freq.find(unk_token) == tokens_freq.end()) { - std::cerr << "The `unk_token` " << unk_token - << " wasn't found in the `ordered_dict`. Adding the `unk_token` " - "to the beginning of the Vocab." 
-              << std::endl;
-
-    unique_tokens.insert(unique_tokens.begin(), unk_token);
-  }
-
   // create stoi
   IndexDict stoi;
   stoi.reserve(num_lines);
@@ -279,7 +335,6 @@ _concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
 constexpr int64_t GRAIN_SIZE = 13107;
 
 Vocab _load_vocab_from_file(const std::string &file_path,
-                            const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus) {
   std::cerr << "[INFO] Reading file " << file_path << std::endl;
@@ -322,15 +377,12 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   IndexDict stoi;
   StringList tokens;
   std::tie(stoi, tokens) =
-      _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, false);
+      _concat_tokens(chunk_counters, min_freq, num_lines, false);
 
-  int64_t unk_index = stoi.find(unk_token)->second;
-
-  return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
+  return Vocab(std::move(tokens), std::move(stoi));
 }
 
 Vocab _build_vocab_from_text_file(const std::string &file_path,
-                                  const std::string &unk_token,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer) {
@@ -373,26 +425,29 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   IndexDict stoi;
   StringList tokens;
   std::tie(stoi, tokens) =
-      _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, true);
-  int64_t unk_index = stoi.find(unk_token)->second;
+      _concat_tokens(chunk_counters, min_freq, num_lines, true);
 
-  return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
+  return Vocab(std::move(tokens), std::move(stoi));
 }
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   std::vector<int64_t> integers;
   StringList strings = self->itos_;
-  strings.push_back(self->unk_token_);
   std::vector<torch::Tensor> tensors;
 
-  VocabStates states = std::make_tuple(self->version_str_, std::move(integers),
-                                       std::move(strings), std::move(tensors));
+  c10::optional<int64_t> default_index = {};
+  if (self->default_index_.has_value())
+    default_index = self->default_index_.value();
+
+  VocabStates states = std::make_tuple(
+      self->version_str_, std::move(integers), std::move(strings),
+      default_index, std::move(tensors));
   return states;
 }
 
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size<decltype(states)>::value;
-  if (state_size != 4) {
+  if (state_size != 5) {
 #ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Expected deserialized Vocab to have 4 states "
+    std::cerr << "[RuntimeError] Expected deserialized Vocab to have 5 states "
                  "but found "
@@ -406,7 +461,8 @@ c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto &version_str = std::get<0>(states);
   auto &integers = std::get<1>(states);
   auto &strings = std::get<2>(states);
-  auto &tensors = std::get<3>(states);
+  auto &default_index = std::get<3>(states);
+  auto &tensors = std::get<4>(states);
 
   // check integers and tensors are empty
   if (integers.size() != 0 || tensors.size() != 0) {
@@ -420,10 +476,11 @@ c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   }
 
   if (version_str.compare("0.0.1") >= 0) {
-    std::string unk_token = strings.back();
-    strings.pop_back(); // remove last element which is unk_token
+    auto vocab_instance = c10::make_intrusive<Vocab>(std::move(strings));
+    if (default_index.has_value())
+      vocab_instance->set_default_index(default_index.value());
 
-    return c10::make_intrusive<Vocab>(std::move(strings), std::move(unk_token));
+    return vocab_instance;
   }
 #ifdef _MSC_VER
   std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: "
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 0da660a633..609549cd18 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -6,28 +6,28 @@
 typedef std::vector<std::string> StringList;
 typedef ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>
     IndexDict;
 typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
-                   std::vector<torch::Tensor>>
+                   c10::optional<int64_t>, std::vector<torch::Tensor>>
     VocabStates;
 
 struct Vocab : torch::CustomClassHolder {
 private:
-  int64_t unk_index_;
   IndexDict stoi_;
 
 public:
   const std::string version_str_ = "0.0.1";
   StringList itos_;
-  std::string unk_token_;
+  c10::optional<int64_t> default_index_ = {};
 
-  explicit Vocab(const std::vector<std::string> &tokens,
-                 const std::string &unk_token);
-  explicit Vocab(const StringList &tokens, const IndexDict &stoi,
-                 const std::string &unk_token, const int64_t unk_index);
+  explicit Vocab(const std::vector<std::string> &tokens);
+  explicit Vocab(const StringList &tokens, const IndexDict &stoi);
 
   int64_t __len__() const;
   int64_t __getitem__(const std::string &token) const;
+  void __setitem__(const std::string &token, const int64_t &index);
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
+  void __delitem__(const std::string &token);
+  void set_default_index(const int64_t index);
+  int64_t get_default_index() const;
   std::string lookup_token(const int64_t &index);
   std::vector<std::string> lookup_tokens(const std::vector<int64_t> &indices);
   std::vector<int64_t> lookup_indices(const std::vector<std::string> &tokens);
@@ -39,10 +39,8 @@ VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
 
 Vocab _load_vocab_from_file(const std::string &file_path,
-                            const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus);
 Vocab _build_vocab_from_text_file(const std::string &file_path,
-                                  const std::string &unk_token,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer);
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 606965daa9..6af1211301 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -1,6 +1,5 @@
 import logging
 from typing import Dict, List
-import warnings
 from collections import Counter, OrderedDict
 import torch
 import torch.nn as nn
@@ -19,7 +18,7 @@
 logger = logging.getLogger(__name__)
 
 
-def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_token='<unk>', num_cpus=4):
+def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, num_cpus=4):
     r"""Create a `Vocab` object from a raw text file.
 
     The `file_object` can contain any raw text. This function applies a generic JITed tokenizer in
@@ -31,7 +30,6 @@ def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_tok
         jited_tokenizer (ScriptModule): a tokenizer that has been JITed using `torch.jit.script`
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
         num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4.
 
     Returns:
@@ -45,12 +43,15 @@ def build_vocab_from_text_file(file_object, jited_tokenizer, min_freq=1, unk_tok
         >>> tokenizer = basic_english_normalize()
         >>> jit_tokenizer = torch.jit.script(tokenizer)
         >>> v = build_vocab_from_text_file(f, jit_tokenizer)
+        >>> v.insert_token('<unk>', 0)
+        >>> v.set_default_index(0)
+        >>> v.get_default_index()
     """
-    vocab_obj = _build_vocab_from_text_file(file_object.name, unk_token, min_freq, num_cpus, jited_tokenizer)
+    vocab_obj = _build_vocab_from_text_file(file_object.name, min_freq, num_cpus, jited_tokenizer)
     return Vocab(vocab_obj)
 
 
-def load_vocab_from_file(file_object, min_freq=1, unk_token='<unk>', num_cpus=4):
+def load_vocab_from_file(file_object, min_freq=1, num_cpus=4):
     r"""Create a `Vocab` object from a text file.
 
     The `file_object` should contain tokens separated by new lines. Note that the vocab
     will be created in the order that the tokens first appear in the file (and not by the frequency of tokens).
@@ -65,7 +66,6 @@ def load_vocab_from_file(file_object, min_freq=1, unk_token='<unk>', num_cpus=4)
         file_object (FileObject): a file like object to read data from.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
         num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4.
 
     Returns:
@@ -75,13 +75,15 @@ def load_vocab_from_file(file_object, min_freq=1, unk_token='<unk>', num_cpus=4)
         >>> from torchtext.experimental.vocab import load_vocab_from_file
        >>> f = open('vocab.txt', 'r')
         >>> v = load_vocab_from_file(f)
+        >>> v.insert_token('<unk>', 0)
+        >>> v.set_default_index(0)
+        >>> v.get_default_index()
     """
-
-    vocab_obj = _load_vocab_from_file(file_object.name, unk_token, min_freq, num_cpus)
+    vocab_obj = _load_vocab_from_file(file_object.name, min_freq, num_cpus)
     return Vocab(vocab_obj)
 
 
-def build_vocab_from_iterator(iterator, min_freq=1, unk_token='<unk>'):
+def build_vocab_from_iterator(iterator, min_freq=1):
     """
     Build a Vocab from an iterator.
 
@@ -89,7 +91,16 @@ def build_vocab_from_iterator(iterator, min_freq=1, unk_token='<unk>'):
         iterator: Iterator used to build Vocab. Must yield list or iterator of tokens.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+
+    Examples:
+        >>> from torchtext.experimental.vocab import build_vocab_from_iterator
+        >>> tokens = [['this', 'is', 'an', 'example', 'for', 'vocab']]
+        >>> v = build_vocab_from_iterator(tokens)
+        >>> v.insert_token('<unk>', 0)
+        >>> v.set_default_index(0)
+        >>> v.get_default_index()
+        >>> tokens_iter = iter([['this', 'is', 'an'], ['example', 'for', 'vocab']])
+        >>> v1 = build_vocab_from_iterator(tokens_iter)
     """
 
     counter = Counter()
@@ -97,25 +108,20 @@ def build_vocab_from_iterator(iterator, min_freq=1):
         counter.update(tokens)
     sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
     ordered_dict = OrderedDict(sorted_by_freq_tuples)
-    word_vocab = vocab(ordered_dict, min_freq=min_freq, unk_token=unk_token)
+    word_vocab = vocab(ordered_dict, min_freq=min_freq)
     return word_vocab
 
 
-def vocab(ordered_dict, min_freq=1, unk_token='<unk>'):
+def vocab(ordered_dict, min_freq=1):
     r"""Factory method for creating a vocab object which maps tokens to indices.
 
     Note that the ordering in which key value pairs were inserted in the `ordered_dict` will be respected when building the vocab.
     Therefore if sorting by token frequency is important to the user, the `ordered_dict` should be created in a way to reflect this.
-    Additionally, the if the `unk_token` isn't found inside of the `ordered_dict`, it will be added to the end of the vocab.
 
     Args:
         ordered_dict (collections.OrderedDict): object holding the frequencies of each token found in the data.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
-
-    Raises:
-        ValueError: if a default `unk_token` isn't provided.
 
     Examples:
         >>> from torchtext.experimental.vocab import vocab
@@ -124,23 +130,17 @@ def vocab(ordered_dict, min_freq=1, unk_token='<unk>'):
         >>> sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
         >>> ordered_dict = OrderedDict(sorted_by_freq_tuples)
         >>> v1 = vocab(ordered_dict)
+        >>> v1.insert_token('<unk>', 0)
+        >>> v1.set_default_index(0)
+        >>> v1.get_default_index()
         >>> tokens = ['e', 'd', 'c', 'b', 'a']
         >>> v2 = vocab(OrderedDict([(token, 1) for token in tokens]))
     """
-
-    if not unk_token:
-        raise ValueError("A default unk token wasn't provided.")
-
     tokens = []
     for token, freq in ordered_dict.items():
         if freq >= min_freq:
             tokens.append(token)
-
-    if unk_token not in tokens:
-        tokens.insert(0, unk_token)
-        warnings.warn("The `unk_token` '{}' wasn't found in the `ordered_dict`. Adding the `unk_token` "
-                      "to the beginning of the Vocab.".format(unk_token), RuntimeWarning)
-
-    return Vocab(VocabPybind(tokens, unk_token))
+    return Vocab(VocabPybind(tokens))
 
 
 class Vocab(nn.Module):
@@ -189,6 +189,32 @@ def __getitem__(self, token: str) -> int:
         """
         return self.vocab[token]
 
+    @torch.jit.export
+    def __setitem__(self, token: str, index: int) -> None:
+        r"""Assign a token to a specific index. The token previously assigned to `index` is
+        replaced by the new token.
+
+        Args:
+            token (str): the token to assign.
+            index (int): the index to assign the token to.
+
+        Raises:
+            RuntimeError: if `index` is not between [0, Vocab.size()] or if `token` already exists in the vocab.
+        """
+        self.vocab[token] = index
+
+    @torch.jit.export
+    def __delitem__(self, token: str) -> None:
+        r"""Delete a token from the vocab and shift all the tokens that follow it left by 1.
+
+        Args:
+            token (str): the token to be deleted.
+
+        Raises:
+            RuntimeError: if `token` is not in the vocab.
+        """
+        del self.vocab[token]
+
     @torch.jit.export
     def insert_token(self, token: str, index: int) -> None:
         r"""
@@ -201,6 +227,24 @@ def insert_token(self, token: str, index: int) -> None:
         """
         self.vocab.insert_token(token, index)
 
+    @torch.jit.export
+    def set_default_index(self, index: int) -> None:
+        r"""
+        Args:
+            index (int): the index to return when an out-of-vocabulary token is looked up.
+
+        """
+        self.vocab.set_default_index(index)
+
+    @torch.jit.export
+    def get_default_index(self) -> int:
+        r"""
+        Returns:
+            index (int): the index returned when an out-of-vocabulary token is looked up.
+
+        """
+        return self.vocab.get_default_index()
+
     @torch.jit.export
     def append_token(self, token: str) -> None:
         r"""
@@ -267,5 +311,9 @@ def get_itos(self) -> List[str]:
 
     def __prepare_scriptable__(self):
         r"""Return a JITable Vocab.
         """
-        cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_)
-        return Vocab(cpp_vocab)
+        cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_)
+        try:
+            cpp_vocab.set_default_index(self.vocab.get_default_index())
+            return Vocab(cpp_vocab)
+        except RuntimeError:
+            return Vocab(cpp_vocab)
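
Usage note: a minimal end-to-end sketch of the workflow this patch enables, distilled from the docstring examples and tests above. The '<unk>' token and the sample data are illustrative only; the API no longer privileges any particular unknown token:

    from collections import Counter, OrderedDict
    from torchtext.experimental.vocab import vocab

    counter = Counter(['a', 'a', 'b'])
    ordered_dict = OrderedDict(sorted(counter.items(), key=lambda x: x[1], reverse=True))
    v = vocab(ordered_dict)      # itos: ['a', 'b'], with no implicit '<unk>' entry anymore

    # Until a default index is set, out-of-vocabulary lookups raise a RuntimeError.
    v.insert_token('<unk>', 0)   # itos: ['<unk>', 'a', 'b']
    v.set_default_index(0)       # OOV lookups now return index 0 instead of raising
    assert v['not_in_it'] == 0
    assert v.get_default_index() == 0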