@@ -173,7 +173,8 @@ def set_use_memory_efficient_attention_xformers(
             LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
-            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+            self.processor,
+            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
         )
         is_added_kv_processor = hasattr(self, "processor") and isinstance(
             self.processor,
@@ -261,7 +262,12 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             elif is_custom_diffusion:
-                processor = CustomDiffusionAttnProcessor(
+                attn_processor_class = (
+                    CustomDiffusionAttnProcessor2_0
+                    if hasattr(F, "scaled_dot_product_attention")
+                    else CustomDiffusionAttnProcessor
+                )
+                processor = attn_processor_class(
                     train_kv=self.processor.train_kv,
                     train_q_out=self.processor.train_q_out,
                     hidden_size=self.processor.hidden_size,
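The hunk above only falls back to the new processor class when PyTorch 2.0's fused attention kernel is available. As a standalone sketch of that feature-detection pattern (not part of the patch itself): the import path is the module being changed here, while the `hidden_size`/`cross_attention_dim` values are placeholder assumptions.

```python
import torch.nn.functional as F

from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

# Prefer the PyTorch 2.0 path when F.scaled_dot_product_attention exists
# (torch >= 2.0); otherwise fall back to the plain attention implementation.
attn_processor_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else CustomDiffusionAttnProcessor
)

processor = attn_processor_class(
    train_kv=True,
    train_q_out=True,
    hidden_size=320,          # placeholder: must match the attention layer's channels
    cross_attention_dim=768,  # placeholder: text-encoder hidden size for SD 1.x
)
```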
@@ -1156,6 +1162,111 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         return hidden_states


+class CustomDiffusionAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
+    dot-product attention.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include the bias parameter in `train_q_out`.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # `_custom_diffusion` id for easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        inner_dim = hidden_states.shape[-1]
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     r"""
     Processor for implementing sliced attention.
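To try the new class end to end, the sketch below (not part of the patch) wires it into a UNet the same way the Custom Diffusion training script assigns per-layer processors. `attn_processors` and `set_attn_processor` are existing `UNet2DConditionModel` APIs; the checkpoint id and the per-block size lookup mirror the training script and are assumptions here.

```python
import torch.nn.functional as F

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

# Any Stable Diffusion UNet works; this checkpoint id is only an example.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_processor_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")
    else CustomDiffusionAttnProcessor
)

custom_diffusion_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers ("attn1") see latent features only; cross-attention
    # layers ("attn2") additionally receive the text-encoder hidden states.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]

    if cross_attention_dim is not None:
        # Train the new key/value projections on cross-attention layers only.
        custom_diffusion_attn_procs[name] = attn_processor_class(
            train_kv=True,
            train_q_out=False,
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
        )
    else:
        # Self-attention layers keep the frozen projections of the base model.
        custom_diffusion_attn_procs[name] = attn_processor_class(
            train_kv=False,
            train_q_out=False,
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
        )

unet.set_attn_processor(custom_diffusion_attn_procs)
```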
@@ -1639,6 +1750,7 @@ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
     XFormersAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
     # deprecated
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,