diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 31d1f1ff418a..61186a253572 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -174,10 +174,6 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): """ Args: @@ -448,10 +444,6 @@ def __init__( f" correctly and a GPU is available: {e}" ) - def _set_attention_slice(self, slice_size): - self.attn1._slice_size = slice_size - self.attn2._slice_size = slice_size - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): print("Here is how to install it") @@ -534,6 +526,7 @@ def __init__( # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads self._slice_size = None self._use_memory_efficient_attention_xformers = False @@ -559,6 +552,12 @@ def reshape_batch_dim_to_heads(self, tensor): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + def forward(self, hidden_states, context=None, mask=None): batch_size, sequence_length, _ = hidden_states.shape diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index d78804b18e75..63e2b809d721 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -401,23 +401,6 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -595,23 +578,6 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise 
ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -1190,25 +1156,6 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - self.gradient_checkpointing = False - def forward( self, hidden_states, diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f9d3402d0619..2acc2cb4764e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -229,28 +229,69 @@ def __init__( self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): - head_dims = self.config.attention_head_dim - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`.
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + num_slicable_layers = len(sliceable_head_dims) - self.mid_block.set_attention_slice(slice_size) + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index e65d55e20cd9..2ae968554f0d 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -839,3 +839,34 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): module = getattr(self, module_name) if isinstance(module, torch.nn.Module): fn_recursive_set_mem_eff(module) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + self.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def set_attention_slice(self, slice_size: Optional[int]): + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for module_name in module_names: + module = getattr(self, module_name) + if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index fb64a34a0bd8..993b9d305186 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -166,38 +166,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 346f5f727bb8..3a8ad7c2407f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -179,38 +179,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. 
- - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index a688a52a7a6e..1639b723af99 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -209,40 +209,6 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a3a8703f3ea4..cbd817cc0588 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -165,38 +165,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index d77e71653078..1d34280d34de 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -134,40 +134,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. 
If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 933f59c3bd32..b343a165e124 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -178,40 +178,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index f4ce399d8d0b..34ee8e6840af 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -243,40 +243,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 60d52eaa1ab4..7c248ac1a721 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -191,40 +191,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. 
- - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 53395fa9c3ca..4eb53a7f607f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -92,40 +92,6 @@ def __init__( ) self.register_to_config(max_noise_level=max_noise_level) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. 
If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 5cb0f2c03daf..921befcdf3cb 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -182,33 +182,6 @@ def safety_concept(self, concept): """ self._safety_text_concept = concept - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index d1a3d4c55e91..fdf97481feb9 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -307,28 +307,69 @@ def __init__( self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): - head_dims = self.config.attention_head_dim - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. 
+ + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) - self.mid_block.set_attention_slice(slice_size) + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): @@ -739,23 +780,6 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -948,25 +972,6 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - self.gradient_checkpointing = False - def forward( self, hidden_states, @@ -1092,23 +1097,6 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, 
self.resnets[1:]): diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index 7be7f4d3aee6..e9ef505040c8 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -80,34 +80,6 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - self.image_unet.set_attention_slice(slice_size) - self.text_unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def image_variation( self, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 3a90ae2c7620..89b333c9c499 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -147,40 +147,6 @@ def _revert_dual_attention(self): self.image_unet.register_to_config(dual_cross_attention=False) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. 
- """ - if slice_size == "auto": - if isinstance(self.image_unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.image_unet.config.attention_head_dim) - - self.image_unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index b68dd244ce47..87924fdff81a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -73,40 +73,6 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.image_unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.image_unet.config.attention_head_dim) - - self.image_unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index c9c4bb7dc40e..3920955f672d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -98,40 +98,6 @@ def _swap_unet_attention_blocks(self): def remove_unused_weights(self): self.register_modules(text_unet=None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.image_unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.image_unet.config.attention_head_dim) - - self.image_unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 59b9e02ff8b9..4a2d5a96ed9e 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -334,6 +334,48 @@ def test_model_with_use_linear_projection(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + def test_model_slicable_head_dim(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + + def check_slicable_dim_attr(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + assert isinstance(module.sliceable_head_dim, int) + + for child in module.children(): + check_slicable_dim_attr(child) + + # retrieve number of attention layers + for module in model.children(): + check_slicable_dim_attr(module) + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -479,6 +521,84 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): return model + def test_set_attention_slice_auto(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice("auto") + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def test_set_attention_slice_max(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice("max") + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def test_set_attention_slice_int(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice(2) + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + + def test_set_attention_slice_list(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + # there are 32 slicable layers + 
slice_list = 16 * [2, 3] + unet = self.get_unet_model() + unet.set_attention_slice(slice_list) + + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + + assert mem_bytes < 5 * 10**9 + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): dtype = torch.float16 if fp16 else torch.float32 hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) @@ -500,6 +620,8 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): latents = self.get_latents(seed) encoder_hidden_states = self.get_encoder_hidden_states(seed) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -526,6 +648,8 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -552,6 +676,8 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): latents = self.get_latents(seed) encoder_hidden_states = self.get_encoder_hidden_states(seed) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -578,6 +704,8 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -604,6 +732,8 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): latents = self.get_latents(seed, shape=(4, 9, 64, 64)) encoder_hidden_states = self.get_encoder_hidden_states(seed) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -630,6 +760,8 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -656,6 +788,8 @@ def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
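
For reviewers, a minimal usage sketch of the refactored API (not part of the diff). It assumes the `CompVis/stable-diffusion-v1-4` checkpoint referenced in the integration tests above; the slice sizes passed below are illustrative, not prescriptive.

import torch

from diffusers import StableDiffusionPipeline

# Load a pipeline; the checkpoint id mirrors the one used in the tests above.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# The pipeline-level helpers now live on DiffusionPipeline and forward the slice
# size to every sub-module that exposes `set_attention_slice` (here, the UNet).
pipe.enable_attention_slicing()        # "auto": each layer slices to half its head count
pipe.enable_attention_slicing("max")   # smallest slices (size 1 per layer), lowest memory
pipe.enable_attention_slicing(2)       # the same integer slice size for every layer
pipe.disable_attention_slicing()       # slice_size=None, attention computed in one step

# The UNet additionally accepts one slice size per sliceable layer; the SD v1-4
# UNet has 32 such layers, matching `test_set_attention_slice_list` above.
pipe.unet.set_attention_slice(16 * [2, 3])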