Bundler API for TorchText T5 Model #1846
I intended to add the tests in a separate PR, but I have no objection to adding them to this PR.
parmeet left a comment
Overall LGTM! Do we plan to add integration tests in a follow-up PR?
Oops, just saw the answer here #1846 (comment). Feel free to ignore :)
Gotcha, that should be fine. I do agree with @parmeet's comment about not expecting the user to pass in the [...]
Sure, feel free to do it in a separate PR.
Description
Add pre-trained T5 model weights for the base configuration and an API to load them.
Process
- `T5Conf` dataclass was created to store the configuration values of the model.
- `T5Bundle` API was created such that the `build_model` method initializes a T5 model (encoder-only or encoder-decoder) according to the input configuration and checkpoint provided to load the pre-trained weights. The `get_model` method uses the `build_model` method to load a model according to the object's config and path attributes (the latter stores a path to the saved pre-trained weights).
- `t5.base.encoder.pt` pre-trained weights in the s3 bucket.
- `t5.base.pt` pre-trained weights in the s3 bucket.
Testing
Informal testing was done in this notebook to ensure that we can successfully load the pre-trained models using the bundler objects, and that they have the same output as the HuggingFace implementation.
This PR description will be updated once formal unit tests and integration tests have been added.
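For readers unfamiliar with the bundler pattern described under Process, here is a minimal sketch of the idea in plain Python. It is not the torchtext implementation: `ToyConf`, `ToyModel`, `ToyBundle`, `_load_checkpoint`, the `FAKE_STORE` dictionary, the `load_weights` flag, and the `toy.encoder.pt` path are all hypothetical names made up for illustration, and the weights are plain floats rather than tensors.

```python
from dataclasses import dataclass
from typing import Dict, Optional

# Hypothetical in-memory "bucket" standing in for remote weight storage.
FAKE_STORE: Dict[str, Dict[str, float]] = {
    "toy.encoder.pt": {"embedding": 0.5},
}


@dataclass
class ToyConf:
    """Hypothetical config dataclass; field values are illustrative only."""
    encoder_only: bool = False
    embedding_dim: int = 8
    num_layers: int = 2


class ToyModel:
    """Hypothetical model holding a config and a flat dict of weights."""

    def __init__(self, config: ToyConf) -> None:
        self.config = config
        self.weights: Dict[str, float] = {}

    def load_state_dict(self, state: Dict[str, float]) -> None:
        # Copy so the model does not alias the stored checkpoint.
        self.weights = dict(state)


def _load_checkpoint(path: str) -> Dict[str, float]:
    """Assumed loader: looks the path up in FAKE_STORE instead of downloading."""
    return FAKE_STORE[path]


class ToyBundle:
    """Pairs a weights path with the config needed to rebuild the model."""

    def __init__(self, path: str, config: ToyConf) -> None:
        self._path = path
        self._config = config

    @staticmethod
    def build_model(
        config: ToyConf, checkpoint: Optional[Dict[str, float]] = None
    ) -> ToyModel:
        # Initialize from the config, then load weights if a checkpoint is given.
        model = ToyModel(config)
        if checkpoint is not None:
            model.load_state_dict(checkpoint)
        return model

    def get_model(self, load_weights: bool = True) -> ToyModel:
        # Delegate to build_model using the bundle's own config and path.
        checkpoint = _load_checkpoint(self._path) if load_weights else None
        return self.build_model(self._config, checkpoint)


bundle = ToyBundle("toy.encoder.pt", ToyConf(encoder_only=True))
model = bundle.get_model()
```

Keeping `build_model` static means a caller can build a model from any config and checkpoint, while `get_model` covers the common case where the bundle already knows both.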