Status: Closed
Labels: bug (Something isn't working)
Description
Describe the bug
When fused_attn is used, no scale is passed to torch.nn.functional.scaled_dot_product_attention, so it falls back to its default of q.size(-1) ** -0.5 (i.e. self.dim_head ** -0.5, since q is transposed before the call). This differs from the scale the Attention2d layer applies in the non-fused path (self.scale = num_heads ** -0.5), so the fused and vanilla implementations produce different results.
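For concreteness, with the layer default num_heads=32 and an illustrative dim=256 (the dim value is my own example, not taken from the report), the two code paths scale the attention logits differently:

    dim, num_heads = 256, 32
    dim_head = dim // num_heads               # 8

    sdpa_default_scale = dim_head ** -0.5     # what scaled_dot_product_attention uses by default -> ~0.354
    layer_scale = num_heads ** -0.5           # what Attention2d applies in the non-fused path    -> ~0.177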
pytorch-image-models/timm/layers/attention2d.py, lines 294 to 351 at dafe866:
class Attention2d(nn.Module):
    fused_attn: torch.jit.Final[bool]

    """ multi-head attention for 2D NCHW tensors"""
    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.scale = num_heads ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape

        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
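To make the divergence easy to check, here is a minimal reproduction sketch. It is my own example, not taken from the report: the dim/num_heads values are arbitrary (chosen so that dim_head != num_heads), and it assumes Attention2d can be imported from timm.layers.attention2d and that fused_attn can be toggled on the instance in eager mode.

    import torch
    from timm.layers.attention2d import Attention2d

    torch.manual_seed(0)
    x = torch.randn(1, 256, 8, 8)

    # dim=256 with the default num_heads=32 gives dim_head=8, so the two scales differ
    attn = Attention2d(dim=256, num_heads=32).eval()

    with torch.no_grad():
        attn.fused_attn = True      # force the scaled_dot_product_attention branch
        out_fused = attn(x)
        attn.fused_attn = False     # force the vanilla matmul/softmax branch
        out_vanilla = attn(x)

    print(torch.allclose(out_fused, out_vanilla, atol=1e-5))  # expected: False, the outputs diverge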
Expected behavior
Same results for the two implementations.
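One way to get that (a sketch only, not a confirmed patch; it assumes the layer's self.scale is the intended value and relies on the scale keyword argument that scaled_dot_product_attention accepts since PyTorch 2.1) would be to pass the scale explicitly in the fused branch:

    # hypothetical change inside Attention2d.forward, fused branch only
    x = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(-1, -2).contiguous(),
        k.transpose(-1, -2).contiguous(),
        v.transpose(-1, -2).contiguous(),
        attn_mask=attn_mask,
        dropout_p=self.attn_drop.p if self.training else 0.,
        scale=self.scale,  # use the same scale as the non-fused path
    ).transpose(-1, -2).reshape(B, -1, H, W)

Alternatively, if num_heads ** -0.5 was itself unintended, setting self.scale = self.dim_head ** -0.5 in __init__ would make the non-fused path match SDPA's default; either change would make the two branches agree.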
Desktop:
- OS: macOS
- This repository version: 1.0.12
- PyTorch version: 2.5 (CPU)