diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index dc4ce808d4..186a5ff688 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -56,11 +56,15 @@ class RobertaModelBundle:
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
 
-    def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:
 
         if load_weights:
             assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"
 
+        if freeze_encoder:
+            if not load_weights or not self._path:
+                logger.warning("The encoder is not loaded with pre-trained weights. Setting freeze_encoder to True will prevent the encoder from learning appropriate weights.")
+
         if head is not None:
             input_head = head
             if self._head is not None:
@@ -68,7 +72,7 @@ def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwar
         else:
             input_head = self._head
 
-        model = _get_model(self._params, input_head)
+        model = _get_model(self._params, input_head, freeze_encoder)
 
         if not load_weights:
             return model
diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index 0656e647f8..1c9e291857 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -11,6 +11,8 @@
 from .modules import (
     TransformerEncoder,
 )
+import logging
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -103,6 +105,12 @@ def forward(self, tokens: Tensor) -> Tensor:
         return x
 
 
-def _get_model(params: RobertaEncoderParams, head: Module) -> RobertaModel:
+def _get_model(params: RobertaEncoderParams, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
     encoder = RobertaEncoder(**asdict(params))
+    if freeze_encoder:
+        for param in encoder.parameters():
+            param.requires_grad = False
+
+        logger.info("Encoder weights are frozen")
+
     return RobertaModel(encoder, head)
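
Usage note (not part of the diff): a minimal sketch of how the new freeze_encoder flag would be exercised through an existing bundle. It assumes the RobertaClassificationHead and XLMR_BASE_ENCODER names publicly exposed by torchtext.models at the time of this change; the head dimensions are illustrative only.

# Hypothetical fine-tuning setup: attach a task head and freeze the
# pre-trained encoder so only the head is updated during training.
import torch
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER

classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
model = XLMR_BASE_ENCODER.get_model(head=classifier_head, freeze_encoder=True)

# With the encoder frozen, only the head parameters should remain trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total}")

Freezing via requires_grad = False keeps the encoder in the forward graph while excluding its parameters from gradient updates, so optimizers that filter on requires_grad will skip them automatically.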