@@ -107,17 +107,24 @@ def test_BasicEnglishNormalize(self):
107107 self .assertEqual (eager_tokens , ref_results )
108108 self .assertEqual (experimental_eager_tokens , ref_results )
109109
110- # test pybind load and save
111- save_path = os .path .join (self .test_dir , 'basic_english_normalize_pybind.pt' )
112- torch .save (basic_eng_norm , save_path )
113- loaded_basic_eng_norm = torch .load (save_path )
114- self .assertEqual (loaded_basic_eng_norm (test_sample ), ref_results )
115-
116- # test torchscript load and save
117- save_path = os .path .join (self .test_dir , 'basic_english_normalize_torchscrip.pt' )
118- torch .save (basic_eng_norm .to_ivalue (), save_path )
119- loaded_basic_eng_norm = torch .load (save_path )
120- self .assertEqual (loaded_basic_eng_norm (test_sample ), ref_results )
def test_basicEnglishNormalize_load_and_save(self):
    """Check that ``basic_english_normalize`` survives a save/load round trip.

    Two variants are exercised, each in its own sub-test: the object
    returned by ``basic_english_normalize()`` itself ('pybind') and the
    object returned by its ``to_ivalue()`` method ('torchscript').  Each
    one is written with ``torch.save`` into ``self.test_dir``, read back
    with ``torch.load``, and the reloaded object must still turn the
    sample text into the reference token list.
    """
    raw_text = '\' ".<br />,()!?;: Basic English Normalization for a Line of Text \' ".<br />,()!?;:'
    expected_tokens = [
        "'", '.', ',', '(', ')', '!', '?',
        'basic', 'english', 'normalization', 'for', 'a', 'line', 'of', 'text',
        "'", '.', ',', '(', ')', '!', '?',
    ]

    # (sub-test label, file name under self.test_dir, factory for the object to save).
    # Factories are invoked inside the sub-test so that a construction
    # failure is reported against the matching variant.
    variants = (
        ('pybind', 'ben_pybind.pt',
         lambda: basic_english_normalize()),
        ('torchscript', 'ben_torchscrip.pt',
         lambda: basic_english_normalize().to_ivalue()),
    )

    for label, file_name, build in variants:
        with self.subTest(label):
            checkpoint = os.path.join(self.test_dir, file_name)
            torch.save(build(), checkpoint)
            restored = torch.load(checkpoint)
            self.assertEqual(restored(raw_text), expected_tokens)
121128
122129 # TODO(Nayef211): remove decorator once https://github.com/pytorch/pytorch/issues/38207 is closed
123130 @unittest .skipIf (platform .system () == "Windows" , "Test is known to fail on Windows." )
@@ -151,19 +158,39 @@ def test_RegexTokenizer(self):
151158 self .assertEqual (eager_tokens , ref_results )
152159 self .assertEqual (jit_tokens , ref_results )
153160
154- # test pybind load and save
155- save_path = os .path .join (self .test_dir , 'regex_pybind.pt' )
156- torch .save (r_tokenizer , save_path )
157- loaded_r_tokenizer = torch .load (save_path )
158- loaded_eager_tokens = loaded_r_tokenizer (test_sample )
159- self .assertEqual (loaded_eager_tokens , ref_results )
160-
161- # test torchscript load and save
162- save_path = os .path .join (self .test_dir , 'regex_torchscript.pt' )
163- torch .save (r_tokenizer .to_ivalue (), save_path )
164- loaded_r_tokenizer = torch .load (save_path )
165- loaded_eager_tokens = loaded_r_tokenizer (test_sample )
166- self .assertEqual (loaded_eager_tokens , ref_results )
def test_load_and_save(self):
    """Check that ``regex_tokenizer`` survives a save/load round trip.

    Two variants are exercised, each in its own sub-test: the object
    returned by ``regex_tokenizer(patterns)`` itself ('pybind') and the
    object returned by its ``to_ivalue()`` method ('torchscript').  Each
    one is written with ``torch.save`` into ``self.test_dir``, read back
    with ``torch.load``, and the reloaded object must still turn the
    sample text into the reference token list.
    """
    raw_text = '\' ".<br />,()!?;: Basic Regex Tokenization for a Line of Text \' ".<br />,()!?;:'
    expected_tokens = [
        "'", '.', ',', '(', ')', '!', '?',
        'Basic', 'Regex', 'Tokenization', 'for', 'a', 'Line', 'of', 'Text',
        "'", '.', ',', '(', ')', '!', '?',
    ]
    # (regex, replacement) pairs handed to regex_tokenizer.
    replacements = [
        (r'\'', ' \' '),
        (r'\"', ''),
        (r'\.', ' . '),
        (r'<br \/>', ' '),
        (r',', ' , '),
        (r'\(', ' ( '),
        (r'\)', ' ) '),
        (r'\!', ' ! '),
        (r'\?', ' ? '),
        (r'\;', ' '),
        (r'\:', ' '),
        (r'\s+', ' '),
    ]

    # (sub-test label, file name under self.test_dir, factory for the object to save).
    # Factories are invoked inside the sub-test so that a construction
    # failure is reported against the matching variant.
    variants = (
        ('pybind', 'regex_pybind.pt',
         lambda: regex_tokenizer(replacements)),
        ('torchscript', 'regex_torchscript.pt',
         lambda: regex_tokenizer(replacements).to_ivalue()),
    )

    for label, file_name, build in variants:
        with self.subTest(label):
            checkpoint = os.path.join(self.test_dir, file_name)
            torch.save(build(), checkpoint)
            restored = torch.load(checkpoint)
            self.assertEqual(restored(raw_text), expected_tokens)
167194
168195 def test_custom_replace (self ):
169196 custom_replace_transform = custom_replace ([(r'S' , 's' ), (r'\s+' , ' ' )])
0 commit comments