From 24db527dcca6d37cbec88995ab67e1fab0db0c16 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:04:59 +0100 Subject: [PATCH 01/14] make attn slice recursive --- src/diffusers/models/attention.py | 16 +++- src/diffusers/models/unet_2d_blocks.py | 96 +++++++++++------------ src/diffusers/models/unet_2d_condition.py | 33 +++++--- src/diffusers/pipeline_utils.py | 30 +++++++ 4 files changed, 117 insertions(+), 58 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 31d1f1ff418a..edaef7b58c26 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -410,6 +410,7 @@ def __init__( ): super().__init__() self.only_cross_attention = only_cross_attention + self.attention_head_dim = attention_head_dim self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, @@ -448,7 +449,20 @@ def __init__( f" correctly and a GPU is available: {e}" ) - def _set_attention_slice(self, slice_size): + def set_attention_slice(self, slice_size): + head_dims = self.attention_head_dim + head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + raise ValueError( + f"Make sure slice_size {slice_size} is a common divisor of " + f"the number of heads used in cross_attention: {head_dims}" + ) + if slice_size is not None and slice_size > min(head_dims): + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to " + f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + ) + self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index d78804b18e75..d605061c0fa1 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -401,22 +401,22 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) + # def set_attention_slice(self, slice_size): + # head_dims = self.attn_num_head_channels + # head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + # if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + # raise ValueError( + # f"Make sure slice_size {slice_size} is a common divisor of " + # f"the number of heads used in cross_attention: {head_dims}" + # ) + # if slice_size is not None and slice_size > min(head_dims): + # raise ValueError( + # f"slice_size {slice_size} has to be smaller or equal to " + # f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + # ) + + # for attn in self.attentions: + # attn._set_attention_slice(slice_size) def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = 
self.resnets[0](hidden_states, temb) @@ -595,22 +595,22 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) + # def set_attention_slice(self, slice_size): + # head_dims = self.attn_num_head_channels + # head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + # if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + # raise ValueError( + # f"Make sure slice_size {slice_size} is a common divisor of " + # f"the number of heads used in cross_attention: {head_dims}" + # ) + # if slice_size is not None and slice_size > min(head_dims): + # raise ValueError( + # f"slice_size {slice_size} has to be smaller or equal to " + # f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + # ) + + # for attn in self.attentions: + # attn._set_attention_slice(slice_size) def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -1190,22 +1190,22 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) + # def set_attention_slice(self, slice_size): + # head_dims = self.attn_num_head_channels + # head_dims = [head_dims] if isinstance(head_dims, int) else head_dims + # if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): + # raise ValueError( + # f"Make sure slice_size {slice_size} is a common divisor of " + # f"the number of heads used in cross_attention: {head_dims}" + # ) + # if slice_size is not None and slice_size > min(head_dims): + # raise ValueError( + # f"slice_size {slice_size} has to be smaller or equal to " + # f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + # ) + + # for attn in self.attentions: + # attn._set_attention_slice(slice_size) self.gradient_checkpointing = False diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f9d3402d0619..c4c8c3b6b16e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -229,6 +229,15 @@ def __init__( self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def 
set_attention_slice(self, slice_size): + if slice_size == "auto": + if isinstance(self.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.config.attention_head_dim) + head_dims = self.config.attention_head_dim head_dims = [head_dims] if isinstance(head_dims, int) else head_dims if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): @@ -242,15 +251,21 @@ def set_attention_slice(self, slice_size): f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) - - self.mid_block.set_attention_slice(slice_size) - - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size) + + for child in module.children(): + fn_recursive_set_attention_slice(child) + + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for module_name in module_names: + module = getattr(self, module_name) + if isinstance(module, torch.nn.Module): + fn_recursive_set_attention_slice(module) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index e65d55e20cd9..b069fe481d86 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -839,3 +839,33 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): module = getattr(self, module_name) if isinstance(module, torch.nn.Module): fn_recursive_set_mem_eff(module) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + self.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. 
+ """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + + def set_attention_slice(self, slice_size: Optional[int]): + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for module_name in module_names: + module = getattr(self, module_name) + if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size) From e813e628a0505f74ec4e94bd9169c140db203661 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:16:54 +0100 Subject: [PATCH 02/14] remove set_attention_slice from blocks --- src/diffusers/models/unet_2d_blocks.py | 53 -------------------------- 1 file changed, 53 deletions(-) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index d605061c0fa1..63e2b809d721 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -401,23 +401,6 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - # def set_attention_slice(self, slice_size): - # head_dims = self.attn_num_head_channels - # head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - # if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - # raise ValueError( - # f"Make sure slice_size {slice_size} is a common divisor of " - # f"the number of heads used in cross_attention: {head_dims}" - # ) - # if slice_size is not None and slice_size > min(head_dims): - # raise ValueError( - # f"slice_size {slice_size} has to be smaller or equal to " - # f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - # ) - - # for attn in self.attentions: - # attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): @@ -595,23 +578,6 @@ def __init__( self.gradient_checkpointing = False - # def set_attention_slice(self, slice_size): - # head_dims = self.attn_num_head_channels - # head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - # if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - # raise ValueError( - # f"Make sure slice_size {slice_size} is a common divisor of " - # f"the number of heads used in cross_attention: {head_dims}" - # ) - # if slice_size is not None and slice_size > min(head_dims): - # raise ValueError( - # f"slice_size {slice_size} has to be smaller or equal to " - # f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - # ) - - # for attn in self.attentions: - # attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -1190,25 +1156,6 @@ def __init__( self.gradient_checkpointing = False - # def set_attention_slice(self, slice_size): - # head_dims = self.attn_num_head_channels - # head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - # if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - # raise ValueError( - # f"Make sure slice_size {slice_size} is a common divisor of " - # f"the number of heads used in cross_attention: {head_dims}" - # ) - # if slice_size is not None and slice_size > min(head_dims): - # raise ValueError( - # f"slice_size {slice_size} has to be smaller or equal to " - # f"the lowest number of heads used in 
cross_attention: min({head_dims}) = {min(head_dims)}" - # ) - # for attn in self.attentions: - # attn._set_attention_slice(slice_size) - - self.gradient_checkpointing = False - def forward( self, hidden_states, From a5dbf608265d63aa11b7156c301dd98e34db9e47 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:22:23 +0100 Subject: [PATCH 03/14] fix copies --- .../versatile_diffusion/modeling_text_unet.py | 82 +++++-------------- 1 file changed, 22 insertions(+), 60 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index d1a3d4c55e91..b5e35ade3c47 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -307,6 +307,15 @@ def __init__( self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): + if slice_size == "auto": + if isinstance(self.config.attention_head_dim, int): + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.config.attention_head_dim // 2 + else: + # if `attention_head_dim` is a list, take the smallest head size + slice_size = min(self.config.attention_head_dim) + head_dims = self.config.attention_head_dim head_dims = [head_dims] if isinstance(head_dims, int) else head_dims if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): @@ -320,15 +329,21 @@ def set_attention_slice(self, slice_size): f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" ) - for block in self.down_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size) - self.mid_block.set_attention_slice(slice_size) + for child in module.children(): + fn_recursive_set_attention_slice(child) - for block in self.up_blocks: - if hasattr(block, "attentions") and block.attentions is not None: - block.set_attention_slice(slice_size) + module_names, _, _ = self.extract_init_dict(dict(self.config)) + for module_name in module_names: + module = getattr(self, module_name) + if isinstance(module, torch.nn.Module): + fn_recursive_set_attention_slice(module) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): @@ -739,23 +754,6 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): output_states = () @@ -948,25 +946,6 @@ def __init__( self.gradient_checkpointing = False - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - - self.gradient_checkpointing = False - def forward( self, hidden_states, @@ -1092,23 +1071,6 @@ def __init__( self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def set_attention_slice(self, slice_size): - head_dims = self.attn_num_head_channels - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - for attn in self.attentions: - attn._set_attention_slice(slice_size) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet 
in zip(self.attentions, self.resnets[1:]): From 2f85c1b37508645c91de3991059d2b8497301125 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:25:13 +0100 Subject: [PATCH 04/14] make enable_attention_slicing base class method of DiffusionPipeline --- .../community/clip_guided_stable_diffusion.py | 10 ------ .../community/composable_stable_diffusion.py | 25 -------------- examples/community/imagic_stable_diffusion.py | 25 -------------- examples/community/img2img_inpainting.py | 27 --------------- .../community/interpolate_stable_diffusion.py | 27 --------------- examples/community/lpw_stable_diffusion.py | 27 --------------- .../multilingual_stable_diffusion.py | 27 --------------- examples/community/sd_text2img_k_diffusion.py | 27 --------------- .../community/seed_resize_stable_diffusion.py | 27 --------------- .../community/speech_to_image_diffusion.py | 8 ----- examples/community/stable_diffusion_mega.py | 27 --------------- examples/community/text_inpainting.py | 27 --------------- .../alt_diffusion/pipeline_alt_diffusion.py | 32 ----------------- .../pipeline_alt_diffusion_img2img.py | 32 ----------------- .../pipeline_cycle_diffusion.py | 34 ------------------- .../pipeline_stable_diffusion.py | 32 ----------------- ...peline_stable_diffusion_image_variation.py | 34 ------------------- .../pipeline_stable_diffusion_img2img.py | 34 ------------------- .../pipeline_stable_diffusion_inpaint.py | 34 ------------------- ...ipeline_stable_diffusion_inpaint_legacy.py | 34 ------------------- .../pipeline_stable_diffusion_upscale.py | 34 ------------------- .../pipeline_stable_diffusion_safe.py | 27 --------------- .../pipeline_versatile_diffusion.py | 28 --------------- ...ipeline_versatile_diffusion_dual_guided.py | 34 ------------------- ...ine_versatile_diffusion_image_variation.py | 34 ------------------- ...eline_versatile_diffusion_text_to_image.py | 34 ------------------- 26 files changed, 741 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 7a319bddf053..43c6a7b5aa18 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -88,16 +88,6 @@ def __init__( set_requires_grad(self.text_encoder, False) set_requires_grad(self.clip_model, False) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - self.enable_attention_slicing(None) - def freeze_vae(self): set_requires_grad(self.vae, False) diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index eb207e1bdd47..2ecc9f8342a4 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -62,31 +62,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. 
- Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def __call__( self, diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index f044a1f568cc..6235df0e62c9 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -105,31 +105,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def train( self, prompt: Union[str, List[str]], diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py index 3fa7db13a482..7f200a67535a 100644 --- a/examples/community/img2img_inpainting.py +++ b/examples/community/img2img_inpainting.py @@ -129,33 +129,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. 
- """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def __call__( self, diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py index 4d7a73f5ba69..5ac14c466b78 100644 --- a/examples/community/interpolate_stable_diffusion.py +++ b/examples/community/interpolate_stable_diffusion.py @@ -120,33 +120,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def __call__( self, diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 8b067f93e733..44279e079574 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -488,33 +488,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. 
If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py index 19974d6df08b..5f3260d49616 100644 --- a/examples/community/multilingual_stable_diffusion.py +++ b/examples/community/multilingual_stable_diffusion.py @@ -135,33 +135,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def __call__( self, diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py index 1c2ba36013f8..8bb8f294dc09 100755 --- a/examples/community/sd_text2img_k_diffusion.py +++ b/examples/community/sd_text2img_k_diffusion.py @@ -115,33 +115,6 @@ def set_sampler(self, scheduler_type: str): sampling = getattr(library, "sampling") self.sampler = getattr(sampling, scheduler_type) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py index 92cd1c04f9f3..c1ee72f72b47 100644 --- a/examples/community/seed_resize_stable_diffusion.py +++ b/examples/community/seed_resize_stable_diffusion.py @@ -67,33 +67,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def __call__( self, diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py index 17bc08e3c291..6e0b97ad5830 100644 --- a/examples/community/speech_to_image_diffusion.py +++ b/examples/community/speech_to_image_diffusion.py @@ -62,14 +62,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - if slice_size == "auto": - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - self.enable_attention_slicing(None) - @torch.no_grad() def __call__( self, diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py index 30699b6a1bf3..e33e966e968f 100644 --- a/examples/community/stable_diffusion_mega.py +++ b/examples/community/stable_diffusion_mega.py @@ -90,33 +90,6 @@ def __init__( def components(self) -> Dict[str, Any]: return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")} - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def inpaint( self, diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py index f02d449fbd1d..37de78d5c5bd 100644 --- a/examples/community/text_inpainting.py +++ b/examples/community/text_inpainting.py @@ -120,33 +120,6 @@ def __init__( feature_extractor=feature_extractor, ) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index fb64a34a0bd8..993b9d305186 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -166,38 +166,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. 
- """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 346f5f727bb8..3a8ad7c2407f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -179,38 +179,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index a688a52a7a6e..1639b723af99 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -209,40 +209,6 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a3a8703f3ea4..cbd817cc0588 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -165,38 +165,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. 
- """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_vae_slicing(self): r""" Enable sliced VAE decoding. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index d77e71653078..1d34280d34de 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -134,40 +134,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 933f59c3bd32..b343a165e124 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -178,40 +178,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index bc416f57d3e0..446b3ef2a8a2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -243,40 +243,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. 
- - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 60d52eaa1ab4..7c248ac1a721 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -191,40 +191,6 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): r""" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 72981aebe184..6510d8a98e39 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -92,40 +92,6 @@ def __init__( ) self.register_to_config(max_noise_level=max_noise_level) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) - - self.unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 5cb0f2c03daf..921befcdf3cb 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -182,33 +182,6 @@ def safety_concept(self, concept): """ self._safety_text_concept = concept - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. 
- - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 - self.unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index 7be7f4d3aee6..e9ef505040c8 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -80,34 +80,6 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - self.image_unet.set_attention_slice(slice_size) - self.text_unet.set_attention_slice(slice_size) - - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. 
- """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - @torch.no_grad() def image_variation( self, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 3a90ae2c7620..89b333c9c499 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -147,40 +147,6 @@ def _revert_dual_attention(self): self.image_unet.register_to_config(dual_cross_attention=False) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.image_unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.image_unet.config.attention_head_dim) - - self.image_unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index b68dd244ce47..87924fdff81a 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -73,40 +73,6 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. 
This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.image_unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.image_unet.config.attention_head_dim) - - self.image_unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index c9c4bb7dc40e..3920955f672d 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -98,40 +98,6 @@ def _swap_unet_attention_blocks(self): def remove_unused_weights(self): self.register_modules(text_unet=None) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. - - Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. - """ - if slice_size == "auto": - if isinstance(self.image_unet.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.image_unet.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.image_unet.config.attention_head_dim) - - self.image_unet.set_attention_slice(slice_size) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing - def disable_attention_slicing(self): - r""" - Disable sliced attention computation. 
If `enable_attention_slicing` was previously invoked, this method will go - back to computing attention in one step. - """ - # set slice_size = `None` to disable `attention slicing` - self.enable_attention_slicing(None) - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, From 05677bc5557fd845a9d18b5e639b82f7e19791da Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:34:16 +0100 Subject: [PATCH 05/14] fix set_attention_slice --- src/diffusers/models/unet_2d_condition.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index c4c8c3b6b16e..f1cbf4a5d561 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -230,13 +230,13 @@ def __init__( def set_attention_slice(self, slice_size): if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): + if isinstance(self.config.attention_head_dim, int): # half the attention head size is usually a good trade-off between # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + slice_size = self.config.attention_head_dim // 2 else: # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) + slice_size = min(self.config.attention_head_dim) head_dims = self.config.attention_head_dim head_dims = [head_dims] if isinstance(head_dims, int) else head_dims From 7546e18f1a8d400bda0a1116deba65854c587238 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:38:58 +0100 Subject: [PATCH 06/14] fix set_attention_slice --- src/diffusers/models/unet_2d_condition.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index f1cbf4a5d561..b886b502d638 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -261,11 +261,8 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module): for child in module.children(): fn_recursive_set_attention_slice(child) - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module): - fn_recursive_set_attention_slice(module) + for module in self.children(): + fn_recursive_set_attention_slice(module) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): From 6800c0400f12fa03369c8df182386f4a409e10e0 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Sat, 3 Dec 2022 17:41:07 +0100 Subject: [PATCH 07/14] fix copies --- .../versatile_diffusion/modeling_text_unet.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index b5e35ade3c47..d33d2525bdd1 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -308,13 +308,13 @@ def __init__( def set_attention_slice(self, slice_size): if slice_size == "auto": - if isinstance(self.unet.config.attention_head_dim, int): + if isinstance(self.config.attention_head_dim, int): # half the 
attention head size is usually a good trade-off between # speed and memory - slice_size = self.unet.config.attention_head_dim // 2 + slice_size = self.config.attention_head_dim // 2 else: # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.unet.config.attention_head_dim) + slice_size = min(self.config.attention_head_dim) head_dims = self.config.attention_head_dim head_dims = [head_dims] if isinstance(head_dims, int) else head_dims @@ -339,11 +339,8 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module): for child in module.children(): fn_recursive_set_attention_slice(child) - module_names, _, _ = self.extract_init_dict(dict(self.config)) - for module_name in module_names: - module = getattr(self, module_name) - if isinstance(module, torch.nn.Module): - fn_recursive_set_attention_slice(module) + for module in self.children(): + fn_recursive_set_attention_slice(module) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): From fc9dda1f7c417eebae2af123f09ba9c20d8e3fab Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 15:41:28 +0000 Subject: [PATCH 08/14] add tests --- src/diffusers/models/attention.py | 36 +++++------ src/diffusers/models/unet_2d_condition.py | 73 ++++++++++++++------- src/diffusers/pipeline_utils.py | 5 +- tests/models/test_models_unet_2d.py | 78 +++++++++++++++++++++++ 4 files changed, 147 insertions(+), 45 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index edaef7b58c26..53d0d8c9c8d5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -14,7 +14,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Union import torch import torch.nn.functional as F @@ -174,9 +174,12 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) + def _set_attention_slice(self, slice_size: Union[int, List[int]]): + if isinstance(slice_size, int): + slice_size = len(self.transformer_blocks) * [slice_size] + + for i, block in enumerate(self.transformer_blocks): + block._set_attention_slice(slice_size[i]) def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): """ @@ -411,6 +414,7 @@ def __init__( super().__init__() self.only_cross_attention = only_cross_attention self.attention_head_dim = attention_head_dim + self.cross_attention_dim = cross_attention_dim self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, @@ -449,23 +453,6 @@ def __init__( f" correctly and a GPU is available: {e}" ) - def set_attention_slice(self, slice_size): - head_dims = self.attention_head_dim - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): - raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" - ) - - 
self.attn1._slice_size = slice_size - self.attn2._slice_size = slice_size - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): print("Here is how to install it") @@ -548,6 +535,7 @@ def __init__( # for slice_size > 0 the attention score computation # is split across the batch axis to save memory # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads self._slice_size = None self._use_memory_efficient_attention_xformers = False @@ -573,6 +561,12 @@ def reshape_batch_dim_to_heads(self, tensor): tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) return tensor + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + def forward(self, hidden_states, context=None, mask=None): batch_size, sequence_length, _ = hidden_states.shape diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b886b502d638..5caae10b489f 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -229,40 +229,69 @@ def __init__( self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`.
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": - if isinstance(self.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.config.attention_head_dim) + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] - head_dims = self.config.attention_head_dim - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module): + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size) + module.set_attention_slice(slice_size.pop()) for child in module.children(): - fn_recursive_set_attention_slice(child) + fn_recursive_set_attention_slice(child, slice_size) + reversed_slice_size = list(reversed(slice_size)) for module in self.children(): - fn_recursive_set_attention_slice(module) + fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index b069fe481d86..2ae968554f0d 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -850,8 +850,9 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto Args: slice_size (`str` or `int`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. 
If - a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, - `attention_head_dim` must be a multiple of `slice_size`. + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. """ self.set_attention_slice(slice_size) diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 59b9e02ff8b9..f2b1240ebff5 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -479,6 +479,84 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): return model + def test_set_attention_slice_auto(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice("auto") + + latents = self.get_latents(0) + encoder_hidden_states = self.get_encoder_hidden_states(0) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + print(mem_bytes) + assert False + + def test_set_attention_slice_max(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice("max") + + latents = self.get_latents(0) + encoder_hidden_states = self.get_encoder_hidden_states(0) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + print(mem_bytes) + assert False + + def test_set_attention_slice_int(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + unet = self.get_unet_model() + unet.set_attention_slice(2) + + latents = self.get_latents(0) + encoder_hidden_states = self.get_encoder_hidden_states(0) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + print(mem_bytes) + assert False + + def test_set_attention_slice_list(self): + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + + # there are 16 slicable layers + slice_list = 8 * [2, 3] + unet = self.get_unet_model() + unet.set_attention_slice(slice_list) + + latents = self.get_latents(0) + encoder_hidden_states = self.get_encoder_hidden_states(0) + timestep = 1 + + with torch.no_grad(): + _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample + + mem_bytes = torch.cuda.max_memory_allocated() + print(mem_bytes) + assert False + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): dtype = torch.float16 if fp16 else torch.float32 hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype) From 876453577500189e1d03d65334f62e904fbae0f1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 15:59:15 +0000 Subject: [PATCH 09/14] up --- tests/models/test_models_unet_2d.py | 74 ++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/tests/models/test_models_unet_2d.py
b/tests/models/test_models_unet_2d.py index f2b1240ebff5..f270daf50b7e 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -334,6 +334,30 @@ def test_model_with_use_linear_projection(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_model_attention_slicing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model.set_attention_slice("auto") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice("max") + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + + model.set_attention_slice(2) + with torch.no_grad(): + output = model(**inputs_dict) + assert output is not None + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel @@ -487,16 +511,16 @@ def test_set_attention_slice_auto(self): unet = self.get_unet_model() unet.set_attention_slice("auto") - latents = self.get_latents(0) - encoder_hidden_states = self.get_encoder_hidden_states(0) + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) timestep = 1 with torch.no_grad(): _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() - print(mem_bytes) - assert False + + assert mem_bytes < 5 * 10**9 def test_set_attention_slice_max(self): torch.cuda.empty_cache() @@ -506,16 +530,16 @@ def test_set_attention_slice_max(self): unet = self.get_unet_model() unet.set_attention_slice("max") - latents = self.get_latents(0) - encoder_hidden_states = self.get_encoder_hidden_states(0) + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) timestep = 1 with torch.no_grad(): _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() - print(mem_bytes) - assert False + + assert mem_bytes < 5 * 10**9 def test_set_attention_slice_int(self): torch.cuda.empty_cache() @@ -525,37 +549,37 @@ def test_set_attention_slice_int(self): unet = self.get_unet_model() unet.set_attention_slice(2) - latents = self.get_latents(0) - encoder_hidden_states = self.get_encoder_hidden_states(0) + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) timestep = 1 with torch.no_grad(): _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() - print(mem_bytes) - assert False + + assert mem_bytes < 5 * 10**9 def test_set_attention_slice_list(self): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() torch.cuda.reset_peak_memory_stats() - # there are 16 slicable layers - slice_list = 8 * [2, 3] + # there are 32 slicable layers + slice_list = 16 * [2, 3] unet = self.get_unet_model() unet.set_attention_slice(slice_list) - latents = self.get_latents(0) - encoder_hidden_states = self.get_encoder_hidden_states(0) + latents = self.get_latents(33) + encoder_hidden_states = self.get_encoder_hidden_states(33) timestep = 1 with torch.no_grad(): _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample mem_bytes = torch.cuda.max_memory_allocated() - print(mem_bytes) - assert False + + assert 
mem_bytes < 5 * 10**9 def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): dtype = torch.float16 if fp16 else torch.float32 @@ -578,6 +602,8 @@ def test_compvis_sd_v1_4(self, seed, timestep, expected_slice): latents = self.get_latents(seed) encoder_hidden_states = self.get_encoder_hidden_states(seed) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -604,6 +630,8 @@ def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -630,6 +658,8 @@ def test_compvis_sd_v1_5(self, seed, timestep, expected_slice): latents = self.get_latents(seed) encoder_hidden_states = self.get_encoder_hidden_states(seed) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -656,6 +686,8 @@ def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -682,6 +714,8 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): latents = self.get_latents(seed, shape=(4, 9, 64, 64)) encoder_hidden_states = self.get_encoder_hidden_states(seed) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -708,6 +742,8 @@ def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample @@ -734,6 +770,8 @@ def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice): latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device) + with torch.no_grad(): sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample From 94135293391e89144d5dd0ccdac306f1927be69c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 16:07:18 +0000 Subject: [PATCH 10/14] up --- .../community/clip_guided_stable_diffusion.py | 10 +++++++ .../community/composable_stable_diffusion.py | 25 +++++++++++++++++ examples/community/imagic_stable_diffusion.py | 25 +++++++++++++++++ examples/community/img2img_inpainting.py | 27 +++++++++++++++++++ .../community/interpolate_stable_diffusion.py | 27 +++++++++++++++++++ examples/community/lpw_stable_diffusion.py | 27 +++++++++++++++++++ 
.../multilingual_stable_diffusion.py | 27 +++++++++++++++++++ examples/community/sd_text2img_k_diffusion.py | 27 +++++++++++++++++++ .../community/seed_resize_stable_diffusion.py | 27 +++++++++++++++++++ .../community/speech_to_image_diffusion.py | 8 ++++++ examples/community/stable_diffusion_mega.py | 27 +++++++++++++++++++ examples/community/text_inpainting.py | 27 +++++++++++++++++++ src/diffusers/models/attention.py | 11 +------- tests/models/test_models_unet_2d.py | 18 +++++++++++++ 14 files changed, 303 insertions(+), 10 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 43c6a7b5aa18..7a319bddf053 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -88,6 +88,16 @@ def __init__( set_requires_grad(self.text_encoder, False) set_requires_grad(self.clip_model, False) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + self.enable_attention_slicing(None) + def freeze_vae(self): set_requires_grad(self.vae, False) diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index 2ecc9f8342a4..eb207e1bdd47 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -62,6 +62,31 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + @torch.no_grad() def __call__( self, diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index 6235df0e62c9..f044a1f568cc 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -105,6 +105,31 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. 
This is useful to save some memory in exchange for a small speed decrease. + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + def train( self, prompt: Union[str, List[str]], diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py index 7f200a67535a..3fa7db13a482 100644 --- a/examples/community/img2img_inpainting.py +++ b/examples/community/img2img_inpainting.py @@ -129,6 +129,33 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + @torch.no_grad() def __call__( self, diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py index 5ac14c466b78..4d7a73f5ba69 100644 --- a/examples/community/interpolate_stable_diffusion.py +++ b/examples/community/interpolate_stable_diffusion.py @@ -120,6 +120,33 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + @torch.no_grad() def __call__( self, diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 44279e079574..8b067f93e733 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -488,6 +488,33 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py index 5f3260d49616..19974d6df08b 100644 --- a/examples/community/multilingual_stable_diffusion.py +++ b/examples/community/multilingual_stable_diffusion.py @@ -135,6 +135,33 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. 
+ """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + @torch.no_grad() def __call__( self, diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py index 8bb8f294dc09..1c2ba36013f8 100755 --- a/examples/community/sd_text2img_k_diffusion.py +++ b/examples/community/sd_text2img_k_diffusion.py @@ -115,6 +115,33 @@ def set_sampler(self, scheduler_type: str): sampling = getattr(library, "sampling") self.sampler = getattr(sampling, scheduler_type) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py index c1ee72f72b47..92cd1c04f9f3 100644 --- a/examples/community/seed_resize_stable_diffusion.py +++ b/examples/community/seed_resize_stable_diffusion.py @@ -67,6 +67,33 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. 
+ """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + @torch.no_grad() def __call__( self, diff --git a/examples/community/speech_to_image_diffusion.py b/examples/community/speech_to_image_diffusion.py index 6e0b97ad5830..17bc08e3c291 100644 --- a/examples/community/speech_to_image_diffusion.py +++ b/examples/community/speech_to_image_diffusion.py @@ -62,6 +62,14 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + if slice_size == "auto": + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + self.enable_attention_slicing(None) + @torch.no_grad() def __call__( self, diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py index e33e966e968f..30699b6a1bf3 100644 --- a/examples/community/stable_diffusion_mega.py +++ b/examples/community/stable_diffusion_mega.py @@ -90,6 +90,33 @@ def __init__( def components(self) -> Dict[str, Any]: return {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")} + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + @torch.no_grad() def inpaint( self, diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py index 37de78d5c5bd..f02d449fbd1d 100644 --- a/examples/community/text_inpainting.py +++ b/examples/community/text_inpainting.py @@ -120,6 +120,33 @@ def __init__( feature_extractor=feature_extractor, ) + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. 
+ + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, + `attention_head_dim` must be a multiple of `slice_size`. + """ + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + r""" + Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go + back to computing attention in one step. + """ + # set slice_size = `None` to disable `attention slicing` + self.enable_attention_slicing(None) + def enable_sequential_cpu_offload(self): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 53d0d8c9c8d5..a38324a03884 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -14,7 +14,7 @@ import math import warnings from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Optional import torch import torch.nn.functional as F @@ -174,13 +174,6 @@ def __init__( self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) - def _set_attention_slice(self, slice_size: Union[int, List[int]]): - if isinstance(slice_size, int): - slice_size = len(self.transformer_blocks) * [slice_size] - - for i, block in enumerate(self.transformer_blocks): - block._set_attention_slice(slice_size[i]) - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): """ Args: @@ -412,8 +405,6 @@ def __init__( only_cross_attention: bool = False, ): super().__init__() - self.only_cross_attention = only_cross_attention - self.attention_head_dim = attention_head_dim self.cross_attention_dim = cross_attention_dim self.attn1 = CrossAttention( query_dim=dim, diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index f270daf50b7e..4a2d5a96ed9e 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -358,6 +358,24 @@ def test_model_attention_slicing(self): output = model(**inputs_dict) assert output is not None + def test_model_slicable_head_dim(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["attention_head_dim"] = (8, 16) + + model = self.model_class(**init_dict) + + def check_slicable_dim_attr(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + assert isinstance(module.sliceable_head_dim, int) + + for child in module.children(): + check_slicable_dim_attr(child) + + # retrieve number of attention layers + for module in model.children(): + check_slicable_dim_attr(module) + class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet2DModel From ce8942b855ac32d798e46b2379ccda7ee433684d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 16:08:23 +0000 Subject: [PATCH 11/14] up --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index a38324a03884..61186a253572 100644 --- 
a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -405,7 +405,7 @@ def __init__( only_cross_attention: bool = False, ): super().__init__() - self.cross_attention_dim = cross_attention_dim + self.only_cross_attention = only_cross_attention self.attn1 = CrossAttention( query_dim=dim, heads=num_attention_heads, From 855a69c06685094be63d5726b998451c8179c2d5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 16:11:26 +0000 Subject: [PATCH 12/14] update --- .../versatile_diffusion/modeling_text_unet.py | 71 +++++++++++++------ 1 file changed, 50 insertions(+), 21 deletions(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index d33d2525bdd1..09889e55f0ce 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -307,40 +307,69 @@ def __init__( self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1) def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`.
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + if slice_size == "auto": - if isinstance(self.config.attention_head_dim, int): - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = self.config.attention_head_dim // 2 - else: - # if `attention_head_dim` is a list, take the smallest head size - slice_size = min(self.config.attention_head_dim) + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] - head_dims = self.config.attention_head_dim - head_dims = [head_dims] if isinstance(head_dims, int) else head_dims - if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims): - raise ValueError( - f"Make sure slice_size {slice_size} is a common divisor of " - f"the number of heads used in cross_attention: {head_dims}" - ) - if slice_size is not None and slice_size > min(head_dims): + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): raise ValueError( - f"slice_size {slice_size} has to be smaller or equal to " - f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}" + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + # Recursively walk through all the children. 
# Any children which exposes the set_attention_slice method # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module): + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size) + module.set_attention_slice(slice_size.pop()) for child in module.children(): - fn_recursive_set_attention_slice(child) + fn_recursive_set_attention_slice(child, slice_size) + reversed_slice_size = list(reversed(slice_size)) for module in self.children(): - fn_recursive_set_attention_slice(module) + fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)): From 068e77996bb41f11b7743fcc8b727d8113031db2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 16:22:24 +0000 Subject: [PATCH 13/14] up --- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5caae10b489f..2acc2cb4764e 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -236,7 +236,7 @@ def set_attention_slice(self, slice_size): in several steps. This is useful to save some memory in exchange for a small speed decrease. Args: - slice_size (`str` or `int`, *optional*, defaults to `"auto"`): + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 09889e55f0ce..07f10425eeea 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import numpy as np import torch From 8e0a98ded13c4895594e32a0c9ad840f2bcfcd1e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 5 Dec 2022 16:25:55 +0000 Subject: [PATCH 14/14] uP --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 07f10425eeea..fdf97481feb9 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union import numpy as np import torch
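For illustration only, not part of the patches themselves: a minimal sketch of the slice-size variants that `UNet2DConditionModel.set_attention_slice` accepts once PATCH 08 and PATCH 13 are applied. It assumes a diffusers checkout at this revision, a CUDA device, and the `CompVis/stable-diffusion-v1-4` checkpoint that the integration tests already load.

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

unet.set_attention_slice("auto")       # per layer: half of that layer's head count
unet.set_attention_slice("max")        # slice size 1 everywhere, the most memory saved
unet.set_attention_slice(2)            # the same integer slice size for every layer
unet.set_attention_slice(16 * [2, 3])  # one entry per sliceable layer (32 in this UNet)
unet.set_attention_slice(None)         # disable slicing again
```

Pipelines continue to expose this through `enable_attention_slicing` / `disable_attention_slicing`, which after PATCH 08 simply forward the value here instead of carrying their own copies of the logic.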
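The pop-from-a-reversed-list pattern in `fn_recursive_set_attention_slice` is easy to misread, so here is a self-contained sketch of the same two-pass mechanism using a hypothetical `Sliceable` module (not a diffusers class): head dims are collected depth-first, one slice size is computed per layer, and the reversed list is consumed with `pop()` so each layer receives its size in the original traversal order.

```python
import torch.nn as nn

class Sliceable(nn.Module):
    """Stand-in for a module exposing set_attention_slice, e.g. CrossAttention."""
    def __init__(self, heads: int):
        super().__init__()
        self.sliceable_head_dim = heads
        self._slice_size = None

    def set_attention_slice(self, slice_size):
        self._slice_size = slice_size

model = nn.Sequential(Sliceable(8), nn.Sequential(Sliceable(16), Sliceable(16)))

# Pass 1: collect head dims depth-first, like fn_recursive_retrieve_slicable_dims.
sliceable_head_dims = []

def retrieve(module: nn.Module):
    if hasattr(module, "set_attention_slice"):
        sliceable_head_dims.append(module.sliceable_head_dim)
    for child in module.children():
        retrieve(child)

for module in model.children():
    retrieve(module)

slice_size = [dim // 2 for dim in sliceable_head_dims]  # the "auto" policy

# Pass 2: hand one size to each layer; pop() takes from the end of the list,
# so reversing first makes the sizes come out in the original traversal order.
def assign(module: nn.Module, sizes):
    if hasattr(module, "set_attention_slice"):
        module.set_attention_slice(sizes.pop())
    for child in module.children():
        assign(child, sizes)

reversed_slice_size = list(reversed(slice_size))
for module in model.children():
    assign(module, reversed_slice_size)

print([m._slice_size for m in model.modules() if hasattr(m, "_slice_size")])  # [4, 8, 8]
```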
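What the stored `_slice_size` ultimately buys inside `CrossAttention` is sketched below. This is a simplified stand-in for the library's sliced-attention path, not its exact code: with queries, keys, and values already reshaped to `(batch * heads, seq_len, head_dim)`, only one chunk of the score matrix is materialized at a time, trading a little speed for a large cut in peak memory.

```python
import torch

def sliced_attention(q, k, v, slice_size=None):
    """Compute softmax(q @ k^T / sqrt(d)) @ v in chunks along the (batch * heads) axis."""
    bh, seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    slice_size = slice_size if slice_size is not None else bh  # None: one big slice
    out = torch.empty_like(q)
    for i in range(0, bh, slice_size):
        # Only this chunk's (slice_size, seq_len, seq_len) score matrix exists at once.
        scores = q[i : i + slice_size] @ k[i : i + slice_size].transpose(1, 2) * scale
        out[i : i + slice_size] = scores.softmax(dim=-1) @ v[i : i + slice_size]
    return out

# Slicing changes peak memory, not the result:
q, k, v = (torch.randn(16, 77, 40) for _ in range(3))
assert torch.allclose(sliced_attention(q, k, v, slice_size=2), sliced_attention(q, k, v), atol=1e-5)
```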