Add T5 Model and Demo on Text Summarization using CNNDM Dataset #1800
🚀 Feature
Add the CNNDM dataset and a pre-trained T5 model to TorchText. Demo the model on the task of abstractive summarization using the CNNDM dataset.
Motivation
There are multiple open-source NLP frameworks catering to a wide variety of audiences. As a result of this fragmentation, a typical NLP researcher usually writes their code in pure PyTorch while copying essential components from other repositories. Adding a pre-trained T5 model and the CNNDM dataset makes the TorchText library more convenient to use and works towards making PyTorch the preferred deep learning framework for NLP research.
T5 (Text-To-Text Transfer Transformer) is a transformer model trained end-to-end with text as input and modified text as output. This text-to-text formatting makes the T5 model suitable for many NLP tasks, such as summarization, question answering, machine translation, and classification. CNNDM (CNN/DailyMail) is a popular dataset in the NLP community for text summarization tasks.
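To make the text-to-text format concrete, here are input/target pairs in the style of the examples from the T5 paper (the task prefixes follow the paper's conventions; the exact strings expected by our pre-trained checkpoints may differ):

```python
# T5 casts every task as text-to-text: a short task prefix on the input tells
# the model what to do, and the target is always plain text.
examples = [
    # abstractive summarization (CNNDM-style)
    ("summarize: state authorities dispatched emergency crews tuesday to "
     "survey the damage after an onslaught of severe weather in mississippi...",
     "six people hospitalized after a storm in attala county."),
    # machine translation
    ("translate English to German: That is good.", "Das ist gut."),
    # classification: the label itself is emitted as text
    ("cola sentence: The course is jumping well.", "not acceptable"),
]
```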
Pitch
The T5 model architecture will be implemented so that it can be initialized with hyper-parameters such as the number of layers, hidden size, attention size, etc. The user should also be able to specify whether they want the encoder-only model (for non-text-generation tasks) or the full encoder-decoder model. For pre-trained weights, Google has released five checkpoints for the different-sized T5 models; these checkpoint weights will be hosted on pytorch.org, and an API will be implemented to load them. Finally, integration tests will be added for both the encoder-only and encoder-decoder model APIs.
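As a rough sketch of what such a configuration-driven API could look like (all names here, T5Conf, T5_BASE, get_model, are hypothetical placeholders, not the final API):

```python
from dataclasses import dataclass


@dataclass
class T5Conf:
    """Hypothetical config holding the hyper-parameters named above."""
    encoder_only: bool = False  # True -> encoder-only (non-generation tasks)
    num_layers: int = 12        # layers per stack
    d_model: int = 768          # hidden size
    d_kv: int = 64              # size of each attention head
    num_heads: int = 12         # number of attention heads


# Envisioned usage: a bundle pairs a T5Conf with checkpoint weights hosted on
# pytorch.org, one bundle per released checkpoint size (T5_BASE is a
# placeholder name for one of the five checkpoints):
#   encoder = T5_BASE.get_model(encoder_only=True)   # non-generation tasks
#   model   = T5_BASE.get_model(encoder_only=False)  # full encoder-decoder
```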
The CNNDM dataset will also be made available in the TorchText library. This will allow us to demo the pre-trained T5 model by using it to perform abstractive summarization on the CNNDM dataset. A text pre-processing pipeline will also need to be implemented to prep the data for the model.
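A minimal sketch of how these pieces could fit together, assuming a CNNDM API that yields (article, abstract) pairs; the dataset signature is an assumption, and real truncation would happen at the token level inside the transform:

```python
# Sketch of the planned pre-processing step. T5 checkpoints use a
# SentencePiece vocabulary, so the real transform would wrap that tokenizer;
# here we only show the task-prefixing convention.

def preprocess(article: str) -> str:
    """Prefix the task name, per T5's text-to-text convention."""
    return "summarize: " + article.strip()


# Envisioned usage once the dataset lands (CNNDM here is an assumed API):
#   from torchtext.datasets import CNNDM
#   train_iter = CNNDM(split="train")      # yields (article, abstract) pairs
#   article, abstract = next(iter(train_iter))
#   model_input = preprocess(article)
```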
Milestone 1: Add CNNDM dataset
- Add CNNDM dataset
Milestone 2: Implement T5 model architecture
- Create method to compute relative position buckets for relative attention bias (#1830; see the sketch after this list)
- Create method to compute the relative attention bias term (#1831)
- Create method to compute attention scores using relative attention bias (#1832)
- Implement MultiheadAttention module
- Implement root-mean-square layer normalization module (#1826; also sketched after this list)
- Implement T5Layer module that can be used in both the encoder and decoder stacks (#1827)
- Implement T5Stack module that can function as either the encoder or the decoder of a T5 model (#1828)
- Implement T5Model module that can function as either an encoder-only or an encoder-decoder model (#1829)
- Add pre-trained T5 model weights and a bundler API to load them (#1846)
- Create integration tests for the T5 model APIs (#1848)
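For reference, here are minimal sketches of two of the pieces above, following the formulation in the T5 paper; the names and signatures are illustrative, not necessarily what the final torchtext implementation will use:

```python
import math

import torch


def relative_position_bucket(relative_position: torch.Tensor,
                             bidirectional: bool = True,
                             num_buckets: int = 32,
                             max_distance: int = 128) -> torch.Tensor:
    """Map (query pos - key pos) offsets to bucket indices, per the T5 paper:
    half the buckets cover exact small offsets, the rest grow logarithmically
    up to max_distance; in bidirectional (encoder) attention, positive and
    negative offsets get separate bucket ranges."""
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:  # decoder self-attention only looks backwards
        relative_position = -torch.min(relative_position,
                                       torch.zeros_like(relative_position))
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # logarithmic buckets for larger offsets, capped at the last bucket
    large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    return relative_buckets + torch.where(is_small, relative_position, large)


class RMSNorm(torch.nn.Module):
    """T5-style layer norm: rescale by root-mean-square, with no mean
    subtraction and no bias."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)
```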
Milestone 3: Demo T5 model on text summarization
- Create text pre-processing pipeline (T5Transform) to prep data for the T5 model (#1852)
- Expose the text pre-processing pipeline in the T5 bundler (#1856)
- Demonstrate text generation using T5 on the CNNDM dataset (a decoding sketch follows this list)
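The initial demo could use simple greedy decoding, sketched below under the assumption that the model takes (encoder_tokens, decoder_tokens) and returns logits of shape (batch, target_length, vocab), and that T5's SentencePiece conventions apply (pad = 0 doubles as the decoder start token, eos = 1); both are assumptions about the eventual interface:

```python
import torch


@torch.no_grad()
def greedy_generate(model, encoder_ids, start_id=0, eos_id=1, max_len=128):
    """Greedy decoding sketch for the encoder-decoder T5 model."""
    decoder_ids = torch.full((encoder_ids.size(0), 1), start_id, dtype=torch.long)
    for _ in range(max_len):
        logits = model(encoder_ids, decoder_ids)       # (batch, tgt_len, vocab)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        decoder_ids = torch.cat([decoder_ids, next_token], dim=1)
        if (next_token == eos_id).all():               # stop once all rows end
            break
    return decoder_ids
```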
Stretch Goals
- Implement a beam search generator and update the T5 demo to use it (#1869; a minimal sketch follows this list)
- Demo the T5 model on additional tasks, such as sentiment classification and translation (#1872)
- Make the T5 model TorchScriptable (#1876)
- Add a wrapper class for end-to-end T5 usage (#1880)
- Add remaining model configs and pre-trained weights (small, large, 3B, 11B) (#1879)
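For the beam search stretch goal, a minimal single-example sketch (no length penalty, same assumed model interface and token-id conventions as the greedy sketch above):

```python
import torch


@torch.no_grad()
def beam_search(model, encoder_ids, beam_size=4, start_id=0, eos_id=1,
                max_len=128):
    """Single-example beam search sketch (batch size 1).

    Each beam is a (decoder_token_ids, cumulative_log_prob) pair.
    """
    beams = [(torch.full((1, 1), start_id, dtype=torch.long), 0.0)]
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            logits = model(encoder_ids, tokens)        # (1, tgt_len, vocab)
            log_probs = logits[0, -1].log_softmax(dim=-1)
            top_lp, top_ids = log_probs.topk(beam_size)
            for lp, tok in zip(top_lp.tolist(), top_ids.tolist()):
                new_tokens = torch.cat(
                    [tokens, torch.full((1, 1), tok, dtype=torch.long)], dim=1)
                candidates.append((new_tokens, score + lp))
        # keep the overall best beam_size continuations
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            target = finished if tokens[0, -1].item() == eos_id else beams
            target.append((tokens, score))
        if not beams:  # every surviving beam has emitted eos
            break
    best = max(finished or beams, key=lambda c: c[1])
    return best[0]
```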