From 8e2a51f6891fb1ee9a4567b13360295aa071060c Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 4 Nov 2021 10:51:24 -0400
Subject: [PATCH 1/5] Update RobertaModel to directly take in configuration

---
 torchtext/models/roberta/__init__.py |  6 ++++--
 torchtext/models/roberta/bundler.py  | 19 ++++++++++++-----
 torchtext/models/roberta/model.py    | 31 +++++++++++++++-------------
 3 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/torchtext/models/roberta/__init__.py b/torchtext/models/roberta/__init__.py
index f03218844b..1057c6deb6 100644
--- a/torchtext/models/roberta/__init__.py
+++ b/torchtext/models/roberta/__init__.py
@@ -1,6 +1,7 @@
 from .model import (
-    RobertaEncoderParams,
+    RobertaEncoderConf,
     RobertaClassificationHead,
+    RobertaModel,
 )
 
 from .bundler import (
@@ -10,8 +11,9 @@
 )
 
 __all__ = [
-    "RobertaEncoderParams",
+    "RobertaEncoderConf",
     "RobertaClassificationHead",
+    "RobertaModel",
     "RobertaModelBundle",
     "XLMR_BASE_ENCODER",
     "XLMR_LARGE_ENCODER",
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 186a5ff688..0689b7ba25 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -11,7 +11,7 @@
 logger = logging.getLogger(__name__)
 
 from .model import (
-    RobertaEncoderParams,
+    RobertaEncoderConf,
     RobertaModel,
     _get_model,
 )
@@ -50,8 +50,17 @@ class RobertaModelBundle:
         >>> output = classification_model(model_input)
         >>> output.shape
         torch.Size([1, 2])
+
+    Example - User-specified configuration and checkpoint
+        >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
+        >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
+        >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
+        >>> roberta_bundle = RobertaModelBundle(_params=roberta_encoder_conf, _path=model_weights_path)
+        >>> encoder = roberta_bundle.get_model()
+        >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
+        >>> classifier = roberta_bundle.get_model(head=classifier_head)
     """
-    _params: RobertaEncoderParams
+    _params: RobertaEncoderConf
     _path: Optional[str] = None
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
@@ -86,13 +95,13 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         return model
 
     @property
-    def params(self) -> RobertaEncoderParams:
+    def params(self) -> RobertaEncoderConf:
         return self._params
 
 
 XLMR_BASE_ENCODER = RobertaModelBundle(
     _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _params=RobertaEncoderParams(vocab_size=250002),
+    _params=RobertaEncoderConf(vocab_size=250002),
     transform=partial(get_xlmr_transform,
                       vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
                       spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
@@ -101,7 +110,7 @@ def params(self) -> RobertaEncoderParams:
 
 XLMR_LARGE_ENCODER = RobertaModelBundle(
     _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
-    _params=RobertaEncoderParams(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
+    _params=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
     transform=partial(get_xlmr_transform,
                       vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
                       spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index 35d27fc79a..32412f7d22 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -16,7 +16,7 @@
 
 
 @dataclass
-class RobertaEncoderParams:
+class RobertaEncoderConf:
     vocab_size: int = 50265
     embedding_dim: int = 768
     ffn_dimension: int = 3072
@@ -62,6 +62,10 @@ def __init__(
             return_all_layers=False,
         )
 
+    @classmethod
+    def from_config(cls, config: RobertaEncoderConf):
+        return cls(**asdict(config))
+
     def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
         output = self.transformer(tokens)
         if torch.jit.isinstance(output, List[Tensor]):
@@ -94,13 +98,19 @@ def forward(self, features):
 
 
 class RobertaModel(Module):
-    def __init__(self, encoder: Module, head: Optional[Module] = None):
+    def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
         super().__init__()
-        self.encoder = encoder
+        self.encoder = RobertaEncoder.from_config(config)
+        if freeze_encoder:
+            for param in self.encoder.parameters():
+                param.requires_grad = False
+
+            logger.info("Encoder weights are frozen")
+
         self.head = head
 
-    def forward(self, tokens: Tensor) -> Tensor:
-        features = self.encoder(tokens)
+    def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+        features = self.encoder(tokens, mask)
         if self.head is None:
             return features
 
@@ -108,12 +118,5 @@ def forward(self, tokens: Tensor) -> Tensor:
         return x
 
 
-def _get_model(params: RobertaEncoderParams, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
-    encoder = RobertaEncoder(**asdict(params))
-    if freeze_encoder:
-        for param in encoder.parameters():
-            param.requires_grad = False
-
-        logger.info("Encoder weights are frozen")
-
-    return RobertaModel(encoder, head)
+def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
+    return RobertaModel(config, head, freeze_encoder)

From a9467de25e0c84b63aed4936a9f8f3cc3009751e Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 4 Nov 2021 10:53:20 -0400
Subject: [PATCH 2/5] flake fix

---
 torchtext/models/roberta/bundler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 0689b7ba25..61fe3c6f50 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -52,7 +52,7 @@ class RobertaModelBundle:
         torch.Size([1, 2])
 
     Example - User-specified configuration and checkpoint
-        >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead 
+        >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
         >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
         >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
         >>> roberta_bundle = RobertaModelBundle(_params=roberta_encoder_conf, _path=model_weights_path)

From 786cb744da265ae3d6339d85d39afca4d039f894 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 4 Nov 2021 12:49:15 -0400
Subject: [PATCH 3/5] add docstring

---
 torchtext/models/roberta/model.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index 32412f7d22..df56f49870 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -98,6 +98,15 @@ def forward(self, features):
 
 
 class RobertaModel(Module):
+    """
+
+    Example - Instantiate model with user-specified configuration
+        >>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead
+        >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
+        >>> encoder = RobertaModel(config=roberta_encoder_conf)
+        >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
+        >>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head)
+    """
     def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
         super().__init__()
         self.encoder = RobertaEncoder.from_config(config)

From 1d5f90af1364abfd662f427504418862ee674886 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Thu, 4 Nov 2021 13:09:27 -0400
Subject: [PATCH 4/5] change _params to _encoder_conf to align the naming

---
 torchtext/models/roberta/bundler.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 61fe3c6f50..d26a2d71cc 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -55,12 +55,12 @@ class RobertaModelBundle:
         >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
         >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
         >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
-        >>> roberta_bundle = RobertaModelBundle(_params=roberta_encoder_conf, _path=model_weights_path)
+        >>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path)
         >>> encoder = roberta_bundle.get_model()
         >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
         >>> classifier = roberta_bundle.get_model(head=classifier_head)
     """
-    _params: RobertaEncoderConf
+    _encoder_conf: RobertaEncoderConf
     _path: Optional[str] = None
     _head: Optional[Module] = None
     transform: Optional[Callable] = None
@@ -81,7 +81,7 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         else:
             input_head = self._head
 
-        model = _get_model(self._params, input_head, freeze_encoder)
+        model = _get_model(self._encoder_conf, input_head, freeze_encoder)
 
         if not load_weights:
             return model
@@ -95,13 +95,13 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         return model
 
     @property
-    def params(self) -> RobertaEncoderConf:
-        return self._params
+    def encoderConf(self) -> RobertaEncoderConf:
+        return self._encoder_conf
 
 
 XLMR_BASE_ENCODER = RobertaModelBundle(
     _path=os.path.join(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _params=RobertaEncoderConf(vocab_size=250002),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
     transform=partial(get_xlmr_transform,
                       vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
                       spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),
@@ -110,7 +110,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 
 XLMR_LARGE_ENCODER = RobertaModelBundle(
     _path=os.path.join(_TEXT_BUCKET, "xlmr.large.encoder.pt"),
-    _params=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24),
     transform=partial(get_xlmr_transform,
                       vocab_path=os.path.join(_TEXT_BUCKET, "xlmr.vocab.pt"),
                       spm_model_path=os.path.join(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model"),

From 31188626e8a23312262ed50fe92dedd24ad37067 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Fri, 5 Nov 2021 01:26:08 -0400
Subject: [PATCH 5/5] add type checking

---
 torchtext/models/roberta/model.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchtext/models/roberta/model.py b/torchtext/models/roberta/model.py
index df56f49870..60a4362550 100644
--- a/torchtext/models/roberta/model.py
+++ b/torchtext/models/roberta/model.py
@@ -107,8 +107,11 @@ class RobertaModel(Module):
         >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
         >>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head)
     """
+
     def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
         super().__init__()
+        assert isinstance(config, RobertaEncoderConf)
+
         self.encoder = RobertaEncoder.from_config(config)
         if freeze_encoder:
             for param in self.encoder.parameters():
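
Usage sketch for the refactored API (a minimal example, not part of the patches, assuming a
torchtext build with all five patches applied; the token ids below are made up for illustration):

    import torch
    from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead

    # RobertaEncoderConf (renamed from RobertaEncoderParams in PATCH 1/5) holds the encoder
    # hyper-parameters; RobertaModel now builds its own encoder from it via
    # RobertaEncoder.from_config, and PATCH 5/5 asserts the config type at construction.
    encoder_conf = RobertaEncoderConf(vocab_size=250002)
    classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
    classifier = RobertaModel(config=encoder_conf, head=classifier_head, freeze_encoder=True)

    tokens = torch.randint(0, encoder_conf.vocab_size, (1, 10))  # dummy batch of token ids
    output = classifier(tokens)
    print(output.shape)  # expected: torch.Size([1, 2]), matching the bundler docstring example

Pre-trained weights still come through RobertaModelBundle, whose field and property were renamed
in PATCH 4/5: construct the bundle with _encoder_conf=... and read the configuration back through
the encoderConf property.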