[train_text_to_image] allow using non-ema weights for training #1834
Conversation
The documentation is not available anymore as the PR was closed or merged.
pcuenca
left a comment
Looks great, thanks a lot!
However, I may not be fully understanding it yet. I know that `state_dict` and `load_state_dict` are used by accelerate during the checkpointing process, but I don't understand how `store` and `restore` are used. In addition, the line to resume from a checkpoint appears to have been removed; is resuming performed differently now?
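For reference, here is a minimal sketch of the store / copy_to / restore pattern the question refers to. It is not the PR's actual `EMAModel`; the class name, decay value, and toy model below are made up for illustration. The idea: stash the live training weights, swap in the EMA weights for evaluation or saving, then put the training weights back.

```python
import torch


class TinyEMA:
    """Toy EMA helper used only to illustrate store / copy_to / restore."""

    def __init__(self, parameters, decay=0.999):
        self.decay = decay
        self.shadow = [p.detach().clone() for p in parameters]
        self.backup = None

    def step(self, parameters):
        # standard EMA update: shadow = decay * shadow + (1 - decay) * param
        for s, p in zip(self.shadow, parameters):
            s.mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def store(self, parameters):
        # remember the current (non-EMA) training weights
        self.backup = [p.detach().clone() for p in parameters]

    def copy_to(self, parameters):
        # overwrite the model with the EMA weights
        for s, p in zip(self.shadow, parameters):
            p.data.copy_(s)

    def restore(self, parameters):
        # bring the stored training weights back after evaluation/saving
        for b, p in zip(self.backup, parameters):
            p.data.copy_(b)


model = torch.nn.Linear(4, 4)
ema = TinyEMA(model.parameters())
ema.step(model.parameters())

ema.store(model.parameters())    # stash training weights
ema.copy_to(model.parameters())  # evaluate or save with EMA weights here
ema.restore(model.parameters())  # resume training with the original weights
```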
```
    temporarily stored. If `None`, the parameters with which this
    `ExponentialMovingAverage` was initialized will be used.
    """
    parameters = list(parameters)
```
Very minor question: why do we need the conversion to a list here?
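A likely reason (my assumption, not stated in the thread) is that `parameters` may be a generator such as the one returned by `Module.parameters()`, which can only be iterated once; materializing it as a list lets the EMA code loop over it more than once.

```python
import torch

model = torch.nn.Linear(2, 2)

params = model.parameters()        # generator
print(len(list(params)))           # 2 (weight, bias)
print(len(list(params)))           # 0 -- the generator is already exhausted

params = list(model.parameters())  # materialize once...
print(len(params), len(params))    # ...and it can be iterated any number of times
```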
My bad, removed it by mistake.
```python
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True
```
This gives ~1.3x speed-up on A100.
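For context, a hedged sketch of how a flag like `--allow_tf32` can be wired up; the argument parser and help text below are assumptions, only the `torch.backends` line mirrors the diff above.

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--allow_tf32",
    action="store_true",
    help="Allow TF32 matmuls on Ampere GPUs to speed up full-precision (fp32) training.",
)
args = parser.parse_args()

if args.allow_tf32:
    # TF32 trades a small amount of matmul precision for throughput on Ampere (e.g. A100).
    torch.backends.cuda.matmul.allow_tf32 = True
```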
```python
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)
accelerator.register_for_checkpointing(lr_scheduler)
```
It's not required to register `lr_scheduler` here; it's automatically checkpointed by accelerate. We only need to register custom objects.
Interesting, thanks. This documentation led me to believe we needed to register it, but in those examples the learning rate scheduler is not being passed to prepare.
Hmm should we maybe ask on the accelerate repo?
I've verified this: all standard objects that we pass to `prepare` (like `nn.Module`, `DataLoader`, `Optimizer`, `Scheduler`) are automatically checkpointed by accelerate. We only need to register custom objects or models that we don't pass to `prepare`.
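A minimal sketch of that behaviour, using a toy model/optimizer/dataloader/scheduler (none of this is the PR's code): everything passed to `prepare` is saved and restored by `save_state`/`load_state` without an explicit `register_for_checkpointing` call.

```python
import torch
from accelerate import Accelerator

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

accelerator = Accelerator()
model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)
# No register_for_checkpointing() needed for the objects above; it would only be
# required for extra stateful objects (e.g. an EMA wrapper exposing state_dict /
# load_state_dict) that are not passed to prepare().

accelerator.save_state("checkpoint-0")  # saves model/optimizer/scheduler/dataloader state
accelerator.load_state("checkpoint-0")  # restores the same state to resume training
```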
```python
inputs = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
```
Always padding to `max_length` now to completely match the original implementation.
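To illustrate the difference (this snippet assumes the CLIP tokenizer can be downloaded; it is not the training script itself): `padding=True` pads only to the longest caption in the batch, while `padding="max_length"` always pads to `tokenizer.model_max_length` (77 for CLIP), matching the original implementation.

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
captions = ["a photo of a cat", "a painting"]

longest = tokenizer(captions, padding=True, return_tensors="pt")
fixed = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(longest.input_ids.shape)  # padded only to the longest caption in the batch
print(fixed.input_ids.shape)    # always (batch_size, 77)
```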
pcuenca
left a comment
Looks great!
patrickvonplaten
left a comment
Looks clean! I trust @pcuenca and @patil-suraj here :-)
This PR allows using `non-ema` weights for training and `ema` weights for EMA updates, to mimic the original training process. For now, the workflow is as follows:
- The non-EMA weights live in a branch (revision) called `non-ema` (see the loading sketch after this list).
- Adds a `--non_ema_revision` argument. If it's `None`, it will default to using EMA weights for both training and EMA updates, as is the case now.
- If `--non_ema_revision` is specified, it will be used to load the `unet` for training, and the EMA (main) weights will be used for EMA updates.
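A hedged sketch of the loading logic described above (the model id is an assumption used only for illustration): the trainable `unet` comes from the `non-ema` revision, while the main-revision weights seed the EMA copy.

```python
from diffusers import UNet2DConditionModel

model_id = "CompVis/stable-diffusion-v1-4"  # assumed model id, for illustration only
non_ema_revision = "non-ema"                # value that would be passed via --non_ema_revision

# weights used for gradient updates
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision=non_ema_revision)

# EMA (main-revision) weights used to initialize the EMA model
ema_unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
```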
This approach of using branches is not the best solution, but it will be used until we have the `variations` feature in `diffusers`.

This PR also:
- Makes sure the EMA `unet` is checkpointed.
- Adds `--allow_tf32` to enable TF32 on Ampere GPUs (A100) for faster full-precision training. Gives a ~1.33x speed-up.

Example command:
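(The command below is illustrative only; apart from `--non_ema_revision` and `--allow_tf32`, which this PR adds, the model, dataset, and remaining flags are assumptions and may differ from the PR's original example.)

```bash
accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
  --dataset_name="lambdalabs/pokemon-blip-captions" \
  --use_ema \
  --non_ema_revision="non-ema" \
  --allow_tf32 \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --output_dir="sd-model-finetuned"
```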
Fixes #1153