diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py
index d091d7796b..afb140c71e 100644
--- a/test/experimental/test_vocab.py
+++ b/test/experimental/test_vocab.py
@@ -1,9 +1,7 @@
 # -*- coding: utf-8 -*-
 from collections import OrderedDict
 import os
-import platform
 import torch
-import unittest
 from test.common.torchtext_test_case import TorchtextTestCase
 from torchtext.experimental.vocab import (
     vocab,
@@ -20,18 +18,12 @@ def tearDown(self):
     def test_has_unk(self):
         c = OrderedDict()
         v = vocab(c)
-
-        # check if unk is mapped to the first index
-        self.assertEqual(v['not_in_it'], 0)
         self.assertEqual(v['<unk>'], 0)
 
     def test_new_unk(self):
         c = OrderedDict()
         v = vocab(c, unk_token="<new_unk>")
-
-        # check if new_unk is mapped to the first index
         self.assertEqual(v['<new_unk>'], 0)
-        self.assertEqual(v['not_in_it'], 0)
 
     def test_vocab_membership(self):
         token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
@@ -54,6 +46,50 @@ def test_vocab_get_item(self):
         self.assertEqual(v['a'], 1)
         self.assertEqual(v['b'], 2)
 
+    def test_reassign_token(self):
+        token_to_freq = {'<unk>': 1, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=1)
+
+        self.assertEqual(v['<unk>'], 2)
+        self.assertEqual(v['a'], 0)
+        self.assertEqual(v['b'], 1)
+        v.reassign_token('<unk>', 0)
+        self.assertEqual(v['<unk>'], 0)
+        self.assertEqual(v['a'], 1)
+        self.assertEqual(v['b'], 2)
+
+        self.assertEqual(v.get_itos(), ['<unk>', 'a', 'b'])
+
+        with self.assertRaises(RuntimeError):
+            v.reassign_token('not in vocab', 0)
+
+        with self.assertRaises(RuntimeError):
+            v.reassign_token('<unk>', 3)
+
+    def test_default_index(self):
+        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=2)
+
+        self.assertTrue(v.get_default_index() is None)
+        with self.assertRaises(RuntimeError):
+            v['not in vocab']
+
+        v.set_default_index(0)
+        self.assertEqual(v['not in vocab'], 0)
+
+    def test_default_index_jit(self):
+        token_to_freq = {'<unk>': 2, 'a': 2, 'b': 2}
+        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
+        c = OrderedDict(sorted_by_freq_tuples)
+        v = vocab(c, min_freq=2)
+        v.set_default_index(0)
+        v_jit = torch.jit.script(v)
+        self.assertEqual(v_jit['not in vocab'], 0)
+
     def test_vocab_insert_token(self):
         c = OrderedDict({'<unk>': 2, 'a': 2})
 
@@ -88,6 +124,10 @@ def test_vocab_append_token(self):
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
+        # token must not exist to be appended
+        with self.assertRaises(RuntimeError):
+            v.append_token('b')
+
     def test_vocab_len(self):
         token_to_freq = {'a': 2, 'b': 2, 'c': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
@@ -149,6 +189,8 @@ def test_vocab_lookup_token(self):
         v = vocab(c)
 
         self.assertEqual(v.lookup_token(1), 'a')
+        with self.assertRaises(RuntimeError):
+            v.lookup_token(100)
 
     def test_vocab_lookup_tokens(self):
         token_to_freq = {'a': 2, 'b': 2, 'c': 2}
@@ -172,24 +214,6 @@ def test_vocab_lookup_indices(self):
 
         self.assertEqual(v.lookup_indices(tokens), expected_indices)
 
-    # we separate out these errors because Windows runs into seg faults when propagating
-    # exceptions from C++ using pybind11
-    @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
-    def test_errors_vocab_cpp(self):
-        token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
-        sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
-        c = OrderedDict(sorted_by_freq_tuples)
-
-        with self.assertRaises(RuntimeError):
-            # Test proper error raised when setting a token out of bounds
-            v = vocab(c, min_freq=3)
-            v.insert_token('new_token', 100)
-
-        with self.assertRaises(RuntimeError):
-            # Test proper error raised when looking up a token out of bounds
-            v = vocab(c)
-            v.lookup_token(100)
-
     def test_errors_vocab_python(self):
         token_to_freq = {'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}
         sorted_by_freq_tuples = sorted(token_to_freq.items(), key=lambda x: x[1], reverse=True)
@@ -205,6 +229,7 @@ def test_vocab_load_and_save(self):
         c = OrderedDict(sorted_by_freq_tuples)
 
         v = vocab(c, min_freq=3)
+        v.set_default_index(0)
         expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
         expected_stoi = {x: index for index, x in enumerate(expected_itos)}
 
@@ -218,6 +243,7 @@ def test_vocab_load_and_save(self):
             loaded_v = torch.load(vocab_path)
             self.assertEqual(v.get_itos(), expected_itos)
             self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+            self.assertEqual(v['not in vocab'], 0)
 
         with self.subTest('torchscript'):
             vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
@@ -227,6 +253,7 @@ def test_vocab_load_and_save(self):
             loaded_v = torch.load(vocab_path)
             self.assertEqual(v.get_itos(), expected_itos)
             self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+            self.assertEqual(v['not in vocab'], 0)
 
     def test_build_vocab_iterator(self):
         iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 0b325bcda6..cf6656d12a 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -15,12 +15,10 @@ namespace py = pybind11;
 namespace {
 Vocab build_vocab_from_text_file(const std::string &file_path,
-                                 const std::string &unk_token,
                                  const int64_t min_freq,
                                  const int64_t num_cpus, py::object fn) {
   torch::jit::script::Module module(*torch::jit::as_module(fn));
-  return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
-                                     module);
+  return _build_vocab_from_text_file(file_path, min_freq, num_cpus, module);
 }
 } // namespace
@@ -104,23 +102,27 @@ PYBIND11_MODULE(_torchtext, m) {
       }));
 
   py::class_<Vocab, c10::intrusive_ptr<Vocab>>(m, "Vocab")
-      .def(py::init<std::vector<std::string>, std::string>())
+      .def(py::init<StringList, c10::optional<int64_t>>())
       .def_readonly("itos_", &Vocab::itos_)
-      .def_readonly("unk_token_", &Vocab::unk_token_)
-      .def("__contains__",
-           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
-             ssize_t length;
-             const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
-             return self->__contains__(c10::string_view{buffer, (size_t)length});
-           })
+      .def_readonly("default_index_", &Vocab::default_index_)
+      .def(
+          "__contains__",
+          [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> bool {
+            ssize_t length;
+            const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
+            return self->__contains__(c10::string_view{buffer, (size_t)length});
+          })
       .def("__getitem__",
           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
             ssize_t length;
            const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
            return self->__getitem__(c10::string_view{buffer, (size_t)length});
          })
-      .def("__len__", &Vocab::__len__)
+      .def("reassign_token", &Vocab::reassign_token)
       .def("insert_token", &Vocab::insert_token)
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
+      .def("__len__", &Vocab::__len__)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
@@ -234,15 +236,18 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       });
 
   m.class_<Vocab>("Vocab")
-      .def(torch::init<StringList, std::string>())
+      .def(torch::init<StringList, c10::optional<int64_t>>())
      .def("__contains__",
          [](const c10::intrusive_ptr<Vocab> &self,
             const std::string &item) -> bool {
            return self->__contains__(c10::string_view{item});
          })
      .def("__getitem__",
          [](const c10::intrusive_ptr<Vocab> &self,
             const std::string &item) -> int64_t {
            return self->__getitem__(c10::string_view{item});
          })
-      .def("__len__", &Vocab::__len__)
+      .def("reassign_token", &Vocab::reassign_token)
       .def("insert_token", &Vocab::insert_token)
+      .def("__len__", &Vocab::__len__)
+      .def("set_default_index", &Vocab::set_default_index)
+      .def("get_default_index", &Vocab::get_default_index)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 1831d46f39..e659fbdb70 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -7,27 +7,21 @@
 #include // @manual
 
 namespace torchtext {
-Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
-    : stoi_(MAX_VOCAB_SIZE, -1), unk_token_(std::move(unk_token)) {
-  for (std::size_t i = 0; i < tokens.size(); i++) {
-    // tokens should not have any duplicates
-    auto token_position =
-        _find(c10::string_view{tokens[i].data(), tokens[i].size()});
-    if (stoi_[token_position] != -1) {
-#ifdef _MSC_VER
-      std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
-                << tokens[i] << std::endl;
-#endif
-      throw std::runtime_error("Duplicate token found in tokens list: " +
-                               tokens[i]);
-    }
+Vocab::Vocab(const StringList &tokens,
+             const c10::optional<int64_t> &default_index)
+    : stoi_(MAX_VOCAB_SIZE, -1), default_index_{default_index} {
+  for (size_t i = 0; i < tokens.size(); i++) {
+    // throw error if duplicate token is found
+    auto id = _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    TORCH_CHECK(stoi_[id] == -1,
+                "Duplicate token found in tokens list: " + tokens[i]);
+    _add(tokens[i]);
   }
-
-  unk_index_ =
-      stoi_[_find(c10::string_view{unk_token.data(), unk_token.size()})];
 }
 
+Vocab::Vocab(const StringList &tokens) : Vocab(tokens, {}) {}
+
 int64_t Vocab::__len__() const { return itos_.size(); }
 
 bool Vocab::__contains__(const c10::string_view &token) const {
@@ -38,77 +32,92 @@ bool Vocab::__contains__(const c10::string_view &token) const {
   return false;
 }
-
 int64_t Vocab::__getitem__(const c10::string_view &token) const {
   int64_t id = _find(token);
-  if (stoi_[id] != -1) {
+  if (stoi_[id] != -1)
     return stoi_[id];
-  }
-  return unk_index_;
+
+  // throw error if default_index_ is not set
+  TORCH_CHECK(default_index_.has_value(),
+              "Token " + std::string(token) +
+                  " not found and default index is not set");
+
+  // return default index if token is OOV
+  return default_index_.value();
 }
 
-void Vocab::append_token(const std::string &token) { _add(token); }
+void Vocab::set_default_index(int64_t index) { default_index_ = index; }
+
+c10::optional<int64_t> Vocab::get_default_index() const {
+  return default_index_;
+}
+
+void Vocab::append_token(const std::string &token) {
+  // throw error if token already exist in vocab
+  auto id = _find(c10::string_view{token.data(), token.size()});
+  TORCH_CHECK(stoi_[id] == -1, "Token " + token +
+                                   " already exists in the Vocab with index: " +
+                                   std::to_string(stoi_[id]));
+
+  _add(token);
+}
+
+void Vocab::reassign_token(const std::string &token, const int64_t &index) {
+  // throw error if index is not valid
+  TORCH_CHECK(index >= 0 && index < __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));
+
+  // throw error if token not found
+  TORCH_CHECK(__contains__(token), "Token " + token + " not found in Vocab");
+
+  _remove(token);
+  insert_token(token, index);
+}
 
 void Vocab::insert_token(const std::string &token, const int64_t &index) {
-  if (index < 0 || index > itos_.size()) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of stoi dictionary: "
-              << stoi_.size() << std::endl;
-#endif
-    throw std::runtime_error(
-        "Specified index " + std::to_string(index) +
-        " is out of bounds of the size of stoi dictionary: " +
-        std::to_string(stoi_.size()) + ".");
-  }
+  // throw error if index is not valid
+  TORCH_CHECK(index >= 0 && index <= __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));
 
-  // if item already in stoi we throw an error
-  auto token_position = _find(c10::string_view{token.data(), token.size()});
-  if (stoi_[token_position] != -1) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Token " << token
-              << " already exists in the Vocab with index: "
-              << stoi_[token_position] << std::endl;
-#endif
-    throw std::runtime_error("Token " + token +
-                             " already exists in the Vocab with index: " +
-                             std::to_string(stoi_[token_position]) + ".");
-  }
+  // throw error if token already present
+  TORCH_CHECK(!__contains__(token), "Token " + token + " not found in Vocab");
 
   // need to offset all tokens greater than or equal index by 1
-  for (size_t i = index; i < itos_.size(); i++) {
+  for (size_t i = index; i < __len__(); i++) {
     stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
   }
 
   itos_.insert(itos_.begin() + index, token);
   stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
-
-  // need to update unk_index in case token equals unk_token or token
-  // inserted before unk_token
-  unk_index_ =
-      stoi_[_find(c10::string_view{unk_token_.data(), unk_token_.size()})];
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
-  if (index < 0 || index > static_cast<int64_t>(itos_.size())) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of itos dictionary: "
-              << itos_.size() << std::endl;
-#endif
-    throw std::runtime_error(
-        "Specified index " + std::to_string(index) +
-        " is out of bounds of the size of itos dictionary: " +
-        std::to_string(itos_.size()) + ".");
-  }
+  // throw error if index is not valid
+  TORCH_CHECK(index >= 0 && index < __len__(),
+              "Specified index " + std::to_string(index) +
+                  " is out of bounds for vocab of size " +
+                  std::to_string(__len__()));
 
   return itos_[index];
 }
 
 StringList Vocab::lookup_tokens(const std::vector<int64_t> &indices) {
+  // throw error if indices are not valid
+  for (size_t i = 0; i < indices.size(); i++) {
+    TORCH_CHECK(indices[i] >= 0 && indices[i] < __len__(),
+                "Specified index " + std::to_string(indices[i]) +
+                    " at position " + std::to_string(i) +
+                    " is out of bounds for vocab of size " +
+                    std::to_string(__len__()));
+  }
+
   std::vector<std::string> tokens(indices.size());
-  for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); i++) {
-    tokens[i] = lookup_token(indices[i]);
+  for (size_t i = 0; i < indices.size(); i++) {
+    tokens[i] = itos_[indices[i]];
   }
   return tokens;
 }
@@ -116,7 +125,7 @@ StringList Vocab::lookup_tokens(const std::vector<int64_t> &indices) {
 std::vector<int64_t>
 Vocab::lookup_indices(const std::vector &tokens) {
   std::vector<int64_t> indices(tokens.size());
-  for (int64_t i = 0; i < static_cast<int64_t>(tokens.size()); i++) {
+  for (size_t i = 0; i < tokens.size(); i++) {
     indices[i] = __getitem__(tokens[i]);
   }
   return indices;
@@ -148,9 +157,7 @@ void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
                             const int64_t start_line, const int64_t end_line,
                             std::shared_ptr counter) {
   std::ifstream fin(file_path, std::ios::in);
-  if (!fin.is_open()) {
-    throw std::runtime_error("Cannot open input file " + file_path + "\n");
-  }
+  TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
 
   fin.seekg(offset);
 
@@ -172,9 +179,7 @@ void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
                                std::shared_ptr counter,
                                torch::jit::script::Module &module) {
   std::ifstream fin(file_path, std::ios::in);
-  if (!fin.is_open()) {
-    throw std::runtime_error("Cannot open input file " + file_path + "\n");
-  }
+  TORCH_CHECK(fin.is_open(), "Cannot open input file " + file_path);
 
   fin.seekg(offset);
 
@@ -210,8 +215,9 @@ struct CompareTokens {
 StringList
 _concat_tokens(std::vector> chunk_counters,
-               const std::string &unk_token, const int64_t min_freq,
-               const int64_t num_lines, const bool sort_tokens) {
+               const int64_t min_freq, const int64_t num_lines,
+               const bool sort_tokens) {
+
   TORCH_CHECK(chunk_counters.size() > 0,
               "There must be at least 1 chunk to concatenate!");
 
@@ -257,24 +263,12 @@ _concat_tokens(std::vector> chunk_counters,
       unique_tokens.push_back(token_freq_pair.first);
     }
 
-  // insert unk_token if not present
-  if (tokens_freq.find(unk_token) == tokens_freq.end()) {
-    std::cerr << "The `unk_token` " << unk_token
-              << " wasn't found in the `ordered_dict`. Adding the `unk_token` "
-                 "to the beginning of the Vocab."
-              << std::endl;
-
-    unique_tokens.insert(unique_tokens.begin(), unk_token);
-  }
-
   return unique_tokens;
 }
 
 constexpr int64_t GRAIN_SIZE = 13107;
 
 Vocab _load_vocab_from_file(const std::string &file_path,
-                            const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus) {
-  std::cerr << "[INFO] Reading file " << file_path << std::endl;
 
   int64_t num_lines = _infer_lines(file_path);
   int64_t chunk_size = impl::divup(num_lines, num_cpus);
@@ -313,17 +307,15 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   cv.wait(lock, [&thread_count] { return thread_count == 0; });
 
   StringList tokens =
-      _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, false);
+      _concat_tokens(chunk_counters, min_freq, num_lines, false);
 
-  return Vocab(std::move(tokens), unk_token);
+  return Vocab(std::move(tokens));
 }
 
 Vocab _build_vocab_from_text_file(const std::string &file_path,
-                                  const std::string &unk_token,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer) {
-  std::cerr << "[INFO] Reading file " << file_path << std::endl;
 
   int64_t num_lines = _infer_lines(file_path);
   int64_t chunk_size = impl::divup(num_lines, num_cpus);
   // Launching a thread on less lines than this likely has too much overhead.
@@ -359,18 +351,20 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   std::unique_lock<std::mutex> lock(m);
   cv.wait(lock, [&thread_count] { return thread_count == 0; });
 
-  StringList tokens =
-      _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, true);
+  StringList tokens = _concat_tokens(chunk_counters, min_freq, num_lines, true);
 
-  return Vocab(std::move(tokens), unk_token);
+  return Vocab(std::move(tokens));
 }
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
   std::vector<int64_t> integers;
   StringList strings = self->itos_;
-  strings.push_back(self->unk_token_);
   std::vector<torch::Tensor> tensors;
 
+  if (self->default_index_.has_value()) {
+    integers.push_back(self->default_index_.value());
+  }
+
   VocabStates states = std::make_tuple(self->version_str_, std::move(integers),
                                        std::move(strings), std::move(tensors));
   return states;
@@ -378,45 +372,27 @@ VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self) {
 
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states) {
   auto state_size = std::tuple_size::value;
-  if (state_size != 4) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Expected deserialized Vocab to have 4 states "
-                 "but found "
-              << state_size << " states." << std::endl;
-#endif
-    throw std::runtime_error(
-        "Expected deserialized Vocab to have 4 states but found " +
-        std::to_string(state_size) + " states.");
-  }
+  TORCH_CHECK(state_size == 4,
+              "Expected deserialized Vocab to have 4 states but found " +
+                  std::to_string(state_size) + " states");
 
   auto &version_str = std::get<0>(states);
   auto &integers = std::get<1>(states);
   auto &strings = std::get<2>(states);
   auto &tensors = std::get<3>(states);
 
-  // check integers and tensors are empty
-  if (integers.size() != 0 || tensors.size() != 0) {
-#ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Expected `integers` and `tensors` states to "
-                 "be empty."
-              << std::endl;
-#endif
-    throw std::runtime_error(
-        "Expected `integers` and `tensors` states to be empty.");
-  }
+  // check tensors are empty
+  TORCH_CHECK(tensors.size() == 0, "Expected `tensors` states to be empty");
 
-  if (version_str.compare("0.0.1") >= 0) {
-    std::string unk_token = strings.back();
-    strings.pop_back(); // remove last element which is unk_token
+  // throw error if version is not compatible
+  TORCH_CHECK(version_str.compare("0.0.2") >= 0,
+              "Found unexpected version for serialized Vocab: " + version_str);
 
-    return c10::make_intrusive<Vocab>(std::move(strings), std::move(unk_token));
+  c10::optional<int64_t> default_index = {};
+  if (integers.size() > 0) {
+    default_index = integers[0];
   }
-#ifdef _MSC_VER
-  std::cerr << "[RuntimeError] Found unexpected version for serialized Vocab: "
-            << version_str << std::endl;
-#endif
-  throw std::runtime_error(
-      "Found unexpected version for serialized Vocab: " + version_str + ".");
+  return c10::make_intrusive<Vocab>(std::move(strings), default_index);
 }
 } // namespace torchtext
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index d915c7de27..06f98865d3 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -1,3 +1,4 @@
+#include
 #include
 #include
 namespace torchtext {
@@ -13,17 +14,24 @@ struct Vocab : torch::CustomClassHolder {
   static const int32_t MAX_VOCAB_SIZE = 30000000;
   int64_t unk_index_;
   std::vector<int64_t> stoi_;
-  const std::string version_str_ = "0.0.1";
+  const std::string version_str_ = "0.0.2";
   StringList itos_;
-  std::string unk_token_;
+  c10::optional<int64_t> default_index_ = {};
 
-  explicit Vocab(const std::vector<std::string> &tokens,
-                 const std::string &unk_token);
+  // TODO: [can we remove this?] we need to keep this constructor, otherwise
+  // torch binding gets compilation error: no matching constructor for
+  // initialization of 'torchtext::Vocab'
+  explicit Vocab(const StringList &tokens);
+  explicit Vocab(const StringList &tokens,
+                 const c10::optional<int64_t> &default_index);
   int64_t __len__() const;
   int64_t __getitem__(const c10::string_view &token) const;
   bool __contains__(const c10::string_view &token) const;
-  void append_token(const std::string &token);
+  void set_default_index(int64_t index);
+  c10::optional<int64_t> get_default_index() const;
+  void reassign_token(const std::string &token, const int64_t &index);
   void insert_token(const std::string &token, const int64_t &index);
+  void append_token(const std::string &token);
   std::string lookup_token(const int64_t &index);
   std::vector<std::string> lookup_tokens(const std::vector<int64_t> &indices);
   std::vector<int64_t>
@@ -44,7 +52,7 @@ struct Vocab : torch::CustomClassHolder {
   uint32_t _find(const c10::string_view &w) const {
     uint32_t stoi_size = stoi_.size();
     uint32_t id = _hash(w) % stoi_size;
-    while (stoi_[id] != -1 && itos_[stoi_[id]]!= w) {
+    while (stoi_[id] != -1 && itos_[stoi_[id]] != w) {
       id = (id + 1) % stoi_size;
     }
     return id;
   }
@@ -57,16 +65,22 @@ struct Vocab : torch::CustomClassHolder {
       stoi_[h] = itos_.size() - 1;
     }
   }
+
+  void _remove(const std::string &w) {
+    uint32_t h = _find(c10::string_view{w.data(), w.size()});
+    if (stoi_[h] != -1) {
+      stoi_[h] = -1;
+      itos_.erase(std::find(itos_.begin(), itos_.end(), w));
+    }
+  }
 };
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
 Vocab _load_vocab_from_file(const std::string &file_path,
-                            const std::string &unk_token,
                             const int64_t min_freq, const int64_t num_cpus);
 Vocab _build_vocab_from_text_file(const std::string &file_path,
-                                  const std::string &unk_token,
                                   const int64_t min_freq,
                                   const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer);
diff --git a/torchtext/experimental/vocab.py b/torchtext/experimental/vocab.py
index 26f393ce36..a1aa2290d9 100644
--- a/torchtext/experimental/vocab.py
+++ b/torchtext/experimental/vocab.py
@@ -1,6 +1,5 @@
 import logging
-from typing import Dict, List
-import warnings
+from typing import Dict, List, Optional
 from collections import Counter, OrderedDict
 import torch
 import torch.nn as nn
@@ -30,7 +29,7 @@ def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token
         jited_tokenizer (ScriptModule): a tokenizer that has been JITed using `torch.jit.script`
         min_freq: The minimum frequency needed to include a token in the vocabulary.
             Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+        unk_token: The default unknown token to use. Default: '<unk>'. If not found in text file, it will be inserted to index 0.
         num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4.
 
     Returns:
@@ -44,7 +43,9 @@ def build_vocab_from_text_file(file_path, jited_tokenizer, min_freq=1, unk_token
         >>> jit_tokenizer = torch.jit.script(tokenizer)
         >>> v = build_vocab_from_text_file('vocab.txt', jit_tokenizer)
     """
-    vocab_obj = _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, jited_tokenizer)
+    vocab_obj = _build_vocab_from_text_file(file_path, min_freq, num_cpus, jited_tokenizer)
+    if unk_token not in vocab_obj:
+        vocab_obj.insert_token(unk_token, 0)
     return Vocab(vocab_obj)
 
@@ -62,7 +63,7 @@ def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
         file_object (FileObject): a file like object to read data from.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
            Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+        unk_token: The default unknown token to use. Default: '<unk>'. If not found in vocab file, it will be inserted to index 0.
         num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 4.
 
     Returns:
@@ -73,7 +74,9 @@ def load_vocab_from_file(file_path, min_freq=1, unk_token='<unk>', num_cpus=4):
 
         >>> v = load_vocab_from_file('vocab.txt')
     """
-    vocab_obj = _load_vocab_from_file(file_path, unk_token, min_freq, num_cpus)
+    vocab_obj = _load_vocab_from_file(file_path, min_freq, num_cpus)
+    if unk_token not in vocab_obj:
+        vocab_obj.insert_token(unk_token, 0)
     return Vocab(vocab_obj)
 
@@ -108,7 +111,7 @@ def vocab(ordered_dict, min_freq=1, unk_token='<unk>'):
         ordered_dict (collections.OrderedDict): object holding the frequencies of each token found in the data.
         min_freq: The minimum frequency needed to include a token in the vocabulary.
            Values less than 1 will be set to 1. Default: 1.
-        unk_token: The default unknown token to use. Default: '<unk>'.
+        unk_token: The default unknown token to use. Default: '<unk>'. If not found in ordered_dict, it will be inserted at index 0.
 
     Raises:
         ValueError: if a default `unk_token` isn't provided.
@@ -134,9 +137,7 @@
     if unk_token not in tokens:
         tokens.insert(0, unk_token)
-        warnings.warn("The `unk_token` '{}' wasn't found in the `ordered_dict`. Adding the `unk_token` "
-                      "to the beginning of the Vocab.".format(unk_token), RuntimeWarning)
-    return Vocab(VocabPybind(tokens, unk_token))
+    return Vocab(VocabPybind(tokens, None))
 
 
 class Vocab(nn.Module):
@@ -198,12 +199,38 @@ def __getitem__(self, token: str) -> int:
         return self.vocab[token]
 
     @torch.jit.export
-    def insert_token(self, token: str, index: int) -> None:
+    def set_default_index(self, index: int) -> None:
+        r"""
+        Args:
+            index: Value of default index. This index will be returned when OOV token is queried
+        """
+        self.vocab.set_default_index(index)
+
+    @torch.jit.export
+    def get_default_index(self) -> Optional[int]:
+        r"""
+        Returns:
+            index (optional[int]): Value of default index if it is set.
+        """
+        return self.vocab.get_default_index()
+
+    @torch.jit.export
+    def reassign_token(self, token: str, index: int) -> None:
         r"""
         Args:
             token (str): the token used to lookup the corresponding index.
             index (int): the index corresponding to the associated token.
 
+        Raises:
+            RuntimeError: If `index` is not range [0,Vocab.size()) or if token is not present in Vocab
+        """
+        self.vocab.reassign_token(token, index)
+
+    @torch.jit.export
+    def insert_token(self, token: str, index: int) -> None:
+        r"""
+        Args:
+            token (str): the token used to lookup the corresponding index.
+            index (int): the index corresponding to the associated token.
 
         Raises:
             RuntimeError: if `index` not between [0, Vocab.size()] or if token already exists in the vocab.
         """
         self.vocab.insert_token(token, index)
 
@@ -214,6 +241,9 @@ def append_token(self, token: str) -> None:
         r"""
         Args:
             token (str): the token used to lookup the corresponding index.
+
+        Raises:
+            RuntimeError: if token already exists in the vocab
         """
         self.vocab.append_token(token)
 
@@ -275,5 +305,7 @@ def get_itos(self) -> List[str]:
     def __prepare_scriptable__(self):
         r"""Return a JITable Vocab.
         """
-        cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.unk_token_)
-        return Vocab(cpp_vocab)
+        if not self.is_jitable:
+            cpp_vocab = torch.classes.torchtext.Vocab(self.vocab.itos_, self.vocab.default_index_)
+            return Vocab(cpp_vocab)
+        return self
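
Usage sketch (not part of the diff): a minimal example of the default-index and reassign-token behaviour introduced above, following the new tests `test_default_index` and `test_reassign_token`. It assumes the experimental `vocab` factory from `torchtext.experimental.vocab` as modified in this change; the `<unk>` token is just an ordinary entry here.

    from collections import OrderedDict
    from torchtext.experimental.vocab import vocab

    # '<unk>' keeps its position in the ordered dict, so it lands at index 2 here
    v = vocab(OrderedDict([('a', 2), ('b', 2), ('<unk>', 1)]), min_freq=1)

    # OOV lookups now raise a RuntimeError unless a default index is set
    v.set_default_index(0)
    assert v['not in vocab'] == 0
    assert v.get_default_index() == 0

    # reassign_token moves an existing token to a new index
    v.reassign_token('<unk>', 0)
    assert v.get_itos() == ['<unk>', 'a', 'b']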