[refactor] make set_attention_slice recursive #1532
Conversation
The documentation is not available anymore as the PR was closed or merged.
    set_requires_grad(self.text_encoder, False)
    set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
Let's not change community pipelines. This will remove functionality at the moment.
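For context, a minimal sketch of what such a base-class method on `DiffusionPipeline` could look like. The attribute-walking discovery loop below is our assumption for illustration, not necessarily the PR's exact code:

```python
from typing import Optional, Union

import torch


class DiffusionPipeline:
    # Minimal sketch; the real class also carries loading/saving logic.

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        # Delegate to every component that knows how to slice its attention.
        self.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        # Passing `None` restores the full (unsliced) attention computation.
        self.enable_attention_slicing(None)

    def set_attention_slice(self, slice_size: Optional[Union[str, int]]):
        # Assumed discovery strategy: walk the pipeline's attributes and
        # forward the call to each sub-model (e.g. the UNet) that implements
        # `set_attention_slice`; components without attention are skipped.
        for module in self.__dict__.values():
            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size)
```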
src/diffusers/models/attention.py (outdated diff)

    )

    -def _set_attention_slice(self, slice_size):
    +def set_attention_slice(self, slice_size):
Nice! Works for me.
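For readers following along, the recursive application boils down to a traversal like the sketch below. It assumes that, after this refactor, only the top-level model and leaf attention blocks still expose `set_attention_slice`; the function name is illustrative:

```python
import torch


def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: int):
    # Modules that support slicing expose `set_attention_slice`; the
    # traversal configures them and keeps descending so nested attention
    # blocks are reached without any per-block boilerplate.
    if hasattr(module, "set_attention_slice"):
        module.set_attention_slice(slice_size)
    for child in module.children():
        fn_recursive_set_attention_slice(child, slice_size)
```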
patrickvonplaten left a comment:
Thanks for doing this PR! Let's just remove the changes from the community pipelines.
pcuenca left a comment:
Very nice! Same comment as Patrick: this would currently break community pipelines, as they are loaded from main.
patil-suraj left a comment:
Looks good to me! Just left one comment.
    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_slicable_layers * [1]
As discussed offline, this logic could be moved to the attention block itself. That way it would be easy to support this functionality in new models which use the attention blocks.
That makes sense. As discussed offline, I'll leave it as is for now, though, so that we can test that the number of passed slices equals the number of sliceable layers. Otherwise we cannot really test it.
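As a concrete illustration of that check, here is a minimal, self-contained sketch of how a scalar or keyword `slice_size` gets resolved into a per-layer list and validated. The function and variable names are ours, not necessarily the PR's:

```python
from typing import List, Union


def resolve_slice_size(
    slice_size: Union[str, int, List[int]], sliceable_head_dims: List[int]
) -> List[int]:
    """Turn "auto"/"max"/int into one slice size per sliceable layer."""
    num_sliceable_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off
        # between speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make the smallest slice possible
        slice_size = num_sliceable_layers * [1]

    # broadcast a single integer to every sliceable layer
    if not isinstance(slice_size, list):
        slice_size = num_sliceable_layers * [slice_size]

    # the check the tests rely on: one slice size per sliceable layer
    if len(slice_size) != num_sliceable_layers:
        raise ValueError(
            f"You have provided {len(slice_size)} slices, but the model has "
            f"{num_sliceable_layers} sliceable layers."
        )
    return slice_size


# e.g. with head dims (8, 16): "auto" -> [4, 8], "max" -> [1, 1], 2 -> [2, 2]
assert resolve_slice_size("auto", [8, 16]) == [4, 8]
assert resolve_slice_size("max", [8, 16]) == [1, 1]
assert resolve_slice_size(2, [8, 16]) == [2, 2]
```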
    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between
        # speed and memory
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = num_slicable_layers * [1]

    slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
Same comment as above.
    def test_model_attention_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model.set_attention_slice("auto")
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

        model.set_attention_slice("max")
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

        model.set_attention_slice(2)
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

    def test_model_slicable_head_dim(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)

        def check_slicable_dim_attr(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                assert isinstance(module.sliceable_head_dim, int)

            for child in module.children():
                check_slicable_dim_attr(child)
Thanks a lot for adding the tests!
Commits:

* make attn slice recursive
* remove set_attention_slice from blocks
* fix copies
* make enable_attention_slicing base class method of DiffusionPipeline
* fix set_attention_slice
* fix set_attention_slice
* fix copies
* add tests
* up
* up
* up
* update
* up
* uP

Co-authored-by: Patrick von Platen <[email protected]>
Inspired by #1493, this PR makes the `set_attention_slice` method recursive to be able to easily apply it to various attention blocks without much boilerplate.

* Makes `set_attention_slice` a method of `DiffusionPipeline`, so it can be applied to all pipelines, since all of them have attention.
* Keeps the `auto` slice logic inside `UNet2DConditionModel` and recurses there to apply the slice to all blocks.

The logic differs a bit from #1493 in that we keep the `set_attention_slice` method on main model classes like `UNet2DConditionModel` and move the recursive application logic there. This is so that we can control this functionality using individual model classes. A usage sketch follows below.
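Put together, usage would look roughly like this. A hypothetical sketch: the checkpoint id is just an example, and the exact slice values shown are illustrative:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Pipeline level: DiffusionPipeline forwards to every component that
# implements set_attention_slice, so every pipeline gets the feature.
pipe.enable_attention_slicing()       # defaults to "auto"
pipe.enable_attention_slicing("max")  # smallest slices, lowest memory
pipe.disable_attention_slicing()      # back to full attention

# Model level: UNet2DConditionModel owns the "auto"/"max" resolution and
# recursively configures all of its attention blocks.
pipe.unet.set_attention_slice("auto")  # half of each sliceable head dim
pipe.unet.set_attention_slice(2)       # broadcast one fixed slice size
```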