Merged
18 changes: 0 additions & 18 deletions examples/community/lpw_stable_diffusion.py
@@ -488,24 +488,6 @@ def __init__(
feature_extractor=feature_extractor,
)

def enable_xformers_memory_efficient_attention(self):
Contributor Author: here and below: this inherits from DiffusionPipeline, so I figured these could be defined there (with the recursive approach) to remove a lot of code duplication. See the usage sketch at the end of this file's diff.

r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
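With these per-pipeline copies removed, the toggles are inherited from DiffusionPipeline. A minimal usage sketch, assuming xformers is installed and using a standard Stable Diffusion checkpoint as an example:

```python
from diffusers import DiffusionPipeline

# Any pipeline deriving from DiffusionPipeline now inherits the toggles.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example checkpoint
pipe.enable_xformers_memory_efficient_attention()   # route attention through xformers
image = pipe("a photo of an astronaut").images[0]
pipe.disable_xformers_memory_efficient_attention()  # back to the default attention path
```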
18 changes: 0 additions & 18 deletions examples/community/sd_text2img_k_diffusion.py
@@ -106,24 +106,6 @@ def set_sampler(self, scheduler_type: str):
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
18 changes: 0 additions & 18 deletions examples/community/text_inpainting.py
@@ -183,24 +183,6 @@ def _execution_device(self):
return torch.device(module._hf_hook.execution_device)
return self.device

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@torch.no_grad()
def __call__(
self,
16 changes: 2 additions & 14 deletions src/diffusers/models/attention.py
@@ -246,10 +246,6 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, retu

return Transformer2DModelOutput(sample=output)

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)


class AttentionBlock(nn.Module):
"""
@@ -428,7 +424,7 @@ def __init__(
# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
self._set_use_memory_efficient_attention_xformers(True)
self.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
@@ -439,7 +435,7 @@ def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
Contributor Author (@blefaudeux, Nov 30, 2022): called from the outside, so it can be public? It also conveys the idea that this is a capability being exposed. See the naming-contract sketch at the end of this file's diff.

if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
@@ -849,11 +845,3 @@ def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_di
return (output_states,)

return Transformer2DModelOutput(sample=output_states)

def _set_attention_slice(self, slice_size):
for transformer in self.transformers:
transformer._set_attention_slice(slice_size)

def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for transformer in self.transformers:
transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
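The rename is load-bearing: the pipeline-level walk added in pipeline_utils.py (later in this diff) checks for exactly this public name, so matching it is what opts a module into the new toggle. A toy sketch with a hypothetical module:

```python
from torch import nn

class MyAttention(nn.Module):
    """Hypothetical block that opts into the pipeline-level toggle."""

    def __init__(self):
        super().__init__()
        self.use_xformers = False

    # The exact public name below is the contract the recursive walk tests
    # with hasattr(); an underscore-prefixed variant would not be found.
    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        self.use_xformers = valid
```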
12 changes: 0 additions & 12 deletions src/diffusers/models/unet_2d_blocks.py
@@ -418,10 +418,6 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -616,10 +612,6 @@ def set_attention_slice(self, slice_size):
for attn in self.attentions:
attn._set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()

@@ -1217,10 +1209,6 @@ def set_attention_slice(self, slice_size):

self.gradient_checkpointing = False

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def forward(
self,
hidden_states,
11 changes: 0 additions & 11 deletions src/diffusers/models/unet_2d_condition.py
@@ -252,17 +252,6 @@ def set_attention_slice(self, slice_size):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
Contributor Author: all of these are just trampolines, not needed with the recursive call from the top. One issue with these trampolines is that they are bound to miss some cases (and they do), since they have to be updated any time a new capability is exposed somewhere in the pipeline. See the sketch at the end of this file's diff.

for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
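To make the trampoline objection concrete, here is a hedged sketch contrasting the two dispatch styles; the classes are illustrative stand-ins, not the real diffusers hierarchy:

```python
from torch import nn

class Attn(nn.Module):
    """Stand-in attention block exposing the capability."""

    def __init__(self):
        super().__init__()
        self.use_xformers = False

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        self.use_xformers = valid

class TrampolineUNet(nn.Module):
    """Trampoline style: the parent hard-codes which children it forwards to."""

    def __init__(self):
        super().__init__()
        self.down_blocks = nn.ModuleList([Attn()])
        self.extra_block = Attn()  # added later; the trampoline below misses it

    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
        for block in self.down_blocks:  # must be kept in sync by hand
            block.set_use_memory_efficient_attention_xformers(valid)

def recursive_set(module: nn.Module, valid: bool) -> None:
    """Recursive style: new children are picked up with no per-class knowledge."""
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid)
    for child in module.children():
        recursive_set(child, valid)

unet = TrampolineUNet()
unet.set_use_memory_efficient_attention_xformers(True)
print(unet.extra_block.use_xformers)  # False: the trampoline missed it

unet2 = TrampolineUNet()
recursive_set(unet2, True)
print(unet2.extra_block.use_xformers)  # True: found by the recursive walk
```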
35 changes: 35 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -784,3 +784,38 @@ def progress_bar(self, iterable=None, total=None):

def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs

def enable_xformers_memory_efficient_attention(self):
Contributor Author: these enable/disable shorthands are just here because many derived pipelines were already using them, so I figured it was cheaper to expose the call here :)

Contributor: Makes sense to me to make this a method of DiffusionPipeline!

r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.set_use_memory_efficient_attention_xformers(False)

def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
Contributor Author: this is the actual single implementation of how to enable memory-efficient attention across the whole model, for all pipelines (it covers super-resolution, outpainting, or text2img, which exercise attention in different places). A defensive usage sketch follows at the end of this file's diff.

# Recursively walk through all the children.
# Any child that exposes the set_use_memory_efficient_attention_xformers
# method gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)

for child in module.children():
fn_recursive_set_mem_eff(child)

module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
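Since the block-level setter raises ModuleNotFoundError when xformers is missing, a defensive pattern at call sites is to guard the enable call. A usage sketch (not part of this PR; checkpoint name is an example):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
try:
    # Recursively walks every nn.Module registered on the pipeline (unet,
    # vae, text encoder, ...) and toggles any submodule exposing the setter.
    pipe.enable_xformers_memory_efficient_attention()
except ModuleNotFoundError:
    # xformers is not installed; keep the default attention implementation.
    pass
```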
18 changes: 0 additions & 18 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -166,24 +166,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -251,24 +251,6 @@ def _execution_device(self):
return torch.device(module._hf_hook.execution_device)
return self.device

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
@@ -285,26 +285,6 @@ def _execution_device(self):
return torch.device(module._hf_hook.execution_device)
return self.device

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
@@ -165,24 +165,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
@@ -134,26 +134,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
@@ -254,26 +254,6 @@ def _execution_device(self):
return torch.device(module._hf_hook.execution_device)
return self.device

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
@@ -300,26 +300,6 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.

When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.

Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)

@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):