
Commit 3a0d0a3: Graduate MaskTransform from prototype (#1882)

Parent: 2fd12f3

4 files changed (+261, −264)
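With this change, `MaskTransform` moves out of the prototype namespace and becomes importable from the stable `torchtext.transforms` module. A minimal sketch of what the graduation means for callers, using the constructor signature exercised by the tests in this diff (the inline comments describe each argument based on how the tests use it; the `mask_prob` value here is illustrative):

```python
# Before this commit, MaskTransform lived in the prototype namespace:
# from torchtext.prototype.transforms import MaskTransform

# After this commit, it is part of the stable transforms module:
from torchtext.transforms import MaskTransform

# Positional arguments follow the order used by the tests below:
# vocab length, [MASK] id, [BOS] id, [PAD] id.
mask_transform = MaskTransform(
    7,               # size of the vocabulary
    5,               # id of the [MASK] token
    6,               # id of the [BOS] token
    4,               # id of the [PAD] token
    mask_bos=False,  # [BOS] tokens are not eligible for masking
    mask_prob=1.0,   # fraction of eligible tokens selected for replacement
)
```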

test/prototype/test_transforms.py

Lines changed: 0 additions & 120 deletions

```diff
@@ -1,7 +1,6 @@
 import os
 import shutil
 import tempfile
-from unittest.mock import patch
 
 import torch
 from test.common.assets import get_asset_path
@@ -10,12 +9,9 @@
     sentencepiece_processor,
     sentencepiece_tokenizer,
     VectorTransform,
-    MaskTransform,
 )
 from torchtext.prototype.vectors import FastText
 
-from ..common.parameterized_utils import nested_params
-
 
 class TestTransforms(TorchtextTestCase):
     def test_sentencepiece_processor(self) -> None:
@@ -140,119 +136,3 @@ def test_sentencepiece_load_and_save(self) -> None:
         torch.save(spm, save_path)
         loaded_spm = torch.load(save_path)
         self.assertEqual(expected, loaded_spm(input))
-
-
-class TestMaskTransform(TorchtextTestCase):
-
-    """
-    Testing under these assumed conditions:
-
-    Vocab maps the following tokens to the following ids:
-    ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
-
-    The sample token sequences are:
-    [["[BOS]", "a", "b", "c", "d"],
-    ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
-    """
-
-    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
-
-    vocab_len = 7
-    pad_idx = 4
-    mask_idx = 5
-    bos_idx = 6
-
-    @nested_params([0.0, 1.0])
-    def test_mask_transform_probs(self, test_mask_prob):
-
-        # We pass (vocab_len - 1) into MaskTransform to test masking with a random token.
-        # This modifies the distribution from which token ids are randomly selected such that the
-        # largest token id available for selection is 1 less than the actual largest token id in our
-        # vocab, which we've assigned to the [BOS] token. This lets us test random replacement:
-        # when the first token ([BOS]) in the first sample sequence is selected for random replacement,
-        # we know with certainty that it is replaced with a token different from [BOS].
-        # In practice, however, the actual vocab length should be provided as the input parameter
-        # so that random replacement selects from all possible tokens in the vocab.
-        mask_transform = MaskTransform(
-            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
-        )
-
-        # when mask_prob = 0, we expect the first token of the first sample sequence to be chosen for replacement
-        if test_mask_prob == 0.0:
-
-            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(self.sample_token_ids, masked_tokens)
-
-            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement to be
-            # changed to a random token_id
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 1.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-
-                # first token in first sequence should be different
-                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
-                # replaced token id should still be in vocab, not including [BOS]
-                assert masked_tokens[0, 0] in range(self.vocab_len - 1)
-
-                # all other tokens except for first token of first sequence should remain the same
-                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
-                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])
-
-            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
-                self.assertEqual(exp_tokens, masked_tokens)
-
-        # when mask_prob = 1, we expect all tokens that are not [BOS] or [PAD] to be chosen for replacement
-        # (under the default condition that mask_transform.mask_bos=False)
-        if test_mask_prob == 1.0:
-
-            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(self.sample_token_ids, masked_tokens)
-
-            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement
-            # to be changed to random token_ids. It is possible that a randomly selected token id is the same
-            # as the original token id; however, we know deterministically that the [BOS] and [PAD] tokens
-            # in the sequences will remain unchanged.
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 1.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(masked_tokens[:, 0], 6 * torch.ones_like(masked_tokens[:, 0]))
-                self.assertEqual(masked_tokens[1, 3:], 4 * torch.ones_like(masked_tokens[1, 3:]))
-
-            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
-                self.assertEqual(exp_tokens, masked_tokens)
-
-    def test_mask_transform_mask_bos(self) -> None:
-        # MaskTransform has a boolean parameter mask_bos indicating whether or not [BOS] tokens
-        # should be eligible for replacement. The above tests of MaskTransform use the default value
-        # mask_bos = False. Here we test the case where mask_bos = True.
-        mask_transform = MaskTransform(
-            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
-        )
-
-        # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-        with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-            "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-        ):
-            masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-            exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
-            self.assertEqual(exp_tokens, masked_tokens)
```
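The vocab mapping assumed in the class docstring above is never materialized; the tests operate directly on token-id tensors. For illustration only, a hypothetical construction of that vocab with `torchtext.vocab.vocab` (which the destination test file already imports) could look like this:

```python
from collections import OrderedDict

from torchtext.vocab import vocab

# Hypothetical vocab reproducing the mapping assumed by TestMaskTransform:
# ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
tokens = ["a", "b", "c", "d", "[PAD]", "[MASK]", "[BOS]"]
v = vocab(OrderedDict((t, 1) for t in tokens))

# The first sample sequence maps to the first row of sample_token_ids.
assert v.lookup_indices(["[BOS]", "a", "b", "c", "d"]) == [6, 0, 1, 2, 3]
```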

test/test_transforms.py

Lines changed: 118 additions & 1 deletion

```diff
@@ -1,9 +1,10 @@
 import os
 from collections import OrderedDict
+from unittest.mock import patch
 
 import torch
 from torchtext import transforms
-from torchtext.transforms import RegexTokenizer
+from torchtext.transforms import MaskTransform, RegexTokenizer
 from torchtext.vocab import vocab
 
 from .common.assets import get_asset_path
@@ -750,3 +751,119 @@ def test_regex_tokenizer_save_load(self) -> None:
         loaded_tokenizer = torch.jit.load(save_path)
         results = loaded_tokenizer(self.test_sample)
         self.assertEqual(results, self.ref_results)
+
+
+class TestMaskTransform(TorchtextTestCase):
+
+    """
+    Testing under these assumed conditions:
+
+    Vocab maps the following tokens to the following ids:
+    ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
+
+    The sample token sequences are:
+    [["[BOS]", "a", "b", "c", "d"],
+    ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
+    """
+
+    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
+
+    vocab_len = 7
+    pad_idx = 4
+    mask_idx = 5
+    bos_idx = 6
+
+    @nested_params([0.0, 1.0])
+    def test_mask_transform_probs(self, test_mask_prob):
+
+        # We pass (vocab_len - 1) into MaskTransform to test masking with a random token.
+        # This modifies the distribution from which token ids are randomly selected such that the
+        # largest token id available for selection is 1 less than the actual largest token id in our
+        # vocab, which we've assigned to the [BOS] token. This lets us test random replacement:
+        # when the first token ([BOS]) in the first sample sequence is selected for random replacement,
+        # we know with certainty that it is replaced with a token different from [BOS].
+        # In practice, however, the actual vocab length should be provided as the input parameter
+        # so that random replacement selects from all possible tokens in the vocab.
+        mask_transform = MaskTransform(
+            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
+        )
+
+        # when mask_prob = 0, we expect the first token of the first sample sequence to be chosen for replacement
+        if test_mask_prob == 0.0:
+
+            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                self.assertEqual(self.sample_token_ids, masked_tokens)
+
+            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement to be
+            # changed to a random token_id
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 1.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+
+                # first token in first sequence should be different
+                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
+                # replaced token id should still be in vocab, not including [BOS]
+                assert masked_tokens[0, 0] in range(self.vocab_len - 1)
+
+                # all other tokens except for first token of first sequence should remain the same
+                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
+                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])
+
+            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
+                self.assertEqual(exp_tokens, masked_tokens)
+
+        # when mask_prob = 1, we expect all tokens that are not [BOS] or [PAD] to be chosen for replacement
+        # (under the default condition that mask_transform.mask_bos=False)
+        if test_mask_prob == 1.0:
+
+            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                self.assertEqual(self.sample_token_ids, masked_tokens)
+
+            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement
+            # to be changed to random token_ids. It is possible that a randomly selected token id is the same
+            # as the original token id; however, we know deterministically that the [BOS] and [PAD] tokens
+            # in the sequences will remain unchanged.
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 1.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                self.assertEqual(masked_tokens[:, 0], 6 * torch.ones_like(masked_tokens[:, 0]))
+                self.assertEqual(masked_tokens[1, 3:], 4 * torch.ones_like(masked_tokens[1, 3:]))
+
+            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
+            with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
+                "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+            ):
+                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+                exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]])
+                self.assertEqual(exp_tokens, masked_tokens)
+
+    def test_mask_transform_mask_bos(self) -> None:
+        # MaskTransform has a boolean parameter mask_bos indicating whether or not [BOS] tokens
+        # should be eligible for replacement. The above tests of MaskTransform use the default value
+        # mask_bos = False. Here we test the case where mask_bos = True.
+        mask_transform = MaskTransform(
+            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0
+        )
+
+        # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
+        with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
+            "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
+        ):
+            masked_tokens, _, _ = mask_transform(self.sample_token_ids)
+            exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]])
+            self.assertEqual(exp_tokens, masked_tokens)
```
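Taken together, the tests pin down how the three probabilities interact: `mask_prob` selects eligible tokens, and each selected token is then replaced by `[MASK]` with probability `mask_mask_prob` or by a random id with probability `rand_mask_prob`. A minimal sketch of forcing fully deterministic `[MASK]` replacement outside the test suite, patching the class attributes exactly as the tests above do:

```python
import torch
from unittest.mock import patch

from torchtext.transforms import MaskTransform

# Same setup as the tests: ids 4/5/6 are [PAD]/[MASK]/[BOS].
tokens = torch.tensor([[6, 0, 1, 2, 3]])  # "[BOS] a b c d"
mask_transform = MaskTransform(7, 5, 6, 4, mask_bos=False, mask_prob=1.0)

# Pin both replacement probabilities so every selected token becomes [MASK].
with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
    "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0
):
    masked_tokens, _, _ = mask_transform(tokens)

print(masked_tokens)  # tensor([[6, 5, 5, 5, 5]]) -- [BOS] kept, the rest masked
```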
