Description
I tried passing an attention_mask into a stable-diffusion UNet, but it doesn't actually get passed down as deep as CrossAttention#forward.
I tried fixing this by passing the param down, but it blows up with a tensor size mismatch, because self-attention and cross-attention have different masking requirements.
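To illustrate the mismatch, here's a toy sketch (not diffusers code; the dimensions are made up): in the SD UNet, self-attention attends over the spatial latents, while cross-attention attends over the CLIP text tokens, so a mask sized for the text sequence can't broadcast against the self-attention scores.

```python
import torch

batch, spatial_tokens, text_tokens = 1, 1024, 77  # e.g. a 32x32 latent and CLIP's context length

self_attn_scores = torch.randn(batch, spatial_tokens, spatial_tokens)  # queries & keys: latents
cross_attn_scores = torch.randn(batch, spatial_tokens, text_tokens)    # keys: text token embeds

# a mask over text tokens broadcasts fine against the cross-attention scores...
text_mask = torch.ones(batch, 1, text_tokens, dtype=torch.bool)
cross_attn_scores = cross_attn_scores.masked_fill(~text_mask, float("-inf"))

# ...but the same mask cannot broadcast against the (spatial, spatial) self-attention scores:
# 77 mask positions vs 1024 key positions raises a size-mismatch error.
# self_attn_scores.masked_fill(~text_mask, float("-inf"))  # RuntimeError
```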
I made my own implementation of cross-attention masking a few weeks ago (before the refactor), but never upstreamed it, mainly because I wasn't sure whether I'd done it correctly (I re-used the lucidrains implementation that CompVis used):
cbb4c02
EDIT: rebased implementation to show how it would fit in with the existing attention masking and the refactored attention:
Birch-san@e3a93e9
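For reference, here's roughly what that approach looks like (an illustrative sketch, not the actual diffusers CrossAttention API; `masked_cross_attention` and its parameter names are made up here): masked key positions are filled with a large negative value before the softmax, so they receive effectively zero attention weight, as in the lucidrains implementation.

```python
import torch

def masked_cross_attention(q, k, v, cross_attention_mask=None, heads=8):
    # q: [batch*heads, q_tokens, dim_head]; k, v: [batch*heads, kv_tokens, dim_head]
    scale = q.shape[-1] ** -0.5
    sim = torch.bmm(q, k.transpose(-1, -2)) * scale          # [batch*heads, q_tokens, kv_tokens]
    if cross_attention_mask is not None:
        # cross_attention_mask: [batch, kv_tokens] bools; True = attend, False = ignore
        mask = cross_attention_mask[:, None, :]               # [batch, 1, kv_tokens]
        mask = mask.repeat_interleave(heads, dim=0)           # heads are folded into the batch dim
        sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
    attn = sim.softmax(dim=-1)
    return torch.bmm(attn, v)                                 # [batch*heads, q_tokens, dim_head]
```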
I explicitly named the parameter a cross-attention mask, because a self-attention mask has entirely different requirements.
In terms of wider API design, I wonder whether it should be an attention map instead (i.e. so you can use it to increase/decrease attention scores for certain token embeds). But for now I'm mostly interested in the mask aspect, because waifu-diffusion makes use of "multiple CLIP embeddings stitched together", so attention masking is useful to avoid attending to padding token embeddings, which would be biased towards conveying the high-level semantics of the final CLIP segment only.
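As a concrete example of that use case (a hedged sketch, not waifu-diffusion's actual pipeline code; the model id and variable names are just for illustration): two CLIP segments get encoded separately and stitched along the sequence axis, and the tokenizer's attention_mask becomes the cross-attention mask that hides each segment's padding tokens.

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a watercolor painting of a fox", "sitting in a snowy forest at dusk"]
enc = tokenizer(prompts, padding="max_length", max_length=77, truncation=True, return_tensors="pt")

with torch.no_grad():
    embeds = text_encoder(enc.input_ids).last_hidden_state     # [2, 77, 768]

# stitch the two segments along the sequence axis: [1, 154, 768]
stitched_embeds = embeds.reshape(1, -1, embeds.shape[-1])
# the tokenizer's attention_mask (1 = real token, 0 = padding) becomes the cross-attention mask
cross_attention_mask = enc.attention_mask.reshape(1, -1).bool()  # [1, 154]
# passed down to CrossAttention, this stops the image latents attending to padding token embeds
```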