Fix MPS scheduler indexing when using mps (#450)
Conversation
patil-suraj left a comment:
This looks good to me!
But looking at the recent comments in the linked issue, it seems the problem has been resolved on the PyTorch side. Do we still need this fix?
```python
    noise: Union[torch.FloatTensor, np.ndarray],
    timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
    timesteps = timesteps.to(self.alphas_cumprod.device)
```
This breaks the assumption that we can work with numpy arrays as well :(
Maybe we could add a `device` argument to `scheduler.set_format()` and call it from `pipeline.to()`?
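A hypothetical sketch of that alternative, just to make the idea concrete: give `set_format` a `device` argument that `pipeline.to()` could forward. Neither the signature nor the attribute scan below reflects the actual diffusers API.

```python
import torch

def set_format(self, tensor_format="pt", device=None):
    # Illustrative only: record the format, then move any tensor attributes
    # (e.g. alphas_cumprod) to the requested device.
    self.tensor_format = tensor_format
    if tensor_format == "pt" and device is not None:
        for name, value in list(vars(self).items()):
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
    return self
```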
I'd rather minimize the changes across files just for this particular case, which only affects mps. How about some defensive coding, like:
- check that the format is `pt`, and
- check that `alphas_cumprod` is actually a tensor,

then move it using `to()`? (See the sketch after this list.)
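A minimal sketch of that defensive check, assuming the scheduler carries `tensor_format` and `alphas_cumprod` attributes; the helper name is made up for illustration.

```python
import torch

def _maybe_move_timesteps(scheduler, timesteps):
    # Only call `.to()` when the scheduler is in PyTorch format and
    # `alphas_cumprod` is actually a tensor; a numpy array has neither
    # `.to()` nor `.device`.
    if getattr(scheduler, "tensor_format", None) == "pt" and torch.is_tensor(
        scheduler.alphas_cumprod
    ):
        timesteps = timesteps.to(scheduler.alphas_cumprod.device)
    return timesteps
```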
Think we'll move away from `set_format()` since we'll soon have framework-dependent schedulers, so I think this is ok here.
patrickvonplaten left a comment:
Given that we'll soon have framework-specific schedulers I think we can go full PyTorch on the existing ones
Makes total sense to me. I'll verify that the inputs are actually torch tensors for now, and we can maybe remove the check once the framework-specific schedulers land.
anton-l left a comment:
Ah, this makes sense if we'll have framework-specific schedulers, all good then!
* Fix LMS scheduler indexing in `add_noise`, huggingface#358.
* Fix DDIM and DDPM indexing with mps device.
* Verify format is PyTorch before using `.to()`.
This affects `add_noise`, and therefore the image-to-image tasks. See #358.
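For context, a standalone illustration of the failure mode this PR addresses (it needs an mps-capable machine to reproduce, and the tensor values are arbitrary): indexing a CPU tensor with indices that live on `mps` raises a device-mismatch error, which is why `timesteps` must be moved to the device of `alphas_cumprod` first.

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.98, 1000).cumprod(dim=0)  # lives on CPU
timesteps = torch.tensor([10, 500], device="mps")

# alphas_cumprod[timesteps]  # fails: indices are on a different device
values = alphas_cumprod[timesteps.to(alphas_cumprod.device)]  # the PR's fix
```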