diff --git a/test/test_functional.py b/test/test_functional.py index f8dde30e06..f9b6065638 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -4,87 +4,86 @@ class TestFunctional(TorchtextTestCase): - def test_to_tensor(self): + def _to_tensor(self, test_scripting): input = [[1, 2], [1, 2, 3]] padding_value = 0 - actual = functional.to_tensor(input, padding_value=padding_value) + func = functional.to_tensor + if test_scripting: + func = torch.jit.script(func) + actual = func(input, padding_value=padding_value) expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) torch.testing.assert_close(actual, expected) input = [1, 2] - actual = functional.to_tensor(input, padding_value=padding_value) + actual = func(input, padding_value=padding_value) expected = torch.tensor([1, 2], dtype=torch.long) torch.testing.assert_close(actual, expected) - def test_to_tensor_jit(self): - input = [[1, 2], [1, 2, 3]] - padding_value = 0 - to_tensor_jit = torch.jit.script(functional.to_tensor) - actual = to_tensor_jit(input, padding_value=padding_value) - expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_to_tensor(self): + """test tensorization on both single sequence and batch of sequence""" + self._to_tensor(test_scripting=False) - input = [1, 2] - actual = to_tensor_jit(input, padding_value=padding_value) - expected = torch.tensor([1, 2], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_to_tensor_jit(self): + """test tensorization with scripting on both single sequence and batch of sequence""" + self._to_tensor(test_scripting=True) - def test_truncate(self): - input = [[1, 2], [1, 2, 3]] + def _truncate(self, test_scripting): max_seq_len = 2 + func = functional.truncate + if test_scripting: + func = torch.jit.script(func) - actual = functional.truncate(input, max_seq_len=max_seq_len) + input = [[1, 2], [1, 2, 3]] + actual = func(input, 
max_seq_len=max_seq_len) expected = [[1, 2], [1, 2]] self.assertEqual(actual, expected) input = [1, 2, 3] - actual = functional.truncate(input, max_seq_len=max_seq_len) + actual = func(input, max_seq_len=max_seq_len) expected = [1, 2] self.assertEqual(actual, expected) - def test_truncate_jit(self): - input = [[1, 2], [1, 2, 3]] - max_seq_len = 2 - truncate_jit = torch.jit.script(functional.truncate) - actual = truncate_jit(input, max_seq_len=max_seq_len) - expected = [[1, 2], [1, 2]] + input = [["a", "b"], ["a", "b", "c"]] + actual = func(input, max_seq_len=max_seq_len) + expected = [["a", "b"], ["a", "b"]] self.assertEqual(actual, expected) - input = [1, 2, 3] - actual = truncate_jit(input, max_seq_len=max_seq_len) - expected = [1, 2] + input = ["a", "b", "c"] + actual = func(input, max_seq_len=max_seq_len) + expected = ["a", "b"] self.assertEqual(actual, expected) - def test_add_token(self): - input = [[1, 2], [1, 2, 3]] - token_id = 0 - actual = functional.add_token(input, token_id=token_id) - expected = [[0, 1, 2], [0, 1, 2, 3]] - self.assertEqual(actual, expected) + def test_truncate(self): + """test truncation on both sequence and batch of sequence with both str and int types""" + self._truncate(test_scripting=False) - actual = functional.add_token(input, token_id=token_id, begin=False) - expected = [[1, 2, 0], [1, 2, 3, 0]] - self.assertEqual(actual, expected) + def test_truncate_jit(self): + """test truncation with scripting on both sequence and batch of sequence with both str and int types""" + self._truncate(test_scripting=True) - input = [1, 2] - actual = functional.add_token(input, token_id=token_id, begin=False) - expected = [1, 2, 0] - self.assertEqual(actual, expected) + def _add_token(self, test_scripting): - def test_add_token_jit(self): + func = functional.add_token + if test_scripting: + func = torch.jit.script(func) input = [[1, 2], [1, 2, 3]] token_id = 0 - add_token_jit = torch.jit.script(functional.add_token) - actual = add_token_jit(input, 
token_id=token_id) + actual = func(input, token_id=token_id) expected = [[0, 1, 2], [0, 1, 2, 3]] self.assertEqual(actual, expected) - actual = add_token_jit(input, token_id=token_id, begin=False) + actual = func(input, token_id=token_id, begin=False) expected = [[1, 2, 0], [1, 2, 3, 0]] self.assertEqual(actual, expected) input = [1, 2] - actual = add_token_jit(input, token_id=token_id, begin=False) + actual = func(input, token_id=token_id, begin=False) expected = [1, 2, 0] self.assertEqual(actual, expected) + + def test_add_token(self): + self._add_token(test_scripting=False) + + def test_add_token_jit(self): + self._add_token(test_scripting=True) diff --git a/test/test_transforms.py b/test/test_transforms.py index d7df6d3876..9f8ea22c16 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8,10 +8,12 @@ class TestTransforms(TorchtextTestCase): - def test_spmtokenizer(self): + def _spmtokenizer(self, test_scripting): asset_name = "spm_example.model" asset_path = get_asset_path(asset_name) transform = transforms.SentencePieceTokenizer(asset_path) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(["Hello World!, how are you?"]) expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] @@ -21,24 +23,19 @@ def test_spmtokenizer(self): expected = ['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?'] self.assertEqual(actual, expected) - def test_spmtokenizer_jit(self): - asset_name = "spm_example.model" - asset_path = get_asset_path(asset_name) - transform = transforms.SentencePieceTokenizer(asset_path) - transform_jit = torch.jit.script(transform) - - actual = transform_jit(["Hello World!, how are you?"]) - expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] - self.assertEqual(actual, expected) + def test_spmtokenizer(self): + """test tokenization on single sentence input as well as batch on sentences""" + self._spmtokenizer(test_scripting=False) - actual = transform_jit("Hello 
World!, how are you?") - expected = ['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?'] - self.assertEqual(actual, expected) + def test_spmtokenizer_jit(self): + """test tokenization with scripting on single sentence input as well as batch on sentences""" + self._spmtokenizer(test_scripting=True) - def test_vocab_transform(self): + def _vocab_transform(self, test_scripting): vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)])) transform = transforms.VocabTransform(vocab_obj) - + if test_scripting: + transform = torch.jit.script(transform) actual = transform([['a', 'b', 'c']]) expected = [[0, 1, 2]] self.assertEqual(actual, expected) @@ -47,22 +44,20 @@ def test_vocab_transform(self): expected = [0, 1, 2] self.assertEqual(actual, expected) - def test_vocab_transform_jit(self): - vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)])) - transform_jit = torch.jit.script(transforms.VocabTransform(vocab_obj)) - - actual = transform_jit([['a', 'b', 'c']]) - expected = [[0, 1, 2]] - self.assertEqual(actual, expected) + def test_vocab_transform(self): + """test token to indices on both sequence of input tokens as well as batch of sequence""" + self._vocab_transform(test_scripting=False) - actual = transform_jit(['a', 'b', 'c']) - expected = [0, 1, 2] - self.assertEqual(actual, expected) + def test_vocab_transform_jit(self): + """test token to indices with scripting on both sequence of input tokens as well as batch of sequence""" + self._vocab_transform(test_scripting=True) - def test_totensor(self): - input = [[1, 2], [1, 2, 3]] + def _totensor(self, test_scripting): padding_value = 0 transform = transforms.ToTensor(padding_value=padding_value) + if test_scripting: + transform = torch.jit.script(transform) + input = [[1, 2], [1, 2, 3]] actual = transform(input) expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) @@ -73,29 +68,26 @@ def test_totensor(self): expected = torch.tensor([1, 2], dtype=torch.long) 
torch.testing.assert_close(actual, expected) - def test_totensor_jit(self): - input = [[1, 2], [1, 2, 3]] - padding_value = 0 - transform = transforms.ToTensor(padding_value=padding_value) - transform_jit = torch.jit.script(transform) - - actual = transform_jit(input) - expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_totensor(self): + """test tensorization on both single sequence and batch of sequence""" + self._totensor(test_scripting=False) - input = [1, 2] - actual = transform_jit(input) - expected = torch.tensor([1, 2], dtype=torch.long) - torch.testing.assert_close(actual, expected) + def test_totensor_jit(self): + """test tensorization with scripting on both single sequence and batch of sequence""" + self._totensor(test_scripting=True) - def test_labeltoindex(self): + def _labeltoindex(self, test_scripting): label_names = ['test', 'label', 'indices'] transform = transforms.LabelToIndex(label_names=label_names) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(label_names) expected = [0, 1, 2] self.assertEqual(actual, expected) transform = transforms.LabelToIndex(label_names=label_names, sort_names=True) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(label_names) expected = [2, 1, 0] self.assertEqual(actual, expected) @@ -107,17 +99,16 @@ def test_labeltoindex(self): asset_name = "label_names.txt" asset_path = get_asset_path(asset_name) transform = transforms.LabelToIndex(label_path=asset_path) + if test_scripting: + transform = torch.jit.script(transform) actual = transform(label_names) expected = [0, 1, 2] self.assertEqual(actual, expected) - def test_labeltoindex_jit(self): - label_names = ['test', 'label', 'indices'] - transform_jit = torch.jit.script(transforms.LabelToIndex(label_names=label_names)) - actual = transform_jit(label_names) - expected = [0, 1, 2] - self.assertEqual(actual, expected) + def 
test_labeltoindex(self): + """test label to ids on single label input as well as batch of labels""" + self._labeltoindex(test_scripting=False) - actual = transform_jit("test") - expected = 0 - self.assertEqual(actual, expected) + def test_labeltoindex_jit(self): + """test label to ids with scripting on single label input as well as batch of labels""" + self._labeltoindex(test_scripting=True) diff --git a/torchtext/functional.py b/torchtext/functional.py index fe81920ed5..8f4f41d8da 100644 --- a/torchtext/functional.py +++ b/torchtext/functional.py @@ -1,7 +1,7 @@ import torch from torch import Tensor from torch.nn.utils.rnn import pad_sequence -from typing import List, Optional, Union +from typing import List, Optional, Any __all__ = [ 'to_tensor', @@ -10,10 +10,20 @@ ] -def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor: +def to_tensor(input: Any, padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> Tensor: + r"""Convert input to torch tensor + + :param padding_value: Pad value to make each input in the batch of length equal to the longest sequence in the batch. 
+ :type padding_value: Optional[int] + :param dtype: :class:`torch.dtype` of output tensor + :type dtype: :class:`torch.dtype` + :param input: Sequence or batch of token ids + :type input: Union[List[int], List[List[int]]] + :rtype: Tensor + """ if torch.jit.isinstance(input, List[int]): return torch.tensor(input, dtype=torch.long) - else: + elif torch.jit.isinstance(input, List[List[int]]): if padding_value is None: output = torch.tensor(input, dtype=dtype) return output @@ -24,27 +34,61 @@ def to_tensor(input: Union[List[int], List[List[int]]], padding_value: Optional[ padding_value=float(padding_value) ) return output + else: + raise TypeError("Input type not supported") + +def truncate(input: Any, max_seq_len: int) -> Any: + """ Truncate input sequence or batch -def truncate(input: Union[List[int], List[List[int]]], max_seq_len: int) -> Union[List[int], List[List[int]]]: + :param input: Input sequence or batch to be truncated + :type input: Union[List[Union[str, int]], List[List[Union[str, int]]]] + :param max_seq_len: Maximum length beyond which input is discarded + :type max_seq_len: int + :return: Truncated sequence + :rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]] + """ if torch.jit.isinstance(input, List[int]): return input[:max_seq_len] - else: + elif torch.jit.isinstance(input, List[str]): + return input[:max_seq_len] + elif torch.jit.isinstance(input, List[List[int]]): output: List[List[int]] = [] - for ids in input: output.append(ids[:max_seq_len]) - return output + elif torch.jit.isinstance(input, List[List[str]]): + output: List[List[str]] = [] + for ids in input: + output.append(ids[:max_seq_len]) + return output + else: + raise TypeError("Input type not supported") -def add_token(input: Union[List[int], List[List[int]]], token_id: int, begin: bool = True) -> Union[List[int], List[List[int]]]: - if torch.jit.isinstance(input, List[int]): +def add_token(input: Any, token_id: Any, begin: bool = True) -> Any: + """Add token to start 
or end of sequence + + :param input: Input sequence or batch + :type input: Union[List[Union[str, int]], List[List[Union[str, int]]]] + :param token_id: token to be added + :type token_id: Union[str, int] + :param begin: Whether to insert token at start or end of sequence, defaults to True + :type begin: bool, optional + :return: sequence or batch with token_id added to the beginning or end of input + :rtype: Union[List[Union[str, int]], List[List[Union[str, int]]]] + """ + if torch.jit.isinstance(input, List[int]) and torch.jit.isinstance(token_id, int): if begin: return [token_id] + input else: return input + [token_id] - else: + elif torch.jit.isinstance(input, List[str]) and torch.jit.isinstance(token_id, str): + if begin: + return [token_id] + input + else: + return input + [token_id] + elif torch.jit.isinstance(input, List[List[int]]) and torch.jit.isinstance(token_id, int): output: List[List[int]] = [] if begin: @@ -55,3 +99,15 @@ def add_token(input: Union[List[int], List[List[int]]], token_id: int, begin: bo output.append(ids + [token_id]) return output + elif torch.jit.isinstance(input, List[List[str]]) and torch.jit.isinstance(token_id, str): + output: List[List[str]] = [] + if begin: + for ids in input: + output.append([token_id] + ids) + else: + for ids in input: + output.append(ids + [token_id]) + + return output + else: + raise TypeError("Input type not supported") diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py index 5cb2279649..683b6406be 100644 --- a/torchtext/models/roberta/transforms.py +++ b/torchtext/models/roberta/transforms.py @@ -5,7 +5,7 @@ from torchtext import transforms from torchtext import functional -from typing import List, Union +from typing import List, Any class XLMRobertaModelTransform(Module): @@ -44,23 +44,38 @@ def __init__( self.bos_idx = self.vocab[self.bos_token] self.eos_idx = self.vocab[self.eos_token] - def forward(self, input: Union[str, List[str]], + def forward(self, input: Any, 
add_bos: bool = True, add_eos: bool = True, - truncate: bool = True) -> Union[List[int], List[List[int]]]: + truncate: bool = True) -> Any: + if torch.jit.isinstance(input, str): + tokens = self.vocab_transform(self.token_transform(input)) - tokens = self.vocab_transform(self.token_transform(input)) + if truncate: + tokens = functional.truncate(tokens, self.max_seq_len - 2) - if truncate: - tokens = functional.truncate(tokens, self.max_seq_len - 2) + if add_bos: + tokens = functional.add_token(tokens, self.bos_idx) - if add_bos: - tokens = functional.add_token(tokens, self.bos_idx) + if add_eos: + tokens = functional.add_token(tokens, self.eos_idx, begin=False) - if add_eos: - tokens = functional.add_token(tokens, self.eos_idx, begin=False) + return tokens + elif torch.jit.isinstance(input, List[str]): + tokens = self.vocab_transform(self.token_transform(input)) - tokens = self.vocab_transform(self.token_transform(input)) + if truncate: + tokens = functional.truncate(tokens, self.max_seq_len - 2) + + if add_bos: + tokens = functional.add_token(tokens, self.bos_idx) + + if add_eos: + tokens = functional.add_token(tokens, self.eos_idx, begin=False) + + return tokens + else: + raise TypeError("Input type not supported") def get_xlmr_transform(vocab_path, spm_model_path, **kwargs) -> XLMRobertaModelTransform: diff --git a/torchtext/transforms.py b/torchtext/transforms.py index cf43e40f4d..f5133d7391 100644 --- a/torchtext/transforms.py +++ b/torchtext/transforms.py @@ -5,7 +5,7 @@ from torchtext.data.functional import load_sp_model from torchtext.utils import download_from_url from torchtext.vocab import Vocab -from typing import List, Optional, Union +from typing import List, Optional, Any import os from torchtext import _CACHE_DIR @@ -20,9 +20,14 @@ class SentencePieceTokenizer(Module): """ - Transform for Sentence Piece tokenizer from pre-trained SentencePiece model + Transform for Sentence Piece tokenizer from pre-trained sentencepiece model - Examples: + Additional details: 
https://github.com/google/sentencepiece + + :param sp_model_path: Path to pre-trained sentencepiece model + :type sp_model_path: str + + Example: >>> from torchtext.transforms import SpmTokenizerTransform >>> transform = SentencePieceTokenizer("spm_model") >>> transform(["hello world", "attention is all you need!"]) @@ -36,21 +41,28 @@ def __init__(self, sp_model_path: str): local_path = download_from_url(url=sp_model_path, root=_CACHE_DIR) self.sp_model = load_sp_model(local_path) - def forward(self, input: Union[str, List[str]]) -> Union[List[str], List[List[str]]]: + def forward(self, input: Any) -> Any: + """ + :param input: Input sentence or list of sentences on which to apply tokenizer. + :type input: Union[str, List[str]] + :return: tokenized text + :rtype: Union[List[str], List[List[str]]] + """ if torch.jit.isinstance(input, List[str]): tokens: List[List[str]] = [] for text in input: tokens.append(self.sp_model.EncodeAsPieces(text)) return tokens - else: + elif torch.jit.isinstance(input, str): return self.sp_model.EncodeAsPieces(input) + else: + raise TypeError("Input type not supported") class VocabTransform(Module): - r"""Vocab transform + r"""Vocab transform to convert input batch of tokens into corresponding token ids - Args: - vocab: an instance of torchtext.vocab.Vocab class. + :param vocab: an instance of :class:`torchtext.vocab.Vocab` class. 
Example: >>> import torch @@ -68,28 +80,33 @@ def __init__(self, vocab: Vocab): assert isinstance(vocab, Vocab) self.vocab = vocab - def forward(self, input: Union[List[str], List[List[str]]]) -> Union[List[int], List[List[int]]]: - r""" - - Args: - input: list of list tokens + def forward(self, input: Any) -> Any: + """ + :param input: Input batch of tokens to convert to corresponding token ids + :type input: Union[List[str], List[List[str]]] + :return: Converted input into corresponding token ids + :rtype: Union[List[int], List[List[int]]] """ if torch.jit.isinstance(input, List[str]): return self.vocab.lookup_indices(input) - else: + elif torch.jit.isinstance(input, List[List[str]]): output: List[List[int]] = [] for tokens in input: output.append(self.vocab.lookup_indices(tokens)) return output + else: + raise TypeError("Input type not supported") class ToTensor(Module): r"""Convert input to torch tensor - Args: - padding_value (int, optional): Pad value to make each input in the batch of length equal to the longest sequence in the batch. + :param padding_value: Pad value to make each input in the batch of length equal to the longest sequence in the batch. 
+ :type padding_value: Optional[int] + :param dtype: :class:`torch.dtype` of output tensor + :type dtype: :class:`torch.dtype` """ def __init__(self, padding_value: Optional[int] = None, dtype: Optional[torch.dtype] = torch.long) -> None: @@ -97,10 +114,11 @@ def __init__(self, padding_value: Optional[int] = None, dtype: Optional[torch.dt self.padding_value = padding_value self.dtype = dtype - def forward(self, input: Union[List[int], List[List[int]]]) -> Tensor: - r""" - Args: - + def forward(self, input: Any) -> Tensor: + """ + :param input: Sequence or batch of token ids + :type input: Union[List[int], List[List[int]]] + :rtype: Tensor """ return F.to_tensor(input, padding_value=self.padding_value, dtype=self.dtype) @@ -109,15 +127,16 @@ class LabelToIndex(Module): r""" Transform labels from string names to ids. - Args: - label_names (List[str], Optional): a list of unique label names - label_path (str, Optional): a path to file containing unique label names containing 1 label per line. + :param label_names: a list of unique label names + :type label_names: Optional[List[str]] + :param label_path: a path to file containing unique label names containing 1 label per line. Note that either label_names or label_path should be supplied + but not both. 
+ :type label_path: Optional[str] """ def __init__( self, label_names: Optional[List[str]] = None, label_path: Optional[str] = None, sort_names=False, ): - assert label_names or label_path, "label_names or label_path is required" assert not (label_names and label_path), "label_names and label_path are mutually exclusive" super().__init__() @@ -133,11 +152,18 @@ def __init__( self._label_vocab = Vocab(torch.classes.torchtext.Vocab(label_names, 0)) self._label_names = self._label_vocab.get_itos() - def forward(self, labels: Union[str, List[str]]) -> Union[int, List[int]]: - if torch.jit.isinstance(labels, List[str]): - return self._label_vocab.lookup_indices(labels) + def forward(self, input: Any) -> Any: + """ + :param input: Input labels to convert to corresponding ids + :type input: Union[str, List[str]] + :rtype: Union[int, List[int]] + """ + if torch.jit.isinstance(input, List[str]): + return self._label_vocab.lookup_indices(input) + elif torch.jit.isinstance(input, str): + return self._label_vocab.__getitem__(input) else: - return self._label_vocab.__getitem__(labels) + raise TypeError("Input type not supported") @property def label_names(self) -> List[str]: