From ed7bb34774d0bc3262d5feb3b8620b399547506c Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 1 Dec 2021 05:48:58 -0500 Subject: [PATCH 1/7] add truncate transform --- test/test_transforms.py | 29 +++++++++++++++++++++++++++++ torchtext/transforms.py | 9 +++++++++ 2 files changed, 38 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index d7df6d3876..ff176c073d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,5 +1,6 @@ import torch from torchtext import transforms +from torchtext.functional import truncate from torchtext.vocab import vocab from collections import OrderedDict @@ -121,3 +122,31 @@ def test_labeltoindex_jit(self): actual = transform_jit("test") expected = 0 self.assertEqual(actual, expected) + + def test_truncate(self): + input = [[1, 2], [1, 2, 3]] + max_seq_len = 2 + + transform = transforms.Truncate(max_seq_len) + actual = transform(input) + expected = [[1, 2], [1, 2]] + self.assertEqual(actual, expected) + + input = [1, 2, 3] + actual = transform(input) + expected = [1, 2] + self.assertEqual(actual, expected) + + def test_truncate_jit(self): + input = [[1, 2], [1, 2, 3]] + max_seq_len = 2 + transform = transforms.Truncate(max_seq_len) + transform_jit = torch.jit.script(transform) + actual = transform_jit(input) + expected = [[1, 2], [1, 2]] + self.assertEqual(actual, expected) + + input = [1, 2, 3] + actual = transform_jit(input) + expected = [1, 2] + self.assertEqual(actual, expected) diff --git a/torchtext/transforms.py b/torchtext/transforms.py index cf43e40f4d..b431ddae46 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -142,3 +142,12 @@ def forward(self, labels: Union[str, List[str]]) -> Union[int, List[int]]: @property def label_names(self) -> List[str]: return self._label_names + + +class Truncate(Module): + def __init__(self, max_seq_len) -> None: + super().__init__() + self.max_seq_len = max_seq_len + + def forward(self, input: Union[List[int], List[List[int]]]) -> Union[List[int], List[List[int]]]: + return F.truncate(input, self.max_seq_len) From 8b9bccbb2432de13c9d564a0a84c13caa912d184 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Sat, 4 Dec 2021 18:26:49 -0500 Subject: [PATCH 2/7] update transform --- torchtext/functional.py | 15 +++++++++++---- torchtext/models/roberta/transforms.py | 6 ++++-- torchtext/transforms.py | 11 +++++++++-- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/torchtext/functional.py b/torchtext/functional.py index fe81920ed5..91bc35afbc 100644 --- a/torchtext/functional.py +++ b/torchtext/functional.py @@ -26,15 +26,22 @@ def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[ return output -def truncate(input: Union[List[int], List[List[int]]], max_seq_len: int) -> Union[List[int], List[List[int]]]: +def truncate(input: Union[List[int], List[str], List[List[int]], List[List[str]]], max_seq_len: int) -> Union[List[int], List[str], List[List[int]], List[List[str]]]: if torch.jit.isinstance(input, List[int]): return input[:max_seq_len] - else: + elif torch.jit.isinstance(input, List[str]): + return input[:max_seq_len] + elif torch.jit.isinstance(input, List[str]): + return input[:max_seq_len] + elif torch.jit.isinstance(input, List[List[int]]): output: List[List[int]] = [] - for ids in input: output.append(ids[:max_seq_len]) - + return output + else: + output: List[List[str]] = [] + for ids in input: + output.append(ids[:max_seq_len]) return output diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py index 5cb2279649..d8f5ec8fa5 100644 --- a/torchtext/models/roberta/transforms.py +++ b/torchtext/models/roberta/transforms.py @@ -32,7 +32,7 @@ def __init__( self.sep_token = sep_token self.max_seq_len = max_seq_len - self.token_transform = transforms.SentencePieceTokenizer(spm_model_path) + self.tokenizer = transforms.SentencePieceTokenizer(spm_model_path) if os.path.exists(vocab_path): self.vocab = torch.load(vocab_path) @@ -49,11 +49,13 @@ def forward(self, input: Union[str, List[str]], add_eos: bool = True, truncate: bool = True) -> Union[List[int], List[List[int]]]: - tokens = self.vocab_transform(self.token_transform(input)) + tokens = self.tokenizer(input) if truncate: tokens = functional.truncate(tokens, self.max_seq_len - 2) + tokens = self.vocab_transform((tokens)) + if add_bos: tokens = functional.add_token(tokens, self.bos_idx) diff --git a/torchtext/transforms.py b/torchtext/transforms.py index b431ddae46..e5348a1a53 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -15,6 +15,7 @@ 'VocabTransform', 'ToTensor', 'LabelToIndex', + 'Truncate', ] @@ -145,9 +146,15 @@ def label_names(self) -> List[str]: class Truncate(Module): - def __init__(self, max_seq_len) -> None: + r"""Truncate input sequence + + Args: + max_seq_len (int): The maximum allowable length for input sequence + """ + + def __init__(self, max_seq_len: int) -> None: super().__init__() self.max_seq_len = max_seq_len - def forward(self, input: Union[List[int], List[List[int]]]) -> Union[List[int], List[List[int]]]: + def forward(self, input: Union[List[int], List[str], List[List[int]], List[List[str]]]) -> Union[List[int], List[str], List[List[int]], List[List[str]]]: return F.truncate(input, self.max_seq_len) From 92904a58b76438c808fa44e90380b1c565d91b20 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Sat, 4 Dec 2021 18:43:30 -0500 Subject: [PATCH 3/7] update doc --- docs/source/transforms.rst | 7 +++++++ test/test_transforms.py | 1 - torchtext/transforms.py | 4 ++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 220f18bf34..f1144c68fa 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -37,3 +37,10 @@ LabelToIndex .. autoclass:: LabelToIndex .. automethod:: forward + +Truncate +------------ + +.. autoclass:: Truncate + + .. automethod:: forward diff --git a/test/test_transforms.py b/test/test_transforms.py index ff176c073d..ac48146ddc 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,6 +1,5 @@ import torch from torchtext import transforms -from torchtext.functional import truncate from torchtext.vocab import vocab from collections import OrderedDict diff --git a/torchtext/transforms.py b/torchtext/transforms.py index e5348a1a53..a2c8773e86 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -157,4 +157,8 @@ def __init__(self, max_seq_len: int) -> None: self.max_seq_len = max_seq_len def forward(self, input: Union[List[int], List[str], List[List[int]], List[List[str]]]) -> Union[List[int], List[str], List[List[int]], List[List[str]]]: + """ + Args: + input: Input sequence to truncate. The input can either be a ``List`` or ``List[List]`` for batched operation + """ return F.truncate(input, self.max_seq_len) From 7e1725d097487c3f19ce3b6a2770999c58d8b88b Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Mon, 6 Dec 2021 19:28:27 -0500 Subject: [PATCH 4/7] change annotation types --- torchtext/functional.py | 22 +++++++++++--------- torchtext/models/roberta/transforms.py | 6 +++--- torchtext/transforms.py | 28 ++++++++++++++++---------- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/torchtext/functional.py b/torchtext/functional.py index 91bc35afbc..89a9310424 100644 --- a/torchtext/functional.py +++ b/torchtext/functional.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from torch.nn.utils.rnn import pad_sequence -from typing import List, Optional, Union +from typing import List, Optional, Any __all__ = [ 'to_tensor', @@ -10,10 +10,10 @@ ] -def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor: +def to_tensor(input: Any, padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor: if torch.jit.isinstance(input, List[int]): return torch.tensor(input, dtype=torch.long) - else: + elif torch.jit.isinstance(input, List[List[int]]): if padding_value is None: output = torch.tensor(input, dtype=dtype) return output @@ -24,34 +24,36 @@ def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[ padding_value=float(padding_value) ) return output + else: + raise TypeError("Input type not supported") -def truncate(input: Union[List[int], List[str], List[List[int]], List[List[str]]], max_seq_len: int) -> Union[List[int], List[str], List[List[int]], List[List[str]]]: +def truncate(input: Any, max_seq_len: int) -> Any: if torch.jit.isinstance(input, List[int]): return input[:max_seq_len] elif torch.jit.isinstance(input, List[str]): return input[:max_seq_len] - elif torch.jit.isinstance(input, List[str]): - return input[:max_seq_len] elif torch.jit.isinstance(input, List[List[int]]): output: List[List[int]] = [] for ids in input: output.append(ids[:max_seq_len]) return output - else: + elif torch.jit.isinstance(input, List[List[str]]): output: List[List[str]] = [] for ids in input: output.append(ids[:max_seq_len]) return output + else: + raise TypeError("Input type not supported") -def add_token(input: Union[List[int], List[List[int]]], token_id: int, begin: bool = True) -> Union[List[int], List[List[int]]]: +def add_token(input: Any, token_id: int, begin: bool = True) -> Any: if torch.jit.isinstance(input, List[int]): if begin: return [token_id] + input else: return input + [token_id] - else: + elif torch.jit.isinstance(input, List[List[int]]): output: List[List[int]] = [] if begin: @@ -62,3 +64,5 @@ def add_token(input: Union[List[int], List[List[int]]], token_id: int, begin: bo output.append(ids + [token_id]) return output + else: + raise TypeError("Input type not supported") diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py index d8f5ec8fa5..bfb7463d45 100644 --- a/torchtext/models/roberta/transforms.py +++ b/torchtext/models/roberta/transforms.py @@ -5,7 +5,7 @@ from torchtext import transforms from torchtext import functional -from typing import List, Union +from typing import Any class XLMRobertaModelTransform(Module): @@ -44,10 +44,10 @@ def __init__( self.bos_idx = self.vocab[self.bos_token] self.eos_idx = self.vocab[self.eos_token] - def forward(self, input: Union[str, List[str]], + def forward(self, input: Any, add_bos: bool = True, add_eos: bool = True, - truncate: bool = True) -> Union[List[int], List[List[int]]]: + truncate: bool = True) -> Any: tokens = self.tokenizer(input) diff --git a/torchtext/transforms.py b/torchtext/transforms.py index a2c8773e86..f94bf1501b 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -5,7 +5,7 @@ from torchtext.data.functional import load_sp_model from torchtext.utils import download_from_url from torchtext.vocab import Vocab -from typing import List, Optional, Union +from typing import List, Optional, Any import os from torchtext import _CACHE_DIR @@ -37,14 +37,16 @@ def __init__(self, sp_model_path: str): local_path = download_from_url(url=sp_model_path, root=_CACHE_DIR) self.sp_model = load_sp_model(local_path) - def forward(self, input: Union[str, List[str]]) -> Union[List[str], List[List[str]]]: + def forward(self, input: Any) -> Any: if torch.jit.isinstance(input, List[str]): tokens: List[List[str]] = [] for text in input: tokens.append(self.sp_model.EncodeAsPieces(text)) return tokens - else: + elif torch.jit.isinstance(input, str): return self.sp_model.EncodeAsPieces(input) + else: + raise TypeError("Input type not supported") class VocabTransform(Module): @@ -69,7 +71,7 @@ def __init__(self, vocab: Vocab): assert isinstance(vocab, Vocab) self.vocab = vocab - def forward(self, input: Union[List[str], List[List[str]]]) -> Union[List[int], List[List[int]]]: + def forward(self, input: Any) -> Any: r""" Args: @@ -78,12 +80,14 @@ def forward(self, input: Union[List[str], List[List[str]]]) -> Union[List[int], if torch.jit.isinstance(input, List[str]): return self.vocab.lookup_indices(input) - else: + elif torch.jit.isinstance(input, List[List[str]]): output: List[List[int]] = [] for tokens in input: output.append(self.vocab.lookup_indices(tokens)) return output + else: + raise TypeError("Input type not supported") class ToTensor(Module): @@ -98,7 +102,7 @@ def __init__(self, padding_value: Optional[int] = None, dtype: Optional[torch.dt self.padding_value = padding_value self.dtype = dtype - def forward(self, input: Union[List[int], List[List[int]]]) -> Tensor: + def forward(self, input: Any) -> Tensor: r""" Args: @@ -134,11 +138,13 @@ def __init__( self._label_vocab = Vocab(torch.classes.torchtext.Vocab(label_names, 0)) self._label_names = self._label_vocab.get_itos() - def forward(self, labels: Union[str, List[str]]) -> Union[int, List[int]]: - if torch.jit.isinstance(labels, List[str]): - return self._label_vocab.lookup_indices(labels) + def forward(self, input: Any) -> Any: + if torch.jit.isinstance(input, List[str]): + return self._label_vocab.lookup_indices(input) + elif torch.jit.isinstance(input, str): + return self._label_vocab.__getitem__(input) else: - return self._label_vocab.__getitem__(labels) + raise TypeError("Input type not supported") @property def label_names(self) -> List[str]: @@ -156,7 +162,7 @@ def __init__(self, max_seq_len: int) -> None: super().__init__() self.max_seq_len = max_seq_len - def forward(self, input: Union[List[int], List[str], List[List[int]], List[List[str]]]) -> Union[List[int], List[str], List[List[int]], List[List[str]]]: + def forward(self, input: Any) -> Any: """ Args: input: Input sequence to truncate. The input can either be a ``List`` or ``List[List]`` for batched operation From 97cafc5a6e79ff2abaec758164725a933fd39bcd Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 8 Dec 2021 00:52:11 -0500 Subject: [PATCH 5/7] update tests and docs --- docs/source/transforms.rst | 7 -- test/test_functional.py | 78 +++++++---------- test/test_transforms.py | 111 ++++++++----------------- torchtext/functional.py | 51 +++++++++++- torchtext/models/roberta/transforms.py | 35 +++++--- torchtext/transforms.py | 76 ++++++++--------- 6 files changed, 174 insertions(+), 184 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index f1144c68fa..220f18bf34 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -37,10 +37,3 @@ LabelToIndex .. autoclass:: LabelToIndex .. automethod:: forward - -Truncate ------------- - -.. autoclass:: Truncate - - .. automethod:: forward diff --git a/test/test_functional.py b/test/test_functional.py index f8dde30e06..497d757664 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -4,87 +4,71 @@ class TestFunctional(TorchtextTestCase): - def test_to_tensor(self): + def _to_tensor(self, test_scripting): input = [[1, 2], [1, 2, 3]] padding_value = 0 - actual = functional.to_tensor(input, padding_value=padding_value) + func = functional.to_tensor + if test_scripting: + func = torch.jit.script(func) + actual = func(input, padding_value=padding_value) expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) torch.testing.assert_close(actual, expected) input = [1, 2] - actual = functional.to_tensor(input, padding_value=padding_value) + actual = func(input, padding_value=padding_value) expected = torch.tensor([1, 2], dtype=torch.long) torch.testing.assert_close(actual, expected) - def test_to_tensor_jit(self): - input = [[1, 2], [1, 2, 3]] - padding_value = 0 - to_tensor_jit = torch.jit.script(functional.to_tensor) - actual = to_tensor_jit(input, padding_value=padding_value) - expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_to_tensor(self): + self._to_tensor(test_scripting=False) - input = [1, 2] - actual = to_tensor_jit(input, padding_value=padding_value) - expected = torch.tensor([1, 2], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_to_tensor_jit(self): + self._to_tensor(test_scripting=True) - def test_truncate(self): + def _truncate(self, test_scripting): input = [[1, 2], [1, 2, 3]] max_seq_len = 2 - actual = functional.truncate(input, max_seq_len=max_seq_len) + func = functional.truncate + if test_scripting: + func = torch.jit.script(func) + actual = func(input, max_seq_len=max_seq_len) expected = [[1, 2], [1, 2]] self.assertEqual(actual, expected) input = [1, 2, 3] - actual = functional.truncate(input, max_seq_len=max_seq_len) + actual = func(input, max_seq_len=max_seq_len) expected = [1, 2] self.assertEqual(actual, expected) - def test_truncate_jit(self): - input = [[1, 2], [1, 2, 3]] - max_seq_len = 2 - truncate_jit = torch.jit.script(functional.truncate) - actual = truncate_jit(input, max_seq_len=max_seq_len) - expected = [[1, 2], [1, 2]] - self.assertEqual(actual, expected) + def test_truncate(self): + self._truncate(test_scripting=False) - input = [1, 2, 3] - actual = truncate_jit(input, max_seq_len=max_seq_len) - expected = [1, 2] - self.assertEqual(actual, expected) + def test_truncate_jit(self): + self._truncate(test_scripting=True) - def test_add_token(self): + def _add_token(self, test_scripting): + func = functional.add_token + if test_scripting: + func = torch.jit.script(func) input = [[1, 2], [1, 2, 3]] token_id = 0 - actual = functional.add_token(input, token_id=token_id) + actual = func(input, token_id=token_id) expected = [[0, 1, 2], [0, 1, 2, 3]] self.assertEqual(actual, expected) - actual = functional.add_token(input, token_id=token_id, begin=False) + actual = func(input, token_id=token_id, begin=False) expected = [[1, 2, 0], [1, 2, 3, 0]] self.assertEqual(actual, expected) input = [1, 2] - actual = functional.add_token(input, token_id=token_id, begin=False) + actual = func(input, token_id=token_id, begin=False) expected = [1, 2, 0] self.assertEqual(actual, expected) - def test_add_token_jit(self): - input = [[1, 2], [1, 2, 3]] - token_id = 0 - add_token_jit = torch.jit.script(functional.add_token) - actual = add_token_jit(input, token_id=token_id) - expected = [[0, 1, 2], [0, 1, 2, 3]] - self.assertEqual(actual, expected) - - actual = add_token_jit(input, token_id=token_id, begin=False) - expected = [[1, 2, 0], [1, 2, 3, 0]] - self.assertEqual(actual, expected) + def test_add_token(self): + self._add_token(test_scripting=False) - input = [1, 2] - actual = add_token_jit(input, token_id=token_id, begin=False) - expected = [1, 2, 0] - self.assertEqual(actual, expected) + def test_add_token_jit(self): + self._add_token(test_scripting=True) diff --git a/test/test_transforms.py b/test/test_transforms.py index ac48146ddc..41badad97c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8,11 +8,12 @@ class TestTransforms(TorchtextTestCase): - def test_spmtokenizer(self): + def _spmtokenizer(self, test_scripting): asset_name = "spm_example.model" asset_path = get_asset_path(asset_name) transform = transforms.SentencePieceTokenizer(asset_path) - + if test_scripting: + transform = torch.jit.script(transform) actual = transform(["Hello World!, how are you?"]) expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] self.assertEqual(actual, expected) @@ -21,24 +22,18 @@ def test_spmtokenizer(self): expected = ['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?'] self.assertEqual(actual, expected) - def test_spmtokenizer_jit(self): - asset_name = "spm_example.model" - asset_path = get_asset_path(asset_name) - transform = transforms.SentencePieceTokenizer(asset_path) - transform_jit = torch.jit.script(transform) + def test_spmtokenizer(self): - actual = transform_jit(["Hello World!, how are you?"]) - expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] - self.assertEqual(actual, expected) + self._spmtokenizer(test_scripting=False) - actual = transform_jit("Hello World!, how are you?") - expected = ['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?'] - self.assertEqual(actual, expected) + def test_spmtokenizer_jit(self): + self._spmtokenizer(test_scripting=True) - def test_vocab_transform(self): + def _vocab_transform(self, test_scripting): vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)])) transform = transforms.VocabTransform(vocab_obj) - + if test_scripting: + transform = torch.jit.script(transform) actual = transform([['a', 'b', 'c']]) expected = [[0, 1, 2]] self.assertEqual(actual, expected) @@ -47,22 +42,18 @@ def test_vocab_transform(self): expected = [0, 1, 2] self.assertEqual(actual, expected) - def test_vocab_transform_jit(self): - vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)])) - transform_jit = torch.jit.script(transforms.VocabTransform(vocab_obj)) - - actual = transform_jit([['a', 'b', 'c']]) - expected = [[0, 1, 2]] - self.assertEqual(actual, expected) + def test_vocab_transform(self): + self._vocab_transform(test_scripting=False) - actual = transform_jit(['a', 'b', 'c']) - expected = [0, 1, 2] - self.assertEqual(actual, expected) + def test_vocab_transform_jit(self): + self._vocab_transform(test_scripting=True) - def test_totensor(self): - input = [[1, 2], [1, 2, 3]] + def _totensor(self, test_scripting): padding_value = 0 transform = transforms.ToTensor(padding_value=padding_value) + if test_scripting: + transform = torch.jit.script(transform) + input = [[1, 2], [1, 2, 3]] actual = transform(input) expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) @@ -73,29 +64,24 @@ def test_totensor(self): expected = torch.tensor([1, 2], dtype=torch.long) torch.testing.assert_close(actual, expected) - def test_totensor_jit(self): - input = [[1, 2], [1, 2, 3]] - padding_value = 0 - transform = transforms.ToTensor(padding_value=padding_value) - transform_jit = torch.jit.script(transform) - - actual = transform_jit(input) - expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_totensor(self): + self._totensor(test_scripting=False) - input = [1, 2] - actual = transform_jit(input) - expected = torch.tensor([1, 2], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_totensor_jit(self): + self._totensor(test_scripting=True) - def test_labeltoindex(self): + def _labeltoindex(self, test_scripting): label_names = ['test', 'label', 'indices'] transform = transforms.LabelToIndex(label_names=label_names) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(label_names) expected = [0, 1, 2] self.assertEqual(actual, expected) transform = transforms.LabelToIndex(label_names=label_names, sort_names=True) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(label_names) expected = [2, 1, 0] self.assertEqual(actual, expected) @@ -107,45 +93,14 @@ def test_labeltoindex(self): asset_name = "label_names.txt" asset_path = get_asset_path(asset_name) transform = transforms.LabelToIndex(label_path=asset_path) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(label_names) expected = [0, 1, 2] self.assertEqual(actual, expected) - def test_labeltoindex_jit(self): - label_names = ['test', 'label', 'indices'] - transform_jit = torch.jit.script(transforms.LabelToIndex(label_names=label_names)) - actual = transform_jit(label_names) - expected = [0, 1, 2] - self.assertEqual(actual, expected) - - actual = transform_jit("test") - expected = 0 - self.assertEqual(actual, expected) - - def test_truncate(self): - input = [[1, 2], [1, 2, 3]] - max_seq_len = 2 - - transform = transforms.Truncate(max_seq_len) - actual = transform(input) - expected = [[1, 2], [1, 2]] - self.assertEqual(actual, expected) - - input = [1, 2, 3] - actual = transform(input) - expected = [1, 2] - self.assertEqual(actual, expected) - - def test_truncate_jit(self): - input = [[1, 2], [1, 2, 3]] - max_seq_len = 2 - transform = transforms.Truncate(max_seq_len) - transform_jit = torch.jit.script(transform) - actual = transform_jit(input) - expected = [[1, 2], [1, 2]] - self.assertEqual(actual, expected) + def test_labeltoindex(self): + self._labeltoindex(test_scripting=False) - input = [1, 2, 3] - actual = transform_jit(input) - expected = [1, 2] - self.assertEqual(actual, expected) + def test_labeltoindex_jit(self): + self._labeltoindex(test_scripting=True) diff --git a/torchtext/functional.py b/torchtext/functional.py index 89a9310424..8f4f41d8da 100644 --- a/torchtext/functional.py +++ b/torchtext/functional.py @@ -11,6 +11,16 @@ def to_tensor(input: Any, padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor: + r"""Convert input to torch tensor + + :param padding_value: Pad value to make each input in the batch of length equal to the longest sequence in the batch. + :type padding_value: Optional[int] + :param dtype: :class:`torch.dtype` of output tensor + :type dtype: :class:`torch.dtype` + :param input: Sequence or batch of token ids + :type input: Union[List[int], List[List[int]]] + :rtype: Tensor + """ if torch.jit.isinstance(input, List[int]): return torch.tensor(input, dtype=torch.long) elif torch.jit.isinstance(input, List[List[int]]): @@ -29,6 +39,15 @@ def to_tensor(input: Any, padding_value: Optional[int] = None, dtype: Optional[t def truncate(input: Any, max_seq_len: int) -> Any: + """ Truncate input sequence or batch + + :param input: Input sequence or batch to be truncated + :type input: Union[List[Union[str, int]], List[List[Union[str, int]]]] + :param max_seq_len: Maximum length beyond which input is discarded + :type max_seq_len: int + :return: Truncated sequence + :rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]] + """ if torch.jit.isinstance(input, List[int]): return input[:max_seq_len] elif torch.jit.isinstance(input, List[str]): @@ -47,15 +66,41 @@ def truncate(input: Any, max_seq_len: int) -> Any: raise TypeError("Input type not supported") -def add_token(input: Any, token_id: int, begin: bool = True) -> Any: - if torch.jit.isinstance(input, List[int]): +def add_token(input: Any, token_id: Any, begin: bool = True) -> Any: + """Add token to start or end of sequence + + :param input: Input sequence or batch + :type input: Union[List[Union[str, int]], List[List[Union[str, int]]]] + :param token_id: token to be added + :type token_id: Union[str, int] + :param begin: Whether to insert token at start or end or sequence, defaults to True + :type begin: bool, optional + :return: sequence or batch with token_id added to begin or end or input + :rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]] + """ + if torch.jit.isinstance(input, List[int]) and torch.jit.isinstance(token_id, int): if begin: return [token_id] + input else: return input + [token_id] - elif torch.jit.isinstance(input, List[List[int]]): + elif torch.jit.isinstance(input, List[str]) and torch.jit.isinstance(token_id, str): + if begin: + return [token_id] + input + else: + return input + [token_id] + elif torch.jit.isinstance(input, List[List[int]]) and torch.jit.isinstance(token_id, int): output: List[List[int]] = [] + if begin: + for ids in input: + output.append([token_id] + ids) + else: + for ids in input: + output.append(ids + [token_id]) + + return output + elif torch.jit.isinstance(input, List[List[str]]) and torch.jit.isinstance(token_id, str): + output: List[List[str]] = [] if begin: for ids in input: output.append([token_id] + ids) diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py index bfb7463d45..683b6406be 100644 --- a/torchtext/models/roberta/transforms.py +++ b/torchtext/models/roberta/transforms.py @@ -5,7 +5,7 @@ from torchtext import transforms from torchtext import functional -from typing import Any +from typing import List, Any class XLMRobertaModelTransform(Module): @@ -32,7 +32,7 @@ def __init__( self.sep_token = sep_token self.max_seq_len = max_seq_len - self.tokenizer = transforms.SentencePieceTokenizer(spm_model_path) + self.token_transform = transforms.SentencePieceTokenizer(spm_model_path) if os.path.exists(vocab_path): self.vocab = torch.load(vocab_path) @@ -48,21 +48,34 @@ def forward(self, input: Any, add_bos: bool = True, add_eos: bool = True, truncate: bool = True) -> Any: + if torch.jit.isinstance(input, str): + tokens = self.vocab_transform(self.token_transform(input)) - tokens = self.tokenizer(input) + if truncate: + tokens = functional.truncate(tokens, self.max_seq_len - 2) - if truncate: - tokens = functional.truncate(tokens, self.max_seq_len - 2) + if add_bos: + tokens = functional.add_token(tokens, self.bos_idx) - tokens = self.vocab_transform((tokens)) + if add_eos: + tokens = functional.add_token(tokens, self.eos_idx, begin=False) - if add_bos: - tokens = functional.add_token(tokens, self.bos_idx) + return tokens + elif torch.jit.isinstance(input, List[str]): + tokens = self.vocab_transform(self.token_transform(input)) - if add_eos: - tokens = functional.add_token(tokens, self.eos_idx, begin=False) + if truncate: + tokens = functional.truncate(tokens, self.max_seq_len - 2) - return tokens + if add_bos: + tokens = functional.add_token(tokens, self.bos_idx) + + if add_eos: + tokens = functional.add_token(tokens, self.eos_idx, begin=False) + + return tokens + else: + raise TypeError("Input type not supported") def get_xlmr_transform(vocab_path, spm_model_path, **kwargs) -> XLMRobertaModelTransform: diff --git a/torchtext/transforms.py b/torchtext/transforms.py index f94bf1501b..f5133d7391 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -15,15 +15,19 @@ 'VocabTransform', 'ToTensor', 'LabelToIndex', - 'Truncate', ] class SentencePieceTokenizer(Module): """ - Transform for Sentence Piece tokenizer from pre-trained SentencePiece model + Transform for Sentence Piece tokenizer from pre-trained sentencepiece model - Examples: + Additiona details: https://github.com/google/sentencepiece + + :param sp_model_path: Path to pre-trained sentencepiece model + :type sp_model_path: str + + Example >>> from torchtext.transforms import SpmTokenizerTransform >>> transform = SentencePieceTokenizer("spm_model") >>> transform(["hello world", "attention is all you need!"]) @@ -38,6 +42,12 @@ def __init__(self, sp_model_path: str): self.sp_model = load_sp_model(local_path) def forward(self, input: Any) -> Any: + """ + :param input: Input sentence or list of sentences on which to apply tokenizer. + :type input: Union[str, List[str]] + :return: tokenized text + :rtype: Union[List[str], List[List(str)]] + """ if torch.jit.isinstance(input, List[str]): tokens: List[List[str]] = [] for text in input: @@ -50,10 +60,9 @@ def forward(self, input: Any) -> Any: class VocabTransform(Module): - r"""Vocab transform + r"""Vocab transform to convert input batch of tokens into corresponding token ids - Args: - vocab: an instance of torchtext.vocab.Vocab class. + :param vocab: an instance of :class:`torchtext.vocab.Vocab` class. Example: >>> import torch @@ -72,10 +81,11 @@ def __init__(self, vocab: Vocab): self.vocab = vocab def forward(self, input: Any) -> Any: - r""" - - Args: - input: list of list tokens + """ + :param input: Input batch of token to convert to correspnding token ids + :type input: Union[List[str], List[List[str]]] + :return: Converted input into corresponding token ids + :rtype: Union[List[int], List[List[int]]] """ if torch.jit.isinstance(input, List[str]): @@ -93,8 +103,10 @@ def forward(self, input: Any) -> Any: class ToTensor(Module): r"""Convert input to torch tensor - Args: - padding_value (int, optional): Pad value to make each input in the batch of length equal to the longest sequence in the batch. + :param padding_value: Pad value to make each input in the batch of length equal to the longest sequence in the batch. + :type padding_value: Optional[int] + :param dtype: :class:`torch.dtype` of output tensor + :type dtype: :class:`torch.dtype` """ def __init__(self, padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> None: @@ -103,9 +115,10 @@ def __init__(self, padding_value: Optional[int] = None, dtype: Optional[torch.dt self.dtype = dtype def forward(self, input: Any) -> Tensor: - r""" - Args: - + """ + :param input: Sequence or batch of token ids + :type input: Union[List[int], List[List[int]]] + :rtype: Tensor """ return F.to_tensor(input, padding_value=self.padding_value, dtype=self.dtype) @@ -114,15 +127,16 @@ class LabelToIndex(Module): r""" Transform labels from string names to ids. - Args: - label_names (List[str], Optional): a list of unique label names - label_path (str, Optional): a path to file containing unique label names containing 1 label per line. + :param label_names: a list of unique label names + :type label_names: Optional[List[str]] + :param label_path: a path to file containing unique label names containing 1 label per line. Note that either label_names or label_path should be supplied + but not both. + :type label_path: Optional[str] """ def __init__( self, label_names: Optional[List[str]] = None, label_path: Optional[str] = None, sort_names=False, ): - assert label_names or label_path, "label_names or label_path is required" assert not (label_names and label_path), "label_names and label_path are mutually exclusive" super().__init__() @@ -139,6 +153,11 @@ def __init__( self._label_names = self._label_vocab.get_itos() def forward(self, input: Any) -> Any: + """ + :param input: Input labels to convert to corresponding ids + :type input: Union[str, List[str]] + :rtype: Union[int, List[int]] + """ if torch.jit.isinstance(input, List[str]): return self._label_vocab.lookup_indices(input) elif torch.jit.isinstance(input, str): @@ -149,22 +168,3 @@ def forward(self, input: Any) -> Any: @property def label_names(self) -> List[str]: return self._label_names - - -class Truncate(Module): - r"""Truncate input sequence - - Args: - max_seq_len (int): The maximum allowable length for input sequence - """ - - def __init__(self, max_seq_len: int) -> None: - super().__init__() - self.max_seq_len = max_seq_len - - def forward(self, input: Any) -> Any: - """ - Args: - input: Input sequence to truncate. The input can either be a ``List`` or ``List[List]`` for batched operation - """ - return F.truncate(input, self.max_seq_len) From 4c09607f1c58adbc61de3bf88f4d120954b2615c Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 8 Dec 2021 10:58:01 -0500 Subject: [PATCH 6/7] add doc to tests and update coverage --- test/test_transforms.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 41badad97c..9f8ea22c16 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -14,6 +14,7 @@ def _spmtokenizer(self, test_scripting): transform = transforms.SentencePieceTokenizer(asset_path) if test_scripting: transform = torch.jit.script(transform) + actual = transform(["Hello World!, how are you?"]) expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] self.assertEqual(actual, expected) @@ -23,10 +24,11 @@ def _spmtokenizer(self, test_scripting): self.assertEqual(actual, expected) def test_spmtokenizer(self): - + """test tokenization on single sentence input as well as batch on sentences""" self._spmtokenizer(test_scripting=False) def test_spmtokenizer_jit(self): + """test tokenization with scripting on single sentence input as well as batch on sentences""" self._spmtokenizer(test_scripting=True) def _vocab_transform(self, test_scripting): @@ -43,9 +45,11 @@ def _vocab_transform(self, test_scripting): self.assertEqual(actual, expected) def test_vocab_transform(self): + """test token to indices on both sequence of input tokens as well as batch of sequence""" self._vocab_transform(test_scripting=False) def test_vocab_transform_jit(self): + """test token to indices with scripting on both sequence of input tokens as well as batch of sequence""" self._vocab_transform(test_scripting=True) def _totensor(self, test_scripting): @@ -65,9 +69,11 @@ def _totensor(self, test_scripting): torch.testing.assert_close(actual, expected) def test_totensor(self): + """test tensorization on both single sequence and batch of sequence""" self._totensor(test_scripting=False) def test_totensor_jit(self): + """test tensorization with scripting on both single sequence and batch of sequence""" self._totensor(test_scripting=True) def _labeltoindex(self, test_scripting): @@ -100,7 +106,9 @@ def _labeltoindex(self, test_scripting): self.assertEqual(actual, expected) def test_labeltoindex(self): + """test labe to ids on single label input as well as batch of labels""" self._labeltoindex(test_scripting=False) def test_labeltoindex_jit(self): + """test labe to ids with scripting on single label input as well as batch of labels""" self._labeltoindex(test_scripting=True) From 79551dbe05c5855e12154adaa3240560129610d3 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 8 Dec 2021 11:05:49 -0500 Subject: [PATCH 7/7] update test --- test/test_functional.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 497d757664..f9b6065638 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -21,18 +21,20 @@ def _to_tensor(self, test_scripting): torch.testing.assert_close(actual, expected) def test_to_tensor(self): + """test tensorization on both single sequence and batch of sequence""" self._to_tensor(test_scripting=False) def test_to_tensor_jit(self): + """test tensorization with scripting on both single sequence and batch of sequence""" self._to_tensor(test_scripting=True) def _truncate(self, test_scripting): - input = [[1, 2], [1, 2, 3]] max_seq_len = 2 - func = functional.truncate if test_scripting: func = torch.jit.script(func) + + input = [[1, 2], [1, 2, 3]] actual = func(input, max_seq_len=max_seq_len) expected = [[1, 2], [1, 2]] self.assertEqual(actual, expected) @@ -42,13 +44,26 @@ def _truncate(self, test_scripting): expected = [1, 2] self.assertEqual(actual, expected) + input = [["a", "b"], ["a", "b", "c"]] + actual = func(input, max_seq_len=max_seq_len) + expected = [["a", "b"], ["a", "b"]] + self.assertEqual(actual, expected) + + input = ["a", "b", "c"] + actual = func(input, max_seq_len=max_seq_len) + expected = ["a", "b"] + self.assertEqual(actual, expected) + def test_truncate(self): + """test truncation on both sequence and batch of sequence with both str and int types""" self._truncate(test_scripting=False) def test_truncate_jit(self): + """test truncation with scripting on both sequence and batch of sequence with both str and int types""" self._truncate(test_scripting=True) def _add_token(self, test_scripting): + func = functional.add_token if test_scripting: func = torch.jit.script(func)