This repository was archived by the owner on Sep 10, 2025. It is now read-only.


@pmabbo13
Contributor

Description

Add a wrapper class for the T5 model that takes text as input and outputs the text generated by the decoder.

Process

The wrapper class performs the following end-to-end process:

  1. Transforms the input text with T5Transform to get truncated and padded token IDs.
  2. Passes the resulting model input through the T5 model, using a simple beam search generator to produce the output sequence.
  3. Uses the transform pipeline to translate the generator output (a sequence of token IDs) back to text.
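The three steps above can be illustrated with a minimal, self-contained sketch. Nothing here is the torchtext implementation: `ToyTransform`, `copy_model`, `beam_search`, and `ToyT5Wrapper` are hypothetical stand-ins, the special-token IDs are assumed, and the "model" is a toy scorer that simply prefers copying its input.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

PAD_ID, EOS_ID = 0, 1  # assumed special-token IDs for this toy example


class ToyTransform:
    """Hypothetical stand-in for T5Transform: whitespace tokens <-> token IDs."""

    def __init__(self, vocab: Sequence[str], max_seq_len: int) -> None:
        self.itos = ["<pad>", "</s>"] + list(vocab)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        self.max_seq_len = max_seq_len

    def encode(self, texts: Sequence[str]) -> List[List[int]]:
        batch = []
        for text in texts:
            # Truncate so the end-of-sequence token still fits, then pad.
            ids = [self.stoi[t] for t in text.split()][: self.max_seq_len - 1]
            ids.append(EOS_ID)
            ids += [PAD_ID] * (self.max_seq_len - len(ids))
            batch.append(ids)
        return batch

    def decode(self, batch: Sequence[Sequence[int]]) -> List[str]:
        return [
            " ".join(self.itos[i] for i in ids if i not in (PAD_ID, EOS_ID))
            for ids in batch
        ]


StepFn = Callable[[Sequence[int], Sequence[int]], Dict[int, float]]


def copy_model(encoder_ids: Sequence[int], prefix: Sequence[int]) -> Dict[int, float]:
    """Toy next-token log-probabilities that favor copying the input."""
    target = encoder_ids[len(prefix)] if len(prefix) < len(encoder_ids) else EOS_ID
    if target in (PAD_ID, EOS_ID):
        return {EOS_ID: 0.0}
    return {target: math.log(0.9), EOS_ID: math.log(0.1)}


def beam_search(step: StepFn, encoder_ids: Sequence[int],
                beam_size: int, max_len: int) -> List[int]:
    # Each hypothesis is (cumulative log-probability, generated token IDs).
    beams: List[Tuple[float, List[int]]] = [(0.0, [])]
    for _ in range(max_len):
        candidates: List[Tuple[float, List[int]]] = []
        for score, seq in beams:
            if seq and seq[-1] == EOS_ID:
                candidates.append((score, seq))  # finished: carry over unchanged
                continue
            for tok, log_prob in step(encoder_ids, seq).items():
                candidates.append((score + log_prob, seq + [tok]))
        # Keep only the beam_size highest-scoring hypotheses.
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]
        if all(seq[-1] == EOS_ID for _, seq in beams):
            break
    return beams[0][1]


class ToyT5Wrapper:
    """Text in, text out: transform -> beam search over the model -> decode."""

    def __init__(self, transform: ToyTransform, step: StepFn, beam_size: int = 2) -> None:
        self.transform = transform
        self.step = step
        self.beam_size = beam_size

    def __call__(self, texts: Sequence[str]) -> List[str]:
        model_input = self.transform.encode(texts)  # step 1
        output_ids = [
            beam_search(self.step, ids, self.beam_size, self.transform.max_seq_len)
            for ids in model_input
        ]  # step 2
        return self.transform.decode(output_ids)  # step 3
```

With a two-word vocabulary, `ToyT5Wrapper(ToyTransform(["hello", "world"], max_seq_len=5), copy_model)(["hello world"])` round-trips the input through all three stages. A real wrapper would replace the toy scorer with the T5 encoder-decoder forward pass.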

The wrapper class accepts either a configuration parameter, a string indicating which bundler object to use (currently only 'base' is supported; this will be updated once other configurations are added), or the parameters checkpoint, t5_conf, and transform when a custom model architecture and weights are to be used.
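One way to picture the either/or constructor contract is the validation sketch below. It is only an illustration: `resolve_wrapper_source` is a hypothetical helper, and letting `configuration` take precedence when both kinds of arguments are passed is an assumption, not something stated in this PR.

```python
from typing import Any, Optional


def resolve_wrapper_source(
    configuration: Optional[str] = None,
    checkpoint: Optional[Any] = None,
    t5_conf: Optional[Any] = None,
    transform: Optional[Any] = None,
) -> str:
    """Return "bundler" or "custom" depending on which arguments were given."""
    if configuration is not None:
        # Assumption: a configuration string takes precedence over custom arguments.
        if configuration != "base":
            raise ValueError(f"unsupported configuration: {configuration!r}")
        return "bundler"
    # The custom path needs all three pieces to build and load the model.
    custom = {"checkpoint": checkpoint, "t5_conf": t5_conf, "transform": transform}
    missing = [name for name, value in custom.items() if value is None]
    if missing:
        raise ValueError("custom setup requires: " + ", ".join(missing))
    return "custom"
```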

Testing

Feed input text to the wrapper class and check that it returns the expected output. The wrapper is instantiated both from a configuration and from custom settings, and both scripted and non-scripted versions are tested.

pytest test/prototype/integration_tests/test_models.py

Contributor

@Nayef211 Nayef211 left a comment


Overall LGTM. How difficult would it be to enable the wrapper class to work with other model configurations?

@pmabbo13
Contributor Author

> Overall LGTM. How difficult would it be to enable the wrapper class to work with other model configurations?

It would be straightforward. We could update line 43 in wrapper.py to assert that the value passed in for configuration is one of ('base', 'small', 'large', '3b', '11b'), and then use if statements to load the bundler object that corresponds to the given configuration.
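A rough sketch of that change, assuming nothing about the current contents of wrapper.py: `select_bundle_name` is a hypothetical helper, and the returned names are placeholders following an assumed `T5_<SIZE>_GENERATION` pattern. Only a 'base' bundler is confirmed to exist in this thread.

```python
SUPPORTED_CONFIGURATIONS = ("base", "small", "large", "3b", "11b")


def select_bundle_name(configuration: str) -> str:
    """Map a configuration string to the (placeholder) name of its bundler object."""
    assert configuration in SUPPORTED_CONFIGURATIONS, (
        f"configuration must be one of {SUPPORTED_CONFIGURATIONS}, got {configuration!r}"
    )
    # Placeholder names; a real wrapper would return the bundler object itself.
    if configuration == "base":
        return "T5_BASE_GENERATION"
    if configuration == "small":
        return "T5_SMALL_GENERATION"
    if configuration == "large":
        return "T5_LARGE_GENERATION"
    if configuration == "3b":
        return "T5_3B_GENERATION"
    return "T5_11B_GENERATION"  # only "11b" remains after the assert above
```

A dictionary lookup would be an equally reasonable alternative to the chain of if statements once the list of configurations grows.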

@pmabbo13 pmabbo13 marked this pull request as ready for review August 15, 2022 14:21
@pmabbo13 pmabbo13 merged commit 5a351b4 into pytorch:main Aug 15, 2022
@pmabbo13 pmabbo13 deleted the feature/t5-wrapper-class branch August 15, 2022 14:22
@pmabbo13 pmabbo13 changed the title [WIP] wrapper class for end-to-end t5 model Wrapper class for end-to-end t5 model Aug 15, 2022