From 6daa7bdebe28a002ae45a28dca31935200b2e2aa Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 3 Jan 2022 19:34:40 -0500 Subject: [PATCH 1/6] add scriptable sequential transform --- docs/source/transforms.rst | 9 ++++++++- test/test_transforms.py | 25 +++++++++++++++++++++++++ torchtext/experimental/transforms.py | 18 ------------------ torchtext/transforms.py | 14 ++++++++++++++ 4 files changed, 47 insertions(+), 19 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6393dd11d2..41f7d0e92c 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -7,7 +7,7 @@ torchtext.transforms .. automodule:: torchtext.transforms .. currentmodule:: torchtext.transforms -Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential` +Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential` or using :class:`torchtext.transforms.Sequential` to support torch-scriptability. SentencePieceTokenizer ---------------------- @@ -51,3 +51,10 @@ AddToken .. autoclass:: AddToken .. automethod:: forward + +Sequential +---------- + +.. autoclass:: Sequential + + .. automethod:: forward diff --git a/test/test_transforms.py b/test/test_transforms.py index 3aebb90b47..aad46f1317 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -203,6 +203,31 @@ def test_add_token(self): def test_add_token_jit(self): self._add_token(test_scripting=True) + def _sequential(self, test_scripting): + max_seq_len = 3 + padding_val = 0 + transform = transforms.Sequential( + transforms.Truncate(max_seq_len=max_seq_len), + transforms.ToTensor(padding_value=padding_val, dtype=torch.long) + ) + + if test_scripting: + transform = torch.jit.script(transform) + + input = [[1, 2, 3], [1, 2, 3]] + + actual = transform(input) + expected = torch.tensor(input) + torch.testing.assert_close(actual, expected) + + def test_sequential(self): + """test pipelining transforms using Sequential transform""" + self._sequential(test_scripting=False) + + def test_sequential_jit(self): + """test pipelining transforms using Sequential transform, ensuring the composite transform is scriptable""" + self._sequential(test_scripting=True) + class TestGPT2BPETokenizer(TorchtextTestCase): def _gpt2_bpe_tokenizer(self, test_scripting): diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index aa37044d22..e29c50bbdd 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -164,24 +164,6 @@ def __prepare_scriptable__(self): return RegexTokenizer(regex_tokenizer) -class TextSequentialTransforms(nn.Sequential): - r"""A container to host a sequential text transforms. - - Example: - >>> import torch - >>> from torchtext.experimental.transforms import basic_english_normalize, TextSequentialTransforms - >>> tokenizer = basic_english_normalize() - >>> txt_pipeline = TextSequentialTransforms(tokenizer) - >>> txt_pipeline('here is an example') - ['here', 'is', 'an', 'example'] - >>> jit_txt_pipeline = torch.jit.script(txt_pipeline) - """ - - def forward(self, input: str): - for module in self: - input = module(input) - return input - PRETRAINED_SP_MODEL = { 'text_unigram_15000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model', diff --git a/torchtext/transforms.py b/torchtext/transforms.py index 3d8fe5bc22..81939d428d 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -485,3 +485,17 @@ def bytes_to_unicode(): n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) + + +class Sequential(torch.nn.Sequential): + r"""A container to host a sequence of text transforms. + """ + + def forward(self, input: Any) -> Any: + """ + :param input: Input sequence or batch. The input type must be supported by the first transform in the sequence. + :type input: `Any` + """ + for module in self: + input = module(input) + return input From dd0fcca52deec9500bf48d95b0f896a8d961430b Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 3 Jan 2022 19:40:41 -0500 Subject: [PATCH 2/6] fix flake --- torchtext/experimental/transforms.py | 2 -- torchtext/transforms.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtext/experimental/transforms.py b/torchtext/experimental/transforms.py index e29c50bbdd..5b0b2e8aed 100644 --- a/torchtext/experimental/transforms.py +++ b/torchtext/experimental/transforms.py @@ -12,7 +12,6 @@ 'regex_tokenizer', 'BasicEnglishNormalize', 'RegexTokenizer', - 'TextSequentialTransforms', 'PRETRAINED_SP_MODEL', 'load_sp_model', 'sentencepiece_tokenizer', @@ -164,7 +163,6 @@ def __prepare_scriptable__(self): return RegexTokenizer(regex_tokenizer) - PRETRAINED_SP_MODEL = { 'text_unigram_15000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model', 'text_unigram_25000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_25000.model', diff --git a/torchtext/transforms.py b/torchtext/transforms.py index 81939d428d..eaeaf4c372 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -18,6 +18,7 @@ 'Truncate', 'AddToken', 'GPT2BPETokenizer', + 'Sequential', ] From 9037b995ca2ad39b00bf3acb9eb46fd8bf52e26d Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 5 Jan 2022 12:38:11 -0500 Subject: [PATCH 3/6] fix experimental test --- test/experimental/test_with_asset.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/experimental/test_with_asset.py b/test/experimental/test_with_asset.py index bfbf82ba3b..deb40211ae 100644 --- a/test/experimental/test_with_asset.py +++ b/test/experimental/test_with_asset.py @@ -7,7 +7,6 @@ VocabTransform, PRETRAINED_SP_MODEL, sentencepiece_processor, - TextSequentialTransforms, ) from torch.utils.data import DataLoader from torchtext.experimental.vocab_factory import ( @@ -213,14 +212,6 @@ def batch_func(data): for item in dataloader: self.assertEqual(item, ref_results) - def test_text_sequential_transform(self): - asset_name = 'vocab_test2.txt' - asset_path = get_asset_path(asset_name) - pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(asset_path)) - jit_pipeline = torch.jit.script(pipeline) - self.assertEqual(pipeline('of that new'), [7, 18, 24]) - self.assertEqual(jit_pipeline('of that new'), [7, 18, 24]) - def test_vectors_from_file(self): asset_name = 'vectors_test.csv' asset_path = get_asset_path(asset_name) From 7ab547ad575947fe568263915dcf1a903afc9ff7 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 5 Jan 2022 15:25:32 -0500 Subject: [PATCH 4/6] separate out test class --- test/test_transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 2157296821..56dc4ed2e1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -204,6 +204,8 @@ def test_add_token(self): def test_add_token_jit(self): self._add_token(test_scripting=True) + +class TestSequential(TorchtextTestCase): def _sequential(self, test_scripting): max_seq_len = 3 padding_val = 0 From ac6701f819d7037591cddbd352e86daa4cf71738 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 5 Jan 2022 15:25:57 -0500 Subject: [PATCH 5/6] minor fix --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 56dc4ed2e1..81831b1467 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -195,7 +195,7 @@ def _add_token(self, test_scripting): input = ['1', '2'] actual = transform(input) - expected = ['1', '2', '0'] + expected = ['1', '2', '0']˝ self.assertEqual(actual, expected) def test_add_token(self): From c83a40987b5b5b95c7c6ef877b24bbf4928f3b34 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 5 Jan 2022 17:20:05 -0500 Subject: [PATCH 6/6] fix typo --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 81831b1467..56dc4ed2e1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -195,7 +195,7 @@ def _add_token(self, test_scripting): input = ['1', '2'] actual = transform(input) - expected = ['1', '2', '0']˝ + expected = ['1', '2', '0'] self.assertEqual(actual, expected) def test_add_token(self):