diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 6393dd11d2..41f7d0e92c 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -7,7 +7,7 @@ torchtext.transforms
 .. automodule:: torchtext.transforms
 .. currentmodule:: torchtext.transforms
 
-Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential`
+Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential` or using :class:`torchtext.transforms.Sequential` to support torch-scriptability.
 
 SentencePieceTokenizer
 ----------------------
@@ -51,3 +51,10 @@ AddToken
 .. autoclass:: AddToken
 
    .. automethod:: forward
+
+Sequential
+----------
+
+.. autoclass:: Sequential
+
+   .. automethod:: forward
diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py
index bfbf82ba3b..deb40211ae 100644
--- a/test/experimental/test_with_asset.py
+++ b/test/experimental/test_with_asset.py
@@ -7,7 +7,6 @@
     VocabTransform,
     PRETRAINED_SP_MODEL,
     sentencepiece_processor,
-    TextSequentialTransforms,
 )
 from torch.utils.data import DataLoader
 from torchtext.experimental.vocab_factory import (
@@ -213,14 +212,6 @@ def batch_func(data):
         for item in dataloader:
             self.assertEqual(item, ref_results)
 
-    def test_text_sequential_transform(self):
-        asset_name = 'vocab_test2.txt'
-        asset_path = get_asset_path(asset_name)
-        pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(asset_path))
-        jit_pipeline = torch.jit.script(pipeline)
-        self.assertEqual(pipeline('of that new'), [7, 18, 24])
-        self.assertEqual(jit_pipeline('of that new'), [7, 18, 24])
-
     def test_vectors_from_file(self):
         asset_name = 'vectors_test.csv'
         asset_path = get_asset_path(asset_name)
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 8771d93436..56dc4ed2e1 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -205,6 +205,36 @@ def test_add_token_jit(self):
         self._add_token(test_scripting=True)
 
 
+class TestSequential(TorchtextTestCase):
+    def _sequential(self, test_scripting):
+        """Chain Truncate and ToTensor with Sequential, optionally script it, and check the output."""
+        max_seq_len = 3
+        padding_val = 0
+        transform = transforms.Sequential(
+            transforms.Truncate(max_seq_len=max_seq_len),
+            transforms.ToTensor(padding_value=padding_val, dtype=torch.long)
+        )
+
+        if test_scripting:
+            transform = torch.jit.script(transform)
+
+        # One row is longer than max_seq_len and one is shorter, so the expected
+        # tensor is only produced if Truncate and then ToTensor (padding) both run.
+        input = [[1, 2, 3, 4], [1, 2]]
+
+        actual = transform(input)
+        expected = torch.tensor([[1, 2, 3], [1, 2, 0]])
+        torch.testing.assert_close(actual, expected)
+
+    def test_sequential(self):
+        """test pipelining transforms using Sequential transform"""
+        self._sequential(test_scripting=False)
+
+    def test_sequential_jit(self):
+        """test pipelining transforms using Sequential transform, ensuring the composite transform is scriptable"""
+        self._sequential(test_scripting=True)
+
+
 class TestGPT2BPETokenizer(TorchtextTestCase):
     def _load_tokenizer(self, test_scripting):
         encoder_json = "gpt2_bpe_encoder.json"
diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py
index aa37044d22..5b0b2e8aed 100644
--- a/torchtext/experimental/transforms.py
+++ b/torchtext/experimental/transforms.py
@@ -12,7 +12,6 @@
     'regex_tokenizer',
     'BasicEnglishNormalize',
     'RegexTokenizer',
-    'TextSequentialTransforms',
     'PRETRAINED_SP_MODEL',
     'load_sp_model',
     'sentencepiece_tokenizer',
@@ -164,25 +163,6 @@ def __prepare_scriptable__(self):
         return RegexTokenizer(regex_tokenizer)
 
 
-class TextSequentialTransforms(nn.Sequential):
-    r"""A container to host a sequential text transforms.
-
-    Example:
-        >>> import torch
-        >>> from torchtext.experimental.transforms import basic_english_normalize, TextSequentialTransforms
-        >>> tokenizer = basic_english_normalize()
-        >>> txt_pipeline = TextSequentialTransforms(tokenizer)
-        >>> txt_pipeline('here is an example')
-            ['here', 'is', 'an', 'example']
-        >>> jit_txt_pipeline = torch.jit.script(txt_pipeline)
-    """
-
-    def forward(self, input: str):
-        for module in self:
-            input = module(input)
-        return input
-
-
 PRETRAINED_SP_MODEL = {
     'text_unigram_15000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model',
     'text_unigram_25000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_25000.model',
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
index 4c20be7f6d..99cff3fe12 100644
--- a/torchtext/transforms.py
+++ b/torchtext/transforms.py
@@ -20,6 +20,7 @@
     'Truncate',
     'AddToken',
     'GPT2BPETokenizer',
+    'Sequential',
 ]
 
 
@@ -335,3 +336,17 @@ def bytes_to_unicode():
             n += 1
     cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
+
+
+class Sequential(torch.nn.Sequential):
+    r"""A container to host a sequence of text transforms, which are applied one after another in order.
+    """
+
+    def forward(self, input: Any) -> Any:
+        """Pass ``input`` through each contained transform in turn and return the output of the last one.
+        :param input: Input sequence or batch. The input type must be supported by the first transform in the sequence.
+        :type input: `Any`
+        """
+        for module in self:  # each transform receives the previous transform's output
+            input = module(input)
+        return input