diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 58774aacd5..7cd9e2c833 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass
-from functools import partial
from urllib.parse import urljoin
from typing import Optional, Callable, Dict, Union, Any
@@ -15,7 +14,7 @@
RobertaModel,
)
-from .transforms import get_xlmr_transform
+import torchtext.transforms as T
from torchtext import _TEXT_BUCKET
@@ -156,10 +155,13 @@ def encoderConf(self) -> RobertaEncoderConf:
XLMR_BASE_ENCODER = RobertaModelBundle(
_path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
_encoder_conf=RobertaEncoderConf(vocab_size=250002),
- transform=partial(get_xlmr_transform,
- vocab_path=urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"),
- spm_model_path=urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
- )
+ transform=lambda: T.Sequential(
+ T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
+ T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
+ T.Truncate(510),
+ T.AddToken(token=0, begin=True),
+ T.AddToken(token=2, begin=False),
+ )
)
XLMR_BASE_ENCODER.__doc__ = (
@@ -174,10 +176,13 @@ def encoderConf(self) -> RobertaEncoderConf:
XLMR_LARGE_ENCODER = RobertaModelBundle(
_path=urljoin(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
_encoder_conf=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
- transform=partial(get_xlmr_transform,
- vocab_path=urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"),
- spm_model_path=urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
- )
+ transform=lambda: T.Sequential(
+ T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
+ T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
+ T.Truncate(510),
+ T.AddToken(token=0, begin=True),
+ T.AddToken(token=2, begin=False),
+ )
)
XLMR_LARGE_ENCODER.__doc__ = (
diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py
deleted file mode 100644
index 683b6406be..0000000000
--- a/torchtext/models/roberta/transforms.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import os
-import torch
-from torch.nn import Module
-from torchtext._download_hooks import load_state_dict_from_url
-from torchtext import transforms
-from torchtext import functional
-
-from typing import List, Any
-
-
-class XLMRobertaModelTransform(Module):
- def __init__(
- self,
- vocab_path: str,
- spm_model_path: str,
-        bos_token: str = "<s>",
-        cls_token: str = "<s>",
-        pad_token: str = "<pad>",
-        eos_token: str = "</s>",
-        sep_token: str = "</s>",
-        unk_token: str = "<unk>",
-        mask_token: str = "<mask>",
- max_seq_len: int = 514,
- ):
- super().__init__()
- self.bos_token = bos_token
- self.eos_token = eos_token
- self.pad_token = pad_token
- self.unk_token = unk_token
- self.mask_token = mask_token
- self.cls_token = cls_token
- self.sep_token = sep_token
- self.max_seq_len = max_seq_len
-
- self.token_transform = transforms.SentencePieceTokenizer(spm_model_path)
-
- if os.path.exists(vocab_path):
- self.vocab = torch.load(vocab_path)
- else:
- self.vocab = load_state_dict_from_url(vocab_path)
-
- self.vocab_transform = transforms.VocabTransform(self.vocab)
- self.pad_idx = self.vocab[self.pad_token]
- self.bos_idx = self.vocab[self.bos_token]
- self.eos_idx = self.vocab[self.eos_token]
-
- def forward(self, input: Any,
- add_bos: bool = True,
- add_eos: bool = True,
- truncate: bool = True) -> Any:
- if torch.jit.isinstance(input, str):
- tokens = self.vocab_transform(self.token_transform(input))
-
- if truncate:
- tokens = functional.truncate(tokens, self.max_seq_len - 2)
-
- if add_bos:
- tokens = functional.add_token(tokens, self.bos_idx)
-
- if add_eos:
- tokens = functional.add_token(tokens, self.eos_idx, begin=False)
-
- return tokens
- elif torch.jit.isinstance(input, List[str]):
- tokens = self.vocab_transform(self.token_transform(input))
-
- if truncate:
- tokens = functional.truncate(tokens, self.max_seq_len - 2)
-
- if add_bos:
- tokens = functional.add_token(tokens, self.bos_idx)
-
- if add_eos:
- tokens = functional.add_token(tokens, self.eos_idx, begin=False)
-
- return tokens
- else:
- raise TypeError("Input type not supported")
-
-
-def get_xlmr_transform(vocab_path, spm_model_path, **kwargs) -> XLMRobertaModelTransform:
- return XLMRobertaModelTransform(vocab_path, spm_model_path, **kwargs)