61 changes: 49 additions & 12 deletions test/data/test_functional.py
@@ -107,13 +107,24 @@ def test_BasicEnglishNormalize(self):
         self.assertEqual(eager_tokens, ref_results)
         self.assertEqual(experimental_eager_tokens, ref_results)
 
-        # test load and save
-        save_path = os.path.join(self.test_dir, 'basic_english_normalize.pt')
-        torch.save(basic_eng_norm.to_ivalue(), save_path)
-        loaded_basic_eng_norm = torch.load(save_path)
-
-        loaded_eager_tokens = loaded_basic_eng_norm(test_sample)
-        self.assertEqual(loaded_eager_tokens, ref_results)
+    def test_basicEnglishNormalize_load_and_save(self):
+        test_sample = '\'".<br />,()!?;: Basic English Normalization for a Line of Text \'".<br />,()!?;:'
+        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'basic', 'english', 'normalization',
+                       'for', 'a', 'line', 'of', 'text', "'", '.', ',', '(', ')', '!', '?']
+
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'ben_pybind.pt')
+            ben = basic_english_normalize()
+            torch.save(ben, save_path)
+            loaded_ben = torch.load(save_path)
+            self.assertEqual(loaded_ben(test_sample), ref_results)
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'ben_torchscript.pt')
+            ben = basic_english_normalize().to_ivalue()
+            torch.save(ben, save_path)
+            loaded_ben = torch.load(save_path)
+            self.assertEqual(loaded_ben(test_sample), ref_results)
 
     # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
     @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
@@ -147,13 +158,39 @@ def test_RegexTokenizer(self):
         self.assertEqual(eager_tokens, ref_results)
         self.assertEqual(jit_tokens, ref_results)
 
-        # test load and save
-        save_path = os.path.join(self.test_dir, 'regex.pt')
-        torch.save(r_tokenizer.to_ivalue(), save_path)
-        loaded_r_tokenizer = torch.load(save_path)
-
-        loaded_eager_tokens = loaded_r_tokenizer(test_sample)
-        self.assertEqual(loaded_eager_tokens, ref_results)
+    def test_load_and_save(self):
+        test_sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
+        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
+                       'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
+        patterns_list = [
+            (r'\'', ' \' '),
+            (r'\"', ''),
+            (r'\.', ' . '),
+            (r'<br \/>', ' '),
+            (r',', ' , '),
+            (r'\(', ' ( '),
+            (r'\)', ' ) '),
+            (r'\!', ' ! '),
+            (r'\?', ' ? '),
+            (r'\;', ' '),
+            (r'\:', ' '),
+            (r'\s+', ' ')]
+
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
+            tokenizer = regex_tokenizer(patterns_list)
+            torch.save(tokenizer, save_path)
+            loaded_tokenizer = torch.load(save_path)
+            results = loaded_tokenizer(test_sample)
+            self.assertEqual(results, ref_results)
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
+            tokenizer = regex_tokenizer(patterns_list).to_ivalue()
+            torch.save(tokenizer, save_path)
+            loaded_tokenizer = torch.load(save_path)
+            results = loaded_tokenizer(test_sample)
+            self.assertEqual(results, ref_results)
 
     def test_custom_replace(self):
         custom_replace_transform = custom_replace([(r'S', 's'), (r'\s+', ' ')])
24 changes: 24 additions & 0 deletions test/experimental/test_transforms.py
@@ -54,3 +54,27 @@ def test_vector_transform(self):
                                                   [-0.32423, -0.098845, -0.0073467]])
         self.assertEqual(vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
         self.assertEqual(jit_vector_transform(['the', 'world'])[:, 0:3], expected_fasttext_simple_en)
+
+    def test_sentencepiece_load_and_save(self):
+        model_path = get_asset_path('spm_example.model')
+        input = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
+        expected = [
+            '▁Sent', 'ence', 'P', 'ie', 'ce', '▁is',
+            '▁an', '▁un', 'super', 'vis', 'ed', '▁text',
+            '▁to', 'ken', 'izer', '▁and',
+            '▁de', 'to', 'ken', 'izer',
+        ]
+
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
+            spm = sentencepiece_tokenizer(model_path)
+            torch.save(spm, save_path)
+            loaded_spm = torch.load(save_path)
+            self.assertEqual(expected, loaded_spm(input))
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
+            spm = sentencepiece_tokenizer(model_path).to_ivalue()
+            torch.save(spm, save_path)
+            loaded_spm = torch.load(save_path)
+            self.assertEqual(expected, loaded_spm(input))
40 changes: 32 additions & 8 deletions test/experimental/test_vectors.py
@@ -111,25 +111,49 @@ def test_vectors_add_item(self):
         self.assertEqual(vectors_obj['b'], tensorB)
         self.assertEqual(vectors_obj['not_in_it'], unk_tensor)
 
-    def test_vectors_load_and_save(self):
+    def test_vectors_update(self):
[Review comment from the PR author] Splitting the test for updating as I think it should be tested separately from serialization.
 
         tensorA = torch.tensor([1, 0], dtype=torch.float)
         tensorB = torch.tensor([0, 1], dtype=torch.float)
+        tensorC = torch.tensor([1, 1], dtype=torch.float)
 
         expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
 
         tokens = ['a', 'b']
         vecs = torch.stack((tensorA, tensorB), 0)
         vectors_obj = build_vectors(tokens, vecs)
 
-        tensorC = torch.tensor([1, 1], dtype=torch.float)
         vectors_obj['b'] = tensorC
 
-        vector_path = os.path.join(self.test_dir, 'vectors.pt')
-        torch.save(vectors_obj.to_ivalue(), vector_path)
-        loaded_vectors_obj = torch.load(vector_path)
-
-        self.assertEqual(loaded_vectors_obj['a'], tensorA)
-        self.assertEqual(loaded_vectors_obj['b'], tensorC)
-        self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+        self.assertEqual(vectors_obj['a'], tensorA)
+        self.assertEqual(vectors_obj['b'], tensorC)
+        self.assertEqual(vectors_obj['not_in_it'], expected_unk_tensor)
+
+    def test_vectors_load_and_save(self):
+        tensorA = torch.tensor([1, 0], dtype=torch.float)
+        tensorB = torch.tensor([0, 1], dtype=torch.float)
+        expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
+
+        tokens = ['a', 'b']
+        vecs = torch.stack((tensorA, tensorB), 0)
+        vectors_obj = build_vectors(tokens, vecs)
+
+        with self.subTest('pybind'):
+            vector_path = os.path.join(self.test_dir, 'vectors_pybind.pt')
+            torch.save(vectors_obj, vector_path)
+            loaded_vectors_obj = torch.load(vector_path)
+
+            self.assertEqual(loaded_vectors_obj['a'], tensorA)
+            self.assertEqual(loaded_vectors_obj['b'], tensorB)
+            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+
+        with self.subTest('torchscript'):
+            vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
+            torch.save(vectors_obj.to_ivalue(), vector_path)
+            loaded_vectors_obj = torch.load(vector_path)
+
+            self.assertEqual(loaded_vectors_obj['a'], tensorA)
+            self.assertEqual(loaded_vectors_obj['b'], tensorB)
+            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
 
     # we separate out these errors because Windows runs into seg faults when propagating
     # exceptions from C++ using pybind11
19 changes: 13 additions & 6 deletions test/experimental/test_vocab.py
@@ -199,12 +199,19 @@ def test_vocab_load_and_save(self):
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
-        vocab_path = os.path.join(self.test_dir, 'vocab.pt')
-        torch.save(v.to_ivalue(), vocab_path)
-        loaded_v = torch.load(vocab_path)
-
-        self.assertEqual(v.get_itos(), expected_itos)
-        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+        with self.subTest('pybind'):
+            vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
+            torch.save(v, vocab_path)
+            loaded_v = torch.load(vocab_path)
+            self.assertEqual(loaded_v.get_itos(), expected_itos)
+            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+
+        with self.subTest('torchscript'):
+            vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
+            torch.save(v.to_ivalue(), vocab_path)
+            loaded_v = torch.load(vocab_path)
+            self.assertEqual(loaded_v.get_itos(), expected_itos)
+            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
 
     def test_build_vocab_iterator(self):
         iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',
8 changes: 8 additions & 0 deletions torchtext/csrc/regex.cpp
@@ -11,4 +11,12 @@ std::string Regex::Sub(std::string str, const std::string &repl) const {
   return str;
 }
 
+std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self) {
+  return self->re_str_;
+}
+
+c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state) {
+  return c10::make_intrusive<Regex>(std::move(state));
+}
+
 } // namespace torchtext
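
Taken together, `_serialize_regex` and `_deserialize_regex` form a pickle contract: the first reduces a `Regex` to its pattern string, the second rebuilds an equivalent instance from that string. For readers tracing how the tests above can `torch.save` a TorchScript-backed object, here is a minimal sketch of how such helpers are typically wired into the custom-class registration via `def_pickle`. The actual binding code sits outside this diff, so the registration below is an assumption for illustration, not the PR's verbatim code:

#include <torch/custom_class.h>

#include "regex.h"  // declares Regex, _serialize_regex, _deserialize_regex

namespace torchtext {

// Illustrative registration sketch; the real binding lives in a file
// not shown in this diff.
static auto regex_binding =
    torch::class_<Regex>("torchtext", "Regex")
        .def(torch::init<std::string>())
        .def("Sub", &Regex::Sub)
        .def_pickle(
            // __getstate__: reduce the object to a serializable string.
            [](const c10::intrusive_ptr<Regex>& self) -> std::string {
              return _serialize_regex(self);
            },
            // __setstate__: rebuild an equivalent object from that string.
            [](std::string state) -> c10::intrusive_ptr<Regex> {
              return _deserialize_regex(std::move(state));
            });

}  // namespace torchtext

With a binding like this in place, a `Regex` held by a scripted object survives `torch.save`/`torch.load`, which is exactly what the torchscript sub-tests above exercise.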
4 changes: 4 additions & 0 deletions torchtext/csrc/regex.h
@@ -13,4 +13,8 @@ struct Regex : torch::CustomClassHolder {
   Regex(const std::string &re_str);
   std::string Sub(std::string str, const std::string &repl) const;
 };
+
+std::string _serialize_regex(const c10::intrusive_ptr<Regex> &self);
+c10::intrusive_ptr<Regex> _deserialize_regex(std::string &&state);
+
 } // namespace torchtext
11 changes: 11 additions & 0 deletions torchtext/csrc/regex_tokenizer.cpp
@@ -44,4 +44,15 @@ void RegexTokenizer::split_(std::string &str, std::vector<std::string> &tokens,
   }
 }
 
+RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self) {
+  return std::make_tuple(self->patterns_, self->replacements_, self->to_lower_);
+}
+
+c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states) {
+  return c10::make_intrusive<RegexTokenizer>(
+      std::move(std::get<0>(states)),
+      std::move(std::get<1>(states)),
+      std::get<2>(states));
+}
+
 } // namespace torchtext
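
Same pattern as `Regex` above, but the state is the three-field `RegexTokenizerStates` tuple (patterns, replacements, lower-casing flag) declared in the header below. A sketch of the presumed `def_pickle` wiring, again illustrative rather than the PR's verbatim registration, assuming the same includes and namespace as the previous sketch:

static auto regex_tokenizer_binding =
    torch::class_<RegexTokenizer>("torchtext", "RegexTokenizer")
        .def(torch::init<std::vector<std::string>,
                         std::vector<std::string>,
                         bool>())
        .def("forward", &RegexTokenizer::forward)
        .def_pickle(
            // __getstate__: dump (patterns_, replacements_, to_lower_).
            [](const c10::intrusive_ptr<RegexTokenizer>& self)
                -> RegexTokenizerStates {
              return _serialize_regex_tokenizer(self);
            },
            // __setstate__: reconstruct the tokenizer from the tuple.
            [](RegexTokenizerStates states)
                -> c10::intrusive_ptr<RegexTokenizer> {
              return _deserialize_regex_tokenizer(std::move(states));
            });

A tuple of TorchScript-supported types is itself a valid pickle state, which is why the state type below stays a plain `std::tuple` rather than a custom struct.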
6 changes: 6 additions & 0 deletions torchtext/csrc/regex_tokenizer.h
@@ -3,6 +3,9 @@
 
 namespace torchtext {
 
+typedef std::tuple<std::vector<std::string>, std::vector<std::string>, bool>
+    RegexTokenizerStates;
+
 struct RegexTokenizer : torch::CustomClassHolder {
  private:
   std::vector<RE2 *> compiled_patterns_;
@@ -20,4 +23,7 @@ struct RegexTokenizer : torch::CustomClassHolder {
   std::vector<std::string> forward(std::string str) const;
 };
 
+RegexTokenizerStates _serialize_regex_tokenizer(const c10::intrusive_ptr<RegexTokenizer> &self);
+c10::intrusive_ptr<RegexTokenizer> _deserialize_regex_tokenizer(RegexTokenizerStates &&states);
+
 } // namespace torchtext