From 8cbcb3cea6d7db1b744d90e70303c2a48304dd0a Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 8 Mar 2021 10:30:07 -0500
Subject: [PATCH 01/15] error element_type

---
 torchtext/csrc/vocab.cpp | 4 ++--
 torchtext/csrc/vocab.h   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 0e324dbbf5..8716ae571d 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -32,8 +32,8 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
 
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
-int64_t Vocab::__getitem__(const std::string &token) const {
-  const auto &item = stoi_.find(token);
+int64_t Vocab::__getitem__(c10::string_view token) const {
+  const auto &item = stoi_.find(std::string{token});
   if (item != stoi_.end()) {
     return item->second;
   }
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 0da660a633..f9566f282e 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -1,5 +1,5 @@
 #include
-
+#include
 namespace torchtext {
 
 typedef std::vector StringList;
@@ -25,7 +25,7 @@ struct Vocab : torch::CustomClassHolder {
                 const std::string &unk_token, const int64_t unk_index);
   int64_t __len__() const;
-  int64_t __getitem__(const std::string &token) const;
+  int64_t __getitem__(c10::string_view token) const;
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);

From ab77691a1836baf98c0e09b43a34a87f623d3975 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Tue, 9 Mar 2021 18:12:54 -0500
Subject: [PATCH 02/15] added py::str as argument

---
 torchtext/csrc/register_bindings.cpp |  3 ++-
 torchtext/csrc/vocab.cpp             | 10 +++++++---
 torchtext/csrc/vocab.h               |  2 +-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 4c3ef76399..81f8d66f4f 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -203,7 +203,8 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
   m.class_("Vocab")
       .def(torch::init())
-      .def("__getitem__", &Vocab::__getitem__)
+      .def("__getitem__", [](const c10::intrusive_ptr &self, std::string item) -> int64_t {
+        return self->__getitem__(py::str{item});})
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 8716ae571d..4feea1aa32 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -32,8 +32,12 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
 
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
-int64_t Vocab::__getitem__(c10::string_view token) const {
-  const auto &item = stoi_.find(std::string{token});
+int64_t Vocab::__getitem__(const py::str &token) const {
+  py::bytes temp = py::reinterpret_borrow(PyUnicode_AsUTF8String(token.ptr()));
+  char *buffer;
+  ssize_t length;
+  PyBytes_AsStringAndSize(temp.ptr(),&buffer,&length);
+  const auto &item = stoi_.find(std::string{buffer, (size_t)length});
   if (item != stoi_.end()) {
     return item->second;
   }
@@ -118,7 +122,7 @@ StringList Vocab::lookup_tokens(const std::vector &indices) {
 std::vector Vocab::lookup_indices(const StringList &tokens) {
   std::vector indices(tokens.size());
   for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
-    indices[i] = __getitem__(tokens[i]);
+    indices[i] = __getitem__(py::str{tokens[i]});
   }
   return indices;
 }
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index f9566f282e..e6c1782355 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -25,7 +25,7 @@ struct Vocab : torch::CustomClassHolder {
                 const std::string &unk_token, const int64_t unk_index);
   int64_t __len__() const;
-  int64_t __getitem__(c10::string_view token) const;
+  int64_t __getitem__(const py::str &token) const;
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);

From a1c651af94b74ce38cbdc6653ff7b143c61fdecb Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Tue, 9 Mar 2021 21:37:24 -0500
Subject: [PATCH 03/15] using py::str as input arguments to avoid copying
 memory

---
 torchtext/csrc/register_bindings.cpp |  9 +++++++--
 torchtext/csrc/vocab.cpp             | 12 ++++++------
 torchtext/csrc/vocab.h               |  4 ++--
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 81f8d66f4f..584faeea89 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -8,7 +8,7 @@
 #include
 #include // @manual
 #include // @manual
-
+#include
 namespace torchtext {
 
 namespace py = pybind11;
@@ -210,7 +210,12 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices", &Vocab::lookup_indices)
+      .def("lookup_indices",[](const c10::intrusive_ptr &self,const std::vector &item) -> std::vector {
+        std::vector temp(item.size());
+        for(size_t i=0;ilookup_indices(temp);})
       .def("get_stoi", &Vocab::get_stoi)
       .def("get_itos", &Vocab::get_itos)
       .def_pickle(
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 4feea1aa32..b5d59b4bc0 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -25,7 +25,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     throw std::runtime_error("Duplicate token found in tokens list: " +
                              tokens[i]);
   }
-  stoi_[std::move(tokens[i])] = i;
+  stoi_[c10::string_view{tokens[i].data(),tokens[i].size()}] = i;
 }
 unk_index_ = stoi_.find(unk_token)->second;
 }
@@ -37,7 +37,7 @@ int64_t Vocab::__getitem__(const py::str &token) const {
   char *buffer;
   ssize_t length;
   PyBytes_AsStringAndSize(temp.ptr(),&buffer,&length);
-  const auto &item = stoi_.find(std::string{buffer, (size_t)length});
+  const auto &item = stoi_.find(c10::string_view{buffer, (size_t)length});
   if (item != stoi_.end()) {
     return item->second;
   }
@@ -119,10 +119,10 @@ StringList Vocab::lookup_tokens(const std::vector &indices) {
   return tokens;
 }
 
-std::vector Vocab::lookup_indices(const StringList &tokens) {
+std::vector Vocab::lookup_indices(const std::vector &tokens) {
   std::vector indices(tokens.size());
   for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
     indices[i] = __getitem__(tokens[i]);
   }
   return indices;
 }
@@ -133,7 +133,7 @@ std::unordered_map Vocab::get_stoi() const {
 
   // construct tokens and index list
   for (const auto &item : stoi_) {
-    stoi[item.first] = item.second;
+    stoi[std::string{item.first}] = item.second;
   }
   return stoi;
 }
@@ -234,7 +234,7 @@ _concat_tokens(std::vector> chunk_counters,
     // add to tokens list only if we exceed min_freq for the first time
     if (tokens_freq[item.first] - cur_token_freq < min_freq &&
         tokens_freq[item.first] >= min_freq) {
-      unique_tokens.push_back(item.first);
+      unique_tokens.push_back(std::string{item.first});
     }
   }
 }
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index e6c1782355..c9c36e14ea 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -3,7 +3,7 @@
 namespace torchtext {
 
 typedef std::vector StringList;
-typedef ska_ordered::order_preserving_flat_hash_map
+typedef ska_ordered::order_preserving_flat_hash_map
     IndexDict;
 typedef std::tuple, std::vector,
                    std::vector>
@@ -30,7 +30,7 @@ struct Vocab : torch::CustomClassHolder {
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);
   std::vector lookup_tokens(const std::vector &indices);
-  std::vector lookup_indices(const std::vector &tokens);
+  std::vector lookup_indices(const std::vector &tokens);
   std::unordered_map get_stoi() const;
   std::vector get_itos() const;
 };
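A note on the c10::string_view keys that patch 03 introduces and that patches 04 and 06 then have to repair: a map keyed by views is only safe if every key points at storage that outlives the map, so the keys must view the long-lived strings held in itos_, never a function parameter or a moved-from temporary. The following standalone C++17 sketch is editorial illustration, not code from the patches (std::string_view stands in for c10::string_view, and a std::deque is chosen so growth never relocates existing tokens):

    #include <cstdint>
    #include <deque>
    #include <string>
    #include <string_view>
    #include <unordered_map>

    struct MiniVocab {
      // Tokens live in a deque so push_back never relocates existing
      // elements; the map keys are views into that stable storage.
      std::deque<std::string> itos_;
      std::unordered_map<std::string_view, int64_t> stoi_;

      void append_token(const std::string &token) {
        if (stoi_.find(token) == stoi_.end()) {
          itos_.push_back(token);
          // Key the view to itos_.back(), the stable copy. Viewing `token`
          // itself would dangle once the caller's string is destroyed --
          // the invariant patches 04 and 06 below restore.
          stoi_[std::string_view{itos_.back()}] =
              static_cast<int64_t>(itos_.size()) - 1;
        }
      }

      int64_t lookup(std::string_view token, int64_t unk_index) const {
        auto it = stoi_.find(token);
        return it != stoi_.end() ? it->second : unk_index;
      }
    };

With a std::vector backing store, as in the patches, reallocation can additionally relocate short (SSO) strings and silently invalidate the views, which may be part of why later patches in this series replace view keys with byte hashes (patch 10) and then a preallocated id table (patch 12).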
From d635cf4d93081c3802b9bb33ce4aaa46eab7d4c0 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Tue, 9 Mar 2021 21:44:25 -0500
Subject: [PATCH 04/15] fixing append

---
 torchtext/csrc/vocab.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index b5d59b4bc0..925cfb0795 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -51,8 +51,8 @@ void Vocab::append_token(const std::string &token) {
   // For example if the size of `stoi_` is 2, doing
   // `stoi_["test"] = stoi_.size()` will set `stoi_["test"]` to a
   // value of 3 instead of 2 on Windows stoi_[token] = itos_.size();
-  stoi_[token] = itos_.size();
   itos_.push_back(token);
+  stoi_[c10::string_view{itos_.back().data(),itos_.back().size()}] = itos_.size()-1;
 }
 }

From 1b52d953e037b4e42b2bd739bc8188c1ff169fa3 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Tue, 9 Mar 2021 21:49:21 -0500
Subject: [PATCH 05/15] changing to py::reinterpret_steal

---
 torchtext/csrc/vocab.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 925cfb0795..c0d5ceb7ab 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -33,7 +33,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
 int64_t Vocab::__getitem__(const py::str &token) const {
-  py::bytes temp = py::reinterpret_borrow(PyUnicode_AsUTF8String(token.ptr()));
+  py::bytes temp = py::reinterpret_steal(PyUnicode_AsUTF8String(token.ptr()));
   char *buffer;
   ssize_t length;
   PyBytes_AsStringAndSize(temp.ptr(),&buffer,&length);

From e56e087034e317383172c8be870c4542f0899044 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Wed, 10 Mar 2021 10:34:28 -0500
Subject: [PATCH 06/15] fixing string_view issue

---
 torchtext/csrc/vocab.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index c0d5ceb7ab..039d05ef48 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -25,7 +25,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     throw std::runtime_error("Duplicate token found in tokens list: " +
                              tokens[i]);
   }
-  stoi_[c10::string_view{tokens[i].data(),tokens[i].size()}] = i;
+  stoi_[c10::string_view{itos_[i].data(),itos_[i].size()}] = i;
 }
 unk_index_ = stoi_.find(unk_token)->second;
 }
@@ -33,10 +33,8 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
 int64_t Vocab::__getitem__(const py::str &token) const {
-  py::bytes temp = py::reinterpret_steal(PyUnicode_AsUTF8String(token.ptr()));
-  char *buffer;
   ssize_t length;
-  PyBytes_AsStringAndSize(temp.ptr(),&buffer,&length);
+  const char *buffer = PyUnicode_AsUTF8AndSize(token.ptr(),&length);
   const auto &item = stoi_.find(c10::string_view{buffer, (size_t)length});
   if (item != stoi_.end()) {
     return item->second;
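Patch 06's switch from PyUnicode_AsUTF8String to PyUnicode_AsUTF8AndSize is what actually eliminates the copy: the former allocates a fresh bytes object, while the latter returns a pointer to the UTF-8 representation cached inside the str object itself, valid for that object's lifetime and with nothing to free. A minimal sketch of the pattern as a standalone helper (editorial, not code from the patch; note the patch itself skips the null check the CPython API documents for non-encodable input):

    #include <pybind11/pybind11.h>
    #include <string_view>

    namespace py = pybind11;

    // Zero-copy view of a Python str's UTF-8 bytes. The buffer lives inside
    // the str object, so the view is only valid while `token` is alive.
    inline std::string_view utf8_view(const py::str &token) {
      Py_ssize_t length = 0;
      const char *buffer = PyUnicode_AsUTF8AndSize(token.ptr(), &length);
      if (buffer == nullptr) {
        throw py::error_already_set();  // propagate the CPython error
      }
      return std::string_view{buffer, static_cast<size_t>(length)};
    }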
From 19192044ff42e10fc2f284b93d73a9d1a7aeab84 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Wed, 10 Mar 2021 11:57:23 -0500
Subject: [PATCH 07/15] overloading functions for torchbind

---
 torchtext/csrc/register_bindings.cpp | 16 ++++++----------
 torchtext/csrc/vocab.cpp             | 27 ++++++++++++++++++++++++---
 torchtext/csrc/vocab.h               |  4 +++-
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 584faeea89..58484e5d08 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -103,13 +103,13 @@ PYBIND11_MODULE(_torchtext, m) {
       .def(py::init, std::string>())
       .def_readonly("itos_", &Vocab::itos_)
       .def_readonly("unk_token_", &Vocab::unk_token_)
-      .def("__getitem__", &Vocab::__getitem__)
+      .def("__getitem__", py::overload_cast(&Vocab::__getitem__, py::const_))
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices", &Vocab::lookup_indices)
+      .def("lookup_indices", py::overload_cast(&Vocab::lookup_indices))
       .def("get_stoi", &Vocab::get_stoi)
       .def("get_itos", &Vocab::get_itos)
       .def(py::pickle(
@@ -203,19 +203,15 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
   m.class_("Vocab")
       .def(torch::init())
-      .def("__getitem__", [](const c10::intrusive_ptr &self, std::string item) -> int64_t {
-        return self->__getitem__(py::str{item});})
+      .def("__getitem__", py::overload_cast(&Vocab::__getitem__, py::const_))
+      // .def("__getitem__", [](const c10::intrusive_ptr &self, std::string item) -> int64_t {
+      //   return self->__getitem__(py::str{item});})
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices",[](const c10::intrusive_ptr &self,const std::vector &item) -> std::vector {
-        std::vector temp(item.size());
-        for(size_t i=0;ilookup_indices(temp);})
+      .def("lookup_indices", py::overload_cast &>(&Vocab::lookup_indices))
       .def("get_stoi", &Vocab::get_stoi)
       .def("get_itos", &Vocab::get_itos)
       .def_pickle(
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 039d05ef48..34e6b59e66 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -42,6 +42,15 @@ int64_t Vocab::__getitem__(const py::str &token) const {
   return unk_index_;
 }
 
+int64_t Vocab::__getitem__(const std::string &token) const {
+  const auto &item = stoi_.find(c10::string_view{token});
+  if (item != stoi_.end()) {
+    return item->second;
+  }
+  return unk_index_;
+}
+
+
 void Vocab::append_token(const std::string &token) {
@@ -117,7 +126,7 @@ StringList Vocab::lookup_tokens(const std::vector &indices) {
   return tokens;
 }
 
-std::vector Vocab::lookup_indices(const std::vector &tokens) {
+std::vector Vocab::lookup_indices(const py::list &tokens) {
   std::vector indices(tokens.size());
   for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
     indices[i] = __getitem__(tokens[i]);
@@ -125,6 +134,15 @@ std::vector Vocab::lookup_indices(const std::vector &tokens) {
   return indices;
 }
 
+std::vector Vocab::lookup_indices(const std::vector &tokens) {
+  std::vector indices(tokens.size());
+  for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
+    indices[i] = __getitem__(tokens[i]);
+  }
+  return indices;
+}
+
+
 std::unordered_map Vocab::get_stoi() const {
   std::unordered_map stoi;
   stoi.reserve(stoi_.size());
@@ -152,8 +170,11 @@ int64_t _infer_lines(const std::string &file_path) {
 void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
                             const int64_t start_line, const int64_t end_line,
                             std::shared_ptr counter) {
-  std::ifstream fin;
-  fin.open(file_path, std::ios::in);
+  std::ifstream fin(file_path, std::ios::in);
+  if(!fin.is_open()){
+    throw std::runtime_error("Cannot open input file "+file_path+"\n");
+  }
+
   fin.seekg(offset);
 
   for (int64_t i = start_line; i < end_line; i++) {
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index c9c36e14ea..50842c273c 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -26,11 +26,13 @@ struct Vocab : torch::CustomClassHolder {
                 const std::string &unk_token, const int64_t unk_index);
   int64_t __len__() const;
   int64_t __getitem__(const py::str &token) const;
+  int64_t __getitem__(const std::string &token) const;
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);
   std::vector lookup_tokens(const std::vector &indices);
-  std::vector lookup_indices(const std::vector &tokens);
+  std::vector lookup_indices(const py::list &tokens);
+  std::vector lookup_indices(const std::vector &tokens);
   std::unordered_map get_stoi() const;
   std::vector get_itos() const;
 };

From 00e90d25f957dbf3ba8f028c4af4561845736942 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 11 Mar 2021 22:34:06 -0500
Subject: [PATCH 08/15] pushing py stuff into lambda functions

---
 torchtext/csrc/register_bindings.cpp | 33 ++++++++++++++++++++++------
 torchtext/csrc/vocab.cpp             | 25 +++------------------
 torchtext/csrc/vocab.h               |  6 ++---
 3 files changed, 31 insertions(+), 33 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 58484e5d08..18376bb94b 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -103,13 +103,26 @@ PYBIND11_MODULE(_torchtext, m) {
       .def(py::init, std::string>())
       .def_readonly("itos_", &Vocab::itos_)
       .def_readonly("unk_token_", &Vocab::unk_token_)
-      .def("__getitem__", py::overload_cast(&Vocab::__getitem__, py::const_))
+      .def("__getitem__", [](c10::intrusive_ptr &self,const py::str &item){
+        ssize_t length;
+        const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(),&length);
+        return self->__getitem__(c10::string_view{buffer,(size_t)length});
+      })
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices", py::overload_cast(&Vocab::lookup_indices))
+      .def("lookup_indices", [](const c10::intrusive_ptr &self,const py::list &items){
+        std::vector indices(items.size());
+        int64_t counter = 0;
+        for(const auto &item: items){
+          ssize_t length;
+          const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(),&length);
+          indices[counter++] = self->__getitem__(c10::string_view{buffer,(size_t)length});
+        }
+        return indices;
+      })
       .def("get_stoi", &Vocab::get_stoi)
       .def("get_itos", &Vocab::get_itos)
       .def(py::pickle(
@@ -216,15 +229,21 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
   m.class_("Vocab")
       .def(torch::init())
-      .def("__getitem__", py::overload_cast(&Vocab::__getitem__, py::const_))
-      // .def("__getitem__", [](const c10::intrusive_ptr &self, std::string item) -> int64_t {
-      //   return self->__getitem__(py::str{item});})
+      .def("__getitem__", [](const c10::intrusive_ptr &self,const std::string &item) -> int64_t {
+        return self->__getitem__(c10::string_view{item});})
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
-      .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices", py::overload_cast &>(&Vocab::lookup_indices))
+      .def("lookup_tokens", &Vocab::lookup_tokens)
+      .def("lookup_indices", [](const c10::intrusive_ptr &self,const std::vector &items){
+        std::vector temp(items.size());
+        int64_t counter = 0;
+        for(const auto &item: items){
+          temp[counter++] = c10::string_view{item};
+        }
+        return self->lookup_indices(temp);
+      })
       .def("get_stoi", &Vocab::get_stoi)
       .def("get_itos", &Vocab::get_itos)
       .def_pickle(
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 34e6b59e66..9eb6d9f6ba 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -32,25 +32,14 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
 
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
-int64_t Vocab::__getitem__(const py::str &token) const {
-  ssize_t length;
-  const char *buffer = PyUnicode_AsUTF8AndSize(token.ptr(),&length);
-  const auto &item = stoi_.find(c10::string_view{buffer, (size_t)length});
+int64_t Vocab::__getitem__(const c10::string_view &token) const {
+  const auto &item = stoi_.find(token);
   if (item != stoi_.end()) {
     return item->second;
   }
   return unk_index_;
 }
 
-int64_t Vocab::__getitem__(const std::string &token) const {
-  const auto &item = stoi_.find(c10::string_view{token});
-  if (item != stoi_.end()) {
-    return item->second;
-  }
-  return unk_index_;
-}
-
-
 void Vocab::append_token(const std::string &token) {
@@ -115,15 +104,7 @@ StringList Vocab::lookup_tokens(const std::vector &indices) {
   return tokens;
 }
 
-std::vector Vocab::lookup_indices(const py::list &tokens) {
-  std::vector indices(tokens.size());
-  for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
-    indices[i] = __getitem__(tokens[i]);
-  }
-  return indices;
-}
-
-std::vector Vocab::lookup_indices(const std::vector &tokens) {
+std::vector Vocab::lookup_indices(const std::vector &tokens) {
   std::vector indices(tokens.size());
   for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
     indices[i] = __getitem__(tokens[i]);
   }
   return indices;
 }
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 50842c273c..f508b36cd4 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -25,14 +25,12 @@ struct Vocab : torch::CustomClassHolder {
                 const std::string &unk_token, const int64_t unk_index);
   int64_t __len__() const;
-  int64_t __getitem__(const py::str &token) const;
-  int64_t __getitem__(const std::string &token) const;
+  int64_t __getitem__(const c10::string_view &token) const;
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);
   std::vector lookup_tokens(const std::vector &indices);
-  std::vector lookup_indices(const py::list &tokens);
-  std::vector lookup_indices(const std::vector &tokens);
+  std::vector lookup_indices(const std::vector &tokens);
   std::unordered_map get_stoi() const;
   std::vector get_itos() const;
 };
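The design patch 08 lands on is worth spelling out: the Vocab core now speaks only c10::string_view and std::string, and each frontend adapts at the binding boundary, so the pybind11 path can walk a py::list and look up each element without ever materializing a std::vector of copied strings. A self-contained pybind11 sketch of that loop (the module name and lookup_one are illustrative placeholders, not torchtext API):

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include <cstdint>
    #include <string_view>
    #include <vector>

    namespace py = pybind11;

    // Placeholder standing in for a core lookup such as Vocab::__getitem__.
    static int64_t lookup_one(std::string_view token) {
      return static_cast<int64_t>(token.size());
    }

    PYBIND11_MODULE(example, m) {
      m.def("lookup_indices", [](const py::list &items) {
        std::vector<int64_t> indices(items.size());
        size_t counter = 0;
        for (const auto &item : items) {
          Py_ssize_t length = 0;
          const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
          if (buffer == nullptr) throw py::error_already_set();
          // Each element is viewed in place; no per-token std::string copy.
          indices[counter++] =
              lookup_one(std::string_view{buffer, (size_t)length});
        }
        return indices;
      });
    }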
From 763d3900269576eef095f5ddbc874718b51a458e Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Thu, 11 Mar 2021 20:33:14 -0800
Subject: [PATCH 09/15] fixing other vocab class functions

---
 torchtext/csrc/register_bindings.cpp | 265 ++++++++++++++-------------
 torchtext/csrc/vocab.cpp             | 235 +++++++++++-------------
 torchtext/csrc/vocab.h               |   8 +-
 3 files changed, 254 insertions(+), 254 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 18376bb94b..7dbd157c80 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -1,14 +1,14 @@
+#include
 #include
 #include
 #include
-#include // @manual
-#include // @manual
+#include // @manual
+#include // @manual
 #include // @manual
-#include // @manual
+#include // @manual
 #include
 #include // @manual
 #include // @manual
-#include
 namespace torchtext {
 
 namespace py = pybind11;
 
 namespace {
 Vocab build_vocab_from_text_file(const std::string &file_path,
                                  const std::string &unk_token,
-                                 const int64_t min_freq,
-                                 const int64_t num_cpus,
+                                 const int64_t min_freq, const int64_t num_cpus,
                                  py::object fn) {
   torch::jit::script::Module module(*torch::jit::as_module(fn));
-  return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus, module);
+  return _build_vocab_from_text_file(file_path, unk_token, min_freq, num_cpus,
+                                     module);
 }
 } // namespace
@@ -40,7 +40,8 @@ PYBIND11_MODULE(_torchtext, m) {
         return _deserialize_regex(std::move(state));
       }));
 
-  py::class_>(m, "RegexTokenizer")
+  py::class_>(
+      m, "RegexTokenizer")
       .def_readonly("patterns_", &RegexTokenizer::patterns_)
       .def_readonly("replacements_", &RegexTokenizer::replacements_)
       .def_readonly("to_lower_", &RegexTokenizer::to_lower_)
@@ -48,15 +49,18 @@
       .def("forward", &RegexTokenizer::forward)
       .def(py::pickle(
           // __getstate__
-          [](const c10::intrusive_ptr &self) -> RegexTokenizerStates {
+          [](const c10::intrusive_ptr &self)
+              -> RegexTokenizerStates {
             return _serialize_regex_tokenizer(self);
           },
          // __setstate__
-          [](RegexTokenizerStates states) -> c10::intrusive_ptr {
+          [](RegexTokenizerStates states)
+              -> c10::intrusive_ptr {
            return _deserialize_regex_tokenizer(std::move(states));
          }));
 
-  py::class_>(m, "SentencePiece")
+  py::class_>(m,
+                                                               "SentencePiece")
       .def(py::init())
       .def("_return_content",
            [](const SentencePiece &self) { return py::bytes(self.content_); })
@@ -70,14 +74,14 @@
       .def("PieceToId", &SentencePiece::PieceToId)
       .def("IdToPiece", &SentencePiece::IdToPiece)
       .def(py::pickle(
-          // __getstate__
-          [](const c10::intrusive_ptr &self) -> py::bytes{
-            return py::bytes(self->content_);
-          },
-          // __setstate__
-          [](py::bytes state) -> c10::intrusive_ptr {
-            return c10::make_intrusive(std::string(state));
-          }));
+          // __getstate__
+          [](const c10::intrusive_ptr &self) -> py::bytes {
+            return py::bytes(self->content_);
+          },
+          // __setstate__
+          [](py::bytes state) -> c10::intrusive_ptr {
+            return c10::make_intrusive(std::string(state));
+          }));
 
   py::class_>(m, "Vectors")
       .def(py::init, std::vector,
@@ -103,26 +107,30 @@
       .def(py::init, std::string>())
       .def_readonly("itos_", &Vocab::itos_)
       .def_readonly("unk_token_", &Vocab::unk_token_)
-      .def("__getitem__", [](c10::intrusive_ptr &self,const py::str &item){
-        ssize_t length;
-        const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(),&length);
-        return self->__getitem__(c10::string_view{buffer,(size_t)length});
-      })
+      .def("__getitem__",
+           [](c10::intrusive_ptr &self, const py::str &item) {
+             ssize_t length;
+             const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
+             return self->__getitem__(c10::string_view{buffer, (size_t)length});
+           })
       .def("__len__", &Vocab::__len__)
       .def("insert_token", &Vocab::insert_token)
       .def("append_token", &Vocab::append_token)
       .def("lookup_token", &Vocab::lookup_token)
       .def("lookup_tokens", &Vocab::lookup_tokens)
-      .def("lookup_indices", [](const c10::intrusive_ptr &self,const py::list &items){
-        std::vector indices(items.size());
-        int64_t counter = 0;
-        for(const auto &item: items){
-          ssize_t length;
-          const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(),&length);
-          indices[counter++] = self->__getitem__(c10::string_view{buffer,(size_t)length});
-        }
-        return indices;
-      })
+      .def("lookup_indices",
+           [](const c10::intrusive_ptr &self, const py::list &items) {
+             std::vector indices(items.size());
+             int64_t counter = 0;
+             for (const auto &item : items) {
+               ssize_t length;
+               const char *buffer =
+                   PyUnicode_AsUTF8AndSize(item.ptr(), &length);
+               indices[counter++] =
+                   self->__getitem__(c10::string_view{buffer, (size_t)length});
+             }
+             return indices;
+           })
       .def("get_stoi", &Vocab::get_stoi)
       .def("get_itos", &Vocab::get_itos)
       .def(py::pickle(
@@ -144,104 +152,113 @@

 TORCH_LIBRARY_FRAGMENT(torchtext, m) {
   m.class_("Regex")
-    .def(torch::init())
-    .def("Sub", &Regex::Sub)
-    .def_pickle(
-      // __getstate__
-      [](const c10::intrusive_ptr &self) -> std::string {
-        return _serialize_regex(self);
-      },
-      // __setstate__
-      [](std::string state) -> c10::intrusive_ptr {
-        return _deserialize_regex(std::move(state));
-      });
+      .def(torch::init())
+      .def("Sub", &Regex::Sub)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr &self) -> std::string {
+            return _serialize_regex(self);
+          },
+          // __setstate__
+          [](std::string state) -> c10::intrusive_ptr {
+            return _deserialize_regex(std::move(state));
+          });
 
   m.class_("RegexTokenizer")
-    .def(torch::init, std::vector, bool>())
-    .def("forward", &RegexTokenizer::forward)
-    .def_pickle(
-      // __getstate__
-      [](const c10::intrusive_ptr &self) -> RegexTokenizerStates {
-        return _serialize_regex_tokenizer(self);
-      },
-      // __setstate__
-      [](RegexTokenizerStates states) -> c10::intrusive_ptr {
-        return _deserialize_regex_tokenizer(std::move(states));
-      });
+      .def(torch::init, std::vector,
+                        bool>())
+      .def("forward", &RegexTokenizer::forward)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr &self)
+              -> RegexTokenizerStates {
+            return _serialize_regex_tokenizer(self);
+          },
+          // __setstate__
+          [](RegexTokenizerStates states)
+              -> c10::intrusive_ptr {
+            return _deserialize_regex_tokenizer(std::move(states));
+          });
 
   m.class_("SentencePiece")
-    .def(torch::init())
-    .def("Encode", &SentencePiece::Encode)
-    .def("EncodeAsIds", &SentencePiece::EncodeAsIds)
-    .def("DecodeIds", &SentencePiece::DecodeIds)
-    .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
-    .def("DecodePieces", &SentencePiece::DecodePieces)
-    .def("GetPieceSize", &SentencePiece::GetPieceSize)
-    .def("unk_id", &SentencePiece::unk_id)
-    .def("PieceToId", &SentencePiece::PieceToId)
-    .def("IdToPiece", &SentencePiece::IdToPiece)
-    .def_pickle(
-      // The underlying content of SentencePiece contains byte string,
-      // and returing it as std::string cause UTF8 decoding error.
-      // Since TorchScript does not support byte string, we use byte Tensor to
-      // pass around the data.
-      // __getstate__
-      [](const c10::intrusive_ptr &self) -> torch::Tensor {
-        auto *data = static_cast(const_cast(self->content_.data()));
-        auto numel = static_cast(self->content_.size());
-        return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
-      },
-      // __setstate__
-      [](torch::Tensor state) -> c10::intrusive_ptr {
-        auto *data = static_cast(state.data_ptr());
-        auto numel = state.size(0);
-        return c10::make_intrusive(std::string(data, numel));
-      });
+      .def(torch::init())
+      .def("Encode", &SentencePiece::Encode)
+      .def("EncodeAsIds", &SentencePiece::EncodeAsIds)
+      .def("DecodeIds", &SentencePiece::DecodeIds)
+      .def("EncodeAsPieces", &SentencePiece::EncodeAsPieces)
+      .def("DecodePieces", &SentencePiece::DecodePieces)
+      .def("GetPieceSize", &SentencePiece::GetPieceSize)
+      .def("unk_id", &SentencePiece::unk_id)
+      .def("PieceToId", &SentencePiece::PieceToId)
+      .def("IdToPiece", &SentencePiece::IdToPiece)
+      .def_pickle(
+          // The underlying content of SentencePiece contains byte string,
+          // and returing it as std::string cause UTF8 decoding error.
+          // Since TorchScript does not support byte string, we use byte Tensor
+          // to pass around the data.
+          // __getstate__
+          [](const c10::intrusive_ptr &self) -> torch::Tensor {
+            auto *data =
+                static_cast(const_cast(self->content_.data()));
+            auto numel = static_cast(self->content_.size());
+            return torch::from_blob(data, {numel}, {torch::kUInt8}).clone();
+          },
+          // __setstate__
+          [](torch::Tensor state) -> c10::intrusive_ptr {
+            auto *data = static_cast(state.data_ptr());
+            auto numel = state.size(0);
+            return c10::make_intrusive(std::string(data, numel));
+          });
 
   m.class_("Vectors")
-    .def(torch::init, std::vector, torch::Tensor, torch::Tensor>())
-    .def("__getitem__", &Vectors::__getitem__)
-    .def("lookup_vectors", &Vectors::lookup_vectors)
-    .def("__setitem__", &Vectors::__setitem__)
-    .def("__len__", &Vectors::__len__)
-    .def_pickle(
-      // __getstate__
-      [](const c10::intrusive_ptr &self) -> VectorsStates {
-        return _serialize_vectors(self);
-      },
-      // __setstate__
-      [](VectorsStates states) -> c10::intrusive_ptr {
-        return _deserialize_vectors(states);
-      });
+      .def(torch::init, std::vector,
+                        torch::Tensor, torch::Tensor>())
+      .def("__getitem__", &Vectors::__getitem__)
+      .def("lookup_vectors", &Vectors::lookup_vectors)
+      .def("__setitem__", &Vectors::__setitem__)
+      .def("__len__", &Vectors::__len__)
+      .def_pickle(
+          // __getstate__
+          [](const c10::intrusive_ptr &self) -> VectorsStates {
+            return _serialize_vectors(self);
+          },
+          // __setstate__
+          [](VectorsStates states) -> c10::intrusive_ptr {
+            return _deserialize_vectors(states);
+          });
 
   m.class_("Vocab")
-    .def(torch::init())
-    .def("__getitem__", [](const c10::intrusive_ptr &self,const std::string &item) -> int64_t {
-      return self->__getitem__(c10::string_view{item});})
-    .def("__len__", &Vocab::__len__)
-    .def("insert_token", &Vocab::insert_token)
-    .def("append_token", &Vocab::append_token)
-    .def("lookup_token", &Vocab::lookup_token)
-    .def("lookup_tokens", &Vocab::lookup_tokens)
-    .def("lookup_indices", [](const c10::intrusive_ptr &self,const std::vector &items){
-      std::vector temp(items.size());
-      int64_t counter = 0;
-      for(const auto &item: items){
-        temp[counter++] = c10::string_view{item};
-      }
-      return self->lookup_indices(temp);
-    })
-    .def("get_stoi", &Vocab::get_stoi)
-    .def("get_itos", &Vocab::get_itos)
-    .def_pickle(
-      // __getstate__
-      [](const c10::intrusive_ptr &self) -> VocabStates {
-        return _serialize_vocab(self);
-      },
-      // __setstate__
-      [](VocabStates states) -> c10::intrusive_ptr {
-        return _deserialize_vocab(states);
-      });
+      .def(torch::init())
+      .def("__getitem__",
+           [](const c10::intrusive_ptr &self, const std::string &item)
+               -> int64_t { return self->__getitem__(c10::string_view{item}); })
+      .def("__len__", &Vocab::__len__)
+      .def("insert_token", &Vocab::insert_token)
+      .def("append_token", &Vocab::append_token)
+      .def("lookup_token", &Vocab::lookup_token)
+      .def("lookup_tokens", &Vocab::lookup_tokens)
+      .def("lookup_indices",
+           [](const c10::intrusive_ptr &self,
+              const std::vector &items) {
+             std::vector indices(items.size());
+             int64_t counter = 0;
+             for (const auto &item : items) {
+               indices[counter++] =
+                   self->__getitem__(c10::string_view{item});
+             }
+             return indices;
+           })
      .def("get_stoi", &Vocab::get_stoi)
      .def("get_itos", &Vocab::get_itos)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr &self) -> VocabStates {
            return _serialize_vocab(self);
          },
          // __setstate__
          [](VocabStates states) -> c10::intrusive_ptr {
            return _deserialize_vocab(states);
          });
 
   m.def("torchtext::generate_sp_model", &generate_sp_model);
   m.def("torchtext::load_sp_model", &load_sp_model);
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 9eb6d9f6ba..dac5d1bef3 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -2,16 +2,11 @@
 #include
 #include
 #include
-#include // @manual
-#include // @manual
+#include // @manual
+#include // @manual
 
 namespace torchtext {
 
-Vocab::Vocab(const StringList &tokens, const IndexDict &stoi,
-             const std::string &unk_token, const int64_t unk_index)
-    : unk_index_(std::move(unk_index)), stoi_(std::move(stoi)),
-      itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {}
-
 Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     : itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {
   stoi_.reserve(tokens.size());
@@ -25,7 +20,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     throw std::runtime_error("Duplicate token found in tokens list: " +
                              tokens[i]);
   }
-  stoi_[c10::string_view{itos_[i].data(),itos_[i].size()}] = i;
+  stoi_[c10::string_view{itos_[i].data(), itos_[i].size()}] = i;
 }
 unk_index_ = stoi_.find(unk_token)->second;
 }
@@ -32,7 +28,7 @@
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
 int64_t Vocab::__getitem__(const c10::string_view &token) const {
-  const auto &item = stoi_.find(token);
+  const auto &item = stoi_.find(token);
   if (item != stoi_.end()) {
     return item->second;
   }
@@ -42,13 +37,9 @@
 void Vocab::append_token(const std::string &token) {
   if (stoi_.find(token) == stoi_.end()) {
-    // Note: we can't do `stoi_[token] = stoi_.size()` because of a bug
-    // on Windows where the size gets updated before the assign occurs.
-    // For example if the size of `stoi_` is 2, doing
-    // `stoi_["test"] = stoi_.size()` will set `stoi_["test"]` to a
-    // value of 3 instead of 2 on Windows stoi_[token] = itos_.size();
     itos_.push_back(token);
-    stoi_[c10::string_view{itos_.back().data(),itos_.back().size()}] = itos_.size()-1;
+    stoi_[c10::string_view{itos_.back().data(), itos_.back().size()}] =
+        itos_.size() - 1;
   }
 }
@@ -56,7 +47,7 @@
 void Vocab::insert_token(const std::string &token, const int64_t &index) {
     std::to_string(stoi_.size()) + ".");
   }
 
-  const auto &item = stoi_.find(token);
+  const auto &item = stoi_.find(c10::string_view{token});
   // if item already in stoi we throw an error
   if (item != stoi_.end()) {
 #ifdef _MSC_VER
@@ -80,122 +71,127 @@
 
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < itos_.size(); i++) {
-    stoi_[itos_[i]] = i + 1;
+    stoi_[c10::string_view{itos_[i].data(), itos_[i].size()}] = i + 1;
   }
-  stoi_[token] = index;
   itos_.insert(itos_.begin() + index, token);
+  stoi_[c10::string_view{itos_[index].data(),itos_[index].size()}] = index;
 
-  // need to update unk_index in case token equals unk_token or token
-  // inserted before unk_token
-  unk_index_ = stoi_.find(unk_token_)->second;
+  // need to update unk_index in case token equals unk_token or token
+  // inserted before unk_token
+  unk_index_ = stoi_.find(unk_token_)->second;
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
-  if (index < 0 || index > static_cast(itos_.size())) {
+  if (index < 0 || index > static_cast(itos_.size())) {
 #ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of itos dictionary: "
-              << stoi_.size() << std::endl;
+    std::cerr << "[RuntimeError] Specified index " << index
+              << " is out of bounds of the size of itos dictionary: "
+              << stoi_.size() << std::endl;
 #endif
-    throw std::runtime_error(
-        "Specified index " + std::to_string(index) +
-        " is out of bounds of the size of itos dictionary: " +
-        std::to_string(itos_.size()) + ".");
-  }
+    throw std::runtime_error(
+        "Specified index " + std::to_string(index) +
+        " is out of bounds of the size of itos dictionary: " +
+        std::to_string(itos_.size()) + ".");
+  }
 
-  return itos_[index];
+  return itos_[index];
 }
 
 StringList Vocab::lookup_tokens(const std::vector &indices) {
-  std::vector tokens(indices.size());
-  for (int64_t i = 0; i < static_cast(indices.size()); i++) {
-    tokens[i] = lookup_token(indices[i]);
-  }
-  return tokens;
+  std::vector tokens(indices.size());
+  for (int64_t i = 0; i < static_cast(indices.size()); i++) {
+    tokens[i] = lookup_token(indices[i]);
+  }
+  return tokens;
 }
 
-std::vector Vocab::lookup_indices(const std::vector &tokens) {
-  std::vector indices(tokens.size());
-  for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
-    indices[i] = __getitem__(tokens[i]);
-  }
-  return indices;
+std::vector
+Vocab::lookup_indices(const std::vector &tokens) {
+  std::vector indices(tokens.size());
+  for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
+    indices[i] = __getitem__(tokens[i]);
+  }
+  return indices;
 }
 
 std::unordered_map Vocab::get_stoi() const {
-  std::unordered_map stoi;
-  stoi.reserve(stoi_.size());
+  std::unordered_map stoi;
+  stoi.reserve(stoi_.size());
 
-  // construct tokens and index list
-  for (const auto &item : stoi_) {
-    stoi[std::string{item.first}] = item.second;
-  }
-  return stoi;
+  // construct tokens and index list
+  for (const auto &item : stoi_) {
+    stoi[std::string{item.first}] = item.second;
+  }
+  return stoi;
 }
 
-StringList Vocab::get_itos() const { return itos_; }
+StringList Vocab::get_itos() const {
+  return itos_;
 }
 
 int64_t _infer_lines(const std::string &file_path) {
-  int64_t num_lines = 0;
-  std::ifstream fin;
-  fin.open(file_path, std::ios::in);
+  int64_t num_lines = 0;
+  std::ifstream fin;
+  fin.open(file_path, std::ios::in);
 
-  while (fin.ignore(std::numeric_limits::max(), '\n')) {
-    num_lines++;
-  }
-  return num_lines;
+  while (fin.ignore(std::numeric_limits::max(), '\n')) {
+    num_lines++;
+  }
+  return num_lines;
 }
 
 void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
                             const int64_t start_line, const int64_t end_line,
-                            std::shared_ptr counter) {
-  std::ifstream fin(file_path, std::ios::in);
-  if(!fin.is_open()){
-    throw std::runtime_error("Cannot open input file "+file_path+"\n");
-  }
+                            std::shared_ptr> counter) {
+  std::ifstream fin(file_path, std::ios::in);
+  if (!fin.is_open()) {
+    throw std::runtime_error("Cannot open input file " + file_path + "\n");
+  }
 
-  fin.seekg(offset);
+  fin.seekg(offset);
 
-  for (int64_t i = start_line; i < end_line; i++) {
-    std::string token;
-    fin >> token;
-    fin >> std::ws;
+  for (int64_t i = start_line; i < end_line; i++) {
+    std::string token;
+    fin >> token;
+    fin >> std::ws;
 
-    if ((*counter).find(token) == (*counter).end()) {
-      (*counter)[token] = 1;
-    } else {
-      (*counter)[token] += 1;
+    if ((*counter).find(token) == (*counter).end()) {
+      (*counter)[token] = 1;
+    } else {
+      (*counter)[token] += 1;
+    }
   }
 }
 
 void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
                                const int64_t start_line, const int64_t end_line,
-                               std::shared_ptr counter,
+                               std::shared_ptr> counter,
                                torch::jit::script::Module &module) {
-  std::ifstream fin;
-  fin.open(file_path, std::ios::in);
-  fin.seekg(offset);
+  std::ifstream fin(file_path, std::ios::in);
+  if (!fin.is_open()) {
+    throw std::runtime_error("Cannot open input file " + file_path + "\n");
+  }
 
-  std::string line;
-  for (int64_t i = start_line; i < end_line; i++) {
-    std::getline(fin, line);
-    auto token_list =
-        module.forward(std::vector({c10::IValue(line)})).toList();
+  fin.seekg(offset);
 
-    for (size_t i = 0; i < token_list.size(); i++) {
-      c10::IValue token_ref = token_list.get(i);
-      std::string token = token_ref.toStringRef();
+  std::string line;
+  for (int64_t i = start_line; i < end_line; i++) {
+    std::getline(fin, line);
+    auto token_list =
+        module.forward(std::vector({c10::IValue(line)}))
+            .toList();
 
-      if ((*counter).find(token) == (*counter).end()) {
-        (*counter)[token] = 1;
-      } else {
-        (*counter)[token] += 1;
+    for (size_t i = 0; i < token_list.size(); i++) {
+      c10::IValue token_ref = token_list.get(i);
+      std::string token = token_ref.toStringRef();
+
+      if ((*counter).find(token) == (*counter).end()) {
+        (*counter)[token] = 1;
+      } else {
+        (*counter)[token] += 1;
+      }
     }
   }
 }
 
 // sorting using a custom object
 struct CompareTokens {
   const std::pair &b) {
     if (a.second == b.second) {
       return a.first < b.first;
-}
-return a.second > b.second;
-}
-}
-;
+    }
+    return a.second > b.second;
+  }
+};
 
-std::tuple
-_concat_tokens(std::vector> chunk_counters,
-               const std::string &unk_token, const int64_t min_freq,
-               const int64_t num_lines, const bool sort_tokens) {
+StringList _concat_tokens(
+    std::vector>>
+        chunk_counters,
+    const std::string &unk_token, const int64_t min_freq,
+    const int64_t num_lines, const bool sort_tokens) {
   TORCH_CHECK(chunk_counters.size() > 0,
               "There must be at least 1 chunk to concatenate!");
 
-  IndexDict tokens_freq;
+  std::unordered_map tokens_freq;
   StringList unique_tokens;
   unique_tokens.reserve(num_lines);
@@ -268,17 +266,7 @@
     unique_tokens.insert(unique_tokens.begin(), unk_token);
   }
 
-  // create stoi
-  IndexDict stoi;
-  stoi.reserve(num_lines);
-  int64_t index = 0;
-
-  for (const auto &token : unique_tokens) {
-    stoi[token] = index;
-    index++;
-  }
-
-  return std::make_tuple(std::move(stoi), std::move(unique_tokens));
+  return unique_tokens;
 }
 
 constexpr int64_t GRAIN_SIZE = 13107;
@@ -296,7 +284,8 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   std::vector offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector> chunk_counters;
+  std::vector>>
+      chunk_counters;
 
   std::mutex m;
   std::condition_variable cv;
@@ -305,7 +294,8 @@
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr = std::make_shared();
+    auto counter_ptr =
+        std::make_shared>();
 
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
@@ -323,14 +313,10 @@
   std::unique_lock lock(m);
   cv.wait(lock, [&thread_count] { return thread_count == 0; });
 
-  IndexDict stoi;
-  StringList tokens;
-  std::tie(stoi, tokens) =
+  StringList tokens =
       _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, false);
 
-  int64_t unk_index = stoi.find(unk_token)->second;
-
-  return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
+  return Vocab(std::move(tokens), unk_token);
 }
 
 Vocab _build_vocab_from_text_file(const std::string &file_path,
@@ -347,7 +333,8 @@
   std::vector offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector> chunk_counters;
+  std::vector>>
+      chunk_counters;
 
   std::mutex m;
   std::condition_variable cv;
@@ -356,7 +343,8 @@
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr = std::make_shared();
+    auto counter_ptr =
+        std::make_shared>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
       parse_raw_text_file_chunk(file_path, offsets[j], i,
@@ -374,13 +362,10 @@
   std::unique_lock lock(m);
   cv.wait(lock, [&thread_count] { return thread_count == 0; });
 
-  IndexDict stoi;
-  StringList tokens;
-  std::tie(stoi, tokens) =
+  StringList tokens =
       _concat_tokens(chunk_counters, unk_token, min_freq, num_lines, true);
 
-  int64_t unk_index = stoi.find(unk_token)->second;
-  return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
+  return Vocab(std::move(tokens), unk_token);
 }
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr &self) {
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index f508b36cd4..f58d3cc9aa 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -1,5 +1,5 @@
-#include
 #include
+#include
 namespace torchtext {
 
 typedef std::vector StringList;
@@ -21,16 +21,14 @@ struct Vocab : torch::CustomClassHolder {
   explicit Vocab(const std::vector &tokens,
                  const std::string &unk_token);
-  explicit Vocab(const StringList &tokens, const IndexDict &stoi,
-
-                 const std::string &unk_token, const int64_t unk_index);
   int64_t __len__() const;
   int64_t __getitem__(const c10::string_view &token) const;
   void append_token(const std::string &token);
   void insert_token(const std::string &token, const int64_t &index);
   std::string lookup_token(const int64_t &index);
   std::vector lookup_tokens(const std::vector &indices);
-  std::vector lookup_indices(const std::vector &tokens);
+  std::vector
+  lookup_indices(const std::vector &tokens);
   std::unordered_map get_stoi() const;
   std::vector get_itos() const;
 };
From bc7b0d5c67bdb29ce3f8cdba9b90d82271654ebe Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Fri, 12 Mar 2021 10:15:19 -0800
Subject: [PATCH 10/15] added fasttext hash

---
 torchtext/csrc/vocab.cpp | 201 ++++++++++++++++++++------------------
 torchtext/csrc/vocab.h   |  12 ++-
 2 files changed, 112 insertions(+), 101 deletions(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index dac5d1bef3..c41b3f4e3a 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -12,7 +12,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
   stoi_.reserve(tokens.size());
   for (std::size_t i = 0; i < tokens.size(); i++) {
     // tokens should not have any duplicates
-    if (stoi_.find(tokens[i]) != stoi_.end()) {
+    if (stoi_.find(hash(c10::string_view{tokens[i]})) != stoi_.end()) {
 #ifdef _MSC_VER
       std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
                 << tokens[i] << std::endl;
@@ -20,15 +20,16 @@
       throw std::runtime_error("Duplicate token found in tokens list: " +
                                tokens[i]);
     }
-    stoi_[c10::string_view{itos_[i].data(), itos_[i].size()}] = i;
+    stoi_[hash(c10::string_view{itos_[i].data(), itos_[i].size()})] = i;
   }
-  unk_index_ = stoi_.find(unk_token)->second;
+  unk_index_ = stoi_.find(hash(c10::string_view{unk_token}))->second;
 }
 
 int64_t Vocab::__len__() const { return stoi_.size(); }
 
 int64_t Vocab::__getitem__(const c10::string_view &token) const {
-  const auto &item = stoi_.find(token);
+  uint32_t token_hash = hash(token);
+  const auto &item = stoi_.find(token_hash);
   if (item != stoi_.end()) {
     return item->second;
   }
@@ -36,10 +37,10 @@
 
 void Vocab::append_token(const std::string &token) {
-  if (stoi_.find(token) == stoi_.end()) {
+  uint32_t token_hash = hash(token);
+  if (stoi_.find(token_hash) == stoi_.end()) {
     itos_.push_back(token);
-    stoi_[c10::string_view{itos_.back().data(), itos_.back().size()}] =
-        itos_.size() - 1;
+    stoi_[token_hash] = itos_.size() - 1;
   }
 }
 
 void Vocab::insert_token(const std::string &token, const int64_t &index) {
@@ -56,7 +57,8 @@
     std::to_string(stoi_.size()) + ".");
   }
 
-  const auto &item = stoi_.find(c10::string_view{token});
+  uint32_t token_hash = hash(token);
+  const auto &item = stoi_.find(token_hash);
   // if item already in stoi we throw an error
   if (item != stoi_.end()) {
 #ifdef _MSC_VER
@@ -71,127 +73,127 @@
 
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < itos_.size(); i++) {
-    stoi_[c10::string_view{itos_[i].data(), itos_[i].size()}] = i + 1;
+    stoi_[hash(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
   }
   itos_.insert(itos_.begin() + index, token);
-  stoi_[c10::string_view{itos_[index].data(),itos_[index].size()}] = index;
+  stoi_[token_hash] = index;
 
-  // need to update unk_index in case token equals unk_token or token
-  // inserted before unk_token
-  unk_index_ = stoi_.find(unk_token_)->second;
+  // need to update unk_index in case token equals unk_token or token
+  // inserted before unk_token
+  unk_index_ = stoi_.find(hash(c10::string_view(unk_token_)))->second;
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
-  if (index < 0 || index > static_cast(itos_.size())) {
+  if (index < 0 || index > static_cast(itos_.size())) {
 #ifdef _MSC_VER
-    std::cerr << "[RuntimeError] Specified index " << index
-              << " is out of bounds of the size of itos dictionary: "
-              << stoi_.size() << std::endl;
+    std::cerr << "[RuntimeError] Specified index " << index
+              << " is out of bounds of the size of itos dictionary: "
+              << stoi_.size() << std::endl;
 #endif
-    throw std::runtime_error(
-        "Specified index " + std::to_string(index) +
-        " is out of bounds of the size of itos dictionary: " +
-        std::to_string(itos_.size()) + ".");
-  }
+    throw std::runtime_error(
+        "Specified index " + std::to_string(index) +
+        " is out of bounds of the size of itos dictionary: " +
+        std::to_string(itos_.size()) + ".");
+  }
 
-  return itos_[index];
+  return itos_[index];
 }
 
 StringList Vocab::lookup_tokens(const std::vector &indices) {
-  std::vector tokens(indices.size());
-  for (int64_t i = 0; i < static_cast(indices.size()); i++) {
-    tokens[i] = lookup_token(indices[i]);
-  }
-  return tokens;
+  std::vector tokens(indices.size());
+  for (int64_t i = 0; i < static_cast(indices.size()); i++) {
+    tokens[i] = lookup_token(indices[i]);
+  }
+  return tokens;
 }
 
 std::vector
 Vocab::lookup_indices(const std::vector &tokens) {
-  std::vector indices(tokens.size());
-  for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
-    indices[i] = __getitem__(tokens[i]);
-  }
-  return indices;
+  std::vector indices(tokens.size());
+  for (int64_t i = 0; i < static_cast(tokens.size()); i++) {
+    indices[i] = __getitem__(tokens[i]);
+  }
+  return indices;
 }
 
 std::unordered_map Vocab::get_stoi() const {
-  std::unordered_map stoi;
-  stoi.reserve(stoi_.size());
+  std::unordered_map stoi;
+  stoi.reserve(stoi_.size());
 
-  // construct tokens and index list
-  for (const auto &item : stoi_) {
-    stoi[std::string{item.first}] = item.second;
-  }
-  return stoi;
+  // construct tokens and index list
+  for (const auto &item : itos_) {
+    stoi[item] = stoi_.find(hash(c10::string_view{item}))->second;
+  }
+  return stoi;
 }
 
-StringList Vocab::get_itos() const {
-  return itos_;
-}
+StringList Vocab::get_itos() const { return itos_; }
 
 int64_t _infer_lines(const std::string &file_path) {
-  int64_t num_lines = 0;
-  std::ifstream fin;
-  fin.open(file_path, std::ios::in);
+  int64_t num_lines = 0;
+  std::ifstream fin;
+  fin.open(file_path, std::ios::in);
 
-  while (fin.ignore(std::numeric_limits::max(), '\n')) {
-    num_lines++;
-  }
-  return num_lines;
+  while (fin.ignore(std::numeric_limits::max(), '\n')) {
+    num_lines++;
+  }
+  return num_lines;
 }
 
-void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
-                            const int64_t start_line, const int64_t end_line,
-                            std::shared_ptr> counter) {
-  std::ifstream fin(file_path, std::ios::in);
-  if (!fin.is_open()) {
-    throw std::runtime_error("Cannot open input file " + file_path + "\n");
-  }
+void parse_vocab_file_chunk(
+    const std::string &file_path, size_t offset, const int64_t start_line,
+    const int64_t end_line,
+    std::shared_ptr> counter) {
+  std::ifstream fin(file_path, std::ios::in);
+  if (!fin.is_open()) {
+    throw std::runtime_error("Cannot open input file " + file_path + "\n");
+  }
 
-  fin.seekg(offset);
+  fin.seekg(offset);
 
-  for (int64_t i = start_line; i < end_line; i++) {
-    std::string token;
-    fin >> token;
-    fin >> std::ws;
+  for (int64_t i = start_line; i < end_line; i++) {
+    std::string token;
+    fin >> token;
+    fin >> std::ws;
 
-    if ((*counter).find(token) == (*counter).end()) {
-      (*counter)[token] = 1;
-    } else {
-      (*counter)[token] += 1;
-    }
+    if ((*counter).find(token) == (*counter).end()) {
+      (*counter)[token] = 1;
+    } else {
+      (*counter)[token] += 1;
     }
+  }
 }
 
-void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
-                               const int64_t start_line, const int64_t end_line,
-                               std::shared_ptr> counter,
-                               torch::jit::script::Module &module) {
-  std::ifstream fin(file_path, std::ios::in);
-  if (!fin.is_open()) {
-    throw std::runtime_error("Cannot open input file " + file_path + "\n");
-  }
+void parse_raw_text_file_chunk(
+    const std::string &file_path, size_t offset, const int64_t start_line,
+    const int64_t end_line,
+    std::shared_ptr> counter,
+    torch::jit::script::Module &module) {
+  std::ifstream fin(file_path, std::ios::in);
+  if (!fin.is_open()) {
+    throw std::runtime_error("Cannot open input file " + file_path + "\n");
+  }
 
-  fin.seekg(offset);
+  fin.seekg(offset);
 
-  std::string line;
-  for (int64_t i = start_line; i < end_line; i++) {
-    std::getline(fin, line);
-    auto token_list =
-        module.forward(std::vector({c10::IValue(line)}))
-            .toList();
+  std::string line;
+  for (int64_t i = start_line; i < end_line; i++) {
+    std::getline(fin, line);
+    auto token_list =
+        module.forward(std::vector({c10::IValue(line)})).toList();
 
-    for (size_t i = 0; i < token_list.size(); i++) {
-      c10::IValue token_ref = token_list.get(i);
-      std::string token = token_ref.toStringRef();
+    for (size_t i = 0; i < token_list.size(); i++) {
+      c10::IValue token_ref = token_list.get(i);
+      std::string token = token_ref.toStringRef();
 
-      if ((*counter).find(token) == (*counter).end()) {
-        (*counter)[token] = 1;
-      } else {
-        (*counter)[token] += 1;
-      }
+      if ((*counter).find(token) == (*counter).end()) {
+        (*counter)[token] = 1;
+      } else {
+        (*counter)[token] += 1;
       }
     }
+  }
 }
 
 // sorting using a custom object
 struct CompareTokens {
   const std::pair &b) {
     if (a.second == b.second) {
       return a.first < b.first;
-}
-return a.second > b.second;
-}
-}
-;
+    }
+    return a.second > b.second;
+  }
+};
 
 StringList _concat_tokens(
-    std::vector>>
+    std::vector>>
         chunk_counters,
     const std::string &unk_token, const int64_t min_freq,
     const int64_t num_lines, const bool sort_tokens) {
   TORCH_CHECK(chunk_counters.size() > 0,
               "There must be at least 1 chunk to concatenate!");
 
-  std::unordered_map tokens_freq;
+  ska_ordered::order_preserving_flat_hash_map tokens_freq;
   StringList unique_tokens;
   unique_tokens.reserve(num_lines);
@@ -284,7 +285,7 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   std::vector offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector>>
+  std::vector>>
       chunk_counters;
 
   std::mutex m;
@@ -295,8 +296,7 @@
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
     auto counter_ptr =
-        std::make_shared>();
+        std::make_shared>();
 
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
@@ -334,7 +334,7 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   std::vector offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector>>
+  std::vector>>
      chunk_counters;
 
   std::mutex m;
@@ -344,8 +345,7 @@
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
     auto counter_ptr =
-        std::make_shared>();
+        std::make_shared>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
       parse_raw_text_file_chunk(file_path, offsets[j], i,
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index f58d3cc9aa..36514b60e1 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -3,7 +3,7 @@
 namespace torchtext {
 
 typedef std::vector StringList;
-typedef ska_ordered::order_preserving_flat_hash_map
+typedef ska_ordered::order_preserving_flat_hash_map
     IndexDict;
 typedef std::tuple, std::vector,
                    std::vector>
@@ -31,6 +31,16 @@ struct Vocab : torch::CustomClassHolder {
   lookup_indices(const std::vector &tokens);
   std::unordered_map get_stoi() const;
   std::vector get_itos() const;
+
+protected:
+  uint32_t hash(const c10::string_view &str) const {
+    uint32_t h = 2166136261;
+    for (size_t i = 0; i < str.size(); i++) {
+      h = h ^ uint32_t(uint8_t(str[i]));
+      h = h * 16777619;
+    }
+    return h;
+  }
 };
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr &self);

From 8a9e2cea3da2e70dc79d8359bf3f9eac25ab6543 Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Fri, 12 Mar 2021 10:32:55 -0800
Subject: [PATCH 11/15] clang-formatting

---
 torchtext/csrc/register_bindings.cpp |  3 +--
 torchtext/csrc/vocab.cpp             | 25 ++++++++++++++++---------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 7dbd157c80..3820e0eb67 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -243,8 +243,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
              std::vector indices(items.size());
              int64_t counter = 0;
              for (const auto &item : items) {
-               indices[counter++] =
-                   self->__getitem__(c10::string_view{item});
+               indices[counter++] = self->__getitem__(c10::string_view{item});
              }
              return indices;
            })
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index c41b3f4e3a..fdc36e330f 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -144,7 +144,9 @@ int64_t _infer_lines(const std::string &file_path) {
 void parse_vocab_file_chunk(
     const std::string &file_path, size_t offset, const int64_t start_line,
     const int64_t end_line,
-    std::shared_ptr> counter) {
+    std::shared_ptr<
+        ska_ordered::order_preserving_flat_hash_map>
+        counter) {
   std::ifstream fin(file_path, std::ios::in);
   if (!fin.is_open()) {
     throw std::runtime_error("Cannot open input file " + file_path + "\n");
   }
@@ -168,7 +170,9 @@
 void parse_raw_text_file_chunk(
     const std::string &file_path, size_t offset, const int64_t start_line,
     const int64_t end_line,
-    std::shared_ptr> counter,
+    std::shared_ptr<
+        ska_ordered::order_preserving_flat_hash_map>
+        counter,
     torch::jit::script::Module &module) {
   std::ifstream fin(file_path, std::ios::in);
   if (!fin.is_open()) {
     throw std::runtime_error("Cannot open input file " + file_path + "\n");
   }
@@ -208,7 +212,8 @@
 
 StringList _concat_tokens(
-    std::vector>>
+    std::vector>>
         chunk_counters,
     const std::string &unk_token, const int64_t min_freq,
     const int64_t num_lines, const bool sort_tokens) {
@@ -285,7 +290,8 @@
   std::vector offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector>>
+  std::vector>>
      chunk_counters;
 
   std::mutex m;
@@ -295,8 +301,8 @@
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr =
-        std::make_shared>();
+    auto counter_ptr = std::make_shared<
+        ska_ordered::order_preserving_flat_hash_map>();
 
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
@@ -340,7 +346,8 @@
   std::vector offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector>>
+  std::vector>>
      chunk_counters;
 
   std::mutex m;
@@ -344,8 +351,8 @@
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr =
-        std::make_shared>();
+    auto counter_ptr = std::make_shared<
+        ska_ordered::order_preserving_flat_hash_map>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
      parse_raw_text_file_chunk(file_path, offsets[j], i,
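For patch 12 below, the visible pieces are: stoi_ becomes a preallocated table of MAX_VOCAB_SIZE slots initialized to -1, _find maps a token to a slot, a slot holding -1 means absent, and _add appends to itos_ and records the index. The definitions of _find and _add fall outside this excerpt, so the following is a hedged reconstruction in the style of fastText's Dictionary (linear probing with a token comparison on every occupied slot, which is what removes the silent-collision problem of raw hash keys); the table size here is a small demo stand-in, since the real MAX_VOCAB_SIZE value is not shown in this excerpt:

    #include <cstdint>
    #include <string>
    #include <string_view>
    #include <vector>

    constexpr int32_t kTableSize = 1 << 20;  // demo stand-in for MAX_VOCAB_SIZE

    struct MiniDict {
      std::vector<int64_t> stoi_ = std::vector<int64_t>(kTableSize, -1);
      std::vector<std::string> itos_;

      static uint32_t fnv1a(std::string_view s) {
        uint32_t h = 2166136261u;
        for (unsigned char c : s) { h ^= c; h *= 16777619u; }
        return h;
      }

      // Probe until we hit a free slot or the slot whose stored token
      // compares equal -- hash collisions are resolved, never aliased.
      uint32_t find(std::string_view token) const {
        uint32_t id = fnv1a(token) % kTableSize;
        while (stoi_[id] != -1 && itos_[stoi_[id]] != token) {
          id = (id + 1) % kTableSize;
        }
        return id;
      }

      void add(const std::string &token) {
        uint32_t id = find(token);
        if (stoi_[id] == -1) {
          itos_.push_back(token);
          stoi_[id] = static_cast<int64_t>(itos_.size()) - 1;
        }
      }
    };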
From 8a9e2cea3da2e70dc79d8359bf3f9eac25ab6543 Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Fri, 12 Mar 2021 10:32:55 -0800
Subject: [PATCH 11/15] clang-formatting

---
 torchtext/csrc/register_bindings.cpp |  3 +--
 torchtext/csrc/vocab.cpp             | 25 ++++++++++++++++---------
 2 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 7dbd157c80..3820e0eb67 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -243,8 +243,7 @@ TORCH_LIBRARY_FRAGMENT(torchtext, m) {
             std::vector<int64_t> indices(items.size());
             int64_t counter = 0;
             for (const auto &item : items) {
-              indices[counter++] =
-                  self->__getitem__(c10::string_view{item});
+              indices[counter++] = self->__getitem__(c10::string_view{item});
             }
             return indices;
           })
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index c41b3f4e3a..fdc36e330f 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -144,7 +144,9 @@ int64_t _infer_lines(const std::string &file_path) {
 void parse_vocab_file_chunk(
     const std::string &file_path, size_t offset, const int64_t start_line,
     const int64_t end_line,
-    std::shared_ptr<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>> counter) {
+    std::shared_ptr<
+        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>
+        counter) {
   std::ifstream fin(file_path, std::ios::in);
   if (!fin.is_open()) {
     throw std::runtime_error("Cannot open input file " + file_path + "\n");
@@ -168,7 +170,9 @@ void parse_vocab_file_chunk(
 void parse_raw_text_file_chunk(
     const std::string &file_path, size_t offset, const int64_t start_line,
     const int64_t end_line,
-    std::shared_ptr<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>> counter,
+    std::shared_ptr<
+        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>
+        counter,
     torch::jit::script::Module &module) {
   std::ifstream fin(file_path, std::ios::in);
   if (!fin.is_open()) {
@@ -208,7 +212,8 @@ struct CompareTokens {
 };
 
 StringList _concat_tokens(
-    std::vector<std::shared_ptr<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
+    std::vector<std::shared_ptr<
+        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
        chunk_counters,
    const std::string &unk_token, const int64_t min_freq,
    const int64_t num_lines, const bool sort_tokens) {
@@ -285,7 +290,8 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   std::vector<size_t> offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector<std::shared_ptr<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
+  std::vector<std::shared_ptr<
+      ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
      chunk_counters;
 
  std::mutex m;
@@ -295,10 +301,10 @@ Vocab _load_vocab_from_file(const std::string &file_path,
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr =
-        std::make_shared<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>();
+    auto counter_ptr = std::make_shared<
+        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
@@ -334,7 +340,8 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   std::vector<size_t> offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector<std::shared_ptr<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
+  std::vector<std::shared_ptr<
+      ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
      chunk_counters;
 
  std::mutex m;
@@ -344,8 +351,8 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
   // create threads
   int64_t j = 0;
   for (int64_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr =
-        std::make_shared<ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>();
+    auto counter_ptr = std::make_shared<
+        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
       parse_raw_text_file_chunk(file_path, offsets[j], i,
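The counting pattern that clang-format just reflowed — one privately owned counter per chunk, merged once all workers finish — can be sketched without ATen. A minimal, hypothetical version using std::thread and std::unordered_map in place of at::launch and ska_ordered's map:

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <thread>
    #include <unordered_map>
    #include <vector>

    using Counter = std::unordered_map<std::string, int64_t>;

    int main() {
      std::vector<std::vector<std::string>> chunks = {{"a", "b", "a"}, {"b", "c"}};
      std::vector<std::shared_ptr<Counter>> chunk_counters;
      std::vector<std::thread> workers;

      // Each worker owns its counter, so counting itself needs no locks.
      for (const auto &chunk : chunks) {
        auto counter = std::make_shared<Counter>();
        chunk_counters.push_back(counter);
        workers.emplace_back([&chunk, counter]() {
          for (const auto &token : chunk) (*counter)[token]++;
        });
      }
      for (auto &w : workers) w.join();

      // Single-threaded merge afterwards, as _concat_tokens does.
      Counter total;
      for (const auto &c : chunk_counters)
        for (const auto &kv : *c) total[kv.first] += kv.second;
    }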
From e6966af45e8c7e9bedc719d75c87b9d9d85c2f59 Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Sat, 13 Mar 2021 23:15:34 -0800
Subject: [PATCH 12/15] fixing hashing issue

---
 benchmark/benchmark_experimental_vocab.py |  12 +-
 torchtext/csrc/register_bindings.cpp      |   2 +-
 torchtext/csrc/vocab.cpp                  | 131 ++++++++++------------
 torchtext/csrc/vocab.h                    |  34 ++++--
 4 files changed, 94 insertions(+), 85 deletions(-)

diff --git a/benchmark/benchmark_experimental_vocab.py b/benchmark/benchmark_experimental_vocab.py
index e63904934d..3cd4e49a1d 100644
--- a/benchmark/benchmark_experimental_vocab.py
+++ b/benchmark/benchmark_experimental_vocab.py
@@ -3,7 +3,7 @@
 import time
 
 import torch
-from torchtext.experimental.datasets import AG_NEWS
+from torchtext.experimental.datasets import DATASETS
 from torchtext.experimental.vocab import (
     vocab as VocabExperimental,
     load_vocab_from_file,
@@ -76,7 +76,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
     print("Construction time:", time.monotonic() - t0)
 
 
-def benchmark_experimental_vocab_lookup(vocab_file_path=None):
+def benchmark_experimental_vocab_lookup(vocab_file_path=None, dataset='AG_NEWS'):
     def _run_benchmark_lookup(tokens, vocab):
         t0 = time.monotonic()
         # list lookup
@@ -94,7 +94,7 @@ def _run_benchmark_lookup(tokens, vocab):
 
     tokens = []
     tokens_lists = []
-    train = AG_NEWS(split='train')
+    train = DATASETS[dataset](split='train')
     vocab = train.get_vocab()
     for (_, text) in train:
         cur_tokens = []
@@ -124,7 +124,7 @@ def token_iterator(file_path):
             v_experimental = load_vocab_from_file(f)
         print("Construction time:", time.monotonic() - t0)
     else:
-        print("Loading Vocab from AG News")
+        print("Loading Vocab from {}".format(dataset))
         counter = Counter(tokens)
         sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
         ordered_dict = OrderedDict(sorted_by_freq_tuples)
@@ -174,6 +174,8 @@
                         help='The name of vocab file used for construction')
     parser.add_argument('--vocab-filename-lookup', type=str, default=None,
                         help='The name of vocab file used for lookup')
+    parser.add_argument('--dataset', type=str, default='AG_NEWS',
+                        help='The name of the dataset used for the lookup benchmark')
     args = parser.parse_args()
 
     if args.run_construction_benchmark:
@@ -181,4 +183,4 @@
         benchmark_experimental_vocab_construction(args.vocab_filename_construction,
                                                   is_raw_text=args.is_raw_text, is_legacy=args.is_legacy)
     else:
-        benchmark_experimental_vocab_lookup(args.vocab_filename_lookup)
+        benchmark_experimental_vocab_lookup(args.vocab_filename_lookup, args.dataset)
diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp
index 3820e0eb67..e41783cef1 100644
--- a/torchtext/csrc/register_bindings.cpp
+++ b/torchtext/csrc/register_bindings.cpp
@@ -108,7 +108,7 @@ PYBIND11_MODULE(_torchtext, m) {
       .def_readonly("itos_", &Vocab::itos_)
       .def_readonly("unk_token_", &Vocab::unk_token_)
       .def("__getitem__",
-           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) {
+           [](c10::intrusive_ptr<Vocab> &self, const py::str &item) -> int64_t {
             ssize_t length;
             const char *buffer = PyUnicode_AsUTF8AndSize(item.ptr(), &length);
             return self->__getitem__(c10::string_view{buffer, (size_t)length});
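The binding above avoids an intermediate std::string: PyUnicode_AsUTF8AndSize hands back a pointer into the unicode object's cached UTF-8 buffer, which a string_view can wrap directly. A sketch of the same idea as a free helper (the helper name is hypothetical; std::string_view stands in for c10::string_view):

    #include <pybind11/pybind11.h>
    #include <string_view>

    namespace py = pybind11;

    // Valid only while `s` stays alive: the view points into the
    // py::str's internal UTF-8 cache, and no copy is made.
    std::string_view utf8_view(const py::str &s) {
      Py_ssize_t length = 0;
      const char *buffer = PyUnicode_AsUTF8AndSize(s.ptr(), &length);
      if (buffer == nullptr) {
        throw py::error_already_set(); // UTF-8 conversion failed
      }
      return std::string_view{buffer, static_cast<size_t>(length)};
    }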
diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index fdc36e330f..731fd4995f 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -1,18 +1,19 @@
 #include // @manual
 #include
+#include
 #include
 #include
 #include // @manual
 #include // @manual
-
 namespace torchtext {
 
 Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
-    : itos_(std::move(tokens)), unk_token_(std::move(unk_token)) {
-  stoi_.reserve(tokens.size());
+    : stoi_(MAX_VOCAB_SIZE, -1), unk_token_(std::move(unk_token)) {
   for (std::size_t i = 0; i < tokens.size(); i++) {
     // tokens should not have any duplicates
-    if (stoi_.find(hash(c10::string_view{tokens[i]})) != stoi_.end()) {
+    auto token_position =
+        _find(c10::string_view{tokens[i].data(), tokens[i].size()});
+    if (stoi_[token_position] != -1) {
 #ifdef _MSC_VER
       std::cerr << "[RuntimeError] Duplicate token found in tokens list: "
                 << tokens[i] << std::endl;
@@ -20,32 +21,26 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
       throw std::runtime_error("Duplicate token found in tokens list: " +
                                tokens[i]);
     }
-    stoi_[hash(c10::string_view{itos_[i].data(), itos_[i].size()})] = i;
+    _add(tokens[i]);
   }
-  unk_index_ = stoi_.find(hash(c10::string_view{unk_token}))->second;
+
+  unk_index_ = stoi_[_find(c10::string_view{unk_token})];
 }
 
-int64_t Vocab::__len__() const { return stoi_.size(); }
+int64_t Vocab::__len__() const { return itos_.size(); }
 
 int64_t Vocab::__getitem__(const c10::string_view &token) const {
-  uint32_t token_hash = hash(token);
-  const auto &item = stoi_.find(token_hash);
-  if (item != stoi_.end()) {
-    return item->second;
+  uint32_t id = _find(token);
+  if (stoi_[id] != -1) {
+    return (int64_t)stoi_[id];
   }
-  return unk_index_;
+  return (int64_t)unk_index_;
 }
 
-void Vocab::append_token(const std::string &token) {
-  uint32_t token_hash = hash(token);
-  if (stoi_.find(token_hash) == stoi_.end()) {
-    itos_.push_back(token);
-    stoi_[token_hash] = itos_.size() - 1;
-  }
-}
+void Vocab::append_token(const std::string &token) { _add(token); }
 
 void Vocab::insert_token(const std::string &token, const int64_t &index) {
-  if (index < 0 || index > static_cast<int64_t>(stoi_.size())) {
+  if (index < 0 || index > static_cast<int64_t>(itos_.size())) {
 #ifdef _MSC_VER
     std::cerr << "[RuntimeError] Specified index " << index
               << " is out of bounds of the size of stoi dictionary: "
@@ -57,31 +52,30 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
                              std::to_string(stoi_.size()) + ".");
   }
 
-  uint32_t token_hash = hash(token);
-  const auto &item = stoi_.find(token_hash);
   // if item already in stoi we throw an error
-  if (item != stoi_.end()) {
+  auto token_position = _find(c10::string_view{token});
+  if (stoi_[token_position] != -1) {
 #ifdef _MSC_VER
     std::cerr << "[RuntimeError] Token " << token
-              << " already exists in the Vocab with index: " << item->second
-              << std::endl;
+              << " already exists in the Vocab with index: "
+              << stoi_[token_position] << std::endl;
 #endif
     throw std::runtime_error("Token " + token +
                              " already exists in the Vocab with index: " +
-                             std::to_string(item->second) + ".");
+                             std::to_string(stoi_[token_position]) + ".");
   }
 
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < itos_.size(); i++) {
-    stoi_[hash(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
+    stoi_[_find(c10::string_view{itos_[i]})] = i + 1;
   }
 
   itos_.insert(itos_.begin() + index, token);
-  stoi_[token_hash] = index;
+  stoi_[_find(c10::string_view{token})] = index;
 
   // need to update unk_index in case token equals unk_token or token
   // inserted before unk_token
-  unk_index_ = stoi_.find(hash(c10::string_view(unk_token_)))->second;
+  unk_index_ = stoi_[_find(c10::string_view(unk_token_))];
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
+ ska_ordered::order_preserving_flat_hash_map> counter, torch::jit::script::Module &module) { std::ifstream fin(file_path, std::ios::in); @@ -182,7 +174,7 @@ void parse_raw_text_file_chunk( fin.seekg(offset); std::string line; - for (int64_t i = start_line; i < end_line; i++) { + for (uint32_t i = start_line; i < end_line; i++) { std::getline(fin, line); auto token_list = module.forward(std::vector({c10::IValue(line)})).toList(); @@ -202,8 +194,8 @@ void parse_raw_text_file_chunk( // sorting using a custom object struct CompareTokens { - bool operator()(const std::pair &a, - const std::pair &b) { + bool operator()(const std::pair &a, + const std::pair &b) { if (a.second == b.second) { return a.first < b.first; } @@ -213,14 +205,15 @@ struct CompareTokens { StringList _concat_tokens( std::vector>> + ska_ordered::order_preserving_flat_hash_map>> chunk_counters, - const std::string &unk_token, const int64_t min_freq, - const int64_t num_lines, const bool sort_tokens) { + const std::string &unk_token, const uint32_t min_freq, + const uint32_t num_lines, const bool sort_tokens) { TORCH_CHECK(chunk_counters.size() > 0, "There must be at least 1 chunk to concatenate!"); - ska_ordered::order_preserving_flat_hash_map tokens_freq; + ska_ordered::order_preserving_flat_hash_map + tokens_freq; StringList unique_tokens; unique_tokens.reserve(num_lines); @@ -228,7 +221,7 @@ StringList _concat_tokens( for (size_t i = 0; i < chunk_counters.size(); i++) { auto &cur_counter = *chunk_counters[i]; for (const auto &item : cur_counter) { - int64_t cur_token_freq = item.second; + uint32_t cur_token_freq = item.second; if (tokens_freq.find(item.first) != tokens_freq.end()) { tokens_freq[item.first] += cur_token_freq; } else { @@ -244,7 +237,7 @@ StringList _concat_tokens( } // create token freq pairs - std::vector> token_freq_pairs; + std::vector> token_freq_pairs; for (std::string token : unique_tokens) { token_freq_pairs.push_back(std::make_pair(token, tokens_freq[token])); @@ -275,14 +268,14 @@ StringList _concat_tokens( return unique_tokens; } -constexpr int64_t GRAIN_SIZE = 13107; +constexpr uint32_t GRAIN_SIZE = 13107; Vocab _load_vocab_from_file(const std::string &file_path, const std::string &unk_token, - const int64_t min_freq, const int64_t num_cpus) { + const uint32_t min_freq, const uint32_t num_cpus) { std::cerr << "[INFO] Reading file " << file_path << std::endl; - int64_t num_lines = _infer_lines(file_path); - int64_t chunk_size = impl::divup(num_lines, num_cpus); + uint32_t num_lines = _infer_lines(file_path); + uint32_t chunk_size = impl::divup(num_lines, num_cpus); // Launching a thread on less lines than this likely has too much overhead. 
   // TODO: Add explicit test beyond grain size to cover multithreading
   chunk_size = std::max(chunk_size, GRAIN_SIZE);
 
   std::vector<size_t> offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
   std::vector<std::shared_ptr<
-      ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
+      ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>>
       chunk_counters;
 
   std::mutex m;
   std::condition_variable cv;
   std::atomic<int> thread_count(0);
 
   // create threads
-  int64_t j = 0;
-  for (int64_t i = 0; i < num_lines; i += chunk_size) {
+  uint32_t j = 0;
+  for (uint32_t i = 0; i < num_lines; i += chunk_size) {
     auto counter_ptr = std::make_shared<
-        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>();
+        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>();
     thread_count++;
 
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
 
 Vocab _build_vocab_from_text_file(const std::string &file_path,
                                   const std::string &unk_token,
-                                  const int64_t min_freq,
-                                  const int64_t num_cpus,
+                                  const uint32_t min_freq,
+                                  const uint32_t num_cpus,
                                   torch::jit::script::Module tokenizer) {
   std::cerr << "[INFO] Reading file " << file_path << std::endl;
-  int64_t num_lines = _infer_lines(file_path);
-  int64_t chunk_size = impl::divup(num_lines, num_cpus);
+  uint32_t num_lines = _infer_lines(file_path);
+  uint32_t chunk_size = impl::divup(num_lines, num_cpus);
   // Launching a thread on less lines than this likely has too much overhead.
   chunk_size = std::max(chunk_size, GRAIN_SIZE);
 
   std::vector<size_t> offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
   std::vector<std::shared_ptr<
-      ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>>
+      ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>>
       chunk_counters;
 
   std::mutex m;
   std::condition_variable cv;
   std::atomic<int> thread_count(0);
 
   // create threads
-  int64_t j = 0;
-  for (int64_t i = 0; i < num_lines; i += chunk_size) {
+  uint32_t j = 0;
+  for (uint32_t i = 0; i < num_lines; i += chunk_size) {
     auto counter_ptr = std::make_shared<
-        ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>>();
+        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
       parse_raw_text_file_chunk(file_path, offsets[j], i,
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 36514b60e1..6b4c571539 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -3,18 +3,14 @@
 namespace torchtext {
 
 typedef std::vector<std::string> StringList;
-typedef ska_ordered::order_preserving_flat_hash_map<uint32_t, int64_t>
-    IndexDict;
 typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
                    std::vector<torch::Tensor>>
     VocabStates;
 
 struct Vocab : torch::CustomClassHolder {
-private:
+  static const int32_t MAX_VOCAB_SIZE = 30000000;
   int64_t unk_index_;
-  IndexDict stoi_;
-
-public:
+  std::vector<int64_t> stoi_;
   const std::string version_str_ = "0.0.1";
   StringList itos_;
   std::string unk_token_;
@@ -33,7 +29,7 @@ struct Vocab : torch::CustomClassHolder {
   std::vector<std::string> get_itos() const;
 
 protected:
-  uint32_t hash(const c10::string_view &str) const {
+  uint32_t _hash(const c10::string_view &str) const {
     uint32_t h = 2166136261;
     for (size_t i = 0; i < str.size(); i++) {
       h = h ^ uint32_t(uint8_t(str[i]));
@@ -41,6 +37,24 @@ struct Vocab : torch::CustomClassHolder {
     }
     return h;
   }
+
+  int32_t _find(const c10::string_view &w) const {
+    int32_t stoi_size = stoi_.size();
+    int32_t id = _hash(w) % stoi_size;
+    while (stoi_[id] != -1 && c10::string_view{itos_[stoi_[id]].data(),
+                                               itos_[stoi_[id]].size()} != w) {
+      id = (id + 1) % stoi_size;
+    }
+    return id;
+  }
+
+  void _add(const std::string &w) {
+    int32_t h = _find(c10::string_view{w.data(), w.size()});
+    if (stoi_[h] == -1) {
+      itos_.push_back(w);
+      stoi_[h] = itos_.size() - 1;
+    }
+  }
 };
 
 VocabStates _serialize_vocab(const c10::intrusive_ptr<Vocab> &self);
@@ -48,11 +62,11 @@
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
 
 Vocab _load_vocab_from_file(const std::string &file_path,
                             const std::string &unk_token,
-                            const int64_t min_freq, const int64_t num_cpus);
+                            const uint32_t min_freq, const uint32_t num_cpus);
 Vocab _build_vocab_from_text_file(const std::string &file_path,
                                   const std::string &unk_token,
-                                  const int64_t min_freq,
-                                  const int64_t num_cpus,
+                                  const uint32_t min_freq,
+                                  const uint32_t num_cpus,
                                   torch::jit::script::Module tokenizer);
 } // namespace torchtext
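Patch 12's core data-structure change is worth restating outside the diff: stoi_ becomes a flat open-addressed table of indices into itos_, probed linearly from the FNV-1a bucket, with -1 marking an empty slot. A self-contained sketch under those assumptions (a small fixed table instead of MAX_VOCAB_SIZE, std::string_view instead of c10::string_view):

    #include <cstdint>
    #include <string>
    #include <string_view>
    #include <vector>

    struct MiniVocab {
      std::vector<int64_t> stoi = std::vector<int64_t>(1024, -1); // -1 == empty
      std::vector<std::string> itos;

      static uint32_t hash(std::string_view w) {
        uint32_t h = 2166136261u; // FNV-1a, as in the patch
        for (unsigned char c : w) h = (h ^ c) * 16777619u;
        return h;
      }

      // Probe linearly until we hit the token's slot or an empty slot.
      uint32_t find(std::string_view w) const {
        uint32_t id = hash(w) % stoi.size();
        while (stoi[id] != -1 && itos[stoi[id]] != w)
          id = (id + 1) % stoi.size();
        return id;
      }

      void add(const std::string &w) {
        uint32_t id = find(w);
        if (stoi[id] == -1) {
          itos.push_back(w);
          stoi[id] = itos.size() - 1;
        }
      }
    };

Note the trade-off the patch accepts: the table is never resized, so MAX_VOCAB_SIZE bounds the vocabulary, and a nearly full table would degrade toward linear scans.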
From 4216fcb290516df6eeb49a781ca20eed945a5387 Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Sun, 14 Mar 2021 18:32:00 -0700
Subject: [PATCH 13/15] some cleanups

---
 torchtext/csrc/vocab.cpp | 16 ++++++++--------
 torchtext/csrc/vocab.h   |  8 ++++----
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 731fd4995f..1cadd8c5d0 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -24,7 +24,7 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     _add(tokens[i]);
   }
 
-  unk_index_ = stoi_[_find(c10::string_view{unk_token})];
+  unk_index_ = stoi_[_find(c10::string_view{unk_token.data(),unk_token.size()})];
 }
 
 int64_t Vocab::__len__() const { return itos_.size(); }
@@ -32,15 +32,15 @@ int64_t Vocab::__len__() const { return itos_.size(); }
 int64_t Vocab::__getitem__(const c10::string_view &token) const {
   uint32_t id = _find(token);
   if (stoi_[id] != -1) {
-    return (int64_t)stoi_[id];
+    return stoi_[id];
   }
-  return (int64_t)unk_index_;
+  return unk_index_;
 }
 
 void Vocab::append_token(const std::string &token) { _add(token); }
 
 void Vocab::insert_token(const std::string &token, const int64_t &index) {
-  if (index < 0 || index > static_cast<int64_t>(itos_.size())) {
+  if (index < 0 || index > itos_.size()) {
 #ifdef _MSC_VER
     std::cerr << "[RuntimeError] Specified index " << index
               << " is out of bounds of the size of stoi dictionary: "
@@ -53,7 +53,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
   }
 
   // if item already in stoi we throw an error
-  auto token_position = _find(c10::string_view{token});
+  auto token_position = _find(c10::string_view{token.data(),token.size()});
   if (stoi_[token_position] != -1) {
@@ -67,15 +67,15 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
 
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < itos_.size(); i++) {
-    stoi_[_find(c10::string_view{itos_[i]})] = i + 1;
+    stoi_[_find(c10::string_view{itos_[i].data(),itos_[i].size()})] = i + 1;
   }
 
   itos_.insert(itos_.begin() + index, token);
-  stoi_[_find(c10::string_view{token})] = index;
+  stoi_[_find(c10::string_view{token.data(),token.size()})] = index;
 
   // need to update unk_index in case token equals unk_token or token
   // inserted before unk_token
-  unk_index_ = stoi_[_find(c10::string_view(unk_token_))];
+  unk_index_ = stoi_[_find(c10::string_view{unk_token_.data(),unk_token_.size()})];
 }
 
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 6b4c571539..18f6f4031b 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -38,9 +38,9 @@ struct Vocab : torch::CustomClassHolder {
     return h;
   }
 
-  int32_t _find(const c10::string_view &w) const {
-    int32_t stoi_size = stoi_.size();
-    int32_t id = _hash(w) % stoi_size;
+  uint32_t _find(const c10::string_view &w) const {
+    uint32_t stoi_size = stoi_.size();
+    uint32_t id = _hash(w) % stoi_size;
     while (stoi_[id] != -1 && c10::string_view{itos_[stoi_[id]].data(),
                                                itos_[stoi_[id]].size()} != w) {
       id = (id + 1) % stoi_size;
@@ -49,7 +49,7 @@ struct Vocab : torch::CustomClassHolder {
   }
 
   void _add(const std::string &w) {
-    int32_t h = _find(c10::string_view{w.data(), w.size()});
+    uint32_t h = _find(c10::string_view{w.data(), w.size()});
     if (stoi_[h] == -1) {
       itos_.push_back(w);
       stoi_[h] = itos_.size() - 1;
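The cleanup above settles on spelling out {data(), size()} when building a c10::string_view from a std::string. Both forms produce the same non-owning view; a small illustration with std::string_view as a stand-in:

    #include <string>
    #include <string_view>

    int main() {
      std::string token = "hello";
      // Implicit conversion and explicit pointer+length construction
      // yield identical views; neither copies the character data.
      std::string_view a = token;
      std::string_view b{token.data(), token.size()};
      return a == b ? 0 : 1;
    }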
From f00ec120910b8a00ca0d3c6fceb34bb7b3ad0680 Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Wed, 17 Mar 2021 19:05:21 -0700
Subject: [PATCH 14/15] changing uint32_t back to int64_t as per original state

---
 torchtext/csrc/vocab.cpp | 109 +++++++++++++++++----------------------
 torchtext/csrc/vocab.h   |   8 +--
 2 files changed, 53 insertions(+), 64 deletions(-)

diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp
index 1cadd8c5d0..b6fa2099d7 100644
--- a/torchtext/csrc/vocab.cpp
+++ b/torchtext/csrc/vocab.cpp
@@ -23,14 +23,15 @@ Vocab::Vocab(const StringList &tokens, const std::string &unk_token)
     _add(tokens[i]);
   }
-
-  unk_index_ = stoi_[_find(c10::string_view{unk_token.data(),unk_token.size()})];
+
+  unk_index_ =
+      stoi_[_find(c10::string_view{unk_token.data(), unk_token.size()})];
 }
 
 int64_t Vocab::__len__() const { return itos_.size(); }
 
 int64_t Vocab::__getitem__(const c10::string_view &token) const {
-  uint32_t id = _find(token);
+  int64_t id = _find(token);
   if (stoi_[id] != -1) {
     return stoi_[id];
   }
@@ -53,7 +54,7 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
   }
 
   // if item already in stoi we throw an error
-  auto token_position = _find(c10::string_view{token.data(),token.size()});
+  auto token_position = _find(c10::string_view{token.data(), token.size()});
   if (stoi_[token_position] != -1) {
@@ -67,15 +68,16 @@ void Vocab::insert_token(const std::string &token, const int64_t &index) {
 
   // need to offset all tokens greater than or equal index by 1
   for (size_t i = index; i < itos_.size(); i++) {
-    stoi_[_find(c10::string_view{itos_[i].data(),itos_[i].size()})] = i + 1;
+    stoi_[_find(c10::string_view{itos_[i].data(), itos_[i].size()})] = i + 1;
   }
 
   itos_.insert(itos_.begin() + index, token);
-  stoi_[_find(c10::string_view{token.data(),token.size()})] = index;
+  stoi_[_find(c10::string_view{token.data(), token.size()})] = index;
 
   // need to update unk_index in case token equals unk_token or token
   // inserted before unk_token
-  unk_index_ = stoi_[_find(c10::string_view{unk_token_.data(),unk_token_.size()})];
+  unk_index_ =
+      stoi_[_find(c10::string_view{unk_token_.data(), unk_token_.size()})];
 }
 
 std::string Vocab::lookup_token(const int64_t &index) {
@@ -96,7 +98,7 @@ std::string Vocab::lookup_token(const int64_t &index) {
 
 StringList Vocab::lookup_tokens(const std::vector<int64_t> &indices) {
   std::vector<std::string> tokens(indices.size());
-  for (uint32_t i = 0; i < static_cast<uint32_t>(indices.size()); i++) {
+  for (int64_t i = 0; i < static_cast<int64_t>(indices.size()); i++) {
     tokens[i] = lookup_token(indices[i]);
   }
   return tokens;
@@ -105,7 +107,7 @@ StringList Vocab::lookup_tokens(const std::vector<int64_t> &indices) {
 std::vector<int64_t>
 Vocab::lookup_indices(const std::vector<c10::string_view> &tokens) {
   std::vector<int64_t> indices(tokens.size());
-  for (uint32_t i = 0; i < static_cast<uint32_t>(tokens.size()); i++) {
+  for (int64_t i = 0; i < static_cast<int64_t>(tokens.size()); i++) {
     indices[i] = __getitem__(tokens[i]);
   }
   return indices;
@@ -122,8 +124,8 @@ std::unordered_map<std::string, int64_t> Vocab::get_stoi() const {
 
 StringList Vocab::get_itos() const { return itos_; }
 
-uint32_t _infer_lines(const std::string &file_path) {
-  uint32_t num_lines = 0;
+int64_t _infer_lines(const std::string &file_path) {
+  int64_t num_lines = 0;
   std::ifstream fin;
 
   fin.open(file_path, std::ios::in);
@@ -135,12 +133,9 @@ int64_t _infer_lines(const std::string &file_path) {
   return num_lines;
 }
 
-void parse_vocab_file_chunk(
-    const std::string &file_path, size_t offset, const uint32_t start_line,
-    const uint32_t end_line,
-    std::shared_ptr<
-        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>
-        counter) {
+void parse_vocab_file_chunk(const std::string &file_path, size_t offset,
+                            const int64_t start_line, const int64_t end_line,
+                            std::shared_ptr<IndexDict> counter) {
   std::ifstream fin(file_path, std::ios::in);
   if (!fin.is_open()) {
     throw std::runtime_error("Cannot open input file " + file_path + "\n");
 
   fin.seekg(offset);
 
-  for (uint32_t i = start_line; i < end_line; i++) {
+  for (int64_t i = start_line; i < end_line; i++) {
     std::string token;
     fin >> token;
     fin >> std::ws;
 
-void parse_raw_text_file_chunk(
-    const std::string &file_path, size_t offset, const uint32_t start_line,
-    const uint32_t end_line,
-    std::shared_ptr<
-        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>
-        counter,
-    torch::jit::script::Module &module) {
+void parse_raw_text_file_chunk(const std::string &file_path, size_t offset,
+                               const int64_t start_line, const int64_t end_line,
+                               std::shared_ptr<IndexDict> counter,
+                               torch::jit::script::Module &module) {
   std::ifstream fin(file_path, std::ios::in);
   if (!fin.is_open()) {
     throw std::runtime_error("Cannot open input file " + file_path + "\n");
 
   fin.seekg(offset);
 
   std::string line;
-  for (uint32_t i = start_line; i < end_line; i++) {
+  for (int64_t i = start_line; i < end_line; i++) {
     std::getline(fin, line);
     auto token_list =
         module.forward(std::vector<c10::IValue>({c10::IValue(line)})).toList();
 
 // sorting using a custom object
 struct CompareTokens {
-  bool operator()(const std::pair<std::string, uint32_t> &a,
-                  const std::pair<std::string, uint32_t> &b) {
+  bool operator()(const std::pair<std::string, int64_t> &a,
+                  const std::pair<std::string, int64_t> &b) {
     if (a.second == b.second) {
       return a.first < b.first;
     }
 
-StringList _concat_tokens(
-    std::vector<std::shared_ptr<
-        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>>
-        chunk_counters,
-    const std::string &unk_token, const uint32_t min_freq,
-    const uint32_t num_lines, const bool sort_tokens) {
+StringList
+_concat_tokens(std::vector<std::shared_ptr<IndexDict>> chunk_counters,
+               const std::string &unk_token, const int64_t min_freq,
+               const int64_t num_lines, const bool sort_tokens) {
   TORCH_CHECK(chunk_counters.size() > 0,
               "There must be at least 1 chunk to concatenate!");
 
-  ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>
-      tokens_freq;
+  IndexDict tokens_freq;
   StringList unique_tokens;
   unique_tokens.reserve(num_lines);
 
   for (size_t i = 0; i < chunk_counters.size(); i++) {
     auto &cur_counter = *chunk_counters[i];
     for (const auto &item : cur_counter) {
-      uint32_t cur_token_freq = item.second;
+      int64_t cur_token_freq = item.second;
      if (tokens_freq.find(item.first) != tokens_freq.end()) {
        tokens_freq[item.first] += cur_token_freq;
       } else {
 
       // add to tokens list only if we exceed min_freq for the first time
       if (tokens_freq[item.first] - cur_token_freq < min_freq &&
           tokens_freq[item.first] >= min_freq) {
-        unique_tokens.push_back(std::string{item.first});
+        unique_tokens.push_back(item.first);
       }
     }
   }
 
   // create token freq pairs
-  std::vector<std::pair<std::string, uint32_t>> token_freq_pairs;
+  std::vector<std::pair<std::string, int64_t>> token_freq_pairs;
 
   for (std::string token : unique_tokens) {
     token_freq_pairs.push_back(std::make_pair(token, tokens_freq[token]));
 
   return unique_tokens;
 }
 
-constexpr uint32_t GRAIN_SIZE = 13107;
+constexpr int64_t GRAIN_SIZE = 13107;
 
 Vocab _load_vocab_from_file(const std::string &file_path,
                             const std::string &unk_token,
-                            const uint32_t min_freq, const uint32_t num_cpus) {
+                            const int64_t min_freq, const int64_t num_cpus) {
   std::cerr << "[INFO] Reading file " << file_path << std::endl;
 
-  uint32_t num_lines = _infer_lines(file_path);
-  uint32_t chunk_size = impl::divup(num_lines, num_cpus);
+  int64_t num_lines = _infer_lines(file_path);
+  int64_t chunk_size = impl::divup(num_lines, num_cpus);
   // Launching a thread on less lines than this likely has too much overhead.
   // TODO: Add explicit test beyond grain size to cover multithreading
   chunk_size = std::max(chunk_size, GRAIN_SIZE);
 
   std::vector<size_t> offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector<std::shared_ptr<
-      ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>>
-      chunk_counters;
+  std::vector<std::shared_ptr<IndexDict>> chunk_counters;
 
   std::mutex m;
   std::condition_variable cv;
   std::atomic<int> thread_count(0);
 
   // create threads
-  uint32_t j = 0;
-  for (uint32_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr = std::make_shared<
-        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>();
+  int64_t j = 0;
+  for (int64_t i = 0; i < num_lines; i += chunk_size) {
+    auto counter_ptr = std::make_shared<IndexDict>();
     thread_count++;
 
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
 
 Vocab _build_vocab_from_text_file(const std::string &file_path,
                                   const std::string &unk_token,
-                                  const uint32_t min_freq,
-                                  const uint32_t num_cpus,
+                                  const int64_t min_freq,
+                                  const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer) {
   std::cerr << "[INFO] Reading file " << file_path << std::endl;
-  uint32_t num_lines = _infer_lines(file_path);
-  uint32_t chunk_size = impl::divup(num_lines, num_cpus);
+  int64_t num_lines = _infer_lines(file_path);
+  int64_t chunk_size = impl::divup(num_lines, num_cpus);
   // Launching a thread on less lines than this likely has too much overhead.
   chunk_size = std::max(chunk_size, GRAIN_SIZE);
 
   std::vector<size_t> offsets;
   impl::infer_offsets(file_path, num_lines, chunk_size, offsets);
 
-  std::vector<std::shared_ptr<
-      ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>>
-      chunk_counters;
+  std::vector<std::shared_ptr<IndexDict>> chunk_counters;
 
   std::mutex m;
   std::condition_variable cv;
   std::atomic<int> thread_count(0);
 
   // create threads
-  uint32_t j = 0;
-  for (uint32_t i = 0; i < num_lines; i += chunk_size) {
-    auto counter_ptr = std::make_shared<
-        ska_ordered::order_preserving_flat_hash_map<std::string, uint32_t>>();
+  int64_t j = 0;
+  for (int64_t i = 0; i < num_lines; i += chunk_size) {
+    auto counter_ptr = std::make_shared<IndexDict>();
     thread_count++;
     at::launch([&, file_path, num_lines, chunk_size, j, i, counter_ptr]() {
       parse_raw_text_file_chunk(file_path, offsets[j], i,
diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index 18f6f4031b..c99a52378f 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -3,6 +3,8 @@
 namespace torchtext {
 
 typedef std::vector<std::string> StringList;
+typedef ska_ordered::order_preserving_flat_hash_map<std::string, int64_t>
+    IndexDict;
 typedef std::tuple<std::string, std::vector<int64_t>, std::vector<std::string>,
                    std::vector<torch::Tensor>>
     VocabStates;
@@ -62,11 +64,11 @@
 c10::intrusive_ptr<Vocab> _deserialize_vocab(VocabStates states);
 
 Vocab _load_vocab_from_file(const std::string &file_path,
                             const std::string &unk_token,
-                            const uint32_t min_freq, const uint32_t num_cpus);
+                            const int64_t min_freq, const int64_t num_cpus);
 Vocab _build_vocab_from_text_file(const std::string &file_path,
                                   const std::string &unk_token,
-                                  const uint32_t min_freq,
-                                  const uint32_t num_cpus,
+                                  const int64_t min_freq,
+                                  const int64_t num_cpus,
                                   torch::jit::script::Module tokenizer);
 } // namespace torchtext
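Patch 15 below leans on the fact that a std::string can be compared against a string_view directly, so constructing a temporary view of itos_[stoi_[id]] on every probe step is wasted work. Illustrated with the standard-library types, which guarantee the mixed-type comparison overloads:

    #include <string>
    #include <string_view>

    int main() {
      std::string stored = "token";
      std::string_view probe = "token";
      // Mixed comparison compiles and compares character-by-character;
      // no explicit temporary view of `stored` needs to be spelled out.
      bool same = (stored == probe);
      return same ? 0 : 1;
    }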
From 9d7a6145bd221ca7a515455c606958aee1a57aa8 Mon Sep 17 00:00:00 2001
From: Parmeet Bhatia
Date: Tue, 23 Mar 2021 07:32:19 -0700
Subject: [PATCH 15/15] removing unnecessary c10::string_view construction for
 comparison with std::string

---
 torchtext/csrc/vocab.h | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h
index c99a52378f..660f6145d4 100644
--- a/torchtext/csrc/vocab.h
+++ b/torchtext/csrc/vocab.h
@@ -43,8 +43,7 @@ struct Vocab : torch::CustomClassHolder {
   uint32_t _find(const c10::string_view &w) const {
     uint32_t stoi_size = stoi_.size();
     uint32_t id = _hash(w) % stoi_size;
-    while (stoi_[id] != -1 && c10::string_view{itos_[stoi_[id]].data(),
-                                               itos_[stoi_[id]].size()} != w) {
+    while (stoi_[id] != -1 && itos_[stoi_[id]] != w) {
      id = (id + 1) % stoi_size;
    }
    return id;
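Taken together, the series lands on FNV-1a hashing into a fixed-size open-addressed stoi_ table, int64_t indices throughout, and probe-time comparison done directly between std::string and c10::string_view. A compact behavioral model of the resulting lookup contract — known tokens return their index, unknown tokens fall back to the unk index — with a plain hash map standing in for the open-addressed table:

    #include <cassert>
    #include <cstdint>
    #include <string>
    #include <unordered_map>

    // Behavioral model of Vocab::__getitem__ after this series; the
    // storage strategy differs, the observable behavior should not.
    struct VocabModel {
      std::unordered_map<std::string, int64_t> stoi;
      int64_t unk_index;

      int64_t getitem(const std::string &token) const {
        auto it = stoi.find(token);
        return it != stoi.end() ? it->second : unk_index;
      }
    };

    int main() {
      VocabModel v{{{"<unk>", 0}, {"the", 1}, {"of", 2}}, 0};
      assert(v.getitem("the") == 1);
      assert(v.getitem("never-seen") == 0); // falls back to <unk>
    }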