Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ torchtext.transforms
.. automodule:: torchtext.transforms
.. currentmodule:: torchtext.transforms

Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential`
Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential`, or using :class:`torchtext.transforms.Sequential` when the chained pipeline needs to be torch-scriptable.

SentencePieceTokenizer
----------------------
Expand Down Expand Up @@ -51,3 +51,10 @@ AddToken
.. autoclass:: AddToken

.. automethod:: forward

Sequential
----------

.. autoclass:: Sequential

.. automethod:: forward
9 changes: 0 additions & 9 deletions test/experimental/test_with_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
VocabTransform,
PRETRAINED_SP_MODEL,
sentencepiece_processor,
TextSequentialTransforms,
)
from torch.utils.data import DataLoader
from torchtext.experimental.vocab_factory import (
Expand Down Expand Up @@ -213,14 +212,6 @@ def batch_func(data):
for item in dataloader:
self.assertEqual(item, ref_results)

def test_text_sequential_transform(self):
    """Check that the eager and the scripted pipeline return the same token ids."""
    vocab_path = get_asset_path('vocab_test2.txt')
    pipeline = TextSequentialTransforms(basic_english_normalize(), load_vocab_from_file(vocab_path))
    scripted_pipeline = torch.jit.script(pipeline)
    expected_ids = [7, 18, 24]
    # Eager first, then scripted: both must produce the same ids.
    for run in (pipeline, scripted_pipeline):
        self.assertEqual(run('of that new'), expected_ids)

def test_vectors_from_file(self):
asset_name = 'vectors_test.csv'
asset_path = get_asset_path(asset_name)
Expand Down
27 changes: 27 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,33 @@ def test_add_token_jit(self):
self._add_token(test_scripting=True)


class TestSequential(TorchtextTestCase):
    """Tests for chaining transforms with ``transforms.Sequential``."""

    def _sequential(self, test_scripting):
        """Run a Truncate -> ToTensor pipeline, scripted when ``test_scripting`` is true."""
        pipeline = transforms.Sequential(
            transforms.Truncate(max_seq_len=3),
            transforms.ToTensor(padding_value=0, dtype=torch.long),
        )
        if test_scripting:
            pipeline = torch.jit.script(pipeline)

        batch = [[1, 2, 3], [1, 2, 3]]

        # The expected result is the input batch itself, converted to a tensor.
        result = pipeline(batch)
        reference = torch.tensor(batch)
        torch.testing.assert_close(result, reference)

    def test_sequential(self):
        """Pipeline transforms through Sequential in eager mode."""
        self._sequential(test_scripting=False)

    def test_sequential_jit(self):
        """Pipeline transforms through Sequential and confirm the composite is scriptable."""
        self._sequential(test_scripting=True)


class TestGPT2BPETokenizer(TorchtextTestCase):
def _load_tokenizer(self, test_scripting):
encoder_json = "gpt2_bpe_encoder.json"
Expand Down
20 changes: 0 additions & 20 deletions torchtext/experimental/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
'regex_tokenizer',
'BasicEnglishNormalize',
'RegexTokenizer',
'TextSequentialTransforms',
'PRETRAINED_SP_MODEL',
'load_sp_model',
'sentencepiece_tokenizer',
Expand Down Expand Up @@ -164,25 +163,6 @@ def __prepare_scriptable__(self):
return RegexTokenizer(regex_tokenizer)


class TextSequentialTransforms(nn.Sequential):
    r"""Container that runs the text transforms it hosts one after another.

    Example:
        >>> import torch
        >>> from torchtext.experimental.transforms import basic_english_normalize, TextSequentialTransforms
        >>> tokenizer = basic_english_normalize()
        >>> txt_pipeline = TextSequentialTransforms(tokenizer)
        >>> txt_pipeline('here is an example')
        ['here', 'is', 'an', 'example']
        >>> jit_txt_pipeline = torch.jit.script(txt_pipeline)
    """

    def forward(self, input: str):
        # Each transform consumes the previous transform's result; an empty
        # container hands the input back untouched.
        for transform in self:
            input = transform(input)
        return input


PRETRAINED_SP_MODEL = {
'text_unigram_15000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model',
'text_unigram_25000': 'https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_25000.model',
Expand Down
15 changes: 15 additions & 0 deletions torchtext/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
'Truncate',
'AddToken',
'GPT2BPETokenizer',
'Sequential',
]


Expand Down Expand Up @@ -335,3 +336,17 @@ def bytes_to_unicode():
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))


class Sequential(torch.nn.Sequential):
    r"""Container that applies the text transforms it holds, one after another."""

    def forward(self, input: Any) -> Any:
        """
        :param input: Input sequence or batch. Its type must be supported by the first transform in the sequence.
        :type input: `Any`
        :return: Result of the last transform, or ``input`` unchanged when the container holds no transforms.
        :rtype: `Any`
        """
        # Feed the result of each transform into the next one.
        for transform in self:
            input = transform(input)
        return input