From a227f0534db31524296f199caa923b132d55dfa7 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 28 Oct 2021 00:06:41 -0400
Subject: [PATCH 1/2] Provide option for freezing encoder weights

---
 torchtext/models/roberta/bundler.py |  4 ++--
 torchtext/models/roberta/model.py   | 10 +++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index dc4ce808d4..d37d59a468 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -56,7 +56,7 @@ class RobertaModelBundle:
     _head: Optional[Module] = None
     transform: Optional[Callable] = None

-    def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:

         if load_weights:
             assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"
@@ -68,7 +68,7 @@ def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwar
         else:
             input_head = self._head

-        model = _get_model(self._params, input_head)
+        model = _get_model(self._params, input_head, freeze_encoder)

         if not load_weights:
             return model
diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index 0656e647f8..1c9e291857 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -11,6 +11,8 @@
 from .modules import (
     TransformerEncoder,
 )
+import logging
+logger = logging.getLogger(__name__)


 @dataclass
@@ -103,6 +105,12 @@ def forward(self, tokens: Tensor) -> Tensor:
         return x


-def _get_model(params: RobertaEncoderParams, head: Module) -> RobertaModel:
+def _get_model(params: RobertaEncoderParams, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
     encoder = RobertaEncoder(**asdict(params))
+    if freeze_encoder:
+        for param in encoder.parameters():
+            param.requires_grad = False
+
+        logger.info("Encoder weights are frozen")
+
     return RobertaModel(encoder, head)

From db169a3645f3bdb7036dff9367ceee5da9d0baed Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 28 Oct 2021 00:10:52 -0400
Subject: [PATCH 2/2] add warning

---
 torchtext/models/roberta/bundler.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index d37d59a468..186a5ff688 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -61,6 +61,10 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         if load_weights:
             assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"

+        if freeze_encoder:
+            if not load_weights or not self._path:
+                logger.warning("The encoder is not loaded with pre-trained weights. Setting freeze_encoder to True will prevent the encoder from learning appropriate weights.")
+
         if head is not None:
             input_head = head
             if self._head is not None:
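
Usage sketch (not part of the patches above): the snippet below shows the new freeze_encoder flag from the caller's side. It assumes torchtext.models exposes a pre-trained RobertaModelBundle such as XLMR_BASE_ENCODER along with the RobertaClassificationHead module; the bundle name, head, and input_dim=768 are illustrative choices, not taken from the patches.

import torchtext

# Assumed pre-trained bundle; any RobertaModelBundle works the same way.
bundle = torchtext.models.XLMR_BASE_ENCODER

# Illustrative task head; input_dim must match the encoder's embedding dim.
head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=768)

# freeze_encoder=True routes through _get_model, which sets
# requires_grad=False on every encoder parameter, so an optimizer built
# from model.parameters() only updates the head.
model = bundle.get_model(head=head, freeze_encoder=True)

# Sanity check: only head parameters should remain trainable
# (assuming RobertaModel registers its submodules as "encoder" and "head").
print([name for name, p in model.named_parameters() if p.requires_grad])

Note the interaction with the second patch: calling get_model(load_weights=False, freeze_encoder=True) would freeze a randomly initialized encoder, which is exactly the combination the bundler now warns about.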