
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -249,6 +250,32 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
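
A minimal, hedged sketch of what the new `attn_processors` property returns once the model is built. The default `UNet3DConditionModel` configuration is used purely for illustration; it needs no checkpoint, but it does construct the full-size module:

from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel()  # default text-to-video sized config, built from scratch

processors = unet.attn_processors
print(len(processors))        # one entry per Attention layer
print(sorted(processors)[0])  # keys look like "<module path>.processor"
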
@@ -259,34 +286,34 @@ def set_attention_slice(self, slice_size):
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []

-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)

             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)

         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)

-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)

         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
             # speed and memory
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]

-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
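
As a usage note for the three `slice_size` modes described in the docstring above, a short hedged sketch (any `UNet3DConditionModel` instance behaves the same way):

from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel()  # illustrative instance

# "auto": each sliceable attention layer computes attention in two steps,
# using half of its head dimension per slice.
unet.set_attention_slice("auto")

# "max": the most memory-frugal setting, one slice at a time per layer.
unet.set_attention_slice("max")

# Integer: per the docstring, each layer runs attention_head_dim // slice_size
# slices, so attention_head_dim must be a multiple of slice_size.
unet.set_attention_slice(2)
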
@@ -314,6 +341,37 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor of **all** `Attention` layers. If `processor` is a dict, the key needs to define the path
+                to the corresponding cross attention processor. This is strongly recommended when setting trainable
+                attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
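
With both copied methods in place, the 3D UNet exposes the same processor plumbing as `UNet2DConditionModel`, so attention processors can be inspected and swapped end to end. A hedged sketch of that round trip; the checkpoint id below is the text-to-video example commonly used in the diffusers docs and is only an assumption here, and `AttnProcessor` is the stock processor class:

import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor

# Assumed example checkpoint; any pipeline built around UNet3DConditionModel works.
pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
)
unet = pipe.unet

# Option 1: share one processor instance across every Attention layer.
unet.set_attn_processor(AttnProcessor())

# Option 2: pass a dict keyed by the weight-name paths from `attn_processors`,
# which is how individual layers could be targeted (here the mapping is uniform).
unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})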