Make Transformer tolerate missing layers for PP #322
Conversation
A few small changes here let the manual PP frontend 'reconfigure' a whole transformer model down to a stage's portion simply by setting undesired layers to None (for top-level layers) or deleting them from the ModuleDict (for 'layers.*'). These changes don't impact the FQNs of the remaining layers, which is critical for checkpoint load/save compatibility.

ghstack-source-id: 48a7aaf
Pull Request resolved: #322
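For illustration only, here is a minimal sketch (not code from this PR) of how a manual PP frontend could carve a whole model down to one stage's portion. It assumes a torchtitan-style `Transformer` with `tok_embeddings`, a `layers` `ModuleDict`, `norm`, and `output`; the `build_stage_model` helper and the even-split policy are hypothetical.

```python
import copy


def build_stage_model(whole_model, stage_idx, num_stages, n_layers):
    """Hypothetical helper: keep only the modules owned by this PP stage.

    Undesired top-level modules are set to None; undesired transformer
    blocks are deleted from the ModuleDict. The FQNs of the remaining
    modules are unchanged, so checkpoint load/save keeps working.
    """
    # In practice the copy would typically be built on the meta device
    # rather than deep-copied with real weights.
    model = copy.deepcopy(whole_model)

    # Only the first stage keeps the token embeddings.
    if stage_idx != 0:
        model.tok_embeddings = None

    # Only the last stage keeps the final norm and output projection.
    if stage_idx != num_stages - 1:
        model.norm = None
        model.output = None

    # One possible policy: split the blocks evenly across stages.
    per_stage = n_layers // num_stages
    start = stage_idx * per_stage
    stop = n_layers if stage_idx == num_stages - 1 else start + per_stage
    for layer_id in range(n_layers):
        if not (start <= layer_id < stop):
            del model.layers[str(layer_id)]  # kept blocks still have FQN "layers.<id>.*"

    return model
```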
+ h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

- for layer in self.layers:
+ for layer in self.layers.values():
Is order still respected after switching to dict? If not, we need to sort the layers based on int(key).
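For reference, `nn.ModuleDict` is documented as an ordered dictionary that respects insertion order, so iterating `.values()` visits the blocks in the order they were inserted, even after some entries are deleted. A quick sanity check (illustrative only):

```python
import torch.nn as nn

layers = nn.ModuleDict()
for layer_id in range(4):
    layers[str(layer_id)] = nn.Linear(8, 8)

# Deleting an entry does not disturb the order of the remaining ones.
del layers["1"]
print(list(layers.keys()))  # ['0', '2', '3'] -- insertion order preserved
```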
Nice. But it is less intuitive than I originally thought, especially the int/str conversion part. Not sure whether that's the best UX for pippy, or whether a customized PipelineModuleList would be easier for users.
wanchaol left a comment:
lgtm!
+ self.layers = torch.nn.ModuleDict()
  for layer_id in range(model_args.n_layers):
-     self.layers.append(TransformerBlock(layer_id, model_args))
+     self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
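To illustrate the checkpoint-compatibility point, a small check (not part of the PR) that a `ModuleDict` keyed by stringified layer ids yields the same parameter FQNs as the old `ModuleList`:

```python
import torch.nn as nn

as_list = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
as_dict = nn.ModuleDict({str(i): nn.Linear(8, 8) for i in range(2)})

print([name for name, _ in as_list.named_parameters()])
# ['0.weight', '0.bias', '1.weight', '1.bias']
print([name for name, _ in as_dict.named_parameters()])
# ['0.weight', '0.bias', '1.weight', '1.bias'] -- identical FQNs
```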
Curious: why do the dict keys have to be str (as opposed to int directly)?
One downside to using …
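For context on the str-key question: `nn.Module` requires submodule names to be strings (`add_module` raises a `TypeError` otherwise), and `nn.ModuleDict.__setitem__` goes through `add_module`, so int keys are not accepted directly. A small illustration:

```python
import torch.nn as nn

layers = nn.ModuleDict()
try:
    layers[0] = nn.Linear(8, 8)   # int key is rejected by add_module
except TypeError as e:
    print(e)

layers["0"] = nn.Linear(8, 8)     # str key works; the FQNs become "0.weight" / "0.bias"
```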