This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
8 changes: 6 additions & 2 deletions torchtext/models/roberta/bundler.py
@@ -56,19 +56,23 @@ class RobertaModelBundle:
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
 
-    def get_model(self, head: Optional[Module] = None, load_weights=True, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:
 
         if load_weights:
             assert self._path is not None, "load_weights cannot be True. The pre-trained model weights are not available for the current object"
 
+        if freeze_encoder:
+            if not load_weights or not self._path:
+                logger.warn("The encoder is not loaded with pre-trained weights. Setting freeze_encoder to True will hinder encoder from learning appropriate weights.")
+
         if head is not None:
             input_head = head
             if self._head is not None:
                 logger.log("A custom head module was provided, discarding the default head module.")
         else:
             input_head = self._head
 
-        model = _get_model(self._params, input_head)
+        model = _get_model(self._params, input_head, freeze_encoder)
 
         if not load_weights:
             return model
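For reference, a minimal usage sketch of the new flag. This is a hypothetical example, not part of the diff; it assumes the XLMR_BASE_ENCODER bundle and RobertaClassificationHead from torchtext.models, and that get_model is allowed to download the pre-trained weights.

import torch
from torchtext.models import XLMR_BASE_ENCODER, RobertaClassificationHead

# Hypothetical example (not part of this PR): attach a fresh classification head
# and freeze the pre-trained encoder so that only the head is trained.
classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
model = XLMR_BASE_ENCODER.get_model(head=classifier_head, freeze_encoder=True)

# Every encoder parameter now has requires_grad=False, so the optimizer can be
# limited to the remaining trainable parameters (the head).
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-4)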
10 changes: 9 additions & 1 deletion torchtext/models/roberta/model.py
@@ -11,6 +11,8 @@
 from .modules import (
     TransformerEncoder,
 )
+import logging
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -103,6 +105,12 @@ def forward(self, tokens: Tensor) -> Tensor:
         return x
 
 
-def _get_model(params: RobertaEncoderParams, head: Module) -> RobertaModel:
+def _get_model(params: RobertaEncoderParams, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
     encoder = RobertaEncoder(**asdict(params))
+    if freeze_encoder:
+        for param in encoder.parameters():
+            param.requires_grad = False
+
+        logger.info("Encoder weights are frozen")
+
     return RobertaModel(encoder, head)
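The freezing itself is the standard PyTorch requires_grad idiom; a self-contained sketch with a stand-in module (illustration only, not torchtext code) shows its effect:

import torch
from torch import nn

# Stand-in module for illustration; _get_model applies the same idiom to RobertaEncoder.
encoder = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

for param in encoder.parameters():
    param.requires_grad = False  # freeze: these weights will receive no gradients

out = encoder(torch.randn(2, 16))
print(out.requires_grad)  # False: nothing in the forward graph requires grad
print(sum(p.numel() for p in encoder.parameters() if p.requires_grad))  # 0 trainable parameters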