[refactor] Making the xformers mem-efficient attention activation recursive #1493
Changes from all commits:
```diff
@@ -246,10 +246,6 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu
         return Transformer2DModelOutput(sample=output)
 
-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for block in self.transformer_blocks:
-            block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
 
 
 class AttentionBlock(nn.Module):
     """
@@ -428,7 +424,7 @@ def __init__(
         # if xformers is installed try to use memory_efficient_attention by default
         if is_xformers_available():
             try:
-                self._set_use_memory_efficient_attention_xformers(True)
+                self.set_use_memory_efficient_attention_xformers(True)
             except Exception as e:
                 warnings.warn(
                     "Could not enable memory efficient attention. Make sure xformers is installed"
@@ -439,7 +435,7 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
```
Contributor (Author): called from the outside, so it can be public? Plus it conveys the idea that it's a capability being exposed.
```diff
         if not is_xformers_available():
             print("Here is how to install it")
             raise ModuleNotFoundError(
@@ -849,11 +845,3 @@ def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_di
             return (output_states,)
 
         return Transformer2DModelOutput(sample=output_states)
-
-    def _set_attention_slice(self, slice_size):
-        for transformer in self.transformers:
-            transformer._set_attention_slice(slice_size)
-
-    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        for transformer in self.transformers:
-            transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
```
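The rename in the hunk above turns the per-block toggle into a public capability. As a rough illustration (not the diffusers implementation; `ToyAttention` and its layers are hypothetical), a leaf module exposing that capability could look like:

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical leaf block exposing the xformers toggle as a public capability."""

    def __init__(self, dim: int):
        super().__init__()
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.to_out = nn.Linear(dim, dim)
        self._use_memory_efficient_attention_xformers = False

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        # The real method first checks is_xformers_available() and raises
        # ModuleNotFoundError with install instructions; here we only record the flag.
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_qkv(hidden_states).chunk(3, dim=-1)
        # When the flag is set, a real block would route through
        # xformers.ops.memory_efficient_attention(q, k, v) instead of this naive path.
        scores = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        return self.to_out(scores @ v)
```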
```diff
@@ -252,17 +252,6 @@ def set_attention_slice(self, slice_size):
             if hasattr(block, "attentions") and block.attentions is not None:
                 block.set_attention_slice(slice_size)
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
```
Contributor (Author): all these are just trampolines, not needed with the recursive call from the top. An issue with these trampolines is that they're bound to miss some cases (they do), since they would have to be changed any time a new capability is exposed somewhere in the pipeline.
```diff
-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-        self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
-
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
             module.gradient_checkpointing = value
```
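As a self-contained sketch of the reviewer's point, with hypothetical toy modules: a trampoline like the one deleted above has to enumerate attention-bearing children by hand, while a recursive walk of the kind this PR adds to `DiffusionPipeline` reaches any descendant that exposes the method, at whatever depth it sits:

```python
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical leaf that exposes the capability."""

    def __init__(self):
        super().__init__()
        self.use_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool):
        self.use_xformers = valid


class ToyUNet(nn.Module):
    """Hypothetical model; attention can appear at any depth, including newly added blocks."""

    def __init__(self):
        super().__init__()
        self.down_blocks = nn.Sequential(nn.Linear(4, 4), ToyAttention())
        self.mid_block = ToyAttention()


def fn_recursive_set_mem_eff(module: nn.Module, valid: bool) -> None:
    # Mirrors the recursive dispatch added in this PR: flag the module if it
    # exposes the method, then keep walking its children.
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid)
    for child in module.children():
        fn_recursive_set_mem_eff(child, valid)


model = ToyUNet()
fn_recursive_set_mem_eff(model, True)
# Every attention leaf received the flag, without any per-block forwarding code.
assert all(m.use_xformers for m in model.modules() if isinstance(m, ToyAttention))
```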
```diff
@@ -784,3 +784,38 @@ def progress_bar(self, iterable=None, total=None):
 
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
 
+    def enable_xformers_memory_efficient_attention(self):
```
Contributor (Author): these enable and disable shorthands are just there because many derived pipelines were using them, so I figured it was cheaper to expose the call here :)

Contributor: Makes sense to me to make a method of …
```diff
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.set_use_memory_efficient_attention_xformers(False)
+
+    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
```
Contributor (Author): the actual single implementation of how to enable mem-efficient attention across the whole model, for all pipelines (covers superres, outpainting, or text2img, which mobilize attention in different places at times).
```diff
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_memory_efficient_attention_xformers method
+        # gets the message
+        def fn_recursive_set_mem_eff(module: torch.nn.Module):
+            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+                module.set_use_memory_efficient_attention_xformers(valid)
+
+            for child in module.children():
+                fn_recursive_set_mem_eff(child)
+
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for module_name in module_names:
+            module = getattr(self, module_name)
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_mem_eff(module)
```
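With the recursive implementation living on `DiffusionPipeline`, enabling the optimization from user code becomes a single call. A usage sketch, assuming this branch is installed along with xformers, a CUDA device, and the `CompVis/stable-diffusion-v1-4` checkpoint (any pipeline deriving from `DiffusionPipeline` would work the same way):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Walks every registered torch.nn.Module (unet, vae, text_encoder, ...) and flips the
# flag on each submodule exposing set_use_memory_efficient_attention_xformers.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photograph of an astronaut riding a horse").images[0]

# The toggle is reversible.
pipe.disable_xformers_memory_efficient_attention()
```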
Contributor (Author): here and below: inherits from `DiffusionPipeline`, so I figured that this could be defined there (with the recursive take) to remove a lot of code duplication.
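A sketch of the duplication that comment refers to, using a hypothetical derived pipeline: before this change, each pipeline carried its own forwarding method (typically reaching only `self.unet`); afterwards nothing needs to be defined, since the recursive method is inherited from `DiffusionPipeline`:

```python
from diffusers import DiffusionPipeline


# Before (hypothetical, repeated across many derived pipelines):
class MyPipelineBefore(DiffusionPipeline):
    def enable_xformers_memory_efficient_attention(self):
        # Forwarded only to the UNet; other attention-bearing modules could be missed.
        self.unet.set_use_memory_efficient_attention_xformers(True)


# After this PR: nothing to define, the inherited recursive method covers all registered modules.
class MyPipelineAfter(DiffusionPipeline):
    pass
```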