@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
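The `math` import is dropped because `math.isqrt(hw1)` in `sag_masking` could only recover the attention map's side length when the map was square; the hunks below capture the true `(height, width)` with a forward hook instead. A minimal sketch of the failure mode, assuming the standard Stable Diffusion UNet whose mid-block sits at 1/64 of the image resolution (so a hypothetical 512x768 image gives an 8x12 mid-block map):

```python
# Hypothetical repro of the bug this commit fixes: a non-square
# mid-block feature map flattens to 96 tokens, math.isqrt(96) == 9,
# and a 9x9 reshape of 96 elements fails.
import math

import torch

hw1 = 8 * 12            # flattened self-attention length for an 8x12 map
side = math.isqrt(hw1)  # 9, which is not a valid side length
try:
    torch.zeros(1, hw1).reshape(1, side, side)
except RuntimeError as err:
    print(f"reshape failed: {err}")
```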
@@ -606,64 +605,73 @@ def __call__(
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # predict the noise residual
-                noise_pred = self.unet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                ).sample
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # perform self-attention guidance with the stored self-attention map
-                if do_self_attention_guidance:
-                    # classifier-free guidance produces two chunks of attention map
-                    # and we only use unconditional one according to equation (24)
-                    # in https://arxiv.org/pdf/2210.00939.pdf
+
+        map_size = None
+
+        def get_map_size(module, input, output):
+            nonlocal map_size
+            map_size = output.sample.shape[-2:]
+
+        with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size):
+            with self.progress_bar(total=num_inference_steps) as progress_bar:
+                for i, t in enumerate(timesteps):
+                    # expand the latents if we are doing classifier free guidance
+                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                    # predict the noise residual
+
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                    ).sample
+
+                    # perform guidance
                     if do_classifier_free_guidance:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
-                        # get the stored attention maps
-                        uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t)
-                        )
-                        uncond_emb, _ = prompt_embeds.chunk(2)
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
-                        noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
-                    else:
-                        # DDIM-like prediction of x0
-                        pred_x0 = self.pred_x0(latents, noise_pred, t)
-                        # get the stored attention maps
-                        cond_attn = store_processor.attention_probs
-                        # self-attention-based degrading of latents
-                        degraded_latents = self.sag_masking(
-                            pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t)
-                        )
-                        # forward and give guidance
-                        degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
-                        noise_pred += sag_scale * (noise_pred - degraded_pred)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-                if callback is not None and i % callback_steps == 0:
-                    callback(i, t, latents)
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                    # perform self-attention guidance with the stored self-attention map
+                    if do_self_attention_guidance:
+                        # classifier-free guidance produces two chunks of attention map
+                        # and we only use unconditional one according to equation (24)
+                        # in https://arxiv.org/pdf/2210.00939.pdf
+                        if do_classifier_free_guidance:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
+                            # get the stored attention maps
+                            uncond_attn, cond_attn = store_processor.attention_probs.chunk(2)
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, uncond_attn, map_size, t, self.pred_epsilon(latents, noise_pred_uncond, t)
+                            )
+                            uncond_emb, _ = prompt_embeds.chunk(2)
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=uncond_emb).sample
+                            noise_pred += sag_scale * (noise_pred_uncond - degraded_pred)
+                        else:
+                            # DDIM-like prediction of x0
+                            pred_x0 = self.pred_x0(latents, noise_pred, t)
+                            # get the stored attention maps
+                            cond_attn = store_processor.attention_probs
+                            # self-attention-based degrading of latents
+                            degraded_latents = self.sag_masking(
+                                pred_x0, cond_attn, map_size, t, self.pred_epsilon(latents, noise_pred, t)
+                            )
+                            # forward and give guidance
+                            degraded_pred = self.unet(degraded_latents, t, encoder_hidden_states=prompt_embeds).sample
+                            noise_pred += sag_scale * (noise_pred - degraded_pred)
+
+                    # compute the previous noisy sample x_t -> x_t-1
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
 
         # 8. Post-processing
         image = self.decode_latents(latents)
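The new code wraps the denoising loop in `register_forward_hook`, whose `RemovableHandle` return value doubles as a context manager in PyTorch, so the hook is detached once the loop exits; each UNet forward then refreshes `map_size` from the mid-block attention output. A self-contained sketch of that pattern on a toy module (the `.sample` attribute is specific to diffusers model outputs, so the toy hook reads the tensor shape directly):

```python
# Sketch of the hook pattern above on a plain Conv2d instead of the
# diffusers UNet mid-block; RemovableHandle.__exit__ removes the hook.
import torch
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=3, padding=1)
map_size = None

def get_map_size(module, input, output):
    global map_size
    map_size = output.shape[-2:]  # the real hook uses output.sample.shape[-2:]

with conv.register_forward_hook(get_map_size):
    conv(torch.randn(1, 4, 8, 12))
print(map_size)  # torch.Size([8, 12])
```

Reading the added lines together, the final prediction when both modes are active is `noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + sag_scale * (noise_pred_uncond - degraded_pred)`: classifier-free guidance plus a self-attention-guidance term that pushes away from the degraded prediction, per equation (24) of the SAG paper.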
@@ -680,20 +688,22 @@ def __call__(
 
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
-    def sag_masking(self, original_latents, attn_map, t, eps):
+    def sag_masking(self, original_latents, attn_map, map_size, t, eps):
         # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
         bh, hw1, hw2 = attn_map.shape
         b, latent_channel, latent_h, latent_w = original_latents.shape
         h = self.unet.attention_head_dim
         if isinstance(h, list):
             h = h[-1]
-        map_size = math.isqrt(hw1)
 
         # Produce attention mask
         attn_map = attn_map.reshape(b, h, hw1, hw2)
         attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
         attn_mask = (
-            attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype)
+            attn_mask.reshape(b, map_size[0], map_size[1])
+            .unsqueeze(1)
+            .repeat(1, latent_channel, 1, 1)
+            .type(attn_map.dtype)
         )
         attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
 
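With the hooked `map_size`, the mask now reshapes to the true `(height, width)` of the mid-block map before being repeated across latent channels and upsampled to the latent resolution, so non-square generations no longer crash in `sag_masking`. A shape-level sketch with toy dimensions (an 8x12 map and a 64x96 latent; the tensor values are dummies, not from a real checkpoint):

```python
# Shape walk-through of the updated mask path with dummy tensors.
import torch
import torch.nn.functional as F

b, h, latent_channel = 2, 8, 4
map_size = (8, 12)               # as captured by the forward hook
latent_h, latent_w = 64, 96
hw1 = map_size[0] * map_size[1]  # 96 flattened tokens

attn_map = torch.rand(b, h, hw1, hw1)
attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0
attn_mask = (
    attn_mask.reshape(b, map_size[0], map_size[1])  # (2, 8, 12), non-square
    .unsqueeze(1)
    .repeat(1, latent_channel, 1, 1)
    .type(attn_map.dtype)
)
attn_mask = F.interpolate(attn_mask, (latent_h, latent_w))
print(attn_mask.shape)  # torch.Size([2, 4, 64, 96])
```

`F.interpolate` defaults to nearest-neighbor mode, so the 0/1 mask stays binary after upsampling.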