diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py index 58774aacd5..7cd9e2c833 100644 --- a/torchtext/models/roberta/bundler.py +++ b/torchtext/models/roberta/bundler.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from functools import partial from urllib.parse import urljoin from typing import Optional, Callable, Dict, Union, Any @@ -15,7 +14,7 @@ RobertaModel, ) -from .transforms import get_xlmr_transform +import torchtext.transforms as T from torchtext import _TEXT_BUCKET @@ -156,10 +155,13 @@ def encoderConf(self) -> RobertaEncoderConf: XLMR_BASE_ENCODER = RobertaModelBundle( _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"), _encoder_conf=RobertaEncoderConf(vocab_size=250002), - transform=partial(get_xlmr_transform, - vocab_path=urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"), - spm_model_path=urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"), - ) + transform=lambda: T.Sequential( + T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")), + T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))), + T.Truncate(510), + T.AddToken(token=0, begin=True), + T.AddToken(token=2, begin=False), + ) ) XLMR_BASE_ENCODER.__doc__ = ( @@ -174,10 +176,13 @@ def encoderConf(self) -> RobertaEncoderConf: XLMR_LARGE_ENCODER = RobertaModelBundle( _path=urljoin(_TEXT_BUCKET, "xlmr.large.encoder.pt"), _encoder_conf=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24), - transform=partial(get_xlmr_transform, - vocab_path=urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"), - spm_model_path=urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"), - ) + transform=lambda: T.Sequential( + T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")), + T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))), + T.Truncate(510), + T.AddToken(token=0, begin=True), + T.AddToken(token=2, begin=False), + ) ) XLMR_LARGE_ENCODER.__doc__ = ( diff --git a/torchtext/models/roberta/transforms.py b/torchtext/models/roberta/transforms.py deleted file mode 100644 index 683b6406be..0000000000 --- a/torchtext/models/roberta/transforms.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -import torch -from torch.nn import Module -from torchtext._download_hooks import load_state_dict_from_url -from torchtext import transforms -from torchtext import functional - -from typing import List, Any - - -class XLMRobertaModelTransform(Module): - def __init__( - self, - vocab_path: str, - spm_model_path: str, - bos_token: str = "", - cls_token: str = "", - pad_token: str = "", - eos_token: str = "", - sep_token: str = "", - unk_token: str = "", - mask_token: str = "", - max_seq_len: int = 514, - ): - super().__init__() - self.bos_token = bos_token - self.eos_token = eos_token - self.pad_token = pad_token - self.unk_token = unk_token - self.mask_token = mask_token - self.cls_token = cls_token - self.sep_token = sep_token - self.max_seq_len = max_seq_len - - self.token_transform = transforms.SentencePieceTokenizer(spm_model_path) - - if os.path.exists(vocab_path): - self.vocab = torch.load(vocab_path) - else: - self.vocab = load_state_dict_from_url(vocab_path) - - self.vocab_transform = transforms.VocabTransform(self.vocab) - self.pad_idx = self.vocab[self.pad_token] - self.bos_idx = self.vocab[self.bos_token] - self.eos_idx = self.vocab[self.eos_token] - - def forward(self, input: Any, - add_bos: bool = True, - add_eos: bool = True, - truncate: bool = True) -> Any: - if torch.jit.isinstance(input, str): - tokens = self.vocab_transform(self.token_transform(input)) - - if truncate: - tokens = functional.truncate(tokens, self.max_seq_len - 2) - - if add_bos: - tokens = functional.add_token(tokens, self.bos_idx) - - if add_eos: - tokens = functional.add_token(tokens, self.eos_idx, begin=False) - - return tokens - elif torch.jit.isinstance(input, List[str]): - tokens = self.vocab_transform(self.token_transform(input)) - - if truncate: - tokens = functional.truncate(tokens, self.max_seq_len - 2) - - if add_bos: - tokens = functional.add_token(tokens, self.bos_idx) - - if add_eos: - tokens = functional.add_token(tokens, self.eos_idx, begin=False) - - return tokens - else: - raise TypeError("Input type not supported") - - -def get_xlmr_transform(vocab_path, spm_model_path, **kwargs) -> XLMRobertaModelTransform: - return XLMRobertaModelTransform(vocab_path, spm_model_path, **kwargs)