diff --git a/torchtext/models/roberta/__init__.py b/torchtext/models/roberta/__init__.py
index f03218844b..1057c6deb6 100644
--- a/torchtext/models/roberta/__init__.py
+++ b/torchtext/models/roberta/__init__.py
@@ -1,6 +1,7 @@
 from .model import (
-    RobertaEncoderParams,
+    RobertaEncoderConf,
     RobertaClassificationHead,
+    RobertaModel,
 )
 
 from .bundler import (
@@ -10,8 +11,9 @@
 )
 
 __all__ = [
-    "RobertaEncoderParams",
+    "RobertaEncoderConf",
     "RobertaClassificationHead",
+    "RobertaModel",
     "RobertaModelBundle",
     "XLMR_BASE_ENCODER",
     "XLMR_LARGE_ENCODER",
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 186a5ff688..d26a2d71cc 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -11,7 +11,7 @@
 logger = logging.getLogger(__name__)
 
 from .model import (
-    RobertaEncoderParams,
+    RobertaEncoderConf,
     RobertaModel,
     _get_model,
 )
@@ -50,8 +50,17 @@ class RobertaModelBundle:
         >>> output = classification_model(model_input)
         >>> output.shape
         torch.Size([1, 2])
+
+    Example - User-specified configuration and checkpoint
+        >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
+        >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
+        >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
+        >>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path)
+        >>> encoder = roberta_bundle.get_model()
+        >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
+        >>> classifier = roberta_bundle.get_model(head=classifier_head)
     """
-    _params: RobertaEncoderParams
+    _encoder_conf: RobertaEncoderConf
     _path: Optional[str] = None
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
@@ -72,7 +81,7 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         else:
             input_head = self._head
 
-        model = _get_model(self._params, input_head, freeze_encoder)
+        model = _get_model(self._encoder_conf, input_head, freeze_encoder)
 
         if not load_weights:
             return model
@@ -86,13 +95,13 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         return model
 
     @property
-    def params(self) -> RobertaEncoderParams:
-        return self._params
+    def encoderConf(self) -> RobertaEncoderConf:
+        return self._encoder_conf
 
 
 XLMR_BASE_ENCODER = RobertaModelBundle(
     _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _params=RobertaEncoderParams(vocab_size=250002),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
     transform=partial(get_xlmr_transform,
                       vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
                       spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
@@ -101,7 +110,7 @@ def params(self) -> RobertaEncoderParams:
 
 XLMR_LARGE_ENCODER = RobertaModelBundle(
     _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
-    _params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
     transform=partial(get_xlmr_transform,
                       vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
                       spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index 35d27fc79a..60a4362550 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -16,7 +16,7 @@
 
 
 @dataclass
-class RobertaEncoderParams:
+class RobertaEncoderConf:
     vocab_size: int = 50265
     embedding_dim: int = 768
     ffn_dimension: int = 3072
@@ -62,6 +62,10 @@ def __init__(
             return_all_layers=False,
         )
 
+    @classmethod
+    def from_config(cls, config: RobertaEncoderConf):
+        return cls(**asdict(config))
+
     def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         output = self.transformer(tokens)
         if torch.jit.isinstance(output, List[Tensor]):
@@ -94,13 +98,31 @@ def forward(self, features):
 
 
 class RobertaModel(Module):
-    def __init__(self, encoder: Module, head: Optional[Module] = None):
+    """
+
+    Example - Instantiate model with user-specified configuration
+        >>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead
+        >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
+        >>> encoder = RobertaModel(config=roberta_encoder_conf)
+        >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
+        >>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head)
+    """
+
+    def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
         super().__init__()
-        self.encoder = encoder
+        assert isinstance(config, RobertaEncoderConf)
+
+        self.encoder = RobertaEncoder.from_config(config)
+        if freeze_encoder:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+
+            logger.info("Encoder weights are frozen")
+
         self.head = head
 
-    def forward(self, tokens: Tensor) -> Tensor:
-        features = self.encoder(tokens)
+    def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        features = self.encoder(tokens, mask)
         if self.head is None:
             return features
 
@@ -108,12 +130,5 @@ def forward(self, tokens: Tensor) -> Tensor:
         return x
 
 
-def _get_model(params: RobertaEncoderParams, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
-    encoder = RobertaEncoder(**asdict(params))
-    if freeze_encoder:
-        for param in encoder.parameters():
-            param.requires_grad = False
-
-        logger.info("Encoder weights are frozen")
-
-    return RobertaModel(encoder, head)
+def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
+    return RobertaModel(config, head, freeze_encoder)
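
For reference, a minimal usage sketch of the config-driven API this diff converges on. The class and parameter names come from the diff itself; the token ids, sequence length, and batch-first input convention below are illustrative assumptions, not values taken from the torchtext tests.

    import torch
    from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead

    # Build the encoder and a classifier from the same config object;
    # freeze_encoder now lives on RobertaModel rather than on _get_model.
    encoder_conf = RobertaEncoderConf(vocab_size=250002)
    classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
    classifier = RobertaModel(config=encoder_conf, head=classifier_head, freeze_encoder=True)

    # Placeholder batch of token ids; forward() now also accepts an optional mask argument.
    tokens = torch.randint(low=0, high=250002, size=(1, 10))
    logits = classifier(tokens)                             # expected shape (1, 2), as in the docstring example
    features = RobertaModel(config=encoder_conf)(tokens)    # encoder-only output when no head is attached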