multiple prediction options in ddpm, ddim #818
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Yes, we indeed need this now, I think :-) (also for dance diffusion)
patrickvonplaten left a comment:
Generally, this looks good to me :-) We'll definitely need tests here though
I want to make a colab comparing the predictions for training on one scheduler to start (to make sure it works).
I'm new to contributing, so I'm a little confused about what I should be doing. Should I clone the changes and make a colab to compare with the original predictions?
Hey @pie31415, since you mentioned you were interested in this PR, I think it'd be super useful for you to do a PR review here :-)
@pie31415 Another really useful thing would be to verify the implementation against the original papers and links above. This is a pretty tricky port, so I will do this too, but it would be hugely useful. For example, I actually think the DDIM implementation is much closer than the DDPM one.
```python
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prediction_type: str = "epsilon",
```
I think this should go in the `__init__` function, and we've somewhat settled on `predict_epsilon: bool` in terms of naming :-)
Ah ok, now we actually have three types, so we might have to reconsider this choice 😅
But I think it should definitely go in the config of the scheduler and not be an arg of `__call__`.
I actually messaged Nathan about my progress on DDIM v prediction in a separate branch as you commented - crazy timing!
I'll make sure I make these changes in my branch before opening a PR
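For what it's worth, a minimal sketch of what moving `prediction_type` into the scheduler config could look like, assuming the diffusers `ConfigMixin` / `register_to_config` pattern (the class name here is hypothetical, not code from this PR):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin


class SketchScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",  # "epsilon", "sample", or a v-prediction variant
    ):
        # register_to_config stores every __init__ argument on self.config,
        # so step() can read self.config.prediction_type instead of taking
        # it as a per-call argument.
        pass
```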
That's very interesting here actually - @patil-suraj @anton-l could you also take a look? :-)
DDIM will hopefully be ready for review soon too. Results from training on it are still a little pixelated, but you can clearly see the shape of a butterfly. I'm guessing I have something not quite right with the variance calculation. Will hopefully have updates here soon!
* v diffusion support for ddpm
* quality and style
* variable name consistency
* missing base case
* pass prediction type along in the pipeline
* put prediction type in scheduler config
* style
* try to train on ddim
* changes to ddim
* ddim v prediction works to train butterflies example
* fix bad merge, style and quality
* try to fix broken doc strings
* second pass
* one more
* white space
* Update src/diffusers/schedulers/scheduling_ddim.py
* remove extra lines
* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Ben Glickenhaus <[email protected]>
Co-authored-by: Nathan Lambert <[email protected]>
Update for the diffusers team (@patrickvonplaten, @anton-l, @patil-suraj). We updated DDIM now (promising results), and I'll add tests / fix merge issues this afternoon.
@patrickvonplaten this should be good to go. Now, this leaves only … Lots more good work from @bglick13.
```python
set_alpha_to_one: bool = True,
variance_type: str = "fixed",
steps_offset: int = 0,
prediction_type: Literal["epsilon", "sample", "velocity"] = "epsilon",
```
Note we currently have a config parameter called `predict_epsilon` that is already used in multiple schedulers:

```python
predict_epsilon: bool = True,
```

So we cannot really add this `prediction_type` here without deprecating the other one, and also deprecating arguments like this one:

```python
"--predict_epsilon",
```
pcuenca left a comment:
LGTM. Same comments about deprecating `predict_epsilon` everywhere.
```python
def expand_to_shape(input, timesteps, shape, device):
    """
    Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
    """
    out = torch.gather(input.to(device), 0, timesteps.to(device))
    reshape = [shape[0]] + [1] * (len(shape) - 1)
    out = out.reshape(*reshape)
    return out
```
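A quick usage sketch with made-up values, just to show the broadcasting behavior:

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)  # a 1D per-timestep schedule
timesteps = torch.tensor([10, 500, 900])             # one timestep per batch element
samples = torch.randn(3, 3, 64, 64)

coeff = expand_to_shape(alphas_cumprod, timesteps, samples.shape, samples.device)
print(coeff.shape)  # torch.Size([3, 1, 1, 1]) -- broadcasts cleanly against samples
```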
How do we feel about moving this to `scheduling_utils.py`? Maybe `get_alpha_sigma` as well.
I'm good with this. I had this as a TODO in my mind. Could also be made more elegant, but wasn't 100% sure how yet.
My only concern with these two:

- `expand_to_shape` would be the only function like this. It's okay to start the trend.
- `get_alpha_sigma` won't work with many of the schedulers, so I'm okay with leaving it in the ones that use v-prediction for now.
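For reference, a sketch of what `get_alpha_sigma` presumably computes under the variance-preserving convention that v-prediction uses (assumed from the formulas above, not verified against this branch):

```python
import torch


def get_alpha_sigma(alphas_cumprod, timesteps):
    # For a variance-preserving process, x_t = alpha_t * x_0 + sigma_t * eps,
    # with alpha_t**2 + sigma_t**2 = 1.
    alpha = alphas_cumprod[timesteps] ** 0.5
    sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
    return alpha, sigma
```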
Co-authored-by: Pedro Cuenca <[email protected]>
Added more deprecation handling across the board. I tried to address @patrickvonplaten's comment above, but would like a double-check on that!
```python
# no check on predict_epsilon here, due to the deprecation flag above
elif self.prediction_type == "sample" or not self.config.predict_epsilon:
```
I had to mess with these if statements a little to get the tests to pass. It will all be much cleaner once `predict_epsilon` is deprecated.
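For readers following along, a standalone sketch of the three-way branch these conditions implement once the legacy flag is folded in (the helper name is hypothetical, and the "velocity" string follows this PR's `Literal`; the released library may differ):

```python
def predict_original_sample(model_output, sample, alpha_prod_t, prediction_type="epsilon"):
    # alpha_prod_t is the cumulative alpha product at this timestep, in (0, 1).
    beta_prod_t = 1.0 - alpha_prod_t
    if prediction_type == "epsilon":
        # Model predicts the noise: x0 = (x_t - sqrt(1 - a) * eps) / sqrt(a)
        return (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    elif prediction_type == "sample":
        # Model predicts x0 directly.
        return model_output
    elif prediction_type == "velocity":
        # Model predicts v: x0 = sqrt(a) * x_t - sqrt(1 - a) * v
        return alpha_prod_t ** 0.5 * sample - beta_prod_t ** 0.5 * model_output
    raise ValueError(f"unknown prediction_type: {prediction_type!r}")
```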
The code isn't as clear, but you can see some details on model parametrization in the SD 2.0 code here. The option …
@patil-suraj @patrickvonplaten @bglick13 |
Starting to work on the discussion in #778. Please contribute and leave feedback. This is mostly a placeholder for my work right now as I figure out how to do it.
Some relevant repositories: