@@ -1,5 +1,6 @@
 import os
 from collections import OrderedDict
+from typing import List, Optional
 from unittest.mock import patch
 
 import torch
@@ -586,7 +587,9 @@ def test_clip_tokenizer_save_load_torchscript(self) -> None:
 
 
 class TestBERTTokenizer(TorchtextTestCase):
-    def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_tokens: bool):
+    def _load_tokenizer(
+        self, test_scripting: bool, do_lower_case: bool, return_tokens: bool, never_split: Optional[List[str]] = None
+    ):
         if do_lower_case:
             vocab_file = "bert_base_uncased_vocab.txt"
         else:
@@ -596,46 +599,117 @@ def _load_tokenizer(self, test_scripting: bool, do_lower_case: bool, return_toke
             vocab_path=get_asset_path(vocab_file),
             do_lower_case=do_lower_case,
             return_tokens=return_tokens,
+            never_split=never_split,
         )
         if test_scripting:
             tokenizer = torch.jit.script(tokenizer)
         return tokenizer
 
-    def _bert_tokenizer(self, tokenizer, do_lower_case):
+    def _bert_tokenizer(self, tokenizer, do_lower_case, never_split: Optional[List[str]] = None):
         sample_texts = [
             "Hello World!, how are you?",
             "Hélló WoŕlḊ¿",
             "Respublica superiorem",
             "Avdija Vršajević în",
+            " \t HeLLo!how \n Are yoU? [UNK]",
+            "hi world [UNK] [CLS]",
+            "testing, [UNK] words! [SEP]",
         ]
 
-        if do_lower_case:
-            expected_tokens = [
-                ["hello", "world", "!", ",", "how", "are", "you", "?"],
-                ["hello", "world", "¿"],
-                ["res", "##pu", "##bl", "##ica", "superior", "##em"],
-                ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
-            ]
-            expected_token_ids = [
-                ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
-                ["7592", "2088", "1094"],
-                ["24501", "14289", "16558", "5555", "6020", "6633"],
-                ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
-            ]
+        if not never_split:
+            if do_lower_case:
+                expected_tokens = [
+                    ["hello", "world", "!", ",", "how", "are", "you", "?"],
+                    ["hello", "world", "¿"],
+                    ["res", "##pu", "##bl", "##ica", "superior", "##em"],
+                    ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
+                    ["hello", "!", "how", "are", "you", "?", "[", "un", "##k", "]"],
+                    ["hi", "world", "[", "un", "##k", "]", "[", "cl", "##s", "]"],
+                    ["testing", ",", "[", "un", "##k", "]", "words", "!", "[", "sep", "]"],
+                ]
+                expected_token_ids = [
+                    ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
+                    ["7592", "2088", "1094"],
+                    ["24501", "14289", "16558", "5555", "6020", "6633"],
+                    ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
+                    ["7592", "999", "2129", "2024", "2017", "1029", "1031", "4895", "2243", "1033"],
+                    ["7632", "2088", "1031", "4895", "2243", "1033", "1031", "18856", "2015", "1033"],
+                    ["5604", "1010", "1031", "4895", "2243", "1033", "2616", "999", "1031", "19802", "1033"],
+                ]
 
+            else:
+                expected_tokens = [
+                    ["Hello", "World", "!", ",", "how", "are", "you", "?"],
+                    ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
+                    ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
+                    ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
+                    ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[", "UN", "##K", "]"],
+                    ["hi", "world", "[", "UN", "##K", "]", "[", "C", "##LS", "]"],
+                    ["testing", ",", "[", "UN", "##K", "]", "words", "!", "[", "SE", "##P", "]"],
+                ]
+                expected_token_ids = [
+                    ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
+                    ["145", "2744", "2339", "7774", "100", "225"],
+                    ["11336", "20080", "10354", "9538", "7298", "5521"],
+                    ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
+                    [
+                        "1124",
+                        "23955",
+                        "1186",
+                        "106",
+                        "1293",
+                        "2372",
+                        "26063",
+                        "2591",
+                        "136",
+                        "164",
+                        "7414",
+                        "2428",
+                        "166",
+                    ],
+                    ["20844", "1362", "164", "7414", "2428", "166", "164", "140", "15928", "166"],
+                    ["5193", "117", "164", "7414", "2428", "166", "1734", "106", "164", "12342", "2101", "166"],
+                ]
         else:
-            expected_tokens = [
-                ["Hello", "World", "!", ",", "how", "are", "you", "?"],
-                ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
-                ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
-                ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
-            ]
-            expected_token_ids = [
-                ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
-                ["145", "2744", "2339", "7774", "100", "225"],
-                ["11336", "20080", "10354", "9538", "7298", "5521"],
-                ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
-            ]
+            if do_lower_case:
+                expected_tokens = [
+                    ["hello", "world", "!", ",", "how", "are", "you", "?"],
+                    ["hello", "world", "¿"],
+                    ["res", "##pu", "##bl", "##ica", "superior", "##em"],
+                    ["av", "##di", "##ja", "vr", "##sa", "##jevic", "in"],
+                    ["hello", "!", "how", "are", "you", "?", "[UNK]"],
+                    ["hi", "world", "[UNK]", "[CLS]"],
+                    ["testing", ",", "[UNK]", "words", "!", "[", "sep", "]"],
+                ]
+                expected_token_ids = [
+                    ["7592", "2088", "999", "1010", "2129", "2024", "2017", "1029"],
+                    ["7592", "2088", "1094"],
+                    ["24501", "14289", "16558", "5555", "6020", "6633"],
+                    ["20704", "4305", "3900", "27830", "3736", "26782", "1999"],
+                    ["7592", "999", "2129", "2024", "2017", "1029", "100"],
+                    ["7632", "2088", "100", "101"],
+                    ["5604", "1010", "100", "2616", "999", "1031", "19802", "1033"],
+                ]
+
+            else:
+                expected_tokens = [
+                    ["Hello", "World", "!", ",", "how", "are", "you", "?"],
+                    ["H", "##é", "##ll", "##ó", "[UNK]", "¿"],
+                    ["Re", "##sp", "##ub", "##lica", "superior", "##em"],
+                    ["A", "##v", "##di", "##ja", "V", "##r", "##ša", "##je", "##vić", "î", "##n"],
+                    ["He", "##LL", "##o", "!", "how", "Are", "yo", "##U", "?", "[UNK]"],
+                    ["hi", "world", "[UNK]", "[CLS]"],
+                    ["testing", ",", "[UNK]", "words", "!", "[", "SE", "##P", "]"],
+                ]
+                expected_token_ids = [
+                    ["8667", "1291", "106", "117", "1293", "1132", "1128", "136"],
+                    ["145", "2744", "2339", "7774", "100", "225"],
+                    ["11336", "20080", "10354", "9538", "7298", "5521"],
+                    ["138", "1964", "3309", "3174", "159", "1197", "23834", "5561", "10225", "260", "1179"],
+                    ["1124", "23955", "1186", "106", "1293", "2372", "26063", "2591", "136", "100"],
+                    ["20844", "1362", "100", "101"],
+                    ["5193", "117", "100", "1734", "106", "164", "12342", "2101", "166"],
+                ]
 
         # test batch of sentences
         if tokenizer._return_tokens:
@@ -650,14 +724,18 @@ def _bert_tokenizer(self, tokenizer, do_lower_case):
             else:
                 self.assertEqual(tokenizer(txt), expected_token_ids[idx])
 
-    @nested_params([True, False], [True, False], [True, False])
-    def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens):
+    @nested_params([True, False], [True, False], [True, False], [[], None, ["[UNK]", "[CLS]"]])
+    def test_bert_tokenizer(self, test_scripting, do_lower_case, return_tokens, never_split):
         """test tokenization on single sentence input as well as batch on sentences"""
         self._bert_tokenizer(
             self._load_tokenizer(
-                test_scripting=test_scripting, do_lower_case=do_lower_case, return_tokens=return_tokens
+                test_scripting=test_scripting,
+                do_lower_case=do_lower_case,
+                return_tokens=return_tokens,
+                never_split=never_split,
             ),
             do_lower_case=do_lower_case,
+            never_split=never_split,
         )
 
     @nested_params([True, False], [True, False], [True, False])
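
For context, here is a minimal sketch of the behavior these tests pin down, assuming the `never_split` keyword lands on `torchtext.transforms.BERTTokenizer` as exercised in the diff; the vocab path is a hypothetical local copy of the test asset, and the expected outputs are taken from the test's expectations above.

```python
from torchtext.transforms import BERTTokenizer

# Hypothetical local path to the uncased BERT vocab asset used by the tests.
VOCAB_PATH = "bert_base_uncased_vocab.txt"

# Without never_split, punctuation splitting breaks special tokens apart.
plain = BERTTokenizer(vocab_path=VOCAB_PATH, do_lower_case=True, return_tokens=True)
print(plain("hi world [UNK] [CLS]"))
# per the test: ['hi', 'world', '[', 'un', '##k', ']', '[', 'cl', '##s', ']']

# With never_split, the listed tokens pass through tokenization intact.
preserving = BERTTokenizer(
    vocab_path=VOCAB_PATH,
    do_lower_case=True,
    return_tokens=True,
    never_split=["[UNK]", "[CLS]"],
)
print(preserving("hi world [UNK] [CLS]"))
# per the test: ['hi', 'world', '[UNK]', '[CLS]']
```

Note how the parametrization covers both `[]` and `None` for `never_split`: the test treats an empty list the same as the default, so only a non-empty list changes the expected output.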