11import os
22from collections import OrderedDict
3+ from unittest .mock import patch
34
45import torch
56from torchtext import transforms
6- from torchtext .transforms import RegexTokenizer
7+ from torchtext .transforms import MaskTransform , RegexTokenizer
78from torchtext .vocab import vocab
89
910from .common .assets import get_asset_path
@@ -750,3 +751,118 @@ def test_regex_tokenizer_save_load(self) -> None:
750751 loaded_tokenizer = torch .jit .load (save_path )
751752 results = loaded_tokenizer (self .test_sample )
752753 self .assertEqual (results , self .ref_results )
754+
class TestMaskTransform(TorchtextTestCase):
    """Tests for MaskTransform.

    Testing under these assumed conditions:

    Vocab maps the following tokens to the following ids:
        ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]

    The sample token sequences are:
        [["[BOS]", "a", "b", "c", "d"],
        ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
    """

    # The two sample sequences above, encoded with the vocab from the docstring.
    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])

    vocab_len = 7
    pad_idx = 4
    mask_idx = 5
    bos_idx = 6

    @nested_params([0.0, 1.0])
    def test_mask_transform_probs(self, test_mask_prob):
        # We pass (vocab_len - 1) into MaskTransform to test masking with a random token.
        # This shrinks the distribution token ids are sampled from, so the largest id
        # available for selection is one below the actual largest id in our vocab —
        # the id we assigned to [BOS]. Whenever the leading [BOS] of the first sample
        # sequence is picked for random replacement, we therefore know with certainty
        # that its replacement differs from [BOS]. In practice the true vocab length
        # should be passed instead, so random replacement can draw from the full vocab.
        transform = MaskTransform(
            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
        )

        # Patch targets for the transform's two internal replacement probabilities.
        mask_mask_target = "torchtext.transforms.MaskTransform.mask_mask_prob"
        rand_mask_target = "torchtext.transforms.MaskTransform.rand_mask_prob"

        # when mask_prob = 0, we expect the first token of the first sample sequence
        # to be chosen for replacement
        if test_mask_prob == 0.0:
            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
            with patch(mask_mask_target, 0.0):
                with patch(rand_mask_target, 0.0):
                    tokens_out, _, _ = transform(self.sample_token_ids)
                    self.assertEqual(self.sample_token_ids, tokens_out)

            # when mask_mask_prob, rand_mask_prob = 0,1 every token selected for
            # replacement is swapped for a random token id
            with patch(mask_mask_target, 0.0):
                with patch(rand_mask_target, 1.0):
                    tokens_out, _, _ = transform(self.sample_token_ids)

                    # first token in first sequence should be different
                    self.assertNotEqual(tokens_out[0, 0], self.sample_token_ids[0, 0])
                    # replaced token id should still be in vocab, not including [BOS]
                    self.assertTrue(0 <= int(tokens_out[0, 0]) < self.vocab_len - 1)

                    # every token other than the first of the first sequence is untouched
                    self.assertEqual(self.sample_token_ids[0, 1:], tokens_out[0, 1:])
                    self.assertEqual(self.sample_token_ids[1], tokens_out[1])

            # when mask_mask_prob, rand_mask_prob = 1,0 every token selected for
            # replacement becomes [MASK]
            with patch(mask_mask_target, 1.0):
                with patch(rand_mask_target, 0.0):
                    tokens_out, _, _ = transform(self.sample_token_ids)
                    expected = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
                    self.assertEqual(expected, tokens_out)

        # when mask_prob = 1, every token that is not [BOS] or [PAD] is chosen for
        # replacement (under the default condition mask_bos=False)
        if test_mask_prob == 1.0:
            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
            with patch(mask_mask_target, 0.0):
                with patch(rand_mask_target, 0.0):
                    tokens_out, _, _ = transform(self.sample_token_ids)
                    self.assertEqual(self.sample_token_ids, tokens_out)

            # when mask_mask_prob, rand_mask_prob = 0,1 every selected token becomes a
            # random token id. A randomly drawn id may coincide with the original, but
            # we know deterministically that [BOS] and [PAD] tokens stay unchanged.
            with patch(mask_mask_target, 0.0):
                with patch(rand_mask_target, 1.0):
                    tokens_out, _, _ = transform(self.sample_token_ids)
                    self.assertEqual(tokens_out[:, 0], torch.full_like(tokens_out[:, 0], self.bos_idx))
                    self.assertEqual(tokens_out[1, 3:], torch.full_like(tokens_out[1, 3:], self.pad_idx))

            # when mask_mask_prob, rand_mask_prob = 1,0 every token selected for
            # replacement becomes [MASK]
            with patch(mask_mask_target, 1.0):
                with patch(rand_mask_target, 0.0):
                    tokens_out, _, _ = transform(self.sample_token_ids)
                    expected = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
                    self.assertEqual(expected, tokens_out)

    def test_mask_transform_mask_bos(self) -> None:
        # MaskTransform's boolean parameter mask_bos controls whether [BOS] tokens are
        # eligible for replacement. The tests above use the default mask_bos=False;
        # here we exercise mask_bos=True.
        transform = MaskTransform(
            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
        )

        # when mask_mask_prob, rand_mask_prob = 1,0 every token selected for
        # replacement — now including [BOS] — becomes [MASK]
        with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0):
            with patch("torchtext.transforms.MaskTransform.rand_mask_prob", 0.0):
                tokens_out, _, _ = transform(self.sample_token_ids)
                expected = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
                self.assertEqual(expected, tokens_out)
0 commit comments