Update RobertaModel to directly take in configuration #1431
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -16,7 +16,7 @@ | |
| @dataclass | ||
| class RobertaEncoderParams: | ||
| class RobertaEncoderConf: | ||
| vocab_size: int = 50265 | ||
| embedding_dim: int = 768 | ||
| ffn_dimension: int = 3072 | ||
@@ -62,6 +62,10 @@ def __init__( | |
| return_all_layers=False, | ||
| ) | ||
| @classmethod | ||
| def from_config(cls, config: RobertaEncoderConf): | ||
| return cls(**asdict(config)) | ||
| def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor: | ||
| output = self.transformer(tokens) | ||
| if torch.jit.isinstance(output, List[Tensor]): | ||
@@ -94,26 +98,37 @@ def forward(self, features): | |
| class RobertaModel(Module): | ||
| def __init__(self, encoder: Module, head: Optional[Module] = None): | ||
| """ | ||
| Example - Instantiate model with user-specified configuration | ||
| >>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead | ||
| >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002) | ||
| >>> encoder = RobertaModel(config=roberta_encoder_conf) | ||
| >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768) | ||
| >>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head) | ||
| """ | ||
| def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False): | ||
| super().__init__() | ||
| self.encoder = encoder | ||
| assert isinstance(config, RobertaEncoderConf) | ||
| self.encoder = RobertaEncoder.from_config(config) | ||
| if freeze_encoder: | ||
| for param in self.encoder.parameters(): | ||
| param.requires_grad = False | ||
| logger.info("Encoder weights are frozen") | ||
| self.head = head | ||
| def forward(self, tokens: Tensor) -> Tensor: | ||
| features = self.encoder(tokens) | ||
| def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor: | ||
| features = self.encoder(tokens, mask) | ||
| if self.head is None: | ||
| return features | ||
| x = self.head(features) | ||
| return x | ||
| def _get_model(params: RobertaEncoderParams, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel: | ||
| encoder = RobertaEncoder(**asdict(params)) | ||
| if freeze_encoder: | ||
| for param in encoder.parameters(): | ||
| param.requires_grad = False | ||
| logger.info("Encoder weights are frozen") | ||
| return RobertaModel(encoder, head) | ||
| def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel: | ||
Contributor
I'm wondering if we can turn
Contributor
Author
Hmm, it's a good question. RobertaModel is a user-facing API (we might actually keep RobertaEncoder private, since one can still instantiate the encoder through RobertaModel by passing head as None). I am a bit hesitant to introduce a new API for a class method that instantiates objects in a domain library (that hasn't been the case so far for other domains, if I am not mistaken). Though there is nothing inherently wrong with it, I would like to keep things simple and let users use the plain old constructor for object instantiation, unless there is a compelling reason otherwise. Another advantage of having a private builder is that it can be changed without BC-breaking issues. For example, in the future we might want to provide sharding support out of the box, which might require changes in the builder function itself.
Contributor
I guess I was confused because I saw a Your other points about using plain constructors and allowing for future BC-breaking changes make sense to me.
Contributor
Author
Yup, that's right. For now it's only for internal convenience. That said, I am not totally against the idea of introducing from_config; I just need more feedback from other domains on whether it would be useful to expose such APIs. The exact semantics of from_config would also need to be figured out if we decide to expose it as an API for end users. Until then, I think a conservative approach helps, since from_config can be added at any time later without introducing any BC-breaking change :).
| return RobertaModel(config, head, freeze_encoder) | ||
Do we want to continue to accept Python objects in __init__ as we were doing before and instead deal with instantiation from a config in a separate from_config method?
I was thinking about the difference between the config-based and object-based approaches here.
I do not have a recommendation, as neither is a clear winner, but I came to think that the major difference is this: accepting only a config means the code allows only the set of encoders parameterized by that config, while accepting an object means the code allows any encoder with the same interface (duck typing).
You might have already thought this through, but I think it's something to be aware of when making the choice here, so that it's aligned with the goal of this API.
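To make the distinction concrete, here is a minimal pure-Python sketch, not the torchtext implementation. All names (`ToyEncoderConf`, `ToyEncoder`, `ConfigBasedModel`, `ObjectBasedModel`, `ReverseEncoder`, `SupportsEncode`) are hypothetical stand-ins invented for illustration.

```python
from dataclasses import asdict, dataclass
from typing import Protocol


@dataclass
class ToyEncoderConf:
    # Hypothetical stand-in for RobertaEncoderConf; fields are illustrative.
    vocab_size: int = 100
    embedding_dim: int = 8


class ToyEncoder:
    def __init__(self, vocab_size: int, embedding_dim: int):
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

    def encode(self, tokens):
        # Placeholder "encoding": one zero vector per token.
        return [[0.0] * self.embedding_dim for _ in tokens]


class ConfigBasedModel:
    """Accepts only a config, so only ToyEncoder variants can be built."""

    def __init__(self, config: ToyEncoderConf):
        if not isinstance(config, ToyEncoderConf):
            raise TypeError("config must be a ToyEncoderConf")
        self.encoder = ToyEncoder(**asdict(config))


class SupportsEncode(Protocol):
    def encode(self, tokens): ...


class ObjectBasedModel:
    """Accepts any object that has an encode() method (duck typing)."""

    def __init__(self, encoder: SupportsEncode):
        self.encoder = encoder


class ReverseEncoder:
    # Unrelated class that happens to satisfy the same interface.
    def encode(self, tokens):
        return list(reversed(tokens))


config_model = ConfigBasedModel(ToyEncoderConf(vocab_size=250002))
object_model = ObjectBasedModel(ReverseEncoder())
```

In this sketch, `ConfigBasedModel(ReverseEncoder())` raises a `TypeError`, whereas `ObjectBasedModel` accepts any encoder-shaped object.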
I think passing Python objects would mean that users need to instantiate them separately and use this class only to stitch the encoder together with the head. Also, as noted in the previous comment, we might want to keep the encoder private and expose Model as the only public-facing API, since that is sufficient to instantiate either the encoder alone or encoder + head.
I think passing the dataclass config might allow for better type checking (http://blog.ezyang.com/2020/10/idiomatic-algebraic-data-types-in-python-with-dataclasses-and-union/).
Also, as noted in the other comment, users would otherwise need to first instantiate the encoder separately and then make another call to construct the model, when they could have done it with a single call to the RobertaModel constructor.
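As a rough illustration of the type-checking point, the sketch below compares a dataclass config with a plain dict. `ToyEncoderConf` and `try_build` are hypothetical names, and the field defaults simply mirror the values shown in the diff. Note that dataclasses do not validate field value types at runtime; the annotations mainly help static checkers, while unknown field names are rejected at construction time.

```python
from dataclasses import asdict, dataclass


@dataclass
class ToyEncoderConf:
    # Hypothetical config; defaults copied from the diff for illustration.
    vocab_size: int = 50265
    embedding_dim: int = 768


def try_build(**kwargs):
    """Return the config, or the TypeError raised by a bad field name."""
    try:
        return ToyEncoderConf(**kwargs)
    except TypeError as err:
        return err


ok = try_build(vocab_size=250002)
bad = try_build(vocab_sz=250002)  # misspelled field name is rejected

# A plain dict accepts the same typo silently.
loose = {"vocab_sz": 250002}
```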
BTW, we might not want to do this, since this class is specific to RobertaEncoder. Not sure if I misunderstood your comment, though?
In general, my mental model behind the (Roberta)Model class API is as follows:
I chose the dataclass config to completely define the architecture. The advantage of specifying the architecture through a dataclass is that it can be easily carried around across multiple places and use sites (including inside various bundler and pipeline APIs).
The alternative would be to provide the information explicitly as constructor arguments, but this would duplicate architecture parameters across multiple places, and changes in the encoder would inevitably require changes in the model class (or wherever else the encoder is used).
Optionally provide a head module to be used for specific tasks.
Optionally freeze the encoder to allow fine-tuning only the head with a pre-trained encoder.
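The three points above can be sketched without PyTorch as follows. This is a toy model of the pattern under stated assumptions, not the code in this PR: `ToyParam` stands in for a tensor parameter with a `requires_grad` flag, and `ToyEncoderConf`, `ToyEncoder`, and `ToyModel` are hypothetical names.

```python
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional


@dataclass
class ToyEncoderConf:
    # Hypothetical config that fully defines the toy architecture.
    embedding_dim: int = 4
    num_layers: int = 2


class ToyParam:
    """Stand-in for a trainable parameter."""

    def __init__(self) -> None:
        self.requires_grad = True


class ToyEncoder:
    def __init__(self, embedding_dim: int, num_layers: int):
        self.embedding_dim = embedding_dim
        self.params = [ToyParam() for _ in range(num_layers)]

    @classmethod
    def from_config(cls, config: ToyEncoderConf) -> "ToyEncoder":
        # The config's fields map one-to-one onto constructor arguments.
        return cls(**asdict(config))

    def parameters(self) -> List[ToyParam]:
        return self.params

    def __call__(self, tokens: List[int]) -> List[float]:
        return [float(t) for t in tokens]


class ToyModel:
    def __init__(
        self,
        config: ToyEncoderConf,
        head: Optional[Callable[[List[float]], float]] = None,
        freeze_encoder: bool = False,
    ):
        self.encoder = ToyEncoder.from_config(config)
        if freeze_encoder:
            # Only the head would be updated during fine-tuning.
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.head = head

    def __call__(self, tokens: List[int]):
        features = self.encoder(tokens)
        if self.head is None:
            return features
        return self.head(features)


encoder_only = ToyModel(ToyEncoderConf())
classifier = ToyModel(ToyEncoderConf(), head=sum, freeze_encoder=True)
```

A single constructor call covers both cases: `encoder_only` returns encoder features, while `classifier` applies the head on top of a frozen encoder.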
My comment was more me thinking out loud about API design around model construction. I see many different ways to approach it (from_pretrained, object-based, config-based, etc.), and I was trying to see if there is a winner. While there is no single winner, I noticed one essential difference between config-based (strongly typed) construction and object-based construction (duck typing).
It seems like you have good principles for the choice of the API, and the points you listed make sense. I do not think I was misunderstood. Thanks for the reply.