diff --git a/test/asset/xlmr.base.output.pt b/test/asset/xlmr.base.output.pt new file mode 100644 index 0000000000..d32b3c7d8e Binary files /dev/null and b/test/asset/xlmr.base.output.pt differ diff --git a/test/asset/xlmr.large.output.pt b/test/asset/xlmr.large.output.pt new file mode 100644 index 0000000000..5888eb54b1 Binary files /dev/null and b/test/asset/xlmr.large.output.pt differ diff --git a/test/models/test_models.py b/test/models/test_models.py new file mode 100644 index 0000000000..c6b5d01acd --- /dev/null +++ b/test/models/test_models.py @@ -0,0 +1,70 @@ +import torchtext +import torch + +from ..common.torchtext_test_case import TorchtextTestCase +from ..common.assets import get_asset_path + + +class TestModels(TorchtextTestCase): + def test_xlmr_base_output(self): + asset_name = "xlmr.base.output.pt" + asset_path = get_asset_path(asset_name) + xlmr_base = torchtext.models.XLMR_BASE_ENCODER + model = xlmr_base.get_model() + model = model.eval() + model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]) + actual = model(model_input) + expected = torch.load(asset_path) + torch.testing.assert_close(actual, expected) + + def test_xlmr_base_jit_output(self): + asset_name = "xlmr.base.output.pt" + asset_path = get_asset_path(asset_name) + xlmr_base = torchtext.models.XLMR_BASE_ENCODER + model = xlmr_base.get_model() + model = model.eval() + model_jit = torch.jit.script(model) + model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]) + actual = model_jit(model_input) + expected = torch.load(asset_path) + torch.testing.assert_close(actual, expected) + + def test_xlmr_large_output(self): + asset_name = "xlmr.large.output.pt" + asset_path = get_asset_path(asset_name) + xlmr_base = torchtext.models.XLMR_LARGE_ENCODER + model = xlmr_base.get_model() + model = model.eval() + model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]) + actual = model(model_input) + expected = torch.load(asset_path) + torch.testing.assert_close(actual, expected) + + def test_xlmr_large_jit_output(self): + asset_name = "xlmr.large.output.pt" + asset_path = get_asset_path(asset_name) + xlmr_base = torchtext.models.XLMR_LARGE_ENCODER + model = xlmr_base.get_model() + model = model.eval() + model_jit = torch.jit.script(model) + model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]) + actual = model_jit(model_input) + expected = torch.load(asset_path) + torch.testing.assert_close(actual, expected) + + def test_xlmr_transform(self): + xlmr_base = torchtext.models.XLMR_BASE_ENCODER + transform = xlmr_base.transform() + test_text = "XLMR base Model Comparison" + actual = transform([test_text]) + expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]] + torch.testing.assert_close(actual, expected) + + def test_xlmr_transform_jit(self): + xlmr_base = torchtext.models.XLMR_BASE_ENCODER + transform = xlmr_base.transform() + transform_jit = torch.jit.script(transform) + test_text = "XLMR base Model Comparison" + actual = transform_jit([test_text]) + expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]] + torch.testing.assert_close(actual, expected) diff --git a/test/test_functional.py b/test/test_functional.py new file mode 100644 index 0000000000..6d71dd71aa --- /dev/null +++ b/test/test_functional.py @@ -0,0 +1,58 @@ +import torch +from torchtext import functional +from .common.torchtext_test_case import TorchtextTestCase + + +class TestFunctional(TorchtextTestCase): + def test_to_tensor(self): + input = [[1, 2], [1, 
2, 3]] + padding_value = 0 + actual = functional.to_tensor(input, padding_value=padding_value) + expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) + torch.testing.assert_close(actual, expected) + + def test_to_tensor_jit(self): + input = [[1, 2], [1, 2, 3]] + padding_value = 0 + to_tensor_jit = torch.jit.script(functional.to_tensor) + actual = to_tensor_jit(input, padding_value=padding_value) + expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long) + torch.testing.assert_close(actual, expected) + + def test_truncate(self): + input = [[1, 2], [1, 2, 3]] + max_seq_len = 2 + actual = functional.truncate(input, max_seq_len=max_seq_len) + expected = [[1, 2], [1, 2]] + self.assertEqual(actual, expected) + + def test_truncate_jit(self): + input = [[1, 2], [1, 2, 3]] + max_seq_len = 2 + truncate_jit = torch.jit.script(functional.truncate) + actual = truncate_jit(input, max_seq_len=max_seq_len) + expected = [[1, 2], [1, 2]] + self.assertEqual(actual, expected) + + def test_add_token(self): + input = [[1, 2], [1, 2, 3]] + token_id = 0 + actual = functional.add_token(input, token_id=token_id) + expected = [[0, 1, 2], [0, 1, 2, 3]] + self.assertEqual(actual, expected) + + actual = functional.add_token(input, token_id=token_id, begin=False) + expected = [[1, 2, 0], [1, 2, 3, 0]] + self.assertEqual(actual, expected) + + def test_add_token_jit(self): + input = [[1, 2], [1, 2, 3]] + token_id = 0 + add_token_jit = torch.jit.script(functional.add_token) + actual = add_token_jit(input, token_id=token_id) + expected = [[0, 1, 2], [0, 1, 2, 3]] + self.assertEqual(actual, expected) + + actual = add_token_jit(input, token_id=token_id, begin=False) + expected = [[1, 2, 0], [1, 2, 3, 0]] + self.assertEqual(actual, expected) diff --git a/test/test_transforms.py b/test/test_transforms.py new file mode 100644 index 0000000000..ab941fb39c --- /dev/null +++ b/test/test_transforms.py @@ -0,0 +1,33 @@ +import torch +from torchtext import transforms +from torchtext.vocab import vocab +from collections import OrderedDict + +from .common.torchtext_test_case import TorchtextTestCase +from .common.assets import get_asset_path + + +class TestTransforms(TorchtextTestCase): + def test_spmtokenizer_transform(self): + asset_name = "spm_example.model" + asset_path = get_asset_path(asset_name) + transform = transforms.SpmTokenizerTransform(asset_path) + actual = transform(["Hello World!, how are you?"]) + expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] + self.assertEqual(actual, expected) + + def test_spmtokenizer_transform_jit(self): + asset_name = "spm_example.model" + asset_path = get_asset_path(asset_name) + transform = transforms.SpmTokenizerTransform(asset_path) + transform_jit = torch.jit.script(transform) + actual = transform_jit(["Hello World!, how are you?"]) + expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']] + self.assertEqual(actual, expected) + + def test_vocab_transform(self): + vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)])) + transform = transforms.VocabTransform(vocab_obj) + actual = transform([['a', 'b', 'c']]) + expected = [[0, 1, 2]] + self.assertEqual(actual, expected) diff --git a/torchtext/__init__.py b/torchtext/__init__.py index 8e1a92feab..a5c2802614 100644 --- a/torchtext/__init__.py +++ b/torchtext/__init__.py @@ -1,8 +1,15 @@ +import os +_TEXT_BUCKET = 'https://download.pytorch.org/models/text' +_CACHE_DIR = os.path.expanduser('~/.torchtext/cache') + from . import data from . import nn from . import datasets from . 
import utils from . import vocab +from . import transforms +from . import functional +from . import models from . import experimental from . import legacy from ._extension import _init_extension @@ -18,6 +25,9 @@ 'datasets', 'utils', 'vocab', + 'transforms', + 'functional', + 'models', 'experimental', 'legacy'] diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py index 571b43c479..12297931f9 100644 --- a/torchtext/data/datasets_utils.py +++ b/torchtext/data/datasets_utils.py @@ -15,6 +15,8 @@ import defusedxml.ElementTree as ET except ImportError: import xml.etree.ElementTree as ET + +from torchtext import _CACHE_DIR """ These functions and classes are meant solely for use in torchtext.datasets and not for public consumption yet. @@ -213,7 +215,7 @@ def _wrap_split_argument_with_fn(fn, splits): raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) @functools.wraps(fn) - def new_fn(root=os.path.expanduser('~/.torchtext/cache'), split=splits, **kwargs): + def new_fn(root=_CACHE_DIR, split=splits, **kwargs): result = [] for item in _check_default_set(split, splits, fn.__name__): result.append(fn(root, item, **kwargs)) @@ -250,7 +252,7 @@ def decorator(func): raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn)) @functools.wraps(func) - def wrapper(root=os.path.expanduser('~/.torchtext/cache'), *args, **kwargs): + def wrapper(root=_CACHE_DIR, *args, **kwargs): new_root = os.path.join(root, dataset_name) if not os.path.exists(new_root): os.makedirs(new_root) diff --git a/torchtext/functional.py b/torchtext/functional.py new file mode 100644 index 0000000000..9231c9644e --- /dev/null +++ b/torchtext/functional.py @@ -0,0 +1,45 @@ +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence +from typing import List, Optional + +__all__ = [ + 'to_tensor', + 'truncate', + 'add_token', +] + + +def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor: + if padding_value is None: + output = torch.tensor(input, dtype=torch.long) + return output + else: + output = pad_sequence( + [torch.tensor(ids, dtype=torch.long) for ids in input], + batch_first=True, + padding_value=float(padding_value) + ) + return output + + +def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]: + output: List[List[int]] = [] + + for ids in input: + output.append(ids[:max_seq_len]) + + return output + + +def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]: + output: List[List[int]] = [] + + if begin: + for ids in input: + output.append([token_id] + ids) + else: + for ids in input: + output.append(ids + [token_id]) + + return output diff --git a/torchtext/models/__init__.py b/torchtext/models/__init__.py new file mode 100644 index 0000000000..a7cbc0c88a --- /dev/null +++ b/torchtext/models/__init__.py @@ -0,0 +1 @@ +from .roberta import * # noqa: F401, F403 diff --git a/torchtext/models/roberta/__init__.py b/torchtext/models/roberta/__init__.py new file mode 100644 index 0000000000..f03218844b --- /dev/null +++ b/torchtext/models/roberta/__init__.py @@ -0,0 +1,18 @@ +from .model import ( + RobertaEncoderParams, + RobertaClassificationHead, +) + +from .bundler import ( + RobertaModelBundle, + XLMR_BASE_ENCODER, + XLMR_LARGE_ENCODER, +) + +__all__ = [ + "RobertaEncoderParams", + "RobertaClassificationHead", + "RobertaModelBundle", + "XLMR_BASE_ENCODER", + "XLMR_LARGE_ENCODER", +] diff --git 
a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
new file mode 100644
index 0000000000..b7db892db9
--- /dev/null
+++ b/torchtext/models/roberta/bundler.py
@@ -0,0 +1,99 @@
+
+import os
+from dataclasses import dataclass
+from functools import partial
+
+from typing import Optional, Callable
+from torch.hub import load_state_dict_from_url
+from torch.nn import Module
+import logging
+
+logger = logging.getLogger(__name__)
+
+from .model import (
+    RobertaEncoderParams,
+    RobertaModel,
+    _get_model,
+)
+
+from .transforms import get_xlmr_transform
+
+from torchtext import _TEXT_BUCKET
+
+
+@dataclass
+class RobertaModelBundle:
+    """
+    Example - Pretrained encoder
+    >>> import torch, torchtext
+    >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
+    >>> model = xlmr_base.get_model()
+    >>> transform = xlmr_base.transform()
+    >>> model_input = torch.tensor(transform(["Hello World"]))
+    >>> output = model(model_input)
+    >>> output.shape
+    torch.Size([1, 4, 768])
+    >>> input_batch = ["Hello world", "How are you!"]
+    >>> from torchtext.functional import to_tensor
+    >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
+    >>> output = model(model_input)
+    >>> output.shape
+    torch.Size([2, 6, 768])
+
+    Example - Pretrained encoder attached to un-initialized classification head
+    >>> import torch, torchtext
+    >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
+    >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
+    >>> classification_model = xlmr_large.get_model(head=classifier_head)
+    >>> transform = xlmr_large.transform()
+    >>> model_input = torch.tensor(transform(["Hello World"]))
+    >>> output = classification_model(model_input)
+    >>> output.shape
+    torch.Size([1, 2])
+    """
+    _params: RobertaEncoderParams
+    _path: Optional[str] = None
+    _head: Optional[Module] = None
+    transform: Optional[Callable] = None
+
+    def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
+
+        if head is not None:
+            input_head = head
+            if self._head is not None:
+                logger.info("A custom head module was provided, discarding the default head module.")
+        else:
+            input_head = self._head
+
+        model = _get_model(self._params, input_head)
+
+        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
+        if input_head is not None:
+            model.load_state_dict(state_dict, strict=False)
+        else:
+            model.load_state_dict(state_dict, strict=True)
+        return model
+
+    @property
+    def params(self) -> RobertaEncoderParams:
+        return self._params
+
+
+XLMR_BASE_ENCODER = RobertaModelBundle(
+    _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+    _params=RobertaEncoderParams(vocab_size=250002),
+    transform=partial(get_xlmr_transform,
+                      vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
+                      spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
+                      )
+)
+
+XLMR_LARGE_ENCODER = RobertaModelBundle(
+    _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
+    _params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
+    transform=partial(get_xlmr_transform,
+                      vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
+                      spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
+                      )
+)
diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
new file mode 100644
index 0000000000..0656e647f8
--- /dev/null
+++ b/torchtext/models/roberta/model.py @@ -0,0 +1,108 @@ +import math + +from dataclasses import dataclass, asdict +from typing import Optional + +from torch.nn import Module +import torch +from torch import Tensor +import torch.nn as nn + +from .modules import ( + TransformerEncoder, +) + + +@dataclass +class RobertaEncoderParams: + vocab_size: int = 50265 + embedding_dim: int = 768 + ffn_dimension: int = 3072 + padding_idx: int = 1 + max_seq_len: int = 514 + num_attention_heads: int = 12 + num_encoder_layers: int = 12 + dropout: float = 0.1 + scaling: Optional[float] = None + normalize_before: bool = False + + +class RobertaEncoder(Module): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + ffn_dimension: int, + padding_idx: int, + max_seq_len: int, + num_attention_heads: int, + num_encoder_layers: int, + dropout: float = 0.1, + scaling: Optional[float] = None, + normalize_before: bool = False, + ): + super().__init__() + if not scaling: + head_dim = embedding_dim // num_attention_heads + scaling = 1.0 / math.sqrt(head_dim) + + self.transformer = TransformerEncoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + padding_idx=padding_idx, + max_seq_len=max_seq_len, + ffn_dimension=ffn_dimension, + num_encoder_layers=num_encoder_layers, + num_attention_heads=num_attention_heads, + dropout=dropout, + normalize_before=normalize_before, + scaling=scaling, + ) + + def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor: + all_layers = self.transformer(tokens) + last_layer = all_layers[-1].transpose(1, 0) + if mask is not None: + last_layer = last_layer[mask.to(torch.bool), :] + return last_layer + + +# TODO: Add Missing quant noise and spectral norm from latest Roberta head in fairseq repo +class RobertaClassificationHead(nn.Module): + def __init__(self, num_classes, input_dim, inner_dim: Optional[int] = None, dropout: float = 0.1, activation=nn.ReLU): + super().__init__() + if not inner_dim: + inner_dim = input_dim + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + self.activation_fn = activation() + + def forward(self, features): + x = features[:, 0, :] + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class RobertaModel(Module): + def __init__(self, encoder: Module, head: Optional[Module] = None): + super().__init__() + self.encoder = encoder + self.head = head + + def forward(self, tokens: Tensor) -> Tensor: + features = self.encoder(tokens) + if self.head is None: + return features + + x = self.head(features) + return x + + +def _get_model(params: RobertaEncoderParams, head: Module) -> RobertaModel: + encoder = RobertaEncoder(**asdict(params)) + return RobertaModel(encoder, head) diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py new file mode 100644 index 0000000000..5f0fc01d56 --- /dev/null +++ b/torchtext/models/roberta/modules.py @@ -0,0 +1,296 @@ +from typing import Optional, List + +import torch +from torch import nn +from torch.nn import Module + +import math + +from torch.nn import functional as F + + +class PositionalEmbedding(Module): + def __init__( + self, num_embeddings: int, embedding_dim: int, pad_index: int + ): + super().__init__() + self.embedding = nn.Embedding(num_embeddings, embedding_dim, pad_index) + self.pad_index = pad_index + + def forward(self, input): + positions = self._make_positions(input, self.pad_index) + 
return self.embedding(positions) + + def max_positions(self): + if self.pad_index is not None: + return self.num_embeddings - self.pad_index - 1 + else: + return self.num_embeddings + + def _make_positions(self, tensor, pad_index: int): + masked = tensor.ne(pad_index).long() + return torch.cumsum(masked, dim=1) * masked + pad_index + + +class ResidualMLP(Module): + def __init__( + self, + input_dim: int, + hidden_dims: List[int], + dropout: float = 0.1, + activation=nn.GELU, + add_residual=True, + ): + super().__init__() + modules = [] + for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims): + modules.extend( + [nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)] + ) + + last_dim = hidden_dims[-1] if hidden_dims else input_dim + modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)]) + + self.mlp = nn.Sequential(*modules) + self.add_residual = add_residual + + def forward(self, input): + bias = self.mlp(input) + if not hasattr(self, "add_residual"): + self.add_residual = True + if self.add_residual: + return input + bias + else: + return bias + + +class MultiheadSelfAttention(Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + scaling: Optional[float] = None, + dropout: float = 0.1, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + expected_scaling = float(1 / math.sqrt(self.head_dim)) + + if not scaling and self.head_dim == 64: + scaling = 0.125 + + if not scaling: + raise Exception( + f""" + Scaling not set. Please manually set scaling for transformers with + head_dim != 64. The suggested value in this case is {expected_scaling}, + or float(1 / math.sqrt(head_dim)) + where head_dim = embed_dim // num_heads = {self.head_dim} + and embed_dim = {embed_dim} and num_heads = {num_heads}. 
+ """ + ) + + self.scaling = scaling + self.dropout = nn.Dropout(dropout) + self.input_projection = nn.Linear(embed_dim, 3 * embed_dim) + self.output_projection = nn.Linear(embed_dim, embed_dim) + + def forward(self, query, key_padding_mask): + target_length, batch_size, embed_dim = query.size() + mask_batch_size, source_length = key_padding_mask.size() + + torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match") + torch._assert( + batch_size == mask_batch_size, + "query and key_padding_mask batch sizes differed", + ) + + projection = self.input_projection(query) + q, k, v = projection.chunk(3, dim=-1) + q = self.scaling * q + + batch_heads = batch_size * self.num_heads + + q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1) + k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1) + v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1) + + torch._assert( + k.size(1) == source_length, "key size should be equal to source length" + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim") + torch._assert( + attn_weights.size(0) == batch_heads, + "attn_weights shape didn't match for batch heads", + ) + torch._assert( + attn_weights.size(1) == target_length, + "attn_weights shape didn't match for target length", + ) + torch._assert( + attn_weights.size(2) == source_length, + "attn_weights shape didn't match for source length", + ) + + attn_weights = attn_weights.view( + batch_size, self.num_heads, target_length, source_length + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf") + ) + attn_weights = attn_weights.view(batch_heads, target_length, source_length) + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as( + attn_weights + ) + attn_weights = self.dropout(attn_weights) + + attn = torch.bmm(attn_weights, v) + + torch._assert( + attn.dim() == 3, + "unexpected attn dim size", + ) + torch._assert( + attn.size(0) == batch_heads, + "attn shape didn't match for batch heads", + ) + torch._assert( + attn.size(1) == target_length, + "attn shape didn't match for target length", + ) + torch._assert( + attn.size(2) == self.head_dim, + "attn shape didn't match for head dim", + ) + attn = ( + attn.transpose(0, 1) + .contiguous() + .view(target_length, batch_size, self.head_dim * self.num_heads) + ) + attn = self.output_projection(attn) + + return attn + + +class TransformerEncoderLayer(Module): + def __init__( + self, + embedding_dim: int, + num_attention_heads: int, + ffn_dimension: Optional[int] = None, + dropout: float = 0.1, + normalize_before: bool = False, + scaling: Optional[float] = None, + ): + super().__init__() + self.dropout = nn.Dropout(dropout) + self.attention = MultiheadSelfAttention( + embedding_dim, + num_heads=num_attention_heads, + scaling=scaling, + dropout=dropout, + ) + + self.residual_mlp = ResidualMLP( + embedding_dim, + hidden_dims=[ffn_dimension or embedding_dim * 4], + add_residual=not normalize_before, + ) + + self.attention_layer_norm = nn.LayerNorm(embedding_dim) + self.final_layer_norm = nn.LayerNorm(embedding_dim) + self.normalize_before = normalize_before + + def forward(self, input, key_padding_mask): + if not hasattr(self, "normalize_before"): + self.normalize_before = False + if self.normalize_before: + x = self.attention_layer_norm(input) + attention = self.attention(x, key_padding_mask) + attention = self.dropout(attention) + biased_input = input 
+ attention
+            x = self.final_layer_norm(biased_input)
+            return self.residual_mlp(x) + biased_input
+        else:
+            attention = self.attention(input, key_padding_mask)
+            attention = self.dropout(attention)
+            biased_input = input + attention
+            biased_input = self.attention_layer_norm(biased_input)
+            biased = self.residual_mlp(biased_input)
+            return self.final_layer_norm(biased)
+
+
+class TransformerEncoder(Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        embedding_dim: int,
+        padding_idx: int,
+        max_seq_len: int,
+        num_encoder_layers: int,
+        num_attention_heads: int,
+        ffn_dimension: Optional[int] = None,
+        dropout: float = 0.1,
+        normalize_before: bool = False,
+        scaling: Optional[float] = None,
+    ):
+        super().__init__()
+        self.padding_idx = padding_idx
+        self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
+        self.layers = nn.ModuleList(
+            [
+                TransformerEncoderLayer(
+                    embedding_dim=embedding_dim,
+                    num_attention_heads=num_attention_heads,
+                    ffn_dimension=ffn_dimension,
+                    dropout=dropout,
+                    normalize_before=normalize_before,
+                    scaling=scaling,
+                )
+                for _ in range(num_encoder_layers)
+            ]
+        )
+        self.positional_embedding = PositionalEmbedding(
+            max_seq_len, embedding_dim, padding_idx
+        )
+        self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.normalize_before = normalize_before
+
+    def forward(self, tokens: torch.Tensor) -> List[torch.Tensor]:
+        padding_mask = tokens.eq(self.padding_idx)
+
+        token_embeddings = self.token_embedding(tokens)
+        embedded_positions = self.positional_embedding(tokens)
+
+        embedded = token_embeddings + embedded_positions
+
+        if not hasattr(self, "normalize_before"):
+            self.normalize_before = False
+        if not self.normalize_before:
+            embedded = self.embedding_layer_norm(embedded)
+        embedded = self.dropout(embedded)
+
+        padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
+
+        encoded = padded_embedded.transpose(0, 1)
+
+        states = [encoded]
+
+        for layer in self.layers:
+            encoded = layer(encoded, padding_mask)
+            states.append(encoded)
+
+        if self.normalize_before:
+            for i, state in enumerate(states):
+                states[i] = self.embedding_layer_norm(state)
+
+        # states are returned as T x B x C
+        return states
diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py
new file mode 100644
index 0000000000..febf6a858b
--- /dev/null
+++ b/torchtext/models/roberta/transforms.py
@@ -0,0 +1,66 @@
+import os
+import torch
+from torch.nn import Module
+from torch.hub import load_state_dict_from_url
+from torchtext import transforms
+from torchtext import functional
+
+from typing import List
+
+
+class XLMRobertaModelTransform(Module):
+    def __init__(
+        self,
+        vocab_path: str,
+        spm_model_path: str,
+        bos_token: str = "<s>",
+        cls_token: str = "<s>",
+        pad_token: str = "<pad>",
+        eos_token: str = "</s>",
+        sep_token: str = "</s>",
+        unk_token: str = "<unk>",
+        mask_token: str = "<mask>",
+        max_seq_len: int = 514,
+    ):
+        super().__init__()
+        self.bos_token = bos_token
+        self.eos_token = eos_token
+        self.pad_token = pad_token
+        self.unk_token = unk_token
+        self.mask_token = mask_token
+        self.cls_token = cls_token
+        self.sep_token = sep_token
+        self.max_seq_len = max_seq_len
+
+        self.token_transform = transforms.SpmTokenizerTransform(spm_model_path)
+
+        if os.path.exists(vocab_path):
+            self.vocab = torch.load(vocab_path)
+        else:
+            self.vocab = load_state_dict_from_url(vocab_path)
+
+        self.vocab_transform = transforms.VocabTransform(self.vocab)
+        self.pad_idx = self.vocab[self.pad_token]
+
self.bos_idx = self.vocab[self.bos_token] + self.eos_idx = self.vocab[self.eos_token] + + def forward(self, input: List[str], + add_bos: bool = True, + add_eos: bool = True, + truncate: bool = True) -> List[List[int]]: + tokens: List[List[int]] = self.vocab_transform(self.token_transform(input)) + + if truncate: + tokens = functional.truncate(tokens, self.max_seq_len - 2) + + if add_bos: + tokens = functional.add_token(tokens, self.bos_idx) + + if add_eos: + tokens = functional.add_token(tokens, self.eos_idx, begin=False) + + return tokens + + +def get_xlmr_transform(vocab_path, spm_model_path, **kwargs) -> XLMRobertaModelTransform: + return XLMRobertaModelTransform(vocab_path, spm_model_path, **kwargs) diff --git a/torchtext/transforms.py b/torchtext/transforms.py new file mode 100644 index 0000000000..ece8ebbc80 --- /dev/null +++ b/torchtext/transforms.py @@ -0,0 +1,75 @@ +from torch.nn import Module +from torchtext.data.functional import load_sp_model +from torchtext.utils import download_from_url +import torchtext +from typing import List +import os + +from torchtext import _CACHE_DIR + +__all__ = [ + 'SpmTokenizerTransform', + 'VocabTransform', +] + + +class SpmTokenizerTransform(Module): + """ + Transform for Sentence Piece tokenizer. + + Examples: + >>> from torchtext.transforms import PRETRAINED_SP_MODEL + >>> from torchtext.transforms import SpmTokenizerTransform + >>> transform = SpmTokenizerTransform(PRETRAINED_SP_MODEL["text_unigram_15000"]) + >>> transform(["hello world", "attention is all you need!"]) + """ + + def __init__(self, sp_model_path: str): + super().__init__() + if os.path.exists(sp_model_path): + local_path = sp_model_path + else: + local_path = download_from_url(url=sp_model_path, root=_CACHE_DIR) + self.sp_model = load_sp_model(local_path) + + def forward(self, input: List[str]) -> List[List[str]]: + tokens: List[List[str]] = [] + for text in input: + tokens.append(self.sp_model.EncodeAsPieces(text)) + return tokens + + +class VocabTransform(Module): + r"""Vocab transform + + Args: + vocab: an instance of torchtext.vocab.Vocab class. + + Example: + >>> import torch + >>> from torchtext.vocab import vocab + >>> from torchtext.transforms import VocabTransform + >>> from collections import OrderedDict + >>> vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)])) + >>> vocab_transform = VocabTransform(vocab_obj) + >>> output = vocab_transform([['a','b'],['a','b','c']]) + >>> jit_vocab_transform = torch.jit.script(vocab_transform) + """ + + def __init__(self, vocab): + super().__init__() + assert isinstance(vocab, torchtext.vocab.Vocab) + self.vocab = vocab + + def forward(self, input: List[List[str]]) -> List[List[int]]: + r""" + + Args: + input: list of list tokens + """ + + output: List[List[int]] = [] + for tokens in input: + output.append(self.vocab.lookup_indices(tokens)) + + return output
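
Usage sketch (illustrative only, not part of the patch): the pieces added above are designed to compose — torchtext.transforms turns raw strings into token ids, torchtext.functional handles truncation, special tokens, and padding, and the RobertaModelBundle instances tie a pretrained encoder to its matching transform. A minimal end-to-end example, assuming the XLM-R encoder weights and the sentencepiece/vocab assets referenced in bundler.py can be downloaded from _TEXT_BUCKET:

import torch
import torchtext
from torchtext.functional import to_tensor

xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model().eval()        # downloads xlmr.base.encoder.pt on first use
transform = xlmr_base.transform()           # sentencepiece tokenization + vocab lookup + bos/eos tokens

batch = ["Hello world", "How are you!"]
token_ids = transform(batch)                # ragged List[List[int]]
model_input = to_tensor(token_ids, padding_value=transform.pad_idx)

with torch.no_grad():
    output = model(model_input)             # shape (batch_size, seq_len, 768) for the base encoder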