From 6f927d5f02b92fc900faf4ef6c81d3be13617fab Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 16 Nov 2020 08:24:11 -0800 Subject: [PATCH 01/13] sp constructor error --- torchtext/csrc/register_bindings.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 6cc192c9b4..09e0ec0cf4 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -37,8 +37,17 @@ PYBIND11_MODULE(_torchtext, m) { .def("GetPieceSize", &SentencePiece::GetPieceSize) .def("unk_id", &SentencePiece::unk_id) .def("PieceToId", &SentencePiece::PieceToId) - .def("IdToPiece", &SentencePiece::IdToPiece); - + .def("IdToPiece", &SentencePiece::IdToPiece) + .def(py::pickle( + // __getstate__ + [](std::string state) { + SentencePiece p(state); + return p; + }, + // __setstate__ + [](const SentencePiece &self) { + return py::bytes(self.content_); + })); py::class_(m, "Vectors") .def(py::init, std::vector, torch::Tensor, torch::Tensor>()) From 0a1200059abd67d0146c7f19a12fa90d112b6f17 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Mon, 16 Nov 2020 12:01:33 -0800 Subject: [PATCH 02/13] checkpoint --- torchtext/csrc/register_bindings.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 09e0ec0cf4..e0fa11797c 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -40,13 +40,14 @@ PYBIND11_MODULE(_torchtext, m) { .def("IdToPiece", &SentencePiece::IdToPiece) .def(py::pickle( // __getstate__ + [](const SentencePiece &self) { + return py::bytes(self.content_); + }, + // __setstate__ [](std::string state) { + // return c10::make_intrusive(std::move(state)); SentencePiece p(state); return p; - }, - // __setstate__ - [](const SentencePiece &self) { - return py::bytes(self.content_); })); py::class_(m, "Vectors") .def(py::init, std::vector, From c6272f3dc89d7c4bfdf3419934c81f43112ea041 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sun, 22 Nov 2020 12:37:05 -0800 Subject: [PATCH 03/13] checkpoint --- torchtext/csrc/register_bindings.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index e0fa11797c..97de96a3c6 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -40,15 +40,14 @@ PYBIND11_MODULE(_torchtext, m) { .def("IdToPiece", &SentencePiece::IdToPiece) .def(py::pickle( // __getstate__ - [](const SentencePiece &self) { - return py::bytes(self.content_); + [](const c10::intrusive_ptr &self) { + return self->content_; }, // __setstate__ [](std::string state) { - // return c10::make_intrusive(std::move(state)); - SentencePiece p(state); - return p; + return c10::make_intrusive(std::move(state)); })); + py::class_(m, "Vectors") .def(py::init, std::vector, torch::Tensor, torch::Tensor>()) From 28156e3410b9a97c38e3cef6b82efe21bc8f25a4 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sun, 22 Nov 2020 15:16:00 -0800 Subject: [PATCH 04/13] successful pybind11 pickle --- torchtext/csrc/register_bindings.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 97de96a3c6..1d0263ec6b 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -16,7 +16,16 @@ PYBIND11_MODULE(_torchtext, m) { // Classes py::class_(m, "Regex") .def(py::init()) - .def("Sub", &Regex::Sub); + .def("Sub", &Regex::Sub) + .def(py::pickle( + // __getstate__ + [](const Regex &self) { + return self.re_str_; + }, + // __setstate__ + [](std::string state) { + return Regex(state); + })); py::class_(m, "RegexTokenizer") .def_readonly("patterns_", &RegexTokenizer::patterns_) @@ -37,16 +46,7 @@ PYBIND11_MODULE(_torchtext, m) { .def("GetPieceSize", &SentencePiece::GetPieceSize) .def("unk_id", &SentencePiece::unk_id) .def("PieceToId", &SentencePiece::PieceToId) - .def("IdToPiece", &SentencePiece::IdToPiece) - .def(py::pickle( - // __getstate__ - [](const c10::intrusive_ptr &self) { - return self->content_; - }, - // __setstate__ - [](std::string state) { - return c10::make_intrusive(std::move(state)); - })); + .def("IdToPiece", &SentencePiece::IdToPiece); py::class_(m, "Vectors") .def(py::init, std::vector, From eec2363dc9e95a5c6caa931a161341faff200d1f Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sun, 22 Nov 2020 17:47:40 -0800 Subject: [PATCH 05/13] add RegexTokenizer pybind11 pickle --- torchtext/csrc/register_bindings.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 1d0263ec6b..25c2ca2f89 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -32,7 +32,24 @@ PYBIND11_MODULE(_torchtext, m) { .def_readonly("replacements_", &RegexTokenizer::replacements_) .def_readonly("to_lower_", &RegexTokenizer::to_lower_) .def(py::init, std::vector, bool>()) - .def("forward", &RegexTokenizer::forward); + .def("forward", &RegexTokenizer::forward) + .def(py::pickle( + // __setstate__ + [](const RegexTokenizer &self) { + return std::make_tuple(self.patterns_, self.replacements_, + self.to_lower_); + }, + // __getstate__ + [](std::tuple, std::vector, + bool> + states) { + auto patterns = std::get<0>(states); + auto replacements = std::get<1>(states); + auto to_lower = std::get<2>(states); + + return RegexTokenizer( + std::move(patterns), std::move(replacements), to_lower); + })); py::class_(m, "SentencePiece") .def(py::init()) From ea27cbe4529689fd6ff2ac24038fc8e754104783 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sun, 22 Nov 2020 18:15:03 -0800 Subject: [PATCH 06/13] Vocab pybind11 pickle --- torchtext/csrc/register_bindings.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 25c2ca2f89..c56d5031a0 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -88,7 +88,21 @@ PYBIND11_MODULE(_torchtext, m) { .def("lookup_tokens", &Vocab::lookup_tokens) .def("lookup_indices", &Vocab::lookup_indices) .def("get_stoi", &Vocab::get_stoi) - .def("get_itos", &Vocab::get_itos); + .def("get_itos", &Vocab::get_itos) + .def(py::pickle( + // __setstate__ + [](const Vocab &self) { + StringList strings = self.itos_; + strings.push_back(self.unk_token_); + return std::make_tuple(strings); + }, + // __getstate__ + [](std::tuple states) { + auto strings = std::get<0>(states); + std::string unk_token = strings.back(); + strings.pop_back(); // remove last element which is unk_token + return Vocab(std::move(strings), std::move(unk_token)); + })); // Functions m.def("_load_token_and_vectors_from_file", From 0b84d550c056f3266c739c0633251634439ddfcc Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Sun, 22 Nov 2020 18:25:34 -0800 Subject: [PATCH 07/13] vectors pybind11 pickle --- torchtext/csrc/register_bindings.cpp | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index c56d5031a0..3be5045038 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -74,7 +74,33 @@ PYBIND11_MODULE(_torchtext, m) { .def("__getitem__", &Vectors::__getitem__) .def("lookup_vectors", &Vectors::lookup_vectors) .def("__setitem__", &Vectors::__setitem__) - .def("__len__", &Vectors::__len__); + .def("__len__", &Vectors::__len__) + .def(py::pickle( + // __setstate__ + [](const Vectors &self) { + std::vector tokens; + std::vector indices; + for (const auto &item : self.stoi_) { + tokens.push_back(item.first); + indices.push_back(item.second); + } + std::vector integers = std::move(indices); + std::vector strings = std::move(tokens); + std::vector tensors{self.vectors_, self.unk_tensor_}; + return std::make_tuple(std::move(integers), std::move(strings), std::move(tensors)); + }, + // __getstate__ + [](std::tuple, std::vector, std::vector> states) { + auto integers = std::get<0>(states); + auto strings = std::get<1>(states); + auto tensors = std::get<2>(states); + IndexMap stoi; + stoi.reserve(integers.size()); + for (size_t i = 0; i < integers.size(); i++) { + stoi[strings[i]] = integers[i]; + } + return Vectors(std::move(stoi), std::move(tensors[0]), std::move(tensors[1])); + })); py::class_(m, "Vocab") .def(py::init, std::string>()) From 0dfdd31dc34f8e50982817112d0dcfd40b139fed Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 24 Nov 2020 07:17:34 -0800 Subject: [PATCH 08/13] remove pybind pickle for regex and regextokenizer --- torchtext/csrc/register_bindings.cpp | 110 +++++++++++---------------- 1 file changed, 44 insertions(+), 66 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 3be5045038..b5df2dae9d 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -16,40 +16,14 @@ PYBIND11_MODULE(_torchtext, m) { // Classes py::class_(m, "Regex") .def(py::init()) - .def("Sub", &Regex::Sub) - .def(py::pickle( - // __getstate__ - [](const Regex &self) { - return self.re_str_; - }, - // __setstate__ - [](std::string state) { - return Regex(state); - })); + .def("Sub", &Regex::Sub); py::class_(m, "RegexTokenizer") .def_readonly("patterns_", &RegexTokenizer::patterns_) .def_readonly("replacements_", &RegexTokenizer::replacements_) .def_readonly("to_lower_", &RegexTokenizer::to_lower_) .def(py::init, std::vector, bool>()) - .def("forward", &RegexTokenizer::forward) - .def(py::pickle( - // __setstate__ - [](const RegexTokenizer &self) { - return std::make_tuple(self.patterns_, self.replacements_, - self.to_lower_); - }, - // __getstate__ - [](std::tuple, std::vector, - bool> - states) { - auto patterns = std::get<0>(states); - auto replacements = std::get<1>(states); - auto to_lower = std::get<2>(states); - - return RegexTokenizer( - std::move(patterns), std::move(replacements), to_lower); - })); + .def("forward", &RegexTokenizer::forward); py::class_(m, "SentencePiece") .def(py::init()) @@ -76,31 +50,35 @@ PYBIND11_MODULE(_torchtext, m) { .def("__setitem__", &Vectors::__setitem__) .def("__len__", &Vectors::__len__) .def(py::pickle( - // __setstate__ - [](const Vectors &self) { - std::vector tokens; - std::vector indices; - for (const auto &item : self.stoi_) { - tokens.push_back(item.first); - indices.push_back(item.second); - } - std::vector integers = std::move(indices); - std::vector strings = std::move(tokens); - std::vector tensors{self.vectors_, self.unk_tensor_}; - return std::make_tuple(std::move(integers), std::move(strings), std::move(tensors)); - }, - // __getstate__ - [](std::tuple, std::vector, std::vector> states) { - auto integers = std::get<0>(states); - auto strings = std::get<1>(states); - auto tensors = std::get<2>(states); - IndexMap stoi; - stoi.reserve(integers.size()); - for (size_t i = 0; i < integers.size(); i++) { - stoi[strings[i]] = integers[i]; - } - return Vectors(std::move(stoi), std::move(tensors[0]), std::move(tensors[1])); - })); + // __setstate__ + [](const Vectors &self) { + std::vector tokens; + std::vector indices; + for (const auto &item : self.stoi_) { + tokens.push_back(item.first); + indices.push_back(item.second); + } + std::vector integers = std::move(indices); + std::vector strings = std::move(tokens); + std::vector tensors{self.vectors_, self.unk_tensor_}; + return std::make_tuple(std::move(integers), std::move(strings), + std::move(tensors)); + }, + // __getstate__ + [](std::tuple, std::vector, + std::vector> + states) { + auto integers = std::get<0>(states); + auto strings = std::get<1>(states); + auto tensors = std::get<2>(states); + IndexMap stoi; + stoi.reserve(integers.size()); + for (size_t i = 0; i < integers.size(); i++) { + stoi[strings[i]] = integers[i]; + } + return Vectors(std::move(stoi), std::move(tensors[0]), + std::move(tensors[1])); + })); py::class_(m, "Vocab") .def(py::init, std::string>()) @@ -116,19 +94,19 @@ PYBIND11_MODULE(_torchtext, m) { .def("get_stoi", &Vocab::get_stoi) .def("get_itos", &Vocab::get_itos) .def(py::pickle( - // __setstate__ - [](const Vocab &self) { - StringList strings = self.itos_; - strings.push_back(self.unk_token_); - return std::make_tuple(strings); - }, - // __getstate__ - [](std::tuple states) { - auto strings = std::get<0>(states); - std::string unk_token = strings.back(); - strings.pop_back(); // remove last element which is unk_token - return Vocab(std::move(strings), std::move(unk_token)); - })); + // __setstate__ + [](const Vocab &self) { + StringList strings = self.itos_; + strings.push_back(self.unk_token_); + return std::make_tuple(strings); + }, + // __getstate__ + [](std::tuple states) { + auto strings = std::get<0>(states); + std::string unk_token = strings.back(); + strings.pop_back(); // remove last element which is unk_token + return Vocab(std::move(strings), std::move(unk_token)); + })); // Functions m.def("_load_token_and_vectors_from_file", From 7a530f0d231e83311bf23fdcb3d1dfbbb45a876d Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 24 Nov 2020 07:31:06 -0800 Subject: [PATCH 09/13] update tests --- test/experimental/test_vectors.py | 2 +- test/experimental/test_vocab.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/experimental/test_vectors.py b/test/experimental/test_vectors.py index 946c436e95..f88d0363e1 100644 --- a/test/experimental/test_vectors.py +++ b/test/experimental/test_vectors.py @@ -124,7 +124,7 @@ def test_vectors_load_and_save(self): vectors_obj['b'] = tensorC vector_path = os.path.join(self.test_dir, 'vectors.pt') - torch.save(vectors_obj.to_ivalue(), vector_path) + torch.save(vectors_obj, vector_path) loaded_vectors_obj = torch.load(vector_path) self.assertEqual(loaded_vectors_obj['a'], tensorA) diff --git a/test/experimental/test_vocab.py b/test/experimental/test_vocab.py index 626db2e726..f0629d9c78 100644 --- a/test/experimental/test_vocab.py +++ b/test/experimental/test_vocab.py @@ -200,7 +200,7 @@ def test_vocab_load_and_save(self): self.assertEqual(dict(v.get_stoi()), expected_stoi) vocab_path = os.path.join(self.test_dir, 'vocab.pt') - torch.save(v.to_ivalue(), vocab_path) + torch.save(v, vocab_path) loaded_v = torch.load(vocab_path) self.assertEqual(v.get_itos(), expected_itos) From 0382177d94e111f63bf223eb50af62271cd5d8e9 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Tue, 1 Dec 2020 12:38:17 -0800 Subject: [PATCH 10/13] vocab and vectors use same get_state and set_state func --- torchtext/csrc/register_bindings.cpp | 48 +++++++--------------------- torchtext/csrc/vectors.cpp | 12 +++---- torchtext/csrc/vectors.h | 2 +- torchtext/csrc/vocab.cpp | 8 ++--- torchtext/csrc/vocab.h | 5 ++- 5 files changed, 25 insertions(+), 50 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index b5df2dae9d..9040c27e18 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -51,33 +51,13 @@ PYBIND11_MODULE(_torchtext, m) { .def("__len__", &Vectors::__len__) .def(py::pickle( // __setstate__ - [](const Vectors &self) { - std::vector tokens; - std::vector indices; - for (const auto &item : self.stoi_) { - tokens.push_back(item.first); - indices.push_back(item.second); - } - std::vector integers = std::move(indices); - std::vector strings = std::move(tokens); - std::vector tensors{self.vectors_, self.unk_tensor_}; - return std::make_tuple(std::move(integers), std::move(strings), - std::move(tensors)); + [](const Vectors &self) -> VectorsStates { + return _set_vectors_states(self); }, // __getstate__ - [](std::tuple, std::vector, - std::vector> - states) { - auto integers = std::get<0>(states); - auto strings = std::get<1>(states); - auto tensors = std::get<2>(states); - IndexMap stoi; - stoi.reserve(integers.size()); - for (size_t i = 0; i < integers.size(); i++) { - stoi[strings[i]] = integers[i]; - } - return Vectors(std::move(stoi), std::move(tensors[0]), - std::move(tensors[1])); + [](VectorsStates states) -> Vectors { + auto vectors = _get_vectors_from_states(states); + return *vectors; })); py::class_(m, "Vocab") @@ -95,17 +75,13 @@ PYBIND11_MODULE(_torchtext, m) { .def("get_itos", &Vocab::get_itos) .def(py::pickle( // __setstate__ - [](const Vocab &self) { - StringList strings = self.itos_; - strings.push_back(self.unk_token_); - return std::make_tuple(strings); + [](const Vocab &self) -> VocabStates { + return _set_vocab_states(self); }, // __getstate__ - [](std::tuple states) { - auto strings = std::get<0>(states); - std::string unk_token = strings.back(); - strings.pop_back(); // remove last element which is unk_token - return Vocab(std::move(strings), std::move(unk_token)); + [](VocabStates states) -> Vocab { + auto vocab = _get_vocab_from_states(states); + return *vocab; })); // Functions @@ -192,7 +168,7 @@ static auto vocab = .def_pickle( // __setstate__ [](const c10::intrusive_ptr &self) -> VocabStates { - return _set_vocab_states(self); + return _set_vocab_states(*self); }, // __getstate__ [](VocabStates states) -> c10::intrusive_ptr { @@ -210,7 +186,7 @@ static auto vectors = .def_pickle( // __setstate__ [](const c10::intrusive_ptr &self) -> VectorsStates { - return _set_vectors_states(self); + return _set_vectors_states(*self); }, // __getstate__ [](VectorsStates states) -> c10::intrusive_ptr { diff --git a/torchtext/csrc/vectors.cpp b/torchtext/csrc/vectors.cpp index 029b3abb0d..0521ac88f0 100644 --- a/torchtext/csrc/vectors.cpp +++ b/torchtext/csrc/vectors.cpp @@ -275,25 +275,25 @@ std::tuple> _load_token_and_vectors_from_file( return result; } -VectorsStates _set_vectors_states(const c10::intrusive_ptr &self) { +VectorsStates _set_vectors_states(const Vectors &self) { std::vector tokens; std::vector indices; - tokens.reserve(self->stoi_.size()); - indices.reserve(self->stoi_.size()); + tokens.reserve(self.stoi_.size()); + indices.reserve(self.stoi_.size()); // construct tokens and index list // we need to store indices because the `vectors_` tensor may have gaps - for (const auto &item : self->stoi_) { + for (const auto &item : self.stoi_) { tokens.push_back(item.first); indices.push_back(item.second); } std::vector integers = std::move(indices); std::vector strings = std::move(tokens); - std::vector tensors{self->vectors_, self->unk_tensor_}; + std::vector tensors{self.vectors_, self.unk_tensor_}; VectorsStates states = - std::make_tuple(self->version_str_, std::move(integers), + std::make_tuple(self.version_str_, std::move(integers), std::move(strings), std::move(tensors)); return states; diff --git a/torchtext/csrc/vectors.h b/torchtext/csrc/vectors.h index fde1f257f0..7b6077a61f 100644 --- a/torchtext/csrc/vectors.h +++ b/torchtext/csrc/vectors.h @@ -33,7 +33,7 @@ struct Vectors : torch::CustomClassHolder { }; c10::intrusive_ptr _get_vectors_from_states(VectorsStates states); -VectorsStates _set_vectors_states(const c10::intrusive_ptr &self); +VectorsStates _set_vectors_states(const Vectors &self); std::tuple> _load_token_and_vectors_from_file( const std::string &file_path, const std::string delimiter_str, diff --git a/torchtext/csrc/vocab.cpp b/torchtext/csrc/vocab.cpp index a9a5f0c844..56dafac73f 100644 --- a/torchtext/csrc/vocab.cpp +++ b/torchtext/csrc/vocab.cpp @@ -381,13 +381,13 @@ Vocab _build_vocab_from_text_file(const std::string &file_path, return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index); } -VocabStates _set_vocab_states(const c10::intrusive_ptr &self) { +VocabStates _set_vocab_states(const Vocab &self) { std::vector integers; - StringList strings = self->itos_; - strings.push_back(self->unk_token_); + StringList strings = self.itos_; + strings.push_back(self.unk_token_); std::vector tensors; - VocabStates states = std::make_tuple(self->version_str_, std::move(integers), + VocabStates states = std::make_tuple(self.version_str_, std::move(integers), std::move(strings), std::move(tensors)); return states; } diff --git a/torchtext/csrc/vocab.h b/torchtext/csrc/vocab.h index 594f44f2fb..a0d4e2ea7f 100644 --- a/torchtext/csrc/vocab.h +++ b/torchtext/csrc/vocab.h @@ -37,14 +37,13 @@ struct Vocab : torch::CustomClassHolder { }; c10::intrusive_ptr _get_vocab_from_states(VocabStates states); -VocabStates _set_vocab_states(const c10::intrusive_ptr &self); +VocabStates _set_vocab_states(const Vocab &self); Vocab _load_vocab_from_file(const std::string &file_path, const std::string &unk_token, const int64_t min_freq, const int64_t num_cpus); Vocab _build_vocab_from_text_file(const std::string &file_path, const std::string &unk_token, const int64_t min_freq, - const int64_t num_cpus, - py::object tokenizer); + const int64_t num_cpus, py::object tokenizer); } // namespace torchtext From 64b0a18860c96919c13f1a0a6f24c1f8e67840fc Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 2 Dec 2020 09:27:47 -0800 Subject: [PATCH 11/13] add pickle support for tokenizer --- torchtext/csrc/register_bindings.cpp | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 9040c27e18..06a76fb2de 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -16,14 +16,38 @@ PYBIND11_MODULE(_torchtext, m) { // Classes py::class_(m, "Regex") .def(py::init()) - .def("Sub", &Regex::Sub); + .def("Sub", &Regex::Sub) + .def(py::pickle( + // __getstate__ + [](const Regex &self) -> std::string { return self.re_str_; }, + // __setstate__ + [](std::string state) -> Regex { return Regex(std::move(state)); })); py::class_(m, "RegexTokenizer") .def_readonly("patterns_", &RegexTokenizer::patterns_) .def_readonly("replacements_", &RegexTokenizer::replacements_) .def_readonly("to_lower_", &RegexTokenizer::to_lower_) .def(py::init, std::vector, bool>()) - .def("forward", &RegexTokenizer::forward); + .def("forward", &RegexTokenizer::forward) + .def(py::pickle( + // __setstate__ + [](const RegexTokenizer &self) + -> std::tuple, std::vector, + bool> { + return std::make_tuple(self.patterns_, self.replacements_, + self.to_lower_); + }, + // __getstate__ + [](std::tuple, std::vector, + bool> + states) -> RegexTokenizer { + auto patterns = std::get<0>(states); + auto replacements = std::get<1>(states); + auto to_lower = std::get<2>(states); + + return RegexTokenizer(std::move(patterns), std::move(replacements), + to_lower); + })); py::class_(m, "SentencePiece") .def(py::init()) From f0d07fa82bd60fc4b073702247a80377496a2fae Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 2 Dec 2020 13:41:48 -0800 Subject: [PATCH 12/13] move regextokenizer out of registration file --- torchtext/csrc/regex_tokenizer.cpp | 15 +++++++++++ torchtext/csrc/regex_tokenizer.h | 7 ++++++ torchtext/csrc/register_bindings.cpp | 37 ++++++++-------------------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/torchtext/csrc/regex_tokenizer.cpp b/torchtext/csrc/regex_tokenizer.cpp index c31ca36226..ee79e7c354 100644 --- a/torchtext/csrc/regex_tokenizer.cpp +++ b/torchtext/csrc/regex_tokenizer.cpp @@ -44,4 +44,19 @@ void RegexTokenizer::split_(std::string &str, std::vector &tokens, } } +c10::intrusive_ptr +_get_regex_tokenizer_from_states(RegexTokenizerStates states) { + auto &patterns = std::get<0>(states); + auto &replacements = std::get<1>(states); + auto &to_lower = std::get<2>(states); + return c10::make_intrusive(std::move(patterns), + std::move(replacements), to_lower); +} + +RegexTokenizerStates _set_regex_tokenizer_states(const RegexTokenizer &self) { + RegexTokenizerStates states = + std::make_tuple(self.patterns_, self.replacements_, self.to_lower_); + return states; +} + } // namespace torchtext diff --git a/torchtext/csrc/regex_tokenizer.h b/torchtext/csrc/regex_tokenizer.h index d0d9cfbb62..f1e0b369a3 100644 --- a/torchtext/csrc/regex_tokenizer.h +++ b/torchtext/csrc/regex_tokenizer.h @@ -3,6 +3,9 @@ namespace torchtext { +typedef std::tuple, std::vector, bool> + RegexTokenizerStates; + struct RegexTokenizer : torch::CustomClassHolder { private: std::vector compiled_patterns_; @@ -20,4 +23,8 @@ struct RegexTokenizer : torch::CustomClassHolder { std::vector forward(std::string str) const; }; +c10::intrusive_ptr +_get_regex_tokenizer_from_states(RegexTokenizerStates states); +RegexTokenizerStates _set_regex_tokenizer_states(const RegexTokenizer &self); + } // namespace torchtext diff --git a/torchtext/csrc/register_bindings.cpp b/torchtext/csrc/register_bindings.cpp index 06a76fb2de..46262d5680 100644 --- a/torchtext/csrc/register_bindings.cpp +++ b/torchtext/csrc/register_bindings.cpp @@ -31,22 +31,13 @@ PYBIND11_MODULE(_torchtext, m) { .def("forward", &RegexTokenizer::forward) .def(py::pickle( // __setstate__ - [](const RegexTokenizer &self) - -> std::tuple, std::vector, - bool> { - return std::make_tuple(self.patterns_, self.replacements_, - self.to_lower_); + [](const RegexTokenizer &self) -> RegexTokenizerStates { + return _set_regex_tokenizer_states(self); }, // __getstate__ - [](std::tuple, std::vector, - bool> - states) -> RegexTokenizer { - auto patterns = std::get<0>(states); - auto replacements = std::get<1>(states); - auto to_lower = std::get<2>(states); - - return RegexTokenizer(std::move(patterns), std::move(replacements), - to_lower); + [](RegexTokenizerStates states) -> RegexTokenizer { + auto regex_tokenizer = _get_regex_tokenizer_from_states(states); + return *regex_tokenizer; })); py::class_(m, "SentencePiece") @@ -138,21 +129,13 @@ static auto regex_tokenizer = .def_pickle( // __setstate__ [](const c10::intrusive_ptr &self) - -> std::tuple, - std::vector, bool> { - return std::make_tuple(self->patterns_, self->replacements_, - self->to_lower_); + -> RegexTokenizerStates { + return _set_regex_tokenizer_states(*self); }, // __getstate__ - [](std::tuple, std::vector, - bool> - states) -> c10::intrusive_ptr { - auto patterns = std::get<0>(states); - auto replacements = std::get<1>(states); - auto to_lower = std::get<2>(states); - - return c10::make_intrusive( - std::move(patterns), std::move(replacements), to_lower); + [](RegexTokenizerStates states) + -> c10::intrusive_ptr { + return _get_regex_tokenizer_from_states(states); }); static auto sentencepiece = From 54deffc9742f0b08fd3aab30de79d82af7acb515 Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Fri, 4 Dec 2020 07:25:04 -0800 Subject: [PATCH 13/13] add test for pybind tokenizer --- test/data/test_functional.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/data/test_functional.py b/test/data/test_functional.py index 66fda21154..45a7b69494 100644 --- a/test/data/test_functional.py +++ b/test/data/test_functional.py @@ -107,13 +107,21 @@ def test_BasicEnglishNormalize(self): self.assertEqual(eager_tokens, ref_results) self.assertEqual(experimental_eager_tokens, ref_results) - # test load and save + # test pybind load and save + save_path = os.path.join(self.test_dir, 'pybind_basic_english_normalize.pt') + torch.save(basic_eng_norm, save_path) + pybind_loaded_basic_eng_norm = torch.load(save_path) + + pybind_loaded_eager_tokens = pybind_loaded_basic_eng_norm(test_sample) + self.assertEqual(pybind_loaded_eager_tokens, ref_results) + + # test torchbind load and save save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt') torch.save(basic_eng_norm.to_ivalue(), save_path) - loaded_basic_eng_norm = torch.load(save_path) + torchbind_loaded_basic_eng_norm = torch.load(save_path) - loaded_eager_tokens = loaded_basic_eng_norm(test_sample) - self.assertEqual(loaded_eager_tokens, ref_results) + torchbind_loaded_eager_tokens = torchbind_loaded_basic_eng_norm(test_sample) + self.assertEqual(torchbind_loaded_eager_tokens, ref_results) # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")