15 changes: 7 additions & 8 deletions src/diffusers/models/attention.py
@@ -174,10 +174,6 @@ def __init__(
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)

def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
"""
Args:
@@ -448,10 +444,6 @@ def __init__(
f" correctly and a GPU is available: {e}"
)

def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size

def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
Expand Down Expand Up @@ -534,6 +526,7 @@ def __init__(
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False

@@ -559,6 +552,12 @@ def reshape_batch_dim_to_heads(self, tensor):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

self._slice_size = slice_size

def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape

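For context on what the `_slice_size` attribute controls: when slicing is enabled, the attention score matrix is computed chunk by chunk along the batch*heads axis instead of all at once. The sketch below only illustrates that idea and is not the exact diffusers implementation; the function name and shapes are assumptions.

import torch

def sliced_attention(query, key, value, slice_size=None):
    # Illustrative shapes: (batch * heads, seq_len, dim_per_head).
    batch_heads, seq_len, dim = query.shape
    slice_size = slice_size if slice_size is not None else batch_heads  # None -> no slicing
    scale = dim ** -0.5
    out = torch.empty_like(query)
    for start in range(0, batch_heads, slice_size):
        end = min(start + slice_size, batch_heads)
        # Only this slice's (seq_len x seq_len) score matrix is materialized at a time.
        scores = torch.bmm(query[start:end], key[start:end].transpose(1, 2)) * scale
        out[start:end] = torch.bmm(scores.softmax(dim=-1), value[start:end])
    return out

# e.g. 2 images x 8 heads = 16 batch-heads; slice_size=4 -> four chunks of scores
q = k = v = torch.randn(16, 64, 40)
out = sliced_attention(q, k, v, slice_size=4)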
53 changes: 0 additions & 53 deletions src/diffusers/models/unet_2d_blocks.py
@@ -401,23 +401,6 @@ def __init__(
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)

def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)

for attn in self.attentions:
attn._set_attention_slice(slice_size)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -595,23 +578,6 @@ def __init__(

self.gradient_checkpointing = False

def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)

for attn in self.attentions:
attn._set_attention_slice(slice_size)

def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()

@@ -1190,25 +1156,6 @@ def __init__(

self.gradient_checkpointing = False

def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)

for attn in self.attentions:
attn._set_attention_slice(slice_size)

self.gradient_checkpointing = False

def forward(
self,
hidden_states,
81 changes: 61 additions & 20 deletions src/diffusers/models/unet_2d_condition.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -229,28 +229,69 @@ def __init__(
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

def set_attention_slice(self, slice_size):
head_dims = self.config.attention_head_dim
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
r"""
Enable sliced attention computation.

When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.

Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []

def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)

for child in module.children():
fn_recursive_retrieve_slicable_dims(child)

# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)

for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
num_slicable_layers = len(sliceable_head_dims)

self.mid_block.set_attention_slice(slice_size)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]

Comment on lines +260 to +267

Contributor (author): As discussed offline, this logic could be moved to the attention block itself. That way it would be easy to support this functionality in new models which use the attention blocks.

Contributor: That makes sense. As discussed offline, I'll leave it as is for now so that we can test that the number of passed slices equals the number of sliceable layers. Otherwise we cannot really test it.
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)

for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())

for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)

reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
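To summarize the new UNet-level logic above: the model first collects one sliceable head dim per attention layer, then resolves `"auto"`, `"max"`, an int, or a list into a per-layer list of slice sizes, validates it, and finally pushes the values down recursively to every module exposing `set_attention_slice`. A standalone sketch of the resolution step follows; the helper name and the example head dims are made up for illustration.

from typing import List, Optional, Union

def resolve_slice_sizes(
    sliceable_head_dims: List[int],
    slice_size: Union[str, int, List[Optional[int]], None],
) -> List[Optional[int]]:
    # Mirrors the resolution logic in UNet2DConditionModel.set_attention_slice (sketch only).
    num_layers = len(sliceable_head_dims)
    if slice_size == "auto":
        # Half of each head dim: attention runs in two steps per layer.
        slice_size = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # Smallest possible slices: maximum memory savings.
        slice_size = num_layers * [1]
    if not isinstance(slice_size, list):
        slice_size = num_layers * [slice_size]
    if len(slice_size) != num_layers:
        raise ValueError(f"Expected {num_layers} slice sizes, got {len(slice_size)}.")
    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
    return slice_size

print(resolve_slice_sizes([8, 8, 8], "auto"))  # [4, 4, 4]
print(resolve_slice_sizes([8, 8, 8], "max"))   # [1, 1, 1]
print(resolve_slice_sizes([8, 8, 8], 2))       # [2, 2, 2]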
31 changes: 31 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -839,3 +839,34 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.

When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.

Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
self.set_attention_slice(slice_size)

def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def set_attention_slice(self, slice_size: Optional[int]):
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size)
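With these helpers on the pipeline base class, attention slicing is toggled once at the pipeline level and propagated to every sub-module that exposes `set_attention_slice`. A typical usage sketch; the model id and prompt are placeholders.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Trade a little speed for memory: "auto" halves each attention layer's head dim.
pipe.enable_attention_slicing("auto")
image = pipe("a photo of an astronaut riding a horse on mars").images[0]

# Back to full, single-pass attention.
pipe.disable_attention_slicing()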
32 changes: 0 additions & 32 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -166,38 +166,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.

When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.

Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
@@ -179,38 +179,6 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)

def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.

When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.

Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -209,40 +209,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.

When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.

Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)

self.unet.set_attention_slice(slice_size)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""