
Commit b9ca901

Graduate MaskTransform from prototype
1 parent 2fd12f3 commit b9ca901

4 files changed: +259, -260 lines
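Practically, the graduation is a namespace promotion: MaskTransform and its tests move out of torchtext.prototype into the stable torchtext package. Judging from the import hunks below, the only change callers need is the import path:

# before this commit
from torchtext.prototype.transforms import MaskTransform

# after this commit
from torchtext.transforms import MaskTransform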


test/prototype/test_transforms.py

Lines changed: 0 additions & 117 deletions
@@ -10,7 +10,6 @@
     sentencepiece_processor,
     sentencepiece_tokenizer,
     VectorTransform,
-    MaskTransform,
 )
 from torchtext.prototype.vectors import FastText
@@ -140,119 +139,3 @@ def test_sentencepiece_load_and_save(self) -> None:
         torch.save(spm, save_path)
         loaded_spm = torch.load(save_path)
         self.assertEqual(expected, loaded_spm(input))
-
-
-class TestMaskTransform(TorchtextTestCase):
-
-    """
-    Testing under these assumed conditions:
-
-    Vocab maps the following tokens to the following ids:
-    ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
-
-    The sample token sequences are:
-    [["[BOS]", "a", "b", "c", "d"],
-    ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
-    """
-
-    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
-
-    vocab_len = 7
-    pad_idx = 4
-    mask_idx = 5
-    bos_idx = 6
-
-    @nested_params([0.0, 1.0])
-    def test_mask_transform_probs(self, test_mask_prob):
-
-        # We pass (vocab_len - 1) into MaskTransform to test masking with a random token.
-        # This modifies the distribution from which token ids are randomly selected such that the
-        # largest token id available for selection is 1 less than the actual largest token id in our
-        # vocab, which we've assigned to the [BOS] token. This allows us to test random replacement
-        # by ensuring that when the first token ([BOS]) in the first sample sequence is selected for random replacement,
-        # we know with certainty the token it is replaced with is different from the [BOS] token.
-        # In practice, however, the actual vocab length should be provided as the input parameter so that random
-        # replacement selects from all possible tokens in the vocab.
-        mask_transform = MaskTransform(
-            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
-        )
-
-        # when mask_prob = 0, we expect the first token of the first sample sequence to be chosen for replacement
-        if test_mask_prob == 0.0:
-
-            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(self.sample_token_ids, masked_tokens)
-
-            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement to be
-            # changed to a random token_id
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 1.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-
-                # first token in first sequence should be different
-                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
-                # replaced token id should still be in vocab, not including [BOS]
-                assert masked_tokens[0, 0] in range(self.vocab_len - 1)
-
-                # all other tokens except for first token of first sequence should remain the same
-                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
-                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])
-
-            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
-                self.assertEqual(exp_tokens, masked_tokens)
-
-        # when mask_prob = 1, we expect all tokens that are not [BOS] or [PAD] to be chosen for replacement
-        # (under the default condition that mask_transform.mask_bos=False)
-        if test_mask_prob == 1.0:
-
-            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(self.sample_token_ids, masked_tokens)
-
-            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement
-            # to be changed to random token_ids. It is possible that the randomly selected token id is the same
-            # as the original token id, however we know deterministically that [BOS] and [PAD] tokens
-            # in the sequences will remain unchanged.
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 1.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(masked_tokens[:, 0], 6 * torch.ones_like(masked_tokens[:, 0]))
-                self.assertEqual(masked_tokens[1, 3:], 4 * torch.ones_like(masked_tokens[1, 3:]))
-
-            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
-                self.assertEqual(exp_tokens, masked_tokens)
-
-    def test_mask_transform_mask_bos(self) -> None:
-        # MaskTransform has boolean parameter mask_bos to indicate whether or not [BOS] tokens
-        # should be eligible for replacement. The above tests of MaskTransform are under default value
-        # mask_bos = False. Here we test the case where mask_bos = True.
-        mask_transform = MaskTransform(
-            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
-        )
-
-        # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-        with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-            "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-        ):
-            masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-            exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
-            self.assertEqual(exp_tokens, masked_tokens)
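Read together, the deleted assertions pin down MaskTransform's contract: mask_prob selects the positions eligible for replacement ([PAD] never, [BOS] only when mask_bos=True); of the selected positions, a mask_mask_prob fraction becomes [MASK] and a rand_mask_prob fraction becomes a random id drawn from [0, vocab_len). The sketch below re-implements that contract from scratch so the expected tensors are easy to verify by hand. Note that toy_mask and its argument names are illustrative, not torchtext internals, and the sketch skips the at-least-one-selection corner case that the mask_prob=0 branch above relies on.

import torch


def toy_mask(tokens, vocab_len, mask_idx, bos_idx, pad_idx, mask_prob, mask_mask_prob, rand_mask_prob, mask_bos=False):
    # [PAD] is never replaced; [BOS] is replaceable only when mask_bos=True.
    eligible = tokens != pad_idx
    if not mask_bos:
        eligible &= tokens != bos_idx
    # Each eligible position is selected for replacement with probability mask_prob.
    selected = eligible & (torch.rand(tokens.shape) < mask_prob)
    # A mask_mask_prob fraction of the selected positions becomes [MASK] ...
    to_mask = selected & (torch.rand(tokens.shape) < mask_mask_prob)
    # ... and a rand_mask_prob fraction of the remainder becomes a random id in [0, vocab_len).
    to_rand = selected & ~to_mask & (torch.rand(tokens.shape) < rand_mask_prob)
    out = tokens.clone()
    out[to_mask] = mask_idx
    out[to_rand] = torch.randint(vocab_len, (int(to_rand.sum()),))
    return out


# With mask_prob=1.0, mask_mask_prob=1.0, rand_mask_prob=0.0 this reproduces the
# expected tensor from the tests: [[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]].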

test/test_transforms.py

Lines changed: 117 additions & 1 deletion
@@ -1,9 +1,10 @@
 import os
 from collections import OrderedDict
+from unittest.mock import patch

 import torch
 from torchtext import transforms
-from torchtext.transforms import RegexTokenizer
+from torchtext.transforms import MaskTransform, RegexTokenizer
 from torchtext.vocab import vocab

 from .common.assets import get_asset_path
@@ -750,3 +751,118 @@ def test_regex_tokenizer_save_load(self) -> None:
         loaded_tokenizer = torch.jit.load(save_path)
         results = loaded_tokenizer(self.test_sample)
         self.assertEqual(results, self.ref_results)
+
+class TestMaskTransform(TorchtextTestCase):
+
+    """
+    Testing under these assumed conditions:
+
+    Vocab maps the following tokens to the following ids:
+    ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
+
+    The sample token sequences are:
+    [["[BOS]", "a", "b", "c", "d"],
+    ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
+    """
+
+    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
+
+    vocab_len = 7
+    pad_idx = 4
+    mask_idx = 5
+    bos_idx = 6
+
+    @nested_params([0.0, 1.0])
+    def test_mask_transform_probs(self, test_mask_prob):
+
+        # We pass (vocab_len - 1) into MaskTransform to test masking with a random token.
+        # This modifies the distribution from which token ids are randomly selected such that the
+        # largest token id available for selection is 1 less than the actual largest token id in our
+        # vocab, which we've assigned to the [BOS] token. This allows us to test random replacement
+        # by ensuring that when the first token ([BOS]) in the first sample sequence is selected for random replacement,
+        # we know with certainty the token it is replaced with is different from the [BOS] token.
+        # In practice, however, the actual vocab length should be provided as the input parameter so that random
+        # replacement selects from all possible tokens in the vocab.
+        mask_transform = MaskTransform(
+            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
+        )
+
+        # when mask_prob = 0, we expect the first token of the first sample sequence to be chosen for replacement
+        if test_mask_prob == 0.0:
+
+            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                self.assertEqual(self.sample_token_ids, masked_tokens)
+
+            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement to be
+            # changed to a random token_id
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 1.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+
+                # first token in first sequence should be different
+                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
+                # replaced token id should still be in vocab, not including [BOS]
+                assert masked_tokens[0, 0] in range(self.vocab_len - 1)
+
+                # all other tokens except for first token of first sequence should remain the same
+                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
+                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])
+
+            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
+                self.assertEqual(exp_tokens, masked_tokens)
+
+        # when mask_prob = 1, we expect all tokens that are not [BOS] or [PAD] to be chosen for replacement
+        # (under the default condition that mask_transform.mask_bos=False)
+        if test_mask_prob == 1.0:
+
+            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                self.assertEqual(self.sample_token_ids, masked_tokens)
+
+            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement
+            # to be changed to random token_ids. It is possible that the randomly selected token id is the same
+            # as the original token id, however we know deterministically that [BOS] and [PAD] tokens
+            # in the sequences will remain unchanged.
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 1.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                self.assertEqual(masked_tokens[:, 0], 6 * torch.ones_like(masked_tokens[:, 0]))
+                self.assertEqual(masked_tokens[1, 3:], 4 * torch.ones_like(masked_tokens[1, 3:]))
+
+            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
+                self.assertEqual(exp_tokens, masked_tokens)
+
+    def test_mask_transform_mask_bos(self) -> None:
+        # MaskTransform has boolean parameter mask_bos to indicate whether or not [BOS] tokens
+        # should be eligible for replacement. The above tests of MaskTransform are under default value
+        # mask_bos = False. Here we test the case where mask_bos = True.
+        mask_transform = MaskTransform(
+            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
+        )
+
+        # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
+        with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
+            "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+        ):
+            masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+            exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
+            self.assertEqual(exp_tokens, masked_tokens)
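With the move complete, the transform is importable from the stable namespace. A minimal usage sketch assembled only from what these tests exercise: the constructor mirrors the test's positional call (vocab_len, mask_idx, bos_idx, pad_idx), and the forward call returns a 3-tuple whose first element is the masked ids; the other two return values are not exercised in this diff, so they are left unnamed here.

import torch

from torchtext.transforms import MaskTransform  # stable path introduced by this commit

# Toy vocab from the test docstring: ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> 0..6
mask_transform = MaskTransform(7, 5, 6, 4, mask_bos=False, mask_prob=0.15)
tokens = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
masked_tokens, _, _ = mask_transform(tokens)  # ids after random masking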
