diff --git a/test/models/test_models.py b/test/models/test_models.py
index 91b3eb999a..488b9fc561 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -1,6 +1,5 @@
 import torchtext
 import torch
-
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path
 
@@ -91,3 +90,39 @@ def test_xlmr_transform_jit(self):
         actual = transform_jit([test_text])
         expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
         torch.testing.assert_close(actual, expected)
+
+    def test_roberta_bundler_from_config(self):
+        from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
+        dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
+
+        # case: user provides an encoder checkpoint state dict
+        dummy_encoder = RobertaModel(dummy_encoder_conf)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               checkpoint=dummy_encoder.state_dict())
+        self.assertEqual(model.state_dict(), dummy_encoder.state_dict())
+
+        # case: user provides a classifier checkpoint state dict when head is given and override_head is False (the default)
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=another_dummy_classifier_head,
+                                               checkpoint=dummy_classifier.state_dict())
+        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
+
+        # case: user provides a classifier checkpoint state dict when head is given and override_head is set to True
+        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=another_dummy_classifier_head,
+                                               checkpoint=dummy_classifier.state_dict(),
+                                               override_head=True)
+        self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())
+
+        # case: user provides only an encoder checkpoint state dict when head is given
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        encoder_state_dict = {}
+        for k, v in dummy_classifier.encoder.state_dict().items():
+            encoder_state_dict['encoder.' + k] = v
+        model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
+        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index d0ba4b7028..d1b699e4c0 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -1,13 +1,13 @@
-
 from dataclasses import dataclass
 from functools import partial
 from urllib.parse import urljoin
-from typing import Optional, Callable
+from typing import Optional, Callable, Dict, Union, Any
 
 from torchtext._download_hooks import load_state_dict_from_url
 from torch.nn import Module
+import torch
 import logging
-
+import re
 logger = logging.getLogger(__name__)
 
 from .model import (
@@ -21,6 +21,11 @@
 from torchtext import _TEXT_BUCKET
 
 
+def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
+    # ensure all keys are present
+    return all(key in checkpoint.keys() for key in head_state_dict.keys())
+
+
 @dataclass
 class RobertaModelBundle:
     """RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
@@ -44,7 +49,7 @@ class RobertaModelBundle:
     Example - Pretrained encoder attached to un-initialized classification head
         >>> import torch, torchtext
         >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
-        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.params.embedding_dim)
+        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
         >>> classification_model = xlmr_large.get_model(head=classifier_head)
         >>> transform = xlmr_large.transform()
         >>> model_input = torch.tensor(transform(["Hello World"]))
@@ -60,14 +65,28 @@ class RobertaModelBundle:
         >>> encoder = roberta_bundle.get_model()
         >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
        >>> classifier = roberta_bundle.get_model(head=classifier_head)
+        >>> # using from_config
+        >>> encoder = RobertaModelBundle.from_config(encoder_conf=roberta_encoder_conf, checkpoint=model_weights_path)
+        >>> classifier = RobertaModelBundle.from_config(encoder_conf=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
     """
     _encoder_conf: RobertaEncoderConf
     _path: Optional[str] = None
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
 
-    def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self,
+                  head: Optional[Module] = None,
+                  load_weights: bool = True,
+                  freeze_encoder: bool = False,
+                  *,
+                  dl_kwargs=None) -> RobertaModel:
         r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel
+
+        Args:
+            head (nn.Module): A module to be attached to the encoder to perform a specific task. If provided, it will replace the default member head. (Default: ``None``)
+            load_weights (bool): Indicates whether or not to load weights if available. (Default: ``True``)
+            freeze_encoder (bool): Indicates whether or not to freeze the encoder weights. (Default: ``False``)
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
         """
 
         if load_weights:
@@ -84,17 +103,53 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         else:
             input_head = self._head
 
-        model = _get_model(self._encoder_conf, input_head, freeze_encoder)
-
-        if not load_weights:
-            return model
+        return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
+                                              head=input_head,
+                                              freeze_encoder=freeze_encoder,
+                                              checkpoint=self._path,
+                                              override_head=True,
+                                              dl_kwargs=dl_kwargs)
+
+    @classmethod
+    def from_config(
+        cls,
+        encoder_conf: RobertaEncoderConf,
+        head: Optional[Module] = None,
+        freeze_encoder: bool = False,
+        checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
+        *,
+        override_head: bool = False,
+        dl_kwargs: Dict[str, Any] = None,
+    ) -> RobertaModel:
+        """Class method to create a model from a user-defined encoder configuration and checkpoint
+
+        Args:
+            encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defines the encoder configuration
+            head (nn.Module): A module to be attached to the encoder to perform a specific task. (Default: ``None``)
+            freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
+            checkpoint (str or Dict[str, torch.Tensor]): Path to a model state_dict, or the state_dict itself. The state_dict can contain partial weights, i.e. only for the encoder. (Default: ``None``)
+            override_head (bool): Override the checkpoint's head state dict (if present) with the provided head state dict. (Default: ``False``)
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
+        """
+        model = _get_model(encoder_conf, head, freeze_encoder)
+        if checkpoint is not None:
+            if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
+                state_dict = checkpoint
+            elif isinstance(checkpoint, str):
+                dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+                state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)
+            else:
+                raise TypeError("checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)))
+
+            if head is not None:
+                regex = re.compile(r"^head\.")
+                head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
+                # If the checkpoint does not contain head_state_dict (or override_head is True), augment the checkpoint with the user-provided head state_dict
+                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
+                    state_dict.update(head_state_dict)
 
-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
-        if input_head is not None:
-            model.load_state_dict(state_dict, strict=False)
-        else:
             model.load_state_dict(state_dict, strict=True)
+
         return model
 
     @property
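
Reviewer note: a minimal usage sketch of the new `RobertaModelBundle.from_config` entry point introduced above. The toy encoder sizes mirror `test_roberta_bundler_from_config`; the dummy token ids and the printed shape are illustrative assumptions, not part of the diff.

    import torch
    from torchtext.models import (
        RobertaClassificationHead,
        RobertaEncoderConf,
        RobertaModel,
        RobertaModelBundle,
    )

    # Toy encoder configuration (same sizes as the unit test above)
    encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64,
                                      num_attention_heads=2, num_encoder_layers=2)

    # Build a reference classifier and treat its state dict as an in-memory "checkpoint"
    checkpoint = RobertaModel(encoder_conf, RobertaClassificationHead(num_classes=2, input_dim=16)).state_dict()

    # from_config accepts either a URL/path or a state dict; here a dict is passed directly.
    # With override_head=False (the default), the head weights stored in the checkpoint win.
    model = RobertaModelBundle.from_config(
        encoder_conf=encoder_conf,
        head=RobertaClassificationHead(num_classes=2, input_dim=16),
        checkpoint=checkpoint,
    )

    # The result is a regular RobertaModel and can be run immediately
    tokens = torch.tensor([[0, 5, 4, 2]])  # dummy ids, all < vocab_size
    print(model(tokens).shape)  # expected: torch.Size([1, 2]) for num_classes=2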