
Conversation

@yasyf
Contributor

@yasyf yasyf commented Jan 3, 2023

I saw @patrickvonplaten is working on LoRA support for the non-Flax Dreambooth in #1884. We've been taking a stab at implementing LoRA support for TPUs, taking example from the patching method used by @cloneofsimo in cloneofsimo/lora.

I've got it successfully patching and training, but the output is currently no good. I'm reaching the end of the time I have allocated for this—might pick it up in the future, but for now I'm putting this up in case anyone finds it useful!
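
For context, the core of the patching approach is roughly the following (a minimal Flax sketch of the cloneofsimo-style idea, not the code in this PR; the LoRALinear name, rank, and scale values are illustrative assumptions): each target dense projection is replaced with the original layer plus a zero-initialized, trainable low-rank residual, so only the LoRA kernels need to be optimized.

```python
# Minimal sketch of a LoRA-style dense layer in Flax (illustrative only).
import flax.linen as nn


class LoRALinear(nn.Module):
    features: int
    rank: int = 4       # low-rank bottleneck size (assumption)
    scale: float = 1.0  # scaling applied to the LoRA branch (assumption)

    @nn.compact
    def __call__(self, x):
        # Base projection (in actual patching this would carry the frozen,
        # pretrained kernel rather than a fresh init).
        base = nn.Dense(self.features, name="base")(x)
        # Trainable low-rank branch: down-projection with a small random init,
        # up-projection initialized to zero so training starts from the
        # unmodified model output.
        down = nn.Dense(self.rank, use_bias=False,
                        kernel_init=nn.initializers.normal(stddev=0.02),
                        name="lora_down")(x)
        up = nn.Dense(self.features, use_bias=False,
                      kernel_init=nn.initializers.zeros,
                      name="lora_up")(down)
        return base + self.scale * up
```

Only the lora_down/lora_up kernels would then be handed to the optimizer; the rest of the UNet and text-encoder params stay frozen.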

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@Thomas-MMJ

Thomas-MMJ commented Jan 3, 2023

It appears you are using the default Dreambooth learning rate of 1e-6; it should be 1e-4 for LoRA.
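
(In the JAX training loop that just means constructing the optimizer with the larger rate; an illustrative optax sketch, not this PR's code:)

```python
# Illustrative only: LoRA fine-tuning typically tolerates a much larger
# learning rate than full Dreambooth, e.g. 1e-4 instead of 1e-6.
import optax

optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-2)
```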

@yasyf
Contributor Author

yasyf commented Jan 4, 2023

@Thomas-MMJ good call, let me try that

@yasyf
Contributor Author

yasyf commented Jan 5, 2023

Yeah, I still get NaNs for the loss right away. Will leave this up in case it's useful for @patrickvonplaten

@yasyf
Contributor Author

yasyf commented Jan 5, 2023

@cloneofsimo any ideas?

@Thomas-MMJ

Thomas-MMJ commented Jan 5, 2023

@yasyf is your base model fp16? See this PR/commit,

#1916
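
(One common mitigation, offered here as an assumption rather than necessarily the fix in that PR: keep the trainable params in float32 even when the frozen base weights are fp16/bf16. A minimal JAX sketch:)

```python
# Upcast every floating-point leaf of a parameter pytree to float32 so the
# optimizer accumulates in full precision (illustrative sketch, assumed fix).
import jax
import jax.numpy as jnp

def upcast_params(params):
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32)
        if jnp.issubdtype(x.dtype, jnp.floating) else x,
        params,
    )
```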

@yasyf
Contributor Author

yasyf commented Jan 6, 2023

Made some progress. I was only patching FlaxAttentionBlock. By adding FlaxGEGLU to the list, I'm now able to get it to start training. However, after a few steps, the loss rapidly blows up, so this still isn't usable :(

Epoch... (1/100 | Loss: 0.22375227510929108)
Epoch... (2/100 | Loss: 0.42204561829566956)
Epoch... (3/100 | Loss: 0.9007828235626221)
Epoch... (4/100 | Loss: 1.4756149053573608)
Epoch... (5/100 | Loss: 344.7566223144531)
Epoch... (6/100 | Loss: nan)

@Thomas-MMJ nope!
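
(For anyone picking this up: a standard guard against the kind of divergence shown in the log above is clipping the global gradient norm before the optimizer step; a minimal optax sketch, not something this branch currently does:)

```python
# Chain gradient clipping with the optimizer to keep a single bad step from
# blowing up the loss (illustrative; the 1.0 threshold is an assumption).
import optax

tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=1e-4),
)
```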

@krahnikblis

@yasyf are you training the LoRA up/down weights directly, or taking a low-rank approximation of the model gradients?

I've also reinvented this wheel (building LoRA on Flax for Diffusers), but I'm a purist; I built the LoRA model as a wholly separate entity which acts on the params of the unet and text encoder prior to inference, meaning the diffusers models themselves are not altered (no replacing/wrapping of model layers), only their parameters. This works really well for extracting LoRAs from all the fine-tuned models and combining those extractions together into a single lora...

So for training, my thought was to use the dreambooth script, get the gradients against the text encoder and unet models' params, then take the low-rank approximation of those grads; they have the same shape as the lora layers, so apply_gradients is done on a train state of the lora model. It's a bit more mathsy, but essentially I figured that by compressing the important parts of the train-step gradients and applying those to the lora params, it would actually train a lot faster thanks to the higher learning rates enabled by the SVD filtering out noise... but I can't get the dang thing to compile (and on an expensive TPU colab, I don't let it try for more than a few minutes), and when it does (e.g. after 25 mins on a CPU colab), it OOMs when training starts. It might just be that I don't know jax, or it might be that I'm trying to apply gradients to a different set of params than the ones optax calculates them on, complicating all the tracing it does under its magic little hood...

It would definitely simplify things to use the lora up/down params themselves as the gradient-descent target (calculate loss and grads with the lora up/down params as the train state), but I thought that would lose the ability to leverage lora's strength (or rather, SVD's strength) in filtering out the "unimportant" parts of a matrix (the matrix here being the gradients).
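
(A rough JAX sketch of the gradient-SVD idea described above, illustrative rather than the actual code being discussed: take the SVD of a full weight gradient and keep only the top-r components, which then have the same shapes as the LoRA up/down factors.)

```python
# Rank-r approximation of a dense-kernel gradient (illustrative sketch).
import jax.numpy as jnp

def low_rank_grad(grad_matrix, rank=4):
    # grad_matrix: 2-D gradient of shape (m, n).
    u, s, vh = jnp.linalg.svd(grad_matrix, full_matrices=False)
    up = u[:, :rank] * s[:rank]   # (m, rank): left factors scaled by singular values
    down = vh[:rank, :]           # (rank, n): right factors
    return up, down               # up @ down is the best rank-r approximation of the grad
```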

@yasyf
Contributor Author

yasyf commented Jan 18, 2023

@krahnikblis I'm using the up/down as the optimization targets directly, since that's what the reference repo for non-flax was doing!

@yasyf
Contributor Author

yasyf commented Jan 18, 2023

@krahnikblis if you send me the code I can run it on a beefy v4 TPU to see if memory is the issue

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label (Issues that haven't received updates) on Feb 25, 2023
github-actions bot closed this on Mar 6, 2023