diff --git a/test/asset/xlmr.base.output.pt b/test/asset/xlmr.base.output.pt
new file mode 100644
index 0000000000..d32b3c7d8e
Binary files /dev/null and b/test/asset/xlmr.base.output.pt differ
diff --git a/test/asset/xlmr.large.output.pt b/test/asset/xlmr.large.output.pt
new file mode 100644
index 0000000000..5888eb54b1
Binary files /dev/null and b/test/asset/xlmr.large.output.pt differ
diff --git a/test/models/test_models.py b/test/models/test_models.py
new file mode 100644
index 0000000000..c6b5d01acd
--- /dev/null
+++ b/test/models/test_models.py
@@ -0,0 +1,70 @@
+import torchtext
+import torch
+
+from ..common.torchtext_test_case import TorchtextTestCase
+from ..common.assets import get_asset_path
+
+
+class TestModels(TorchtextTestCase):
+ def test_xlmr_base_output(self):
+ asset_name = "xlmr.base.output.pt"
+ asset_path = get_asset_path(asset_name)
+ xlmr_base = torchtext.models.XLMR_BASE_ENCODER
+ model = xlmr_base.get_model()
+ model = model.eval()
+ model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
+ actual = model(model_input)
+ expected = torch.load(asset_path)
+ torch.testing.assert_close(actual, expected)
+
+ def test_xlmr_base_jit_output(self):
+ asset_name = "xlmr.base.output.pt"
+ asset_path = get_asset_path(asset_name)
+ xlmr_base = torchtext.models.XLMR_BASE_ENCODER
+ model = xlmr_base.get_model()
+ model = model.eval()
+ model_jit = torch.jit.script(model)
+ model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
+ actual = model_jit(model_input)
+ expected = torch.load(asset_path)
+ torch.testing.assert_close(actual, expected)
+
+ def test_xlmr_large_output(self):
+ asset_name = "xlmr.large.output.pt"
+ asset_path = get_asset_path(asset_name)
+ xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
+ model = xlmr_large.get_model()
+ model = model.eval()
+ model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
+ actual = model(model_input)
+ expected = torch.load(asset_path)
+ torch.testing.assert_close(actual, expected)
+
+ def test_xlmr_large_jit_output(self):
+ asset_name = "xlmr.large.output.pt"
+ asset_path = get_asset_path(asset_name)
+ xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
+ model = xlmr_large.get_model()
+ model = model.eval()
+ model_jit = torch.jit.script(model)
+ model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
+ actual = model_jit(model_input)
+ expected = torch.load(asset_path)
+ torch.testing.assert_close(actual, expected)
+
+ def test_xlmr_transform(self):
+ xlmr_base = torchtext.models.XLMR_BASE_ENCODER
+ transform = xlmr_base.transform()
+ test_text = "XLMR base Model Comparison"
+ actual = transform([test_text])
+ expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
+ torch.testing.assert_close(actual, expected)
+
+ def test_xlmr_transform_jit(self):
+ xlmr_base = torchtext.models.XLMR_BASE_ENCODER
+ transform = xlmr_base.transform()
+ transform_jit = torch.jit.script(transform)
+ test_text = "XLMR base Model Comparison"
+ actual = transform_jit([test_text])
+ expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
+ torch.testing.assert_close(actual, expected)
diff --git a/test/test_functional.py b/test/test_functional.py
new file mode 100644
index 0000000000..6d71dd71aa
--- /dev/null
+++ b/test/test_functional.py
@@ -0,0 +1,58 @@
+import torch
+from torchtext import functional
+from .common.torchtext_test_case import TorchtextTestCase
+
+
+class TestFunctional(TorchtextTestCase):
+ def test_to_tensor(self):
+ input = [[1, 2], [1, 2, 3]]
+ padding_value = 0
+ actual = functional.to_tensor(input, padding_value=padding_value)
+ expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
+ torch.testing.assert_close(actual, expected)
+
+ def test_to_tensor_jit(self):
+ input = [[1, 2], [1, 2, 3]]
+ padding_value = 0
+ to_tensor_jit = torch.jit.script(functional.to_tensor)
+ actual = to_tensor_jit(input, padding_value=padding_value)
+ expected = torch.tensor([[1, 2, 0], [1, 2, 3]], dtype=torch.long)
+ torch.testing.assert_close(actual, expected)
+
+ def test_truncate(self):
+ input = [[1, 2], [1, 2, 3]]
+ max_seq_len = 2
+ actual = functional.truncate(input, max_seq_len=max_seq_len)
+ expected = [[1, 2], [1, 2]]
+ self.assertEqual(actual, expected)
+
+ def test_truncate_jit(self):
+ input = [[1, 2], [1, 2, 3]]
+ max_seq_len = 2
+ truncate_jit = torch.jit.script(functional.truncate)
+ actual = truncate_jit(input, max_seq_len=max_seq_len)
+ expected = [[1, 2], [1, 2]]
+ self.assertEqual(actual, expected)
+
+ def test_add_token(self):
+ input = [[1, 2], [1, 2, 3]]
+ token_id = 0
+ actual = functional.add_token(input, token_id=token_id)
+ expected = [[0, 1, 2], [0, 1, 2, 3]]
+ self.assertEqual(actual, expected)
+
+ actual = functional.add_token(input, token_id=token_id, begin=False)
+ expected = [[1, 2, 0], [1, 2, 3, 0]]
+ self.assertEqual(actual, expected)
+
+ def test_add_token_jit(self):
+ input = [[1, 2], [1, 2, 3]]
+ token_id = 0
+ add_token_jit = torch.jit.script(functional.add_token)
+ actual = add_token_jit(input, token_id=token_id)
+ expected = [[0, 1, 2], [0, 1, 2, 3]]
+ self.assertEqual(actual, expected)
+
+ actual = add_token_jit(input, token_id=token_id, begin=False)
+ expected = [[1, 2, 0], [1, 2, 3, 0]]
+ self.assertEqual(actual, expected)
diff --git a/test/test_transforms.py b/test/test_transforms.py
new file mode 100644
index 0000000000..ab941fb39c
--- /dev/null
+++ b/test/test_transforms.py
@@ -0,0 +1,33 @@
+import torch
+from torchtext import transforms
+from torchtext.vocab import vocab
+from collections import OrderedDict
+
+from .common.torchtext_test_case import TorchtextTestCase
+from .common.assets import get_asset_path
+
+
+class TestTransforms(TorchtextTestCase):
+ def test_spmtokenizer_transform(self):
+ asset_name = "spm_example.model"
+ asset_path = get_asset_path(asset_name)
+ transform = transforms.SpmTokenizerTransform(asset_path)
+ actual = transform(["Hello World!, how are you?"])
+ expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
+ self.assertEqual(actual, expected)
+
+ def test_spmtokenizer_transform_jit(self):
+ asset_name = "spm_example.model"
+ asset_path = get_asset_path(asset_name)
+ transform = transforms.SpmTokenizerTransform(asset_path)
+ transform_jit = torch.jit.script(transform)
+ actual = transform_jit(["Hello World!, how are you?"])
+ expected = [['▁Hello', '▁World', '!', ',', '▁how', '▁are', '▁you', '?']]
+ self.assertEqual(actual, expected)
+
+ def test_vocab_transform(self):
+ vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
+ transform = transforms.VocabTransform(vocab_obj)
+ actual = transform([['a', 'b', 'c']])
+ expected = [[0, 1, 2]]
+ self.assertEqual(actual, expected)
diff --git a/torchtext/__init__.py b/torchtext/__init__.py
index 8e1a92feab..a5c2802614 100644
--- a/torchtext/__init__.py
+++ b/torchtext/__init__.py
@@ -1,8 +1,15 @@
+import os
+_TEXT_BUCKET = 'https://download.pytorch.org/models/text'
+_CACHE_DIR = os.path.expanduser('~/.torchtext/cache')
+
from . import data
from . import nn
from . import datasets
from . import utils
from . import vocab
+from . import transforms
+from . import functional
+from . import models
from . import experimental
from . import legacy
from ._extension import _init_extension
@@ -18,6 +25,9 @@
'datasets',
'utils',
'vocab',
+ 'transforms',
+ 'functional',
+ 'models',
'experimental',
'legacy']
diff --git a/torchtext/data/datasets_utils.py b/torchtext/data/datasets_utils.py
index 571b43c479..12297931f9 100644
--- a/torchtext/data/datasets_utils.py
+++ b/torchtext/data/datasets_utils.py
@@ -15,6 +15,8 @@
import defusedxml.ElementTree as ET
except ImportError:
import xml.etree.ElementTree as ET
+
+from torchtext import _CACHE_DIR
"""
These functions and classes are meant solely for use in torchtext.datasets and not
for public consumption yet.
@@ -213,7 +215,7 @@ def _wrap_split_argument_with_fn(fn, splits):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
@functools.wraps(fn)
- def new_fn(root=os.path.expanduser('~/.torchtext/cache'), split=splits, **kwargs):
+ def new_fn(root=_CACHE_DIR, split=splits, **kwargs):
result = []
for item in _check_default_set(split, splits, fn.__name__):
result.append(fn(root, item, **kwargs))
@@ -250,7 +252,7 @@ def decorator(func):
raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
@functools.wraps(func)
- def wrapper(root=os.path.expanduser('~/.torchtext/cache'), *args, **kwargs):
+ def wrapper(root=_CACHE_DIR, *args, **kwargs):
new_root = os.path.join(root, dataset_name)
if not os.path.exists(new_root):
os.makedirs(new_root)
diff --git a/torchtext/functional.py b/torchtext/functional.py
new file mode 100644
index 0000000000..9231c9644e
--- /dev/null
+++ b/torchtext/functional.py
@@ -0,0 +1,45 @@
+import torch
+from torch import Tensor
+from torch.nn.utils.rnn import pad_sequence
+from typing import List, Optional
+
+__all__ = [
+ 'to_tensor',
+ 'truncate',
+ 'add_token',
+]
+
+
+def to_tensor(input: List[List[int]], padding_value: Optional[int] = None) -> Tensor:
+ if padding_value is None:
+ output = torch.tensor(input, dtype=torch.long)
+ return output
+ else:
+ output = pad_sequence(
+ [torch.tensor(ids, dtype=torch.long) for ids in input],
+ batch_first=True,
+ padding_value=float(padding_value)
+ )
+ return output
+
+
+def truncate(input: List[List[int]], max_seq_len: int) -> List[List[int]]:
+ output: List[List[int]] = []
+
+ for ids in input:
+ output.append(ids[:max_seq_len])
+
+ return output
+
+
+def add_token(input: List[List[int]], token_id: int, begin: bool = True) -> List[List[int]]:
+ output: List[List[int]] = []
+
+ if begin:
+ for ids in input:
+ output.append([token_id] + ids)
+ else:
+ for ids in input:
+ output.append(ids + [token_id])
+
+ return output
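Taken together, truncate, add_token and to_tensor cover the usual post-tokenization steps. A minimal sketch of how they compose into a padded batch; the token and special ids below are made up for illustration:

from torchtext import functional

batch = [[9, 387, 12], [9, 387, 12, 4452, 7]]                 # hypothetical token ids

batch = functional.truncate(batch, max_seq_len=4)             # [[9, 387, 12], [9, 387, 12, 4452]]
batch = functional.add_token(batch, token_id=0)               # prepend a BOS id
batch = functional.add_token(batch, token_id=2, begin=False)  # append an EOS id
padded = functional.to_tensor(batch, padding_value=1)         # LongTensor, padded to equal length
print(padded.shape)                                           # torch.Size([2, 6])
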
diff --git a/torchtext/models/__init__.py b/torchtext/models/__init__.py
new file mode 100644
index 0000000000..a7cbc0c88a
--- /dev/null
+++ b/torchtext/models/__init__.py
@@ -0,0 +1 @@
+from .roberta import * # noqa: F401, F403
diff --git a/torchtext/models/roberta/__init__.py b/torchtext/models/roberta/__init__.py
new file mode 100644
index 0000000000..f03218844b
--- /dev/null
+++ b/torchtext/models/roberta/__init__.py
@@ -0,0 +1,18 @@
+from .model import (
+ RobertaEncoderParams,
+ RobertaClassificationHead,
+)
+
+from .bundler import (
+ RobertaModelBundle,
+ XLMR_BASE_ENCODER,
+ XLMR_LARGE_ENCODER,
+)
+
+__all__ = [
+ "RobertaEncoderParams",
+ "RobertaClassificationHead",
+ "RobertaModelBundle",
+ "XLMR_BASE_ENCODER",
+ "XLMR_LARGE_ENCODER",
+]
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
new file mode 100644
index 0000000000..b7db892db9
--- /dev/null
+++ b/torchtext/models/roberta/bundler.py
@@ -0,0 +1,99 @@
+
+import os
+from dataclasses import dataclass
+from functools import partial
+
+from typing import Optional, Callable
+from torch.hub import load_state_dict_from_url
+from torch.nn import Module
+import logging
+
+logger = logging.getLogger(__name__)
+
+from .model import (
+ RobertaEncoderParams,
+ RobertaModel,
+ _get_model,
+)
+
+from .transforms import get_xlmr_transform
+
+from torchtext import _TEXT_BUCKET
+
+
+@dataclass
+class RobertaModelBundle:
+ """
+ Example - Pretrained encoder
+ >>> import torch, torchtext
+ >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
+ >>> model = xlmr_base.get_model()
+ >>> transform = xlmr_base.transform()
+ >>> model_input = torch.tensor(transform(["Hello World"]))
+ >>> output = model(model_input)
+ >>> output.shape
+ torch.Size([1, 4, 768])
+ >>> input_batch = ["Hello world", "How are you!"]
+ >>> from torchtext.functional import to_tensor
+ >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
+ >>> output = model(model_input)
+ >>> output.shape
+ torch.Size([2, 6, 768])
+
+ Example - Pretrained encoder attached to an uninitialized classification head
+ >>> import torch, torchtext
+ >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
+ >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.params.embedding_dim)
+ >>> classification_model = xlmr_large.get_model(head=classifier_head)
+ >>> transform = xlmr_large.transform()
+ >>> model_input = torch.tensor(transform(["Hello World"]))
+ >>> output = classification_model(model_input)
+ >>> output.shape
+ torch.Size([1, 2])
+ """
+ _params: RobertaEncoderParams
+ _path: Optional[str] = None
+ _head: Optional[Module] = None
+ transform: Optional[Callable] = None
+
+ def get_model(self, head: Optional[Module] = None, *, dl_kwargs=None) -> RobertaModel:
+
+ if head is not None:
+ input_head = head
+ if self._head is not None:
+ logger.info("A custom head module was provided, discarding the default head module.")
+ else:
+ input_head = self._head
+
+ model = _get_model(self._params, input_head)
+
+ dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+ state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
+ if input_head is not None:
+ model.load_state_dict(state_dict, strict=False)
+ else:
+ model.load_state_dict(state_dict, strict=True)
+ return model
+
+ @property
+ def params(self) -> RobertaEncoderParams:
+ return self._params
+
+
+XLMR_BASE_ENCODER = RobertaModelBundle(
+ _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+ _params=RobertaEncoderParams(vocab_size=250002),
+ transform=partial(get_xlmr_transform,
+ vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
+ spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
+ )
+)
+
+XLMR_LARGE_ENCODER = RobertaModelBundle(
+ _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
+ _params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
+ transform=partial(get_xlmr_transform,
+ vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
+ spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
+ )
+)
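For reference, a sketch of consuming the bundle end to end; dl_kwargs is forwarded to torch.hub.load_state_dict_from_url, so arguments such as progress or model_dir can be passed through (weights and SentencePiece assets are downloaded on first use):

import torch
import torchtext
from torchtext.functional import to_tensor

xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model(dl_kwargs={"progress": False}).eval()
transform = xlmr_base.transform()

batch = ["Hello world", "How are you!"]
model_input = to_tensor(transform(batch), padding_value=transform.pad_idx)
with torch.no_grad():
    features = model(model_input)  # (batch, seq_len, embedding_dim) == (2, seq_len, 768)
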
diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
new file mode 100644
index 0000000000..0656e647f8
--- /dev/null
+++ b/torchtext/models/roberta/model.py
@@ -0,0 +1,108 @@
+import math
+
+from dataclasses import dataclass, asdict
+from typing import Optional
+
+from torch.nn import Module
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+from .modules import (
+ TransformerEncoder,
+)
+
+
+@dataclass
+class RobertaEncoderParams:
+ vocab_size: int = 50265
+ embedding_dim: int = 768
+ ffn_dimension: int = 3072
+ padding_idx: int = 1
+ max_seq_len: int = 514
+ num_attention_heads: int = 12
+ num_encoder_layers: int = 12
+ dropout: float = 0.1
+ scaling: Optional[float] = None
+ normalize_before: bool = False
+
+
+class RobertaEncoder(Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_dim: int,
+ ffn_dimension: int,
+ padding_idx: int,
+ max_seq_len: int,
+ num_attention_heads: int,
+ num_encoder_layers: int,
+ dropout: float = 0.1,
+ scaling: Optional[float] = None,
+ normalize_before: bool = False,
+ ):
+ super().__init__()
+ if not scaling:
+ head_dim = embedding_dim // num_attention_heads
+ scaling = 1.0 / math.sqrt(head_dim)
+
+ self.transformer = TransformerEncoder(
+ vocab_size=vocab_size,
+ embedding_dim=embedding_dim,
+ padding_idx=padding_idx,
+ max_seq_len=max_seq_len,
+ ffn_dimension=ffn_dimension,
+ num_encoder_layers=num_encoder_layers,
+ num_attention_heads=num_attention_heads,
+ dropout=dropout,
+ normalize_before=normalize_before,
+ scaling=scaling,
+ )
+
+ def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ all_layers = self.transformer(tokens)
+ last_layer = all_layers[-1].transpose(1, 0)
+ if mask is not None:
+ last_layer = last_layer[mask.to(torch.bool), :]
+ return last_layer
+
+
+# TODO: Add missing quant noise and spectral norm from the latest RoBERTa head in the fairseq repo
+class RobertaClassificationHead(nn.Module):
+ def __init__(self, num_classes, input_dim, inner_dim: Optional[int] = None, dropout: float = 0.1, activation=nn.ReLU):
+ super().__init__()
+ if not inner_dim:
+ inner_dim = input_dim
+ self.dense = nn.Linear(input_dim, inner_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.out_proj = nn.Linear(inner_dim, num_classes)
+ self.activation_fn = activation()
+
+ def forward(self, features):
+ x = features[:, 0, :]
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = self.activation_fn(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+class RobertaModel(Module):
+ def __init__(self, encoder: Module, head: Optional[Module] = None):
+ super().__init__()
+ self.encoder = encoder
+ self.head = head
+
+ def forward(self, tokens: Tensor) -> Tensor:
+ features = self.encoder(tokens)
+ if self.head is None:
+ return features
+
+ x = self.head(features)
+ return x
+
+
+def _get_model(params: RobertaEncoderParams, head: Optional[Module]) -> RobertaModel:
+ encoder = RobertaEncoder(**asdict(params))
+ return RobertaModel(encoder, head)
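For reference, a minimal sketch of wiring these pieces together outside the bundler; _get_model is the internal helper the bundler calls, and the randomly initialized weights here serve only as a shape check:

import torch

from torchtext.models import RobertaClassificationHead, RobertaEncoderParams
from torchtext.models.roberta.model import _get_model  # internal helper used by the bundler

params = RobertaEncoderParams()   # defaults above; the XLM-R bundles override vocab_size etc.
head = RobertaClassificationHead(num_classes=2, input_dim=params.embedding_dim)
model = _get_model(params, head).eval()

tokens = torch.randint(2, params.vocab_size, (2, 8))  # (batch, seq_len), avoiding padding_idx=1
logits = model(tokens)                                # (2, 2); with head=None the output is (2, 8, 768)
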
diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
new file mode 100644
index 0000000000..5f0fc01d56
--- /dev/null
+++ b/torchtext/models/roberta/modules.py
@@ -0,0 +1,296 @@
+from typing import Optional, List
+
+import torch
+from torch import nn
+from torch.nn import Module
+
+import math
+
+from torch.nn import functional as F
+
+
+class PositionalEmbedding(Module):
+ def __init__(
+ self, num_embeddings: int, embedding_dim: int, pad_index: int
+ ):
+ super().__init__()
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim, pad_index)
+ self.pad_index = pad_index
+
+ def forward(self, input):
+ positions = self._make_positions(input, self.pad_index)
+ return self.embedding(positions)
+
+ def max_positions(self):
+ if self.pad_index is not None:
+ return self.embedding.num_embeddings - self.pad_index - 1
+ else:
+ return self.embedding.num_embeddings
+
+ def _make_positions(self, tensor, pad_index: int):
+ masked = tensor.ne(pad_index).long()
+ return torch.cumsum(masked, dim=1) * masked + pad_index
+
+
+class ResidualMLP(Module):
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: List[int],
+ dropout: float = 0.1,
+ activation=nn.GELU,
+ add_residual=True,
+ ):
+ super().__init__()
+ modules = []
+ for last_dim, dim in zip([input_dim] + hidden_dims, hidden_dims):
+ modules.extend(
+ [nn.Linear(last_dim, dim), activation(), nn.Dropout(dropout)]
+ )
+
+ last_dim = hidden_dims[-1] if hidden_dims else input_dim
+ modules.extend([nn.Linear(last_dim, input_dim), nn.Dropout(dropout)])
+
+ self.mlp = nn.Sequential(*modules)
+ self.add_residual = add_residual
+
+ def forward(self, input):
+ bias = self.mlp(input)
+ if not hasattr(self, "add_residual"):
+ self.add_residual = True
+ if self.add_residual:
+ return input + bias
+ else:
+ return bias
+
+
+class MultiheadSelfAttention(Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ scaling: Optional[float] = None,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+
+ expected_scaling = float(1 / math.sqrt(self.head_dim))
+
+ if not scaling and self.head_dim == 64:
+ scaling = 0.125
+
+ if not scaling:
+ raise Exception(
+ f"""
+ Scaling not set. Please manually set scaling for transformers with
+ head_dim != 64. The suggested value in this case is {expected_scaling},
+ or float(1 / math.sqrt(head_dim))
+ where head_dim = embed_dim // num_heads = {self.head_dim}
+ and embed_dim = {embed_dim} and num_heads = {num_heads}.
+ """
+ )
+
+ self.scaling = scaling
+ self.dropout = nn.Dropout(dropout)
+ self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
+ self.output_projection = nn.Linear(embed_dim, embed_dim)
+
+ def forward(self, query, key_padding_mask):
+ target_length, batch_size, embed_dim = query.size()
+ mask_batch_size, source_length = key_padding_mask.size()
+
+ torch._assert(embed_dim == self.embed_dim, "query embed dim doesn't match")
+ torch._assert(
+ batch_size == mask_batch_size,
+ "query and key_padding_mask batch sizes differed",
+ )
+
+ projection = self.input_projection(query)
+ q, k, v = projection.chunk(3, dim=-1)
+ q = self.scaling * q
+
+ batch_heads = batch_size * self.num_heads
+
+ q = q.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
+ k = k.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
+ v = v.contiguous().view(-1, batch_heads, self.head_dim).transpose(0, 1)
+
+ torch._assert(
+ k.size(1) == source_length, "key size should be equal to source length"
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+
+ torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
+ torch._assert(
+ attn_weights.size(0) == batch_heads,
+ "attn_weights shape didn't match for batch heads",
+ )
+ torch._assert(
+ attn_weights.size(1) == target_length,
+ "attn_weights shape didn't match for target length",
+ )
+ torch._assert(
+ attn_weights.size(2) == source_length,
+ "attn_weights shape didn't match for source length",
+ )
+
+ attn_weights = attn_weights.view(
+ batch_size, self.num_heads, target_length, source_length
+ )
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
+ )
+ attn_weights = attn_weights.view(batch_heads, target_length, source_length)
+
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
+ attn_weights
+ )
+ attn_weights = self.dropout(attn_weights)
+
+ attn = torch.bmm(attn_weights, v)
+
+ torch._assert(
+ attn.dim() == 3,
+ "unexpected attn dim size",
+ )
+ torch._assert(
+ attn.size(0) == batch_heads,
+ "attn shape didn't match for batch heads",
+ )
+ torch._assert(
+ attn.size(1) == target_length,
+ "attn shape didn't match for target length",
+ )
+ torch._assert(
+ attn.size(2) == self.head_dim,
+ "attn shape didn't match for head dim",
+ )
+ attn = (
+ attn.transpose(0, 1)
+ .contiguous()
+ .view(target_length, batch_size, self.head_dim * self.num_heads)
+ )
+ attn = self.output_projection(attn)
+
+ return attn
+
+
+class TransformerEncoderLayer(Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_attention_heads: int,
+ ffn_dimension: Optional[int] = None,
+ dropout: float = 0.1,
+ normalize_before: bool = False,
+ scaling: Optional[float] = None,
+ ):
+ super().__init__()
+ self.dropout = nn.Dropout(dropout)
+ self.attention = MultiheadSelfAttention(
+ embedding_dim,
+ num_heads=num_attention_heads,
+ scaling=scaling,
+ dropout=dropout,
+ )
+
+ self.residual_mlp = ResidualMLP(
+ embedding_dim,
+ hidden_dims=[ffn_dimension or embedding_dim * 4],
+ add_residual=not normalize_before,
+ )
+
+ self.attention_layer_norm = nn.LayerNorm(embedding_dim)
+ self.final_layer_norm = nn.LayerNorm(embedding_dim)
+ self.normalize_before = normalize_before
+
+ def forward(self, input, key_padding_mask):
+ if not hasattr(self, "normalize_before"):
+ self.normalize_before = False
+ if self.normalize_before:
+ x = self.attention_layer_norm(input)
+ attention = self.attention(x, key_padding_mask)
+ attention = self.dropout(attention)
+ biased_input = input + attention
+ x = self.final_layer_norm(biased_input)
+ return self.residual_mlp(x) + biased_input
+ else:
+ attention = self.attention(input, key_padding_mask)
+ attention = self.dropout(attention)
+ biased_input = input + attention
+ biased_input = self.attention_layer_norm(biased_input)
+ biased = self.residual_mlp(biased_input)
+ return self.final_layer_norm(biased)
+
+
+class TransformerEncoder(Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_dim: int,
+ padding_idx: int,
+ max_seq_len: int,
+ num_encoder_layers: int,
+ num_attention_heads: int,
+ ffn_dimension: Optional[int] = None,
+ dropout: float = 0.1,
+ normalize_before: bool = False,
+ scaling: Optional[float] = None,
+ ):
+ super().__init__()
+ self.padding_idx = padding_idx
+ self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
+ self.layers = nn.ModuleList(
+ [
+ TransformerEncoderLayer(
+ embedding_dim=embedding_dim,
+ num_attention_heads=num_attention_heads,
+ ffn_dimension=ffn_dimension,
+ dropout=dropout,
+ normalize_before=normalize_before,
+ scaling=scaling,
+ )
+ for _ in range(num_encoder_layers)
+ ]
+ )
+ self.positional_embedding = PositionalEmbedding(
+ max_seq_len, embedding_dim, padding_idx
+ )
+ self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
+ self.dropout = nn.Dropout(dropout)
+ self.normalize_before = normalize_before
+
+ def forward(self, tokens: torch.Tensor) -> List[torch.Tensor]:
+ padding_mask = tokens.eq(self.padding_idx)
+
+ token_embeddings = self.token_embedding(tokens)
+ embedded_positions = self.positional_embedding(tokens)
+
+ embedded = token_embeddings + embedded_positions
+
+ if not hasattr(self, "normalize_before"):
+ self.normalize_before = False
+ if not self.normalize_before:
+ embedded = self.embedding_layer_norm(embedded)
+ embedded = self.dropout(embedded)
+
+ padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
+
+ encoded = padded_embedded.transpose(0, 1)
+
+ states = [encoded]
+
+ for layer in self.layers:
+ encoded = layer(encoded, padding_mask)
+ states.append(encoded)
+
+ if self.normalize_before:
+ for i, state in enumerate(states):
+ states[i] = self.embedding_layer_norm(state)
+
+ # states are returned as T x B x C
+ return states
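TransformerEncoder.forward returns one state per layer in T x B x C layout (the embedding output plus each encoder layer). A small shape-check sketch with toy hyperparameters; embedding_dim and num_attention_heads are chosen so that head_dim == 64, since MultiheadSelfAttention only infers scaling for that case:

import torch
from torchtext.models.roberta.modules import TransformerEncoder

encoder = TransformerEncoder(
    vocab_size=100, embedding_dim=128, padding_idx=1, max_seq_len=64,
    num_encoder_layers=2, num_attention_heads=2,
)
tokens = torch.tensor([[5, 6, 7, 1, 1], [5, 6, 7, 8, 9]])  # (batch=2, seq_len=5); 1 is padding
states = encoder(tokens)
print(len(states))       # 3: embedding output + one state per encoder layer
print(states[-1].shape)  # torch.Size([5, 2, 128]) -- T x B x C
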
diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py
new file mode 100644
index 0000000000..febf6a858b
--- /dev/null
+++ b/torchtext/models/roberta/transforms.py
@@ -0,0 +1,66 @@
+import os
+import torch
+from torch.nn import Module
+from torch.hub import load_state_dict_from_url
+from torchtext import transforms
+from torchtext import functional
+
+from typing import List
+
+
+class XLMRobertaModelTransform(Module):
+ def __init__(
+ self,
+ vocab_path: str,
+ spm_model_path: str,
+ bos_token: str = "<s>",
+ cls_token: str = "<s>",
+ pad_token: str = "<pad>",
+ eos_token: str = "</s>",
+ sep_token: str = "</s>",
+ unk_token: str = "<unk>",
+ mask_token: str = "<mask>",
+ max_seq_len: int = 514,
+ ):
+ super().__init__()
+ self.bos_token = bos_token
+ self.eos_token = eos_token
+ self.pad_token = pad_token
+ self.unk_token = unk_token
+ self.mask_token = mask_token
+ self.cls_token = cls_token
+ self.sep_token = sep_token
+ self.max_seq_len = max_seq_len
+
+ self.token_transform = transforms.SpmTokenizerTransform(spm_model_path)
+
+ if os.path.exists(vocab_path):
+ self.vocab = torch.load(vocab_path)
+ else:
+ self.vocab = load_state_dict_from_url(vocab_path)
+
+ self.vocab_transform = transforms.VocabTransform(self.vocab)
+ self.pad_idx = self.vocab[self.pad_token]
+ self.bos_idx = self.vocab[self.bos_token]
+ self.eos_idx = self.vocab[self.eos_token]
+
+ def forward(self, input: List[str],
+ add_bos: bool = True,
+ add_eos: bool = True,
+ truncate: bool = True) -> List[List[int]]:
+ tokens: List[List[int]] = self.vocab_transform(self.token_transform(input))
+
+ if truncate:
+ tokens = functional.truncate(tokens, self.max_seq_len - 2)
+
+ if add_bos:
+ tokens = functional.add_token(tokens, self.bos_idx)
+
+ if add_eos:
+ tokens = functional.add_token(tokens, self.eos_idx, begin=False)
+
+ return tokens
+
+
+def get_xlmr_transform(vocab_path, spm_model_path, **kwargs) -> XLMRobertaModelTransform:
+ return XLMRobertaModelTransform(vocab_path, spm_model_path, **kwargs)
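XLMRobertaModelTransform's forward also exposes add_bos, add_eos and truncate flags, which the tests and the bundler use with their defaults. A sketch of calling it directly; the asset paths are placeholders (the bundler normally supplies the real URLs):

from torchtext.models.roberta.transforms import XLMRobertaModelTransform

transform = XLMRobertaModelTransform(
    vocab_path="xlmr.vocab.pt",                     # placeholder: local path or URL
    spm_model_path="xlmr.sentencepiece.bpe.model",  # placeholder: local path or URL
)
ids = transform(["Hello World"])                    # [[0, ..., 2]] with <s>/</s> added
plain = transform(["Hello World"], add_bos=False, add_eos=False)  # raw piece ids only
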
diff --git a/torchtext/transforms.py b/torchtext/transforms.py
new file mode 100644
index 0000000000..ece8ebbc80
--- /dev/null
+++ b/torchtext/transforms.py
@@ -0,0 +1,75 @@
+from torch.nn import Module
+from torchtext.data.functional import load_sp_model
+from torchtext.utils import download_from_url
+import torchtext
+from typing import List
+import os
+
+from torchtext import _CACHE_DIR
+
+__all__ = [
+ 'SpmTokenizerTransform',
+ 'VocabTransform',
+]
+
+
+class SpmTokenizerTransform(Module):
+ """
+ Transform for SentencePiece tokenizer.
+
+ Examples:
+ >>> from torchtext.transforms import PRETRAINED_SP_MODEL
+ >>> from torchtext.transforms import SpmTokenizerTransform
+ >>> transform = SpmTokenizerTransform(PRETRAINED_SP_MODEL["text_unigram_15000"])
+ >>> transform(["hello world", "attention is all you need!"])
+ """
+
+ def __init__(self, sp_model_path: str):
+ super().__init__()
+ if os.path.exists(sp_model_path):
+ local_path = sp_model_path
+ else:
+ local_path = download_from_url(url=sp_model_path, root=_CACHE_DIR)
+ self.sp_model = load_sp_model(local_path)
+
+ def forward(self, input: List[str]) -> List[List[str]]:
+ tokens: List[List[str]] = []
+ for text in input:
+ tokens.append(self.sp_model.EncodeAsPieces(text))
+ return tokens
+
+
+class VocabTransform(Module):
+ r"""Vocab transform
+
+ Args:
+ vocab: an instance of torchtext.vocab.Vocab class.
+
+ Example:
+ >>> import torch
+ >>> from torchtext.vocab import vocab
+ >>> from torchtext.transforms import VocabTransform
+ >>> from collections import OrderedDict
+ >>> vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
+ >>> vocab_transform = VocabTransform(vocab_obj)
+ >>> output = vocab_transform([['a','b'],['a','b','c']])
+ >>> jit_vocab_transform = torch.jit.script(vocab_transform)
+ """
+
+ def __init__(self, vocab):
+ super().__init__()
+ assert isinstance(vocab, torchtext.vocab.Vocab)
+ self.vocab = vocab
+
+ def forward(self, input: List[List[str]]) -> List[List[int]]:
+ r"""
+
+ Args:
+ input: list of list tokens
+ """
+
+ output: List[List[int]] = []
+ for tokens in input:
+ output.append(self.vocab.lookup_indices(tokens))
+
+ return output
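Chained together, the two transforms map raw strings to token ids. A minimal sketch; the SentencePiece model path and the toy vocab are placeholders for illustration:

from collections import OrderedDict

from torchtext import transforms
from torchtext.vocab import vocab

tokenizer = transforms.SpmTokenizerTransform("spm_example.model")  # placeholder path or URL
vocab_obj = vocab(OrderedDict([("<unk>", 1), ("▁Hello", 1), ("▁World", 1), ("!", 1)]))
vocab_obj.set_default_index(vocab_obj["<unk>"])
vocab_transform = transforms.VocabTransform(vocab_obj)

ids = vocab_transform(tokenizer(["Hello World!"]))  # List[List[int]], e.g. [[1, 2, 3]]
# both transforms script cleanly, e.g. torch.jit.script(vocab_transform)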