+ import pytest
import torch
import torch.nn as nn

- from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
+ from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d

import importlib
import os
@@ -119,3 +120,27 @@ def test_get_act_fn_none():
    assert get_act_fn(None) is None
    assert get_act_fn('') is None

+
+ @pytest.mark.parametrize("bias", [True, False])
+ @pytest.mark.parametrize("expand_first", [True, False])
+ @pytest.mark.parametrize("head_first", [True, False])
+ @pytest.mark.parametrize("attn_mask", [True, False])
+ def test_attn2d(bias, expand_first, head_first, attn_mask):
+     x = torch.randn(1, 128, 32, 48)
+     attn = Attention2d(
+         128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+     )
+
+     if attn_mask:
+         # torch.randint(0, 1, ...) yields an all-zeros additive float mask, which
+         # exercises the mask code path without altering the attention weights
+         mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+     else:
+         mask = None
+
+     o1 = attn(x, mask)
+     attn.fused_attn = False
+     o2 = attn(x, mask)
+
+     assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
+
+
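
The added test toggles attn.fused_attn to check that the fused and eager attention paths of Attention2d produce matching outputs on the same input. A minimal usage sketch along the same lines, assuming only the constructor arguments and the forward(x, attn_mask) call exercised in the test above (anything beyond that is not taken from this diff):

    import torch
    from timm.layers import Attention2d

    # NCHW feature map, matching the shapes used in the test above
    x = torch.randn(1, 128, 32, 48)

    # dim -> dim_out projection with 4 heads; other kwargs left at their defaults
    attn = Attention2d(128, 128, num_heads=4)

    # additive float mask over the H*W token positions (zeros = no masking)
    mask = torch.zeros(32 * 48, 32 * 48)

    o1 = attn(x, mask)          # initial path (fused attention when enabled)
    attn.fused_attn = False
    o2 = attn(x, mask)          # eager attention path

    print(torch.allclose(o1, o2, atol=1e-5))
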