This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 2c9a306

Clean up test
1 parent 513a537 commit 2c9a306

File tree: 4 files changed, +106 -64 lines changed

    test/data/test_functional.py
    test/experimental/test_transforms.py
    test/experimental/test_vectors.py
    test/experimental/test_vocab.py

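All four files below move to the same shape: build the object under test, save it with torch.save, reload it with torch.load, and assert the loaded object still matches the reference output, once under a 'pybind' subTest for the eager object and once under a 'torchscript' subTest for the .to_ivalue() variant. The sketch below is a minimal, self-contained illustration of that pattern, not code from this repository; the _LowercaseTransform stand-in and its sample/expected values are invented for the example, and it assumes a PyTorch version of the same era as this commit, where torch.load unpickles arbitrary objects by default.

import os
import tempfile
import unittest

import torch


class _LowercaseTransform:
    # Picklable stand-in for the C++-backed transforms exercised in the real tests.
    def __call__(self, line):
        return line.lower().split()


class LoadAndSaveSketch(unittest.TestCase):
    def test_load_and_save(self):
        sample = 'Basic English Normalization'
        expected = ['basic', 'english', 'normalization']
        with tempfile.TemporaryDirectory() as test_dir:
            # One subTest per backend; the real tests use 'pybind' for the eager
            # object and 'torchscript' for the variant obtained via to_ivalue().
            with self.subTest('pybind'):
                save_path = os.path.join(test_dir, 'transform_pybind.pt')
                transform = _LowercaseTransform()
                torch.save(transform, save_path)   # pickle the object to disk
                loaded = torch.load(save_path)     # reload it
                self.assertEqual(loaded(sample), expected)


if __name__ == '__main__':
    unittest.main()
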
test/data/test_functional.py
Lines changed: 51 additions & 24 deletions

@@ -107,17 +107,24 @@ def test_BasicEnglishNormalize(self):
         self.assertEqual(eager_tokens, ref_results)
         self.assertEqual(experimental_eager_tokens, ref_results)
 
-        # test pybind load and save
-        save_path = os.path.join(self.test_dir, 'basic_english_normalize_pybind.pt')
-        torch.save(basic_eng_norm, save_path)
-        loaded_basic_eng_norm = torch.load(save_path)
-        self.assertEqual(loaded_basic_eng_norm(test_sample), ref_results)
-
-        # test torchscript load and save
-        save_path = os.path.join(self.test_dir, 'basic_english_normalize_torchscrip.pt')
-        torch.save(basic_eng_norm.to_ivalue(), save_path)
-        loaded_basic_eng_norm = torch.load(save_path)
-        self.assertEqual(loaded_basic_eng_norm(test_sample), ref_results)
+    def test_basicEnglishNormalize_load_and_save(self):
+        test_sample = '\'".<br />,()!?;: Basic English Normalization for a Line of Text \'".<br />,()!?;:'
+        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'basic', 'english', 'normalization',
+                       'for', 'a', 'line', 'of', 'text', "'", '.', ',', '(', ')', '!', '?']
+
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'ben_pybind.pt')
+            ben = basic_english_normalize()
+            torch.save(ben, save_path)
+            loaded_ben = torch.load(save_path)
+            self.assertEqual(loaded_ben(test_sample), ref_results)
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt')
+            ben = basic_english_normalize().to_ivalue()
+            torch.save(ben, save_path)
+            loaded_ben = torch.load(save_path)
+            self.assertEqual(loaded_ben(test_sample), ref_results)
 
     # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
     @unittest.skipIf(platform.system() == "Windows", "Test is known to fail on Windows.")
@@ -151,19 +158,39 @@ def test_RegexTokenizer(self):
         self.assertEqual(eager_tokens, ref_results)
         self.assertEqual(jit_tokens, ref_results)
 
-        # test pybind load and save
-        save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
-        torch.save(r_tokenizer, save_path)
-        loaded_r_tokenizer = torch.load(save_path)
-        loaded_eager_tokens = loaded_r_tokenizer(test_sample)
-        self.assertEqual(loaded_eager_tokens, ref_results)
-
-        # test torchscript load and save
-        save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
-        torch.save(r_tokenizer.to_ivalue(), save_path)
-        loaded_r_tokenizer = torch.load(save_path)
-        loaded_eager_tokens = loaded_r_tokenizer(test_sample)
-        self.assertEqual(loaded_eager_tokens, ref_results)
+    def test_load_and_save(self):
+        test_sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
+        ref_results = ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
+                       'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
+        patterns_list = [
+            (r'\'', ' \' '),
+            (r'\"', ''),
+            (r'\.', ' . '),
+            (r'<br \/>', ' '),
+            (r',', ' , '),
+            (r'\(', ' ( '),
+            (r'\)', ' ) '),
+            (r'\!', ' ! '),
+            (r'\?', ' ? '),
+            (r'\;', ' '),
+            (r'\:', ' '),
+            (r'\s+', ' ')]
+
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'regex_pybind.pt')
+            tokenizer = regex_tokenizer(patterns_list)
+            torch.save(tokenizer, save_path)
+            loaded_tokenizer = torch.load(save_path)
+            results = loaded_tokenizer(test_sample)
+            self.assertEqual(results, ref_results)
+
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
+            tokenizer = regex_tokenizer(patterns_list).to_ivalue()
+            torch.save(tokenizer, save_path)
+            loaded_tokenizer = torch.load(save_path)
+            results = loaded_tokenizer(test_sample)
+            self.assertEqual(results, ref_results)
 
     def test_custom_replace(self):
         custom_replace_transform = custom_replace([(r'S', 's'), (r'\s+', ' ')])

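The patterns_list added above is what drives the expected ref_results for the regex tokenizer test: each (pattern, replacement) pair is applied in order and the result is split on whitespace, which is why punctuation becomes standalone tokens while case is preserved (unlike the basic English normalizer, whose ref_results are lowercased). For illustration only, the pure-Python sketch below reproduces that behaviour with the standard re module; it is not how torchtext's regex_tokenizer is implemented (the real transform is backed by a C++ implementation).

import re

patterns_list = [
    (r'\'', ' \' '),
    (r'\"', ''),
    (r'\.', ' . '),
    (r'<br \/>', ' '),
    (r',', ' , '),
    (r'\(', ' ( '),
    (r'\)', ' ) '),
    (r'\!', ' ! '),
    (r'\?', ' ? '),
    (r'\;', ' '),
    (r'\:', ' '),
    (r'\s+', ' ')]


def apply_patterns(line):
    # Apply each substitution in order, then split on whitespace.
    for pattern, replacement in patterns_list:
        line = re.sub(pattern, replacement, line)
    return line.split()


sample = '\'".<br />,()!?;: Basic Regex Tokenization for a Line of Text \'".<br />,()!?;:'
print(apply_patterns(sample))
# ["'", '.', ',', '(', ')', '!', '?', 'Basic', 'Regex', 'Tokenization',
#  'for', 'a', 'Line', 'of', 'Text', "'", '.', ',', '(', ')', '!', '?']
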
test/experimental/test_transforms.py
Lines changed: 12 additions & 11 deletions

@@ -57,7 +57,6 @@ def test_vector_transform(self):
 
     def test_sentencepiece_load_and_save(self):
         model_path = get_asset_path('spm_example.model')
-        spm = sentencepiece_tokenizer((model_path))
         input = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
         expected = [
             '▁Sent', 'ence', 'P', 'ie', 'ce', '▁is',
@@ -66,14 +65,16 @@ def test_sentencepiece_load_and_save(self):
             '▁de', 'to', 'ken', 'izer',
         ]
 
-        # test pybind load and save
-        save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
-        torch.save(spm, save_path)
-        loaded_spm = torch.load(save_path)
-        self.assertEqual(expected, loaded_spm(input))
+        with self.subTest('pybind'):
+            save_path = os.path.join(self.test_dir, 'spm_pybind.pt')
+            spm = sentencepiece_tokenizer((model_path))
+            torch.save(spm, save_path)
+            loaded_spm = torch.load(save_path)
+            self.assertEqual(expected, loaded_spm(input))
 
-        # test torchscript load and save
-        save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
-        torch.save(spm.to_ivalue(), save_path)
-        loaded_spm = torch.load(save_path)
-        self.assertEqual(expected, loaded_spm(input))
+        with self.subTest('torchscript'):
+            save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
+            spm = sentencepiece_tokenizer((model_path)).to_ivalue()
+            torch.save(spm, save_path)
+            loaded_spm = torch.load(save_path)
+            self.assertEqual(expected, loaded_spm(input))

test/experimental/test_vectors.py
Lines changed: 30 additions & 16 deletions

@@ -111,35 +111,49 @@ def test_vectors_add_item(self):
         self.assertEqual(vectors_obj['b'], tensorB)
         self.assertEqual(vectors_obj['not_in_it'], unk_tensor)
 
-    def test_vectors_load_and_save(self):
+    def test_vectors_update(self):
         tensorA = torch.tensor([1, 0], dtype=torch.float)
         tensorB = torch.tensor([0, 1], dtype=torch.float)
+        tensorC = torch.tensor([1, 1], dtype=torch.float)
+
         expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
 
         tokens = ['a', 'b']
         vecs = torch.stack((tensorA, tensorB), 0)
         vectors_obj = build_vectors(tokens, vecs)
 
-        tensorC = torch.tensor([1, 1], dtype=torch.float)
         vectors_obj['b'] = tensorC
 
-        # test pybind load and save
-        vector_path = os.path.join(self.test_dir, 'vectors_pybind.pt')
-        torch.save(vectors_obj, vector_path)
-        loaded_vectors_obj = torch.load(vector_path)
+        self.assertEqual(vectors_obj['a'], tensorA)
+        self.assertEqual(vectors_obj['b'], tensorC)
+        self.assertEqual(vectors_obj['not_in_it'], expected_unk_tensor)
+
+    def test_vectors_load_and_save(self):
+        tensorA = torch.tensor([1, 0], dtype=torch.float)
+        tensorB = torch.tensor([0, 1], dtype=torch.float)
+        expected_unk_tensor = torch.tensor([0, 0], dtype=torch.float)
+
+        tokens = ['a', 'b']
+        vecs = torch.stack((tensorA, tensorB), 0)
+        vectors_obj = build_vectors(tokens, vecs)
+
+        with self.subTest('pybind'):
+            vector_path = os.path.join(self.test_dir, 'vectors_pybind.pt')
+            torch.save(vectors_obj, vector_path)
+            loaded_vectors_obj = torch.load(vector_path)
 
-        self.assertEqual(loaded_vectors_obj['a'], tensorA)
-        self.assertEqual(loaded_vectors_obj['b'], tensorC)
-        self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+            self.assertEqual(loaded_vectors_obj['a'], tensorA)
+            self.assertEqual(loaded_vectors_obj['b'], tensorB)
+            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
 
-        # test torchscript load and save
-        vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
-        torch.save(vectors_obj.to_ivalue(), vector_path)
-        loaded_vectors_obj = torch.load(vector_path)
+        with self.subTest('torchscript'):
+            vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
+            torch.save(vectors_obj.to_ivalue(), vector_path)
+            loaded_vectors_obj = torch.load(vector_path)
 
-        self.assertEqual(loaded_vectors_obj['a'], tensorA)
-        self.assertEqual(loaded_vectors_obj['b'], tensorC)
-        self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
+            self.assertEqual(loaded_vectors_obj['a'], tensorA)
+            self.assertEqual(loaded_vectors_obj['b'], tensorB)
+            self.assertEqual(loaded_vectors_obj['not_in_it'], expected_unk_tensor)
 
     # we separate out these errors because Windows runs into seg faults when propagating
     # exceptions from C++ using pybind11

test/experimental/test_vocab.py
Lines changed: 13 additions & 13 deletions

@@ -199,19 +199,19 @@ def test_vocab_load_and_save(self):
         self.assertEqual(v.get_itos(), expected_itos)
         self.assertEqual(dict(v.get_stoi()), expected_stoi)
 
-        # test pybind load and save
-        vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
-        torch.save(v, vocab_path)
-        loaded_v = torch.load(vocab_path)
-        self.assertEqual(v.get_itos(), expected_itos)
-        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
-
-        # test torchscript load and save
-        vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
-        torch.save(v.to_ivalue(), vocab_path)
-        loaded_v = torch.load(vocab_path)
-        self.assertEqual(v.get_itos(), expected_itos)
-        self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+        with self.subTest('pybind'):
+            vocab_path = os.path.join(self.test_dir, 'vocab_pybind.pt')
+            torch.save(v, vocab_path)
+            loaded_v = torch.load(vocab_path)
+            self.assertEqual(v.get_itos(), expected_itos)
+            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
+
+        with self.subTest('torchscript'):
+            vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
+            torch.save(v.to_ivalue(), vocab_path)
+            loaded_v = torch.load(vocab_path)
+            self.assertEqual(v.get_itos(), expected_itos)
+            self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)
 
     def test_build_vocab_iterator(self):
         iterator = [['hello', 'hello', 'hello', 'freq_low', 'hello', 'world', 'world', 'world', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T',

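A practical property of the self.subTest blocks used throughout this commit: if the pybind round-trip fails, the torchscript round-trip still runs, and each failure is reported under its own label instead of the first assertion aborting the whole test method. A small, self-contained sketch of that standard-library behaviour, unrelated to this repository's code:

import unittest


class SubTestReporting(unittest.TestCase):
    def test_backends(self):
        # The failing 'pybind' subtest is reported with its label, and the loop
        # continues on to the 'torchscript' subtest instead of stopping early.
        for backend, ok in [('pybind', False), ('torchscript', True)]:
            with self.subTest(backend):
                self.assertTrue(ok)


if __name__ == '__main__':
    unittest.main()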