16 changes: 12 additions & 4 deletions test/data/test_functional.py
@@ -107,13 +107,21 @@ def test_BasicEnglishNormalize(self):
self.assertEqual(eager_tokens, ref_results)
self.assertEqual(experimental_eager_tokens, ref_results)

-# test load and save
+# test pybind load and save
+save_path = os.path.join(self.test_dir, 'pybind_basic_english_normalize.pt')
+torch.save(basic_eng_norm, save_path)
+pybind_loaded_basic_eng_norm = torch.load(save_path)
+
+pybind_loaded_eager_tokens = pybind_loaded_basic_eng_norm(test_sample)
+self.assertEqual(pybind_loaded_eager_tokens, ref_results)
+
+# test torchbind load and save
save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt')
torch.save(basic_eng_norm.to_ivalue(), save_path)
-loaded_basic_eng_norm = torch.load(save_path)
+torchbind_loaded_basic_eng_norm = torch.load(save_path)

-loaded_eager_tokens = loaded_basic_eng_norm(test_sample)
-self.assertEqual(loaded_eager_tokens, ref_results)
+torchbind_loaded_eager_tokens = torchbind_loaded_basic_eng_norm(test_sample)
+self.assertEqual(torchbind_loaded_eager_tokens, ref_results)

# TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
@unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
2 changes: 1 addition & 1 deletion test/experimental/test_vectors.py
@@ -124,7 +124,7 @@ def test_vectors_load_and_save(self):
vectors_obj['b'] = tensorC

vector_path = os.path.join(self.test_dir, 'vectors.pt')
-torch.save(vectors_obj.to_ivalue(), vector_path)
+torch.save(vectors_obj, vector_path)
loaded_vectors_obj = torch.load(vector_path)

self.assertEqual(loaded_vectors_obj['a'], tensorA)
2 changes: 1 addition & 1 deletion test/experimental/test_vocab.py
@@ -200,7 +200,7 @@ def test_vocab_load_and_save(self):
self.assertEqual(dict(v.get_stoi()), expected_stoi)

vocab_path = os.path.join(self.test_dir, 'vocab.pt')
-torch.save(v.to_ivalue(), vocab_path)
+torch.save(v, vocab_path)
loaded_v = torch.load(vocab_path)

self.assertEqual(v.get_itos(), expected_itos)
15 changes: 15 additions & 0 deletions torchtext/csrc/regex_tokenizer.cpp
@@ -44,4 +44,19 @@ void RegexTokenizer::split_(std::string &str, std::vector<std::string> &tokens,
}
}

+c10::intrusive_ptr<RegexTokenizer>
+_get_regex_tokenizer_from_states(RegexTokenizerStates states) {
+auto &patterns = std::get<0>(states);
+auto &replacements = std::get<1>(states);
+auto &to_lower = std::get<2>(states);
+return c10::make_intrusive<RegexTokenizer>(std::move(patterns),
+std::move(replacements), to_lower);
+}
+
+RegexTokenizerStates _set_regex_tokenizer_states(const RegexTokenizer &self) {
+RegexTokenizerStates states =
+std::make_tuple(self.patterns_, self.replacements_, self.to_lower_);
+return states;
+}
} // namespace torchtext
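These helpers factor the pickle state handling out of the bindings so the pybind11 and torchbind registrations in register_bindings.cpp can share it. A minimal sketch of the round trip they implement (the standalone main() and the pattern strings are hypothetical; in the library the helpers are only called from the pickle bindings):

    #include <vector>
    #include "regex_tokenizer.h"

    using namespace torchtext;

    int main() {
      // Hypothetical tokenizer: one pattern, one replacement, lowercasing on.
      auto tok = c10::make_intrusive<RegexTokenizer>(
          std::vector<std::string>{"\\'"},
          std::vector<std::string>{" \\'"}, true);

      // Serialize: copy (patterns_, replacements_, to_lower_) into a tuple.
      RegexTokenizerStates states = _set_regex_tokenizer_states(*tok);

      // Deserialize: rebuild an equivalent tokenizer from the tuple alone.
      auto restored = _get_regex_tokenizer_from_states(states);
      std::vector<std::string> tokens = restored->forward("Don't panic");
      return 0;
    }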
7 changes: 7 additions & 0 deletions torchtext/csrc/regex_tokenizer.h
@@ -3,6 +3,9 @@

namespace torchtext {

+typedef std::tuple<std::vector<std::string>, std::vector<std::string>, bool>
+RegexTokenizerStates;
+
struct RegexTokenizer : torch::CustomClassHolder {
private:
std::vector<RE2 *> compiled_patterns_;
@@ -20,4 +23,8 @@ struct RegexTokenizer : torch::CustomClassHolder {
std::vector<std::string> forward(std::string str) const;
};

+c10::intrusive_ptr<RegexTokenizer>
+_get_regex_tokenizer_from_states(RegexTokenizerStates states);
+RegexTokenizerStates _set_regex_tokenizer_states(const RegexTokenizer &self);
+
} // namespace torchtext
65 changes: 46 additions & 19 deletions torchtext/csrc/register_bindings.cpp
@@ -16,14 +16,29 @@ PYBIND11_MODULE(_torchtext, m) {
// Classes
py::class_<Regex>(m, "Regex")
.def(py::init<std::string>())
.def("Sub", &Regex::Sub);
.def("Sub", &Regex::Sub)
.def(py::pickle(
// __getstate__
[](const Regex &self) -> std::string { return self.re_str_; },
// __setstate__
[](std::string state) -> Regex { return Regex(std::move(state)); }));

py::class_<RegexTokenizer>(m, "RegexTokenizer")
.def_readonly("patterns_", &RegexTokenizer::patterns_)
.def_readonly("replacements_", &RegexTokenizer::replacements_)
.def_readonly("to_lower_", &RegexTokenizer::to_lower_)
.def(py::init<std::vector<std::string>, std::vector<std::string>, bool>())
.def("forward", &RegexTokenizer::forward);
.def("forward", &RegexTokenizer::forward)
.def(py::pickle(
// __setstate__
[](const RegexTokenizer &self) -> RegexTokenizerStates {
return _set_regex_tokenizer_states(self);
},
// __getstate__
[](RegexTokenizerStates states) -> RegexTokenizer {
auto regex_tokenizer = _get_regex_tokenizer_from_states(states);
return *regex_tokenizer;
}));

py::class_<SentencePiece>(m, "SentencePiece")
.def(py::init<std::string>())
Expand All @@ -48,7 +63,17 @@ PYBIND11_MODULE(_torchtext, m) {
.def("__getitem__", &Vectors::__getitem__)
.def("lookup_vectors", &Vectors::lookup_vectors)
.def("__setitem__", &Vectors::__setitem__)
.def("__len__", &Vectors::__len__);
.def("__len__", &Vectors::__len__)
.def(py::pickle(
// __setstate__
[](const Vectors &self) -> VectorsStates {
return _set_vectors_states(self);
},
// __getstate__
[](VectorsStates states) -> Vectors {
auto vectors = _get_vectors_from_states(states);
return *vectors;
}));

py::class_<Vocab>(m, "Vocab")
.def(py::init<std::vector<std::string>, std::string>())
Expand All @@ -62,7 +87,17 @@ PYBIND11_MODULE(_torchtext, m) {
.def("lookup_tokens", &Vocab::lookup_tokens)
.def("lookup_indices", &Vocab::lookup_indices)
.def("get_stoi", &Vocab::get_stoi)
.def("get_itos", &Vocab::get_itos);
.def("get_itos", &Vocab::get_itos)
.def(py::pickle(
[Inline review thread]

Contributor: Maybe we can use the existing _set_vocab_states etc. functionality.

Author: Similar to the def_pickle function of torchbind, I tried _set_vectors_states and _get_vectors_from_states in the py::pickle function of pybind11. However, the intrusive_ptr holder used by _set_vectors_states and _get_vectors_from_states is not supported by the pybind11 pickle mechanism:

    error: static assertion failed: pybind11::init(): init function must return a compatible pointer, holder, or value

cpuhrsch (Dec 1, 2020): I think you could change

    _set_vectors_states(const c10::intrusive_ptr<Vectors> &self)

to

    _set_vectors_states(const Vectors &self)

and then change the callsite of _set_vectors_states(self) to _set_vectors_states(*self). You could even use move semantics.
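A minimal sketch of the shape this suggestion leads to. The const-reference signature and the dereferencing torchbind callsite match the final diff below; the std::move variant is only one reading of the move-semantics remark, not code from this PR:

    // Shared serializer: takes the object by const reference, so the pybind11
    // binding (which receives a reference) and the torchbind binding (which
    // receives an intrusive_ptr holder and dereferences it) can both reuse it.
    VectorsStates _set_vectors_states(const Vectors &self);

    // torchbind callsite: dereference the holder before calling.
    [](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
      return _set_vectors_states(*self);
    }

    // Hypothetical move-semantics variant for the deserializer: the state
    // tuple arrives by value, so it can be moved into the factory to avoid
    // copying the token list and tensors.
    [](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
      return _get_vectors_from_states(std::move(states));
    }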

+// __getstate__
+[](const Vocab &self) -> VocabStates {
+return _set_vocab_states(self);
+},
+// __setstate__
+[](VocabStates states) -> Vocab {
+auto vocab = _get_vocab_from_states(states);
+return *vocab;
+}));

// Functions
m.def("_load_token_and_vectors_from_file",
@@ -94,21 +129,13 @@ static auto regex_tokenizer =
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<RegexTokenizer> &self)
--> std::tuple<std::vector<std::string>,
-std::vector<std::string>, bool> {
-return std::make_tuple(self->patterns_, self->replacements_,
-self->to_lower_);
+-> RegexTokenizerStates {
+return _set_regex_tokenizer_states(*self);
},
// __setstate__
-[](std::tuple<std::vector<std::string>, std::vector<std::string>,
-bool>
-states) -> c10::intrusive_ptr<RegexTokenizer> {
-auto patterns = std::get<0>(states);
-auto replacements = std::get<1>(states);
-auto to_lower = std::get<2>(states);
-
-return c10::make_intrusive<RegexTokenizer>(
-std::move(patterns), std::move(replacements), to_lower);
+[](RegexTokenizerStates states)
+-> c10::intrusive_ptr<RegexTokenizer> {
+return _get_regex_tokenizer_from_states(states);
});

static auto sentencepiece =
@@ -148,7 +175,7 @@ static auto vocab =
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vocab> &self) -> VocabStates {
-return _set_vocab_states(self);
+return _set_vocab_states(*self);
},
// __setstate__
[](VocabStates states) -> c10::intrusive_ptr<Vocab> {
@@ -166,7 +193,7 @@ static auto vectors =
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<Vectors> &self) -> VectorsStates {
-return _set_vectors_states(self);
+return _set_vectors_states(*self);
},
// __setstate__
[](VectorsStates states) -> c10::intrusive_ptr<Vectors> {
12 changes: 6 additions & 6 deletions torchtext/csrc/vectors.cpp
@@ -275,25 +275,25 @@ std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
return result;
}

-VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self) {
+VectorsStates _set_vectors_states(const Vectors &self) {
std::vector<std::string> tokens;
std::vector<int64_t> indices;
-tokens.reserve(self->stoi_.size());
-indices.reserve(self->stoi_.size());
+tokens.reserve(self.stoi_.size());
+indices.reserve(self.stoi_.size());

// construct tokens and index list
// we need to store indices because the `vectors_` tensor may have gaps
-for (const auto &item : self->stoi_) {
+for (const auto &item : self.stoi_) {
tokens.push_back(item.first);
indices.push_back(item.second);
}

std::vector<int64_t> integers = std::move(indices);
std::vector<std::string> strings = std::move(tokens);
-std::vector<torch::Tensor> tensors{self->vectors_, self->unk_tensor_};
+std::vector<torch::Tensor> tensors{self.vectors_, self.unk_tensor_};

VectorsStates states =
-std::make_tuple(self->version_str_, std::move(integers),
+std::make_tuple(self.version_str_, std::move(integers),
std::move(strings), std::move(tensors));

return states;
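To make the "gaps" comment concrete, here is roughly the inverse step that _get_vectors_from_states must perform when rebuilding the map (a sketch; the exact container type is an assumption based on the iteration above):

    // Rebuild string -> row-index from the two saved parallel lists. Row
    // indices are stored explicitly because the values in stoi_ need not be
    // the contiguous range 0..N-1: vectors_ may contain unused rows, so a
    // token's position in `strings` alone would not recover the right row.
    std::unordered_map<std::string, int64_t> stoi;
    stoi.reserve(strings.size());
    for (size_t i = 0; i < strings.size(); ++i) {
      stoi[strings[i]] = integers[i];
    }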
2 changes: 1 addition & 1 deletion torchtext/csrc/vectors.h
@@ -33,7 +33,7 @@ struct Vectors : torch::CustomClassHolder {
};

c10::intrusive_ptr<Vectors> _get_vectors_from_states(VectorsStates states);
-VectorsStates _set_vectors_states(const c10::intrusive_ptr<Vectors> &self);
+VectorsStates _set_vectors_states(const Vectors &self);

std::tuple<Vectors, std::vector<std::string>> _load_token_and_vectors_from_file(
const std::string &file_path, const std::string delimiter_str,
8 changes: 4 additions & 4 deletions torchtext/csrc/vocab.cpp
@@ -381,13 +381,13 @@ Vocab _build_vocab_from_text_file(const std::string &file_path,
return Vocab(std::move(tokens), std::move(stoi), unk_token, unk_index);
}

-VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self) {
+VocabStates _set_vocab_states(const Vocab &self) {
std::vector<int64_t> integers;
-StringList strings = self->itos_;
-strings.push_back(self->unk_token_);
+StringList strings = self.itos_;
+strings.push_back(self.unk_token_);
std::vector<torch::Tensor> tensors;

-VocabStates states = std::make_tuple(self->version_str_, std::move(integers),
+VocabStates states = std::make_tuple(self.version_str_, std::move(integers),
std::move(strings), std::move(tensors));
return states;
}
5 changes: 2 additions & 3 deletions torchtext/csrc/vocab.h
@@ -37,14 +37,13 @@ struct Vocab : torch::CustomClassHolder {
};

c10::intrusive_ptr<Vocab> _get_vocab_from_states(VocabStates states);
-VocabStates _set_vocab_states(const c10::intrusive_ptr<Vocab> &self);
+VocabStates _set_vocab_states(const Vocab &self);
Vocab _load_vocab_from_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq, const int64_t num_cpus);
Vocab _build_vocab_from_text_file(const std::string &file_path,
const std::string &unk_token,
const int64_t min_freq,
-const int64_t num_cpus,
-py::object tokenizer);
+const int64_t num_cpus, py::object tokenizer);

} // namespace torchtext