diff --git a/docs/source/models.rst b/docs/source/models.rst
index 500b2a6c7d..b5425c8da4 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -32,3 +32,19 @@ XLMR_LARGE_ENCODER
 
 .. autodata:: XLMR_LARGE_ENCODER
    :no-value:
+
+ROBERTA_BASE_ENCODER
+--------------------
+
+.. container:: py attribute
+
+  .. autodata:: ROBERTA_BASE_ENCODER
+     :no-value:
+
+
+ROBERTA_LARGE_ENCODER
+---------------------
+
+.. container:: py attribute
+
+  .. autodata:: ROBERTA_LARGE_ENCODER
+     :no-value:
diff --git a/test/asset/roberta.base.output.pt b/test/asset/roberta.base.output.pt
new file mode 100644
index 0000000000..d04c740b88
Binary files /dev/null and b/test/asset/roberta.base.output.pt differ
diff --git a/test/asset/roberta.large.output.pt b/test/asset/roberta.large.output.pt
new file mode 100644
index 0000000000..9f5287ca3f
Binary files /dev/null and b/test/asset/roberta.large.output.pt differ
diff --git a/test/integration_tests/__init__.py b/test/integration_tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/integration_tests/test_models.py b/test/integration_tests/test_models.py
new file mode 100644
index 0000000000..ee811beff1
--- /dev/null
+++ b/test/integration_tests/test_models.py
@@ -0,0 +1,67 @@
+import torch
+import torchtext
+
+from ..common.assets import get_asset_path
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+class TestModels(TorchtextTestCase):
+    def test_roberta_base(self):
+        asset_path = get_asset_path("roberta.base.output.pt")
+        test_text = "Roberta base Model Comparison"
+
+        roberta_base = torchtext.models.ROBERTA_BASE_ENCODER
+        transform = roberta_base.transform()
+        model = roberta_base.get_model()
+        model = model.eval()
+
+        model_input = torch.tensor(transform([test_text]))
+        actual = model(model_input)
+        expected = torch.load(asset_path)
+        torch.testing.assert_close(actual, expected)
+
+    def test_roberta_base_jit(self):
+        asset_path = get_asset_path("roberta.base.output.pt")
+        test_text = "Roberta base Model Comparison"
+
+        roberta_base = torchtext.models.ROBERTA_BASE_ENCODER
+        transform = roberta_base.transform()
+        transform_jit = torch.jit.script(transform)
+        model = roberta_base.get_model()
+        model = model.eval()
+        model_jit = torch.jit.script(model)
+
+        model_input = torch.tensor(transform_jit([test_text]))
+        actual = model_jit(model_input)
+        expected = torch.load(asset_path)
+        torch.testing.assert_close(actual, expected)
+
+    def test_roberta_large(self):
+        asset_path = get_asset_path("roberta.large.output.pt")
+        test_text = "Roberta base Model Comparison"
+
+        roberta_large = torchtext.models.ROBERTA_LARGE_ENCODER
+        transform = roberta_large.transform()
+        model = roberta_large.get_model()
+        model = model.eval()
+
+        model_input = torch.tensor(transform([test_text]))
+        actual = model(model_input)
+        expected = torch.load(asset_path)
+        torch.testing.assert_close(actual, expected)
+
+    def test_roberta_large_jit(self):
+        asset_path = get_asset_path("roberta.large.output.pt")
+        test_text = "Roberta base Model Comparison"
+
+        roberta_large = torchtext.models.ROBERTA_LARGE_ENCODER
+        transform = roberta_large.transform()
+        transform_jit = torch.jit.script(transform)
+        model = roberta_large.get_model()
+        model = model.eval()
+        model_jit = torch.jit.script(model)
+
+        model_input = torch.tensor(transform_jit([test_text]))
+        actual = model_jit(model_input)
+        expected = torch.load(asset_path)
+        torch.testing.assert_close(actual, expected)
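The expected-output assets added above (roberta.base.output.pt, roberta.large.output.pt) are binary, so the diff cannot show how they were produced. No regeneration script is part of this change; the sketch below is a hypothetical helper that would produce equivalent tensors by running the same pipeline the tests exercise. Note that, as in the tests above, the large asset is generated from the same "Roberta base Model Comparison" string.

    # Hypothetical helper (not part of this diff): regenerate the expected
    # outputs that test/integration_tests/test_models.py compares against.
    import torch
    import torchtext

    def save_expected_output(bundle, text, out_path):
        transform = bundle.transform()
        model = bundle.get_model().eval()
        with torch.no_grad():
            model_input = torch.tensor(transform([text]))
            torch.save(model(model_input), out_path)

    save_expected_output(torchtext.models.ROBERTA_BASE_ENCODER,
                         "Roberta base Model Comparison",
                         "test/asset/roberta.base.output.pt")
    save_expected_output(torchtext.models.ROBERTA_LARGE_ENCODER,
                         "Roberta base Model Comparison",
                         "test/asset/roberta.large.output.pt")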
diff --git a/torchtext/models/roberta/__init__.py b/torchtext/models/roberta/__init__.py
index 1057c6deb6..79830cb2c3 100644
--- a/torchtext/models/roberta/__init__.py
+++ b/torchtext/models/roberta/__init__.py
@@ -8,6 +8,8 @@
     RobertaModelBundle,
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
+    ROBERTA_BASE_ENCODER,
+    ROBERTA_LARGE_ENCODER,
 )
 
 __all__ = [
@@ -17,4 +19,6 @@
     "RobertaModelBundle",
     "XLMR_BASE_ENCODER",
     "XLMR_LARGE_ENCODER",
+    "ROBERTA_BASE_ENCODER",
+    "ROBERTA_LARGE_ENCODER",
 ]
diff --git a/torchtext/models/roberta/bundler.py b/torchtext/models/roberta/bundler.py
index 7cd9e2c833..fcfb82dbd6 100644
--- a/torchtext/models/roberta/bundler.py
+++ b/torchtext/models/roberta/bundler.py
@@ -192,3 +192,89 @@ def encoderConf(self) -> RobertaEncoderConf:
     Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
     '''
 )
+
+
+ROBERTA_BASE_ENCODER = RobertaModelBundle(
+    _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(vocab_size=50265),
+    transform=lambda: T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(
+            load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))
+        ),
+        T.Truncate(254),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    ),
+)
+
+ROBERTA_BASE_ENCODER.__doc__ = (
+    '''
+    Roberta Encoder with Base configuration
+
+    RoBERTa iterates on BERT's pretraining procedure, including training the model longer,
+    with bigger batches over more data; removing the next sentence prediction objective;
+    training on longer sequences; and dynamically changing the masking pattern applied
+    to the training data.
+
+    The RoBERTa model was pretrained on the union of five datasets: BookCorpus,
+    English Wikipedia, CC-News, OpenWebText, and STORIES. Together these datasets
+    contain over 160GB of text.
+
+    Originally published by the authors of RoBERTa under the MIT License
+    and redistributed with the same license.
+    [`License `__,
+    `Source `__]
+
+    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
+    '''
+)
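Usage of the new bundle follows the same pattern as the existing XLMR bundles; a minimal single-sentence sketch, mirroring the integration tests added above:

    import torch
    import torchtext

    roberta_base = torchtext.models.ROBERTA_BASE_ENCODER
    transform = roberta_base.transform()  # GPT-2 BPE tokenize -> vocab lookup -> truncate -> add BOS/EOS ids
    model = roberta_base.get_model().eval()

    model_input = torch.tensor(transform(["Hello world"]))  # shape: (1, sequence_length)
    features = model(model_input)  # shape: (1, sequence_length, 768) with the base configuration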
+
+
+ROBERTA_LARGE_ENCODER = RobertaModelBundle(
+    _path=urljoin(_TEXT_BUCKET, "roberta.large.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(
+        vocab_size=50265,
+        embedding_dim=1024,
+        ffn_dimension=4096,
+        num_attention_heads=16,
+        num_encoder_layers=24,
+    ),
+    transform=lambda: T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(
+            load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))
+        ),
+        T.Truncate(510),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    ),
+)
+
+ROBERTA_LARGE_ENCODER.__doc__ = (
+    '''
+    Roberta Encoder with Large configuration
+
+    RoBERTa iterates on BERT's pretraining procedure, including training the model longer,
+    with bigger batches over more data; removing the next sentence prediction objective;
+    training on longer sequences; and dynamically changing the masking pattern applied
+    to the training data.
+
+    The RoBERTa model was pretrained on the union of five datasets: BookCorpus,
+    English Wikipedia, CC-News, OpenWebText, and STORIES. Together these datasets
+    contain over 160GB of text.
+
+    Originally published by the authors of RoBERTa under the MIT License
+    and redistributed with the same license.
+    [`License `__,
+    `Source `__]
+
+    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
+    '''
+)
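Both the transform and the encoder are TorchScript-compatible, which is the property the *_jit integration tests exercise. A minimal sketch of scripting a bundle and saving the result for deployment without the Python class definitions (the save/load step is an extra illustration, and the file name is illustrative):

    import torch
    import torchtext

    bundle = torchtext.models.ROBERTA_LARGE_ENCODER
    transform_jit = torch.jit.script(bundle.transform())
    model_jit = torch.jit.script(bundle.get_model().eval())

    output = model_jit(torch.tensor(transform_jit(["Hello world"])))

    # Save and reload the scripted encoder; the path below is illustrative.
    torch.jit.save(model_jit, "roberta_large_encoder_jit.pt")
    reloaded = torch.jit.load("roberta_large_encoder_jit.pt")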