diff --git a/test/prototype/test_transforms.py b/test/prototype/test_transforms.py
index ee1a291cd2..71e9c02f74 100644
--- a/test/prototype/test_transforms.py
+++ b/test/prototype/test_transforms.py
@@ -1,7 +1,6 @@
 import os
 import shutil
 import tempfile
-from unittest.mock import patch
 
 import torch
 from test.common.assets import get_asset_path
@@ -10,12 +9,9 @@
     sentencepiece_processor,
     sentencepiece_tokenizer,
     VectorTransform,
-    MaskTransform,
 )
 from torchtext.prototype.vectors import FastText
 
-from ..common.parameterized_utils import nested_params
-
 
 class TestTransforms(TorchtextTestCase):
     def test_sentencepiece_processor(self) -> None:
@@ -140,119 +136,3 @@ def test_sentencepiece_load_and_save(self) -> None:
         torch.save(spm, save_path)
         loaded_spm = torch.load(save_path)
         self.assertEqual(expected, loaded_spm(input))
-
-
-class TestMaskTransform(TorchtextTestCase):
-
-    """
-    Testing under these assumed conditions:
-
-    Vocab maps the following tokens to the following ids:
-    ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6]
-
-    The sample token sequences are:
-    [["[BOS]", "a", "b", "c", "d"],
-    ["[BOS]", "a", "b", "[PAD]", "[PAD]"]]
-    """
-
-    sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
-
-    vocab_len = 7
-    pad_idx = 4
-    mask_idx = 5
-    bos_idx = 6
-
-    @nested_params([0.0, 1.0])
-    def test_mask_transform_probs(self, test_mask_prob):
-
-        # We pass (vocab_len - 1) into MaskTransform to test masking with a random token.
-        # This modifies the distribution from which token ids are randomly selected such that the
-        # largest token id availible for selection is 1 less than the actual largest token id in our
-        # vocab, which we've assigned to the [BOS] token. This allows us to test random replacement
-        # by ensuring that when the first token ([BOS]) in the first sample sequence is selected for random replacement,
-        # we know with certainty the token it is replaced with is different from the [BOS] token.
-        # In practice, however, the actual vocab length should be provided as the input parameter so that random
-        # replacement selects from all possible tokens in the vocab.
-        mask_transform = MaskTransform(
-            self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob
-        )
-
-        # when mask_prob = 0, we expect the first token of the first sample sequence to be chosen for replacement
-        if test_mask_prob == 0.0:
-
-            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(self.sample_token_ids, masked_tokens)
-
-            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement to be
-            # changed to a random token_id
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 1.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-
-                # first token in first sequence should be different
-                self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0])
-                # replaced token id should still be in vocab, not including [BOS]
-                assert masked_tokens[0, 0] in range(self.vocab_len - 1)
-
-                # all other tokens except for first token of first sequence should remain the same
-                self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:])
-                self.assertEqual(self.sample_token_ids[1], masked_tokens[1])
-
-            # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK]
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
-                self.assertEqual(exp_tokens, masked_tokens)
-
-        # when mask_prob = 1, we expect all tokens that are not [BOS] or [PAD] to be chosen for replacement
-        # (under the default condition that mask_transform.mask_bos=False)
-        if test_mask_prob == 1.0:
-
-            # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change
-            with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch(
-                "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0
-            ):
-                masked_tokens, _, _ = mask_transform(self.sample_token_ids)
-                self.assertEqual(self.sample_token_ids, masked_tokens)
-
-            # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement
-            # to be changed to random token_ids. It is possible that the randomly selected token id is the same
-            # as the original token id, however we know deterministically that [BOS] and [PAD] tokens
-            # in the sequences will remain unchanged.
- with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 0.0), patch( - "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 1.0 - ): - masked_tokens, _, _ = mask_transform(self.sample_token_ids) - self.assertEqual(masked_tokens[:, 0], 6 * torch.ones_like(masked_tokens[:, 0])) - self.assertEqual(masked_tokens[1, 3:], 4 * torch.ones_like(masked_tokens[1, 3:])) - - # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK] - with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch( - "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0 - ): - masked_tokens, _, _ = mask_transform(self.sample_token_ids) - exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]]) - self.assertEqual(exp_tokens, masked_tokens) - - def test_mask_transform_mask_bos(self) -> None: - # MaskTransform has boolean parameter mask_bos to indicate whether or not [BOS] tokens - # should be eligible for replacement. The above tests of MaskTransform are under default value - # mask_bos = False. Here we test the case where mask_bos = True - mask_transform = MaskTransform( - self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0 - ) - - # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK] - with patch("torchtext.prototype.transforms.MaskTransform.mask_mask_prob", 1.0), patch( - "torchtext.prototype.transforms.MaskTransform.rand_mask_prob", 0.0 - ): - masked_tokens, _, _ = mask_transform(self.sample_token_ids) - exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]]) - self.assertEqual(exp_tokens, masked_tokens) diff --git a/test/test_transforms.py b/test/test_transforms.py index 2dc11cb6d0..76f84b66aa 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,9 +1,10 @@ import os from collections import OrderedDict +from unittest.mock import patch import torch from torchtext import transforms -from torchtext.transforms import RegexTokenizer +from torchtext.transforms import MaskTransform, RegexTokenizer from torchtext.vocab import vocab from .common.assets import get_asset_path @@ -750,3 +751,119 @@ def test_regex_tokenizer_save_load(self) -> None: loaded_tokenizer = torch.jit.load(save_path) results = loaded_tokenizer(self.test_sample) self.assertEqual(results, self.ref_results) + + +class TestMaskTransform(TorchtextTestCase): + + """ + Testing under these assumed conditions: + + Vocab maps the following tokens to the following ids: + ['a', 'b', 'c', 'd', '[PAD]', '[MASK]', '[BOS]'] -> [0, 1, 2, 3, 4, 5, 6] + + The sample token sequences are: + [["[BOS]", "a", "b", "c", "d"], + ["[BOS]", "a", "b", "[PAD]", "[PAD]"]] + """ + + sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]]) + + vocab_len = 7 + pad_idx = 4 + mask_idx = 5 + bos_idx = 6 + + @nested_params([0.0, 1.0]) + def test_mask_transform_probs(self, test_mask_prob): + + # We pass (vocab_len - 1) into MaskTransform to test masking with a random token. + # This modifies the distribution from which token ids are randomly selected such that the + # largest token id availible for selection is 1 less than the actual largest token id in our + # vocab, which we've assigned to the [BOS] token. 
This allows us to test random replacement + # by ensuring that when the first token ([BOS]) in the first sample sequence is selected for random replacement, + # we know with certainty the token it is replaced with is different from the [BOS] token. + # In practice, however, the actual vocab length should be provided as the input parameter so that random + # replacement selects from all possible tokens in the vocab. + mask_transform = MaskTransform( + self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=False, mask_prob=test_mask_prob + ) + + # when mask_prob = 0, we expect the first token of the first sample sequence to be chosen for replacement + if test_mask_prob == 0.0: + + # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change + with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + self.assertEqual(self.sample_token_ids, masked_tokens) + + # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement to be + # changed to a random token_id + with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 1.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + + # first token in first sequence should be different + self.assertNotEqual(masked_tokens[0, 0], self.sample_token_ids[0, 0]) + # replaced token id should still be in vocab, not including [BOS] + assert masked_tokens[0, 0] in range(self.vocab_len - 1) + + # all other tokens except for first token of first sequence should remain the same + self.assertEqual(self.sample_token_ids[0, 1:], masked_tokens[0, 1:]) + self.assertEqual(self.sample_token_ids[1], masked_tokens[1]) + + # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK] + with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + exp_tokens = torch.tensor([[5, 0, 1, 2, 3], [6, 0, 1, 4, 4]]) + self.assertEqual(exp_tokens, masked_tokens) + + # when mask_prob = 1, we expect all tokens that are not [BOS] or [PAD] to be chosen for replacement + # (under the default condition that mask_transform.mask_bos=False) + if test_mask_prob == 1.0: + + # when mask_mask_prob, rand_mask_prob = 0,0 no tokens should change + with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + self.assertEqual(self.sample_token_ids, masked_tokens) + + # when mask_mask_prob, rand_mask_prob = 0,1 we expect all tokens selected for replacement + # to be changed to random token_ids. It is possible that the randomly selected token id is the same + # as the original token id, however we know deterministically that [BOS] and [PAD] tokens + # in the sequences will remain unchanged. 
+ with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 0.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 1.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + self.assertEqual(masked_tokens[:, 0], 6 * torch.ones_like(masked_tokens[:, 0])) + self.assertEqual(masked_tokens[1, 3:], 4 * torch.ones_like(masked_tokens[1, 3:])) + + # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK] + with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + exp_tokens = torch.tensor([[6, 5, 5, 5, 5], [6, 5, 5, 4, 4]]) + self.assertEqual(exp_tokens, masked_tokens) + + def test_mask_transform_mask_bos(self) -> None: + # MaskTransform has boolean parameter mask_bos to indicate whether or not [BOS] tokens + # should be eligible for replacement. The above tests of MaskTransform are under default value + # mask_bos = False. Here we test the case where mask_bos = True + mask_transform = MaskTransform( + self.vocab_len - 1, self.mask_idx, self.bos_idx, self.pad_idx, mask_bos=True, mask_prob=1.0 + ) + + # when mask_mask_prob, rand_mask_prob = 1,0 we expect all tokens selected for replacement to be changed to [MASK] + with patch("torchtext.transforms.MaskTransform.mask_mask_prob", 1.0), patch( + "torchtext.transforms.MaskTransform.rand_mask_prob", 0.0 + ): + masked_tokens, _, _ = mask_transform(self.sample_token_ids) + exp_tokens = torch.tensor([[5, 5, 5, 5, 5], [5, 5, 5, 4, 4]]) + self.assertEqual(exp_tokens, masked_tokens) diff --git a/torchtext/prototype/transforms.py b/torchtext/prototype/transforms.py index e11f88fa77..f837894e92 100644 --- a/torchtext/prototype/transforms.py +++ b/torchtext/prototype/transforms.py @@ -1,5 +1,5 @@ import io -from typing import List, Tuple +from typing import List import torch import torch.nn as nn @@ -340,144 +340,3 @@ def forward(self, tokens: List[str]) -> Tensor: """ return self.vector.lookup_vectors(tokens) - - -class MaskTransform(nn.Module): - """ - The transform chooses mask_prob% (example 15%) of the token positions at random for - prediction. - - If the i-th token is chosen, we replace the i-th token with - (1) the [MASK] token 80% of the time - (2) a random token 10% of the time - (3) the unchanged i-th token 10% of the time. - - Args: - vocab_len (int): the length of the vocabulary, including special tokens such as [BOS], [PAD], [MASK] - mask_idx (int): index assigned to mask token in vocabulary - bos_idx (int): index assigned to beginning-of-sequence token in vocabulary - pad_idx (int): index assigned to padding token in vocabulary - mask_bos (bool): indicate whether beginning-of-sequence tokens are eligible for masking (default: False) - mask_prob (float): probability that a token is chosen for replacement (default: 0.15) - - Example: - >>> import torch - >>> from torchtext.experimental.transforms import MaskTransform - >>> sample_tokens = [ - ["[BOS]", "a", "b", "c", "d"], - ["[BOS]", "a", "b", "[PAD]", "[PAD]"] - ] - >>> sample_token_ids = torch.tensor([ - [6, 0, 1, 2, 3], [6, 0, 1, 4, 4] - ]) - >>> mask_transform = MaskTransform( - vocab_len = 7, - mask_idx = 4, - bos_idx = 6, - pad_idx = 5, - mask_bos = False, - mask_prob = 0.15 - ) - >>> masked_tokens, target_tokens, mask = mask_transform(sample_token_ids) - """ - - # maks_mask_prob is prob. of replacing a token with [MASK] (ex. 
80%) - mask_mask_prob = 0.8 - - # rand_mask_thresh is prob. of replacing a token with a random token. (ex.10%) - rand_mask_prob = 0.1 - - def __init__( - self, - vocab_len: int, - mask_idx: int, - bos_idx: int, - pad_idx: int, - mask_bos: bool = False, - mask_prob: float = 0.15, - ): - super().__init__() - self.vocab_len = vocab_len - self.mask_idx = mask_idx - self.bos_idx = bos_idx - self.pad_idx = pad_idx - self.mask_prob = mask_prob - self.mask_bos = mask_bos - - def forward(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Applies mask to input tokens. - - Inputs: - tokens: Tensor with token ids of shape (batch_size x seq_len). Includes token ids for special tokens such as [BOS] and [PAD] - - Outputs: - masked_tokens: Tensor of tokens after masking has been applied - target_tokens: Tensor of token values selected for masking - mask: Tensor with same shape as input tokens (batch_size x seq_len) - with masked tokens represented by a 1 and everything else as 0. - """ - # tokens, mask, mask_mask, rand_mask: (T, C) - mask, mask_mask, rand_mask = self._generate_mask(tokens) - - # a. generate the masked input tokens - # (1) the [MASK] token 80% of the time - masked_tokens = self._mask_input(tokens, mask_mask, self.mask_idx) - # (2) a random token 10% of the time - masked_tokens = self._mask_input( - masked_tokens, - rand_mask, - torch.randint_like(tokens, high=self.vocab_len), - ) - - # b. generate the target prediction - target_tokens = torch.masked_select(tokens, mask.bool()) - - # masked_tokens: (T, C), target_tokens: (T x C x mask_prob, ), mask - return masked_tokens, target_tokens, mask - - def _random_masking(self, tokens: torch.tensor, mask_prob: float) -> torch.Tensor: - """ - Function to mask tokens randomly. - - Inputs: - 1) tokens: Tensor with token ids of shape (batch_size x seq_len). Includes token ids for special tokens such as [BOS] and [PAD] - 2) mask_prob: Probability of masking a particular token - - Outputs: - mask: Tensor with same shape as input tokens (batch_size x seq_len) - with masked tokens represented by a 1 and everything else as 0. - """ - batch_size, seq_len = tokens.size() - num_masked_per_seq = int(seq_len * mask_prob) - - mask = torch.zeros((batch_size, seq_len), dtype=torch.int).to(tokens.device) - mask[:, :num_masked_per_seq] = 1 - for i in range(batch_size): - mask[i] = mask[i, torch.randperm(seq_len)] - - return mask - - def _select_tokens_to_mask(self, tokens: torch.Tensor, mask_prob: float) -> torch.Tensor: - mask = self._random_masking(tokens, mask_prob) - if not self.mask_bos: - mask *= (tokens != self.bos_idx).long() - return mask - - def _generate_mask(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # chooses mask_prob% of the token positions at random - mask = self._select_tokens_to_mask(tokens, self.mask_prob) - # not mask the pad token - mask *= (tokens != self.pad_idx).long() - # keep one masked token to avoid failure in the loss calculation. 
-        mask[0, 0] = 1 if not mask.byte().any() else mask[0, 0]
-
-        probs = torch.rand_like(tokens, dtype=torch.float)
-        # (1) the [MASK] token 80% of the time
-        mask_mask = (probs >= (1 - self.mask_mask_prob)).long() * mask
-        # (2) a random token 10% of the time
-        rand_mask = (probs < self.rand_mask_prob).long() * mask
-        return mask, mask_mask, rand_mask
-
-    def _mask_input(self, tokens: torch.Tensor, mask: torch.Tensor, replacement) -> torch.Tensor:
-        return tokens * (1 - mask) + replacement * mask
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index 60ab52df34..84e93aa3cc 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -1,7 +1,7 @@
 import json
 from copy import deepcopy
 from functools import lru_cache
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torchtext  # noqa: F401
@@ -767,3 +767,144 @@ def forward(self, input: Any) -> Any:
         for module in self:
             input = module(input)
         return input
+
+
+class MaskTransform(torch.nn.Module):
+    """
+    The transform chooses mask_prob (e.g. 15%) of the token positions at random for
+    prediction.
+
+    If the i-th token is chosen, we replace the i-th token with
+    (1) the [MASK] token 80% of the time
+    (2) a random token 10% of the time
+    (3) the unchanged i-th token 10% of the time.
+
+    Args:
+        vocab_len (int): the length of the vocabulary, including special tokens such as [BOS], [PAD], [MASK]
+        mask_idx (int): index assigned to mask token in vocabulary
+        bos_idx (int): index assigned to beginning-of-sequence token in vocabulary
+        pad_idx (int): index assigned to padding token in vocabulary
+        mask_bos (bool): indicates whether beginning-of-sequence tokens are eligible for masking (default: False)
+        mask_prob (float): probability that a token is chosen for replacement (default: 0.15)
+
+    Example:
+        >>> import torch
+        >>> from torchtext.transforms import MaskTransform
+        >>> sample_tokens = [
+                ["[BOS]", "a", "b", "c", "d"],
+                ["[BOS]", "a", "b", "[PAD]", "[PAD]"]
+            ]
+        >>> sample_token_ids = torch.tensor([
+                [6, 0, 1, 2, 3], [6, 0, 1, 4, 4]
+            ])
+        >>> mask_transform = MaskTransform(
+                vocab_len = 7,
+                mask_idx = 5,
+                bos_idx = 6,
+                pad_idx = 4,
+                mask_bos = False,
+                mask_prob = 0.15
+            )
+        >>> masked_tokens, target_tokens, mask = mask_transform(sample_token_ids)
+    """
+
+    # mask_mask_prob is the prob. of replacing a token with [MASK] (ex. 80%)
+    mask_mask_prob = 0.8
+
+    # rand_mask_prob is the prob. of replacing a token with a random token (ex. 10%)
+    rand_mask_prob = 0.1
+
+    def __init__(
+        self,
+        vocab_len: int,
+        mask_idx: int,
+        bos_idx: int,
+        pad_idx: int,
+        mask_bos: bool = False,
+        mask_prob: float = 0.15,
+    ):
+        super().__init__()
+        self.vocab_len = vocab_len
+        self.mask_idx = mask_idx
+        self.bos_idx = bos_idx
+        self.pad_idx = pad_idx
+        self.mask_prob = mask_prob
+        self.mask_bos = mask_bos
+
+    def forward(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Applies mask to input tokens.
+
+        Inputs:
+            tokens: Tensor with token ids of shape (batch_size x seq_len). Includes token ids for special tokens such as [BOS] and [PAD]
+
+        Outputs:
+            masked_tokens: Tensor of tokens after masking has been applied
+            target_tokens: Tensor of token values selected for masking
+            mask: Tensor with same shape as input tokens (batch_size x seq_len)
+                with masked tokens represented by a 1 and everything else as 0.
+        """
+        # tokens, mask, mask_mask, rand_mask: (T, C)
+        mask, mask_mask, rand_mask = self._generate_mask(tokens)
+
+        # a. generate the masked input tokens
+        # (1) the [MASK] token 80% of the time
+        masked_tokens = self._mask_input(tokens, mask_mask, self.mask_idx)
+        # (2) a random token 10% of the time
+        masked_tokens = self._mask_input(
+            masked_tokens,
+            rand_mask,
+            torch.randint_like(tokens, high=self.vocab_len),
+        )
+
+        # b. generate the target prediction
+        target_tokens = torch.masked_select(tokens, mask.bool())
+
+        # masked_tokens: (T, C), target_tokens: (T x C x mask_prob, ), mask
+        return masked_tokens, target_tokens, mask
+
+    def _random_masking(self, tokens: torch.Tensor, mask_prob: float) -> torch.Tensor:
+        """
+        Function to mask tokens randomly.
+
+        Inputs:
+            1) tokens: Tensor with token ids of shape (batch_size x seq_len). Includes token ids for special tokens such as [BOS] and [PAD]
+            2) mask_prob: Probability of masking a particular token
+
+        Outputs:
+            mask: Tensor with same shape as input tokens (batch_size x seq_len)
+                with masked tokens represented by a 1 and everything else as 0.
+        """
+        batch_size, seq_len = tokens.size()
+        num_masked_per_seq = int(seq_len * mask_prob)
+
+        mask = torch.zeros((batch_size, seq_len), dtype=torch.int).to(tokens.device)
+        mask[:, :num_masked_per_seq] = 1
+        for i in range(batch_size):
+            mask[i] = mask[i, torch.randperm(seq_len)]
+
+        return mask
+
+    def _select_tokens_to_mask(self, tokens: torch.Tensor, mask_prob: float) -> torch.Tensor:
+        mask = self._random_masking(tokens, mask_prob)
+        if not self.mask_bos:
+            mask *= (tokens != self.bos_idx).long()
+        return mask
+
+    def _generate_mask(self, tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # chooses mask_prob% of the token positions at random
+        mask = self._select_tokens_to_mask(tokens, self.mask_prob)
+        # do not mask the pad token
+        mask *= (tokens != self.pad_idx).long()
+        # keep at least one masked token to avoid failure in the loss calculation
+        mask[0, 0] = 1 if not mask.byte().any() else mask[0, 0]
+
+        probs = torch.rand_like(tokens, dtype=torch.float)
+        # (1) the [MASK] token 80% of the time
+        mask_mask = (probs >= (1 - self.mask_mask_prob)).long() * mask
+        # (2) a random token 10% of the time
+        rand_mask = (probs < self.rand_mask_prob).long() * mask
+        return mask, mask_mask, rand_mask
+
+    def _mask_input(self, tokens: torch.Tensor, mask: torch.Tensor, replacement) -> torch.Tensor:
+        return tokens * (1 - mask) + replacement * mask
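For reference, a minimal usage sketch of the relocated transform (not part of the diff; it assumes the example vocab from the docstring above, with [PAD]=4, [MASK]=5, [BOS]=6):

# Usage sketch for torchtext.transforms.MaskTransform (illustrative only).
import torch
from torchtext.transforms import MaskTransform

sample_token_ids = torch.tensor([[6, 0, 1, 2, 3], [6, 0, 1, 4, 4]])
mask_transform = MaskTransform(vocab_len=7, mask_idx=5, bos_idx=6, pad_idx=4, mask_bos=False, mask_prob=0.15)

# forward returns the masked input, the original values at the masked positions,
# and a 0/1 mask with the same shape as the input.
masked_tokens, target_tokens, mask = mask_transform(sample_token_ids)
assert masked_tokens.shape == sample_token_ids.shape
assert target_tokens.numel() == int(mask.sum())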