[WIP] Add Flax LoRA Support to Dreambooth #1894
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
It appears you are using the Dreambooth default learning rate of 1e-6; it should be 1e-4 for LoRA.
@Thomas-MMJ good call, let me try that
yea I still get
@cloneofsimo any ideas?
I saw @patrickvonplaten is working on LoRA support for the non-Flax Dreambooth in #1884. We've been taking a stab at implementing LoRA support for TPUs, taking example from the patching method used by @cloneofsimo in cloneofsimo/lora.

I've got it successfully patching and training, but the output is currently no good. I'm reaching the end of the time I have allocated for this. Might pick it up in the future, but for now I'm putting this up in case anyone finds it useful!
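For anyone following along, the patching idea borrowed from cloneofsimo/lora boils down to adding a trainable low-rank update on top of each frozen weight. Here is a minimal NumPy sketch of that idea, not the actual Flax module from this PR; the names `up`, `down`, `rank`, and `alpha` are illustrative:

```python
# Hedged sketch of LoRA patching: the frozen base weight W is left untouched,
# and a low-rank product up @ down (scaled by alpha / rank) is added at
# forward time. Only up and down would be trained.
import numpy as np

rng = np.random.default_rng(0)

d_out, d_in, rank, alpha = 8, 8, 4, 4
W = rng.normal(size=(d_out, d_in))           # frozen base weight
down = rng.normal(size=(rank, d_in)) * 0.01  # trainable
up = np.zeros((d_out, rank))                 # trainable, zero-init so W' == W at start

def lora_forward(x):
    # Effective weight W' = W + (alpha / rank) * up @ down
    return x @ (W + (alpha / rank) * (up @ down)).T

x = rng.normal(size=(2, d_in))
# With up zero-initialized, the patched layer matches the base layer exactly.
assert np.allclose(lora_forward(x), x @ W.T)
```

Zero-initializing `up` is the usual trick so the patched model starts out identical to the base model and training only gradually moves it away.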
Made some progress. I was only patching

@Thomas-MMJ nope!
@yasyf are you training the LoRA up/down weights directly, or taking a low-rank approximation of the model gradients? I've also reinvented this wheel (building LoRA on Flax for Diffusers), but I'm a purist; I built the model as a whole separate entity, which acts on the params of the unet and text encoder prior to inference, meaning the diffusers models themselves are not altered (no replacing/wrapping model layers), only their parameters. This works really well for extracting from all the fine-tune models, and combining all those extractions together into a single LoRA.

So in training, my thought was to use the Dreambooth script, get the gradients against the text encoder and unet models' params, then take the low-rank approximation of the grads. Those are the same shape as the LoRA layers, so the apply_gradients is done on a train state of the LoRA model. It's a bit more mathsy, but I thought that by compressing the important parts of the train-step gradients and then applying those to the LoRA params, it would actually train a lot faster due to the higher learning rates enabled by the SVD filtering of noise.

But I can't get the dang thing to compile (and on expensive TPU Colab, I don't let it try longer than a few minutes), and if it does (e.g. after 25 mins on CPU Colab), it OOMs when training starts. Might just be I don't know JAX, or it might be that I'm trying to apply gradients to a different set of params than Optax calculates them on, complicating all the tracing it's doing under its magic little hood.

It would definitely simplify things to use the LoRA up/down params themselves as the gradient-descent target (calculate loss and grads using the LoRA up/down params as the train state), but I thought that would lose the ability to leverage LoRA's strength (or rather, SVD's strength) in filtering out the "unimportant" parts of a matrix (the matrix here being the gradients).
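The low-rank-approximation-of-gradients idea described above can be sketched with a truncated SVD. This is a hedged NumPy illustration of the concept, not the Flax/Optax code from the PR; shapes and the rank `r` are illustrative:

```python
# Compress a full dense-layer gradient to rank r via truncated SVD, producing
# factors with the same shapes as LoRA up/down matrices.
import numpy as np

rng = np.random.default_rng(0)
grad = rng.normal(size=(16, 16))  # stand-in for one layer's full gradient
r = 4

U, s, Vt = np.linalg.svd(grad, full_matrices=False)
# Keep only the top-r singular directions (the "important parts" of the gradient).
grad_up = U[:, :r] * s[:r]    # shape (16, r), matches a LoRA up matrix
grad_down = Vt[:r, :]         # shape (r, 16), matches a LoRA down matrix

low_rank = grad_up @ grad_down
# By Eckart-Young, this is the best rank-r approximation in Frobenius norm.
assert low_rank.shape == grad.shape
assert np.linalg.norm(grad - low_rank) <= np.linalg.norm(grad)
```

The two factors could then be applied to the LoRA params via `apply_gradients`; the suspected Optax tracing issue is that the loss gradients are computed against one pytree (the full unet/text-encoder params) while the update targets a different one (the LoRA params).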
@krahnikblis I'm using the up/down as the optimization targets directly, since that's what the reference repo for non-Flax was doing!
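A minimal sketch of that simpler route, treating the up/down matrices as the direct optimization targets while the base weight stays frozen. This is a hand-rolled NumPy stand-in for the Flax train step; the loss, shapes, and learning rate are all illustrative:

```python
# One SGD step on the LoRA adapters only: gradients are taken w.r.t. up/down,
# never w.r.t. the frozen base weight W.
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, rank = 6, 6, 2
W = rng.normal(size=(d_out, d_in))         # frozen
up = rng.normal(size=(d_out, rank)) * 0.1  # trainable
down = rng.normal(size=(rank, d_in)) * 0.1 # trainable
x = rng.normal(size=(d_in,))
y = rng.normal(size=(d_out,))

def loss(up, down):
    resid = (W + up @ down) @ x - y
    return 0.5 * float(resid @ resid)

# Analytic gradients of the squared-error loss w.r.t. up and down only
# (in the real script, jax.grad over the LoRA pytree would produce these).
resid = (W + up @ down) @ x - y
g_up = np.outer(resid, down @ x)
g_down = np.outer(up.T @ resid, x)

lr = 1e-3  # LoRA tolerates much higher rates than full fine-tuning (cf. 1e-4 vs 1e-6)
before = loss(up, down)
after = loss(up - lr * g_up, down - lr * g_down)
assert after < before  # one step on the adapters alone reduces the loss
```

Because the train state holds only the small up/down matrices, the optimizer pytree matches the gradient pytree exactly, which sidesteps the mismatch @krahnikblis describes.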
@krahnikblis if you send me the code I can run it on a beefy v4 TPU to see if memory is the issue
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.