From 4b301cd0a79145472c338f191b9982797c4486f5 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 11 Nov 2021 11:32:32 -0500
Subject: [PATCH 1/7] add from_config bundler method

---
 test/models/test_models.py          | 14 ++++++++++++++
 torchtext/models/roberta/bundler.py | 28 ++++++++++++++++++++++++++--
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/test/models/test_models.py b/test/models/test_models.py
index 91b3eb999a..f1eca20d2e 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -1,6 +1,8 @@
 import torchtext
 import torch
+import os
+from torchtext import _TEXT_BUCKET
 
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path
@@ -91,3 +93,15 @@ def test_xlmr_transform_jit(self):
         actual = transform_jit([test_text])
         expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
         torch.testing.assert_close(actual, expected)
+
+    def test_roberta_bundler_from_config(self):
+        from torchtext.models import RobertaEncoderConf
+        asset_name = "xlmr.base.output.pt"
+        asset_path = get_asset_path(asset_name)
+        model_path = os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt")
+        model = torchtext.models.RobertaModelBundle.from_config(config=RobertaEncoderConf(vocab_size=250002), path=model_path)
+        model = model.eval()
+        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
+        actual = model(model_input)
+        expected = torch.load(asset_path)
+        torch.testing.assert_close(actual, expected)
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index d26a2d71cc..8f9ffb2bc0 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -1,4 +1,3 @@
-
 import os
 from dataclasses import dataclass
 from functools import partial
@@ -65,7 +64,12 @@ class RobertaModelBundle:
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
 
-    def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self,
+                  head: Optional[Module] = None,
+                  load_weights: bool = True,
+                  freeze_encoder: bool = False,
+                  *,
+                  dl_kwargs=None) -> RobertaModel:
 
         if load_weights:
             assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"
@@ -94,6 +98,26 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
             model.load_state_dict(state_dict, strict=True)
         return model
 
+    @classmethod
+    def from_config(
+        self,
+        config: RobertaEncoderConf,
+        head: Optional[Module] = None,
+        freeze_encoder: bool = False,
+        path: Optional[str] = None,
+        *,
+        dl_kwargs=None,
+    ) -> RobertaModel:
+        model = _get_model(config, head, freeze_encoder)
+        if path is not None:
+            dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+            state_dict = load_state_dict_from_url(path, **dl_kwargs)
+            if head is not None:
+                model.load_state_dict(state_dict, strict=False)
+            else:
+                model.load_state_dict(state_dict, strict=True)
+        return model
+
     @property
     def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf

From f303d86c37441a6265b5ab5a4eb12586a4c92e14 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Fri, 12 Nov 2021 11:45:29 -0500
Subject: [PATCH 2/7] make amendments

---
 torchtext/models/roberta/bundler.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 10a1c2be93..c31fa0b02d 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -2,9 +2,10 @@
 from functools import partial
 from urllib.parse import urljoin
 
-from typing import Optional, Callable
+from typing import Optional, Callable, Dict, Union
 from torchtext._download_hooks import load_state_dict_from_url
 from torch.nn import Module
+import torch
 
 import logging
 logger = logging.getLogger(__name__)
@@ -42,7 +43,7 @@ class RobertaModelBundle:
         Example - Pretrained encoder attached to un-initialized classification head
         >>> import torch, torchtext
         >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
-        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.params.embedding_dim)
+        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
         >>> classification_model = xlmr_large.get_model(head=classifier_head)
         >>> transform = xlmr_large.transform()
         >>> model_input = torch.tensor(transform(["Hello World"]))
@@ -58,6 +59,9 @@ class RobertaModelBundle:
         >>> encoder = roberta_bundle.get_model()
         >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
         >>> classifier = roberta_bundle.get_model(head=classifier_head)
+        >>> # using from_config
+        >>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path)
+        >>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
     """
     _encoder_conf: RobertaEncoderConf
     _path: Optional[str] = None
@@ -104,14 +108,27 @@ def from_config(
         config: RobertaEncoderConf,
         head: Optional[Module] = None,
         freeze_encoder: bool = False,
-        path: Optional[str] = None,
+        checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
         *,
         dl_kwargs=None,
     ) -> RobertaModel:
+        """Class method to intantiate model with user-defined encoder configuration and checkpoint
+
+        Args:
+            config: An instance of class RobertaEncoderConf that defined the encoder configuration
+            head: A module to be attached to the encoder to perform specific task
+            freeze_encoder: Indicates whether to freeze the encoder weights
+            checkpoint: Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder.
+        """
         model = _get_model(config, head, freeze_encoder)
-        if path is not None:
-            dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-            state_dict = load_state_dict_from_url(path, **dl_kwargs)
+        if checkpoint is not None:
+            if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
+                state_dict = checkpoint
+            elif isinstance(checkpoint, str):
+                dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+                state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)
+            else:
+                raise TypeError("checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)))
             if head is not None:
                 model.load_state_dict(state_dict, strict=False)
             else:

From 9a0b5713237c521e07f6d1b336650b3e843497ff Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Fri, 12 Nov 2021 12:27:47 -0500
Subject: [PATCH 3/7] fix tests

---
 test/models/test_models.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/test/models/test_models.py b/test/models/test_models.py
index f1eca20d2e..037f80c570 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -1,7 +1,6 @@
 import torchtext
 import torch
-import os
-
+import urllib
 from torchtext import _TEXT_BUCKET
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path
@@ -98,8 +97,8 @@ def test_roberta_bundler_from_config(self):
         from torchtext.models import RobertaEncoderConf
         asset_name = "xlmr.base.output.pt"
         asset_path = get_asset_path(asset_name)
-        model_path = os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt")
-        model = torchtext.models.RobertaModelBundle.from_config(config=RobertaEncoderConf(vocab_size=250002), path=model_path)
+        model_path = urllib.parse.urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt")
+        model = torchtext.models.RobertaModelBundle.from_config(config=RobertaEncoderConf(vocab_size=250002), checkpoint=model_path)
         model = model.eval()
         model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
         actual = model(model_input)

From 992304dbf273c900cdd94011e6a781f8e7cb4b6c Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Fri, 12 Nov 2021 12:40:19 -0500
Subject: [PATCH 4/7] fix doc

---
 torchtext/models/roberta/bundler.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index bdd2b32b85..d009ff07bb 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -2,7 +2,7 @@
 from functools import partial
 from urllib.parse import urljoin
 
-from typing import Optional, Callable, Dict, Union
+from typing import Optional, Callable, Dict, Union, Any
 from torchtext._download_hooks import load_state_dict_from_url
 from torch.nn import Module
 import torch
@@ -113,15 +113,16 @@ def from_config(
         freeze_encoder: bool = False,
         checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
         *,
-        dl_kwargs=None,
+        dl_kwargs: Dict[str, Any] = None,
     ) -> RobertaModel:
         """Class method to intantiate model with user-defined encoder configuration and checkpoint
 
         Args:
-            config: An instance of class RobertaEncoderConf that defined the encoder configuration
-            head: A module to be attached to the encoder to perform specific task
-            freeze_encoder: Indicates whether to freeze the encoder weights
-            checkpoint: Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder.
+            config (RobertaEncoderConf): An instance of class RobertaEncoderConf that defined the encoder configuration
+            head (nn.Module, optional): A module to be attached to the encoder to perform specific task
+            freeze_encoder (bool): Indicates whether to freeze the encoder weights
+            checkpoint (str or Dict[str, torch.Tensor], optional): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder.
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
         """
         model = _get_model(config, head, freeze_encoder)
         if checkpoint is not None:

From 7cfdffcbcaf73369257716360bbeef897e2923bc Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Fri, 12 Nov 2021 21:32:12 -0500
Subject: [PATCH 5/7] update checkpoint logic

---
 torchtext/models/roberta/bundler.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index d009ff07bb..a1895b23ae 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -7,7 +7,7 @@
 from torch.nn import Module
 import torch
 import logging
-
+import re
 logger = logging.getLogger(__name__)
 
 from .model import (
@@ -134,10 +134,15 @@ def from_config(
             else:
                 raise TypeError("checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)))
             if head is not None:
-                model.load_state_dict(state_dict, strict=False)
-            else:
-                model.load_state_dict(state_dict, strict=True)
-        return model
+                regex = re.compile(r"^head\.")
+                head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
+                # if not all the head keys are present in checkpoint then we shall update the state_dict with the provided head state_dict
+                if not all(key in state_dict.keys() for key in head_state_dict.keys()):
+                    state_dict.update(head_state_dict)
+
+            model.load_state_dict(state_dict, strict=True)
+
+        return model
 
     @property
     def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf

From 8b4048a6fd73a31d9bff04ce388e56fd73e2ef76 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 22 Nov 2021 11:59:49 -0500
Subject: [PATCH 6/7] solidifying behaviour around checkpoint and head

---
 test/models/test_models.py          | 46 ++++++++++++++++++-------
 torchtext/models/roberta/bundler.py | 53 ++++++++++++++++++-------------------
 2 files changed, 64 insertions(+), 35 deletions(-)

diff --git a/test/models/test_models.py b/test/models/test_models.py
index 037f80c570..488b9fc561 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -1,7 +1,5 @@
 import torchtext
 import torch
-import urllib
-from torchtext import _TEXT_BUCKET
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path
@@ -94,13 +92,37 @@ def test_xlmr_transform_jit(self):
         torch.testing.assert_close(actual, expected)
 
     def test_roberta_bundler_from_config(self):
-        from torchtext.models import RobertaEncoderConf
-        asset_name = "xlmr.base.output.pt"
-        asset_path = get_asset_path(asset_name)
-        model_path = urllib.parse.urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt")
-        model = torchtext.models.RobertaModelBundle.from_config(config=RobertaEncoderConf(vocab_size=250002), checkpoint=model_path)
-        model = model.eval()
-        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
-        actual = model(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
+        from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
+        dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
+
+        # case: user provide encoder checkpoint state dict
+        dummy_encoder = RobertaModel(dummy_encoder_conf)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               checkpoint=dummy_encoder.state_dict())
+        self.assertEqual(model.state_dict(), dummy_encoder.state_dict())
+
+        # case: user provide classifier checkpoint state dict when head is given and override_head is False (by default)
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=another_dummy_classifier_head,
+                                               checkpoint=dummy_classifier.state_dict())
+        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
+
+        # case: user provide classifier checkpoint state dict when head is given and override_head is set True
+        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=another_dummy_classifier_head,
+                                               checkpoint=dummy_classifier.state_dict(),
+                                               override_head=True)
+        self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())
+
+        # case: user provide only encoder checkpoint state dict when head is given
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        encoder_state_dict = {}
+        for k, v in dummy_classifier.encoder.state_dict().items():
+            encoder_state_dict['encoder.' + k] = v
+        model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
+        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index a1895b23ae..6e2dc57551 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -21,6 +21,11 @@
 from torchtext import _TEXT_BUCKET
 
 
+def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
+    # ensure all keys are present
+    return all(key in checkpoint.keys() for key in head_state_dict.keys())
+
+
 @dataclass
 class RobertaModelBundle:
     """RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
@@ -76,6 +81,11 @@ def get_model(self,
                   *,
                   dl_kwargs=None) -> RobertaModel:
         r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel
+
+        Args:
+            head (nn.Module): A module to be attached to the encoder to perform specific task. If provided, it will replace the default member head (Default: ``None``)
+            freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
         """
 
         if load_weights:
@@ -92,39 +102,35 @@ def get_model(self,
         else:
             input_head = self._head
 
-        model = _get_model(self._encoder_conf, input_head, freeze_encoder)
-
-        if not load_weights:
-            return model
-
-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
-        if input_head is not None:
-            model.load_state_dict(state_dict, strict=False)
-        else:
-            model.load_state_dict(state_dict, strict=True)
-        return model
+        return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
+                                              head=input_head,
+                                              freeze_encoder=freeze_encoder,
+                                              checkpoint=self._path,
+                                              override_head=True,
+                                              dl_kwargs=dl_kwargs)
 
     @classmethod
     def from_config(
-        self,
-        config: RobertaEncoderConf,
+        cls,
+        encoder_conf: RobertaEncoderConf,
         head: Optional[Module] = None,
         freeze_encoder: bool = False,
         checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
         *,
+        override_head: bool = False,
         dl_kwargs: Dict[str, Any] = None,
     ) -> RobertaModel:
-        """Class method to intantiate model with user-defined encoder configuration and checkpoint
+        """Class method to create model with user-defined encoder configuration and checkpoint
 
         Args:
-            config (RobertaEncoderConf): An instance of class RobertaEncoderConf that defined the encoder configuration
-            head (nn.Module, optional): A module to be attached to the encoder to perform specific task
-            freeze_encoder (bool): Indicates whether to freeze the encoder weights
-            checkpoint (str or Dict[str, torch.Tensor], optional): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder.
-            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
+            encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defined the encoder configuration
+            head (nn.Module): A module to be attached to the encoder to perform specific task. (Default: ``None``)
+            freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
+            checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder. (Default: ``None``)
+            override_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
         """
-        model = _get_model(config, head, freeze_encoder)
+        model = _get_model(encoder_conf, head, freeze_encoder)
         if checkpoint is not None:
             if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
                 state_dict = checkpoint
@@ -133,11 +139,12 @@ def from_config(
                 state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)
             else:
                 raise TypeError("checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)))
+
             if head is not None:
                 regex = re.compile(r"^head\.")
                 head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
-                # if not all the head keys are present in checkpoint then we shall update the state_dict with the provided head state_dict
-                if not all(key in state_dict.keys() for key in head_state_dict.keys()):
+                # If checkpoint does not contains head_state_dict, then we augment the checkpoint with user-provided head state_dict
+                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
                     state_dict.update(head_state_dict)
 
             model.load_state_dict(state_dict, strict=True)

From a60eaeece9e44f668c96cf079eb27d4d88787759 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 22 Nov 2021 12:54:33 -0500
Subject: [PATCH 7/7] fixing minor docstring

---
 torchtext/models/roberta/bundler.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 6e2dc57551..d1b699e4c0 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -84,7 +84,8 @@ def get_model(self,
 
         Args:
             head (nn.Module): A module to be attached to the encoder to perform specific task. If provided, it will replace the default member head (Default: ``None``)
-            freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
+            load_weights (bool): Indicates whether or not to load weights if available. (Default: ``True``)
+            freeze_encoder (bool): Indicates whether or not to freeze the encoder weights. (Default: ``False``)
             dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
         """
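
A minimal usage sketch of the from_config API as it stands after this series. It mirrors the unit-test cases added in PATCH 6; the tiny encoder configuration values are illustrative only, not a recommended setting.

from torchtext.models import (
    RobertaClassificationHead,
    RobertaEncoderConf,
    RobertaModel,
    RobertaModelBundle,
)

# Tiny encoder configuration, sized for a quick local check (mirrors the test config).
encoder_conf = RobertaEncoderConf(
    vocab_size=10,
    embedding_dim=16,
    ffn_dimension=64,
    num_attention_heads=2,
    num_encoder_layers=2,
)

# Encoder-only model built from a config plus an in-memory state_dict.
# checkpoint may also be a URL/path string, in which case it is fetched
# via load_state_dict_from_url (dl_kwargs are forwarded to that call).
encoder = RobertaModel(encoder_conf)
model = RobertaModelBundle.from_config(
    encoder_conf=encoder_conf,
    checkpoint=encoder.state_dict(),
)

# Attaching a classification head: when the checkpoint carries no head weights,
# from_config fills them in from the provided head module; passing
# override_head=True prefers the provided head even when the checkpoint has one.
head = RobertaClassificationHead(num_classes=2, input_dim=16)
classifier = RobertaModelBundle.from_config(
    encoder_conf=encoder_conf,
    head=head,
    checkpoint=encoder.state_dict(),
)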