 from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
 from captum.attr._core.shapley_value import ShapleyValueSampling
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from torch import nn, Tensor

@@ -60,17 +60,25 @@ def forward(self, input_ids, *args, **kwargs):
     def generate(self, input_ids, *args, mock_response=None, **kwargs):
         assert mock_response, "must mock response to use DummyLLM to geenrate"
         response = self.tokenizer.encode(mock_response)[1:]
-        return torch.cat([input_ids, torch.tensor([response])], dim=1)
+        return torch.cat(
+            [input_ids, torch.tensor([response], device=self.device)], dim=1
+        )

     @property
     def device(self):
         return next(self.parameters()).device


+@parameterized_class(
+    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+)
 class TestLLMAttr(BaseTest):
+    device: str
+
     @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
     def test_llm_attr(self, AttrClass) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

@@ -81,9 +89,12 @@ def test_llm_attr(self, AttrClass) -> None:
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
         self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)

     def test_llm_attr_without_target(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         fa = FeatureAblation(llm)
         llm_fa = LLMAttribution(fa, tokenizer)
@@ -95,9 +106,12 @@ def test_llm_attr_without_target(self) -> None:
         self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
         self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
         self.assertEqual(res.output_tokens, ["x", "y", "z"])
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)

     def test_llm_attr_fa_log_prob(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         fa = FeatureAblation(llm)
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -112,6 +126,7 @@ def test_llm_attr_fa_log_prob(self) -> None:
     @parameterized.expand([(Lime,), (KernelShap,)])
     def test_llm_attr_without_token(self, AttrClass) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         fa = AttrClass(llm)
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -120,14 +135,21 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
         res = llm_fa.attribute(inp, "m n o p q")

         self.assertEqual(res.seq_attr.shape, (4,))
+        self.assertEqual(res.seq_attr.device.type, self.device)
         self.assertEqual(res.token_attr, None)
         self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])


+@parameterized_class(
+    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+)
 class TestLLMGradAttr(BaseTest):
+    device: str
+
     def test_llm_attr(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         attr = LayerIntegratedGradients(llm, llm.emb)
         llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -141,8 +163,12 @@ def test_llm_attr(self) -> None:
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr.device.type, self.device)
+
     def test_llm_attr_without_target(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         attr = LayerIntegratedGradients(llm, llm.emb)
         llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -155,8 +181,12 @@ def test_llm_attr_without_target(self) -> None:
         self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["x", "y", "z"])

+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr.device.type, self.device)
+
     def test_llm_attr_with_skip_tokens(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         attr = LayerIntegratedGradients(llm, llm.emb)
         llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -169,3 +199,6 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         self.assertEqual(res.token_attr.shape, (5, 3))
         self.assertEqual(res.input_tokens, ["a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr.device.type, self.device)
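
For readers unfamiliar with the pattern used in the hunks above, here is a minimal standalone sketch of how parameterized_class drives the device parametrization: it generates one copy of the test class per value tuple and sets the named field ("device") as a class attribute. The class and test names below (ExampleDeviceTest, test_tensor_on_device) are illustrative placeholders, not part of the patch; only the parameterized package and torch are assumed.

import unittest

import torch
from parameterized import parameterized_class


# Run the class once per device; include "cuda" only when it is available.
@parameterized_class(
    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
class ExampleDeviceTest(unittest.TestCase):
    device: str  # populated by parameterized_class for each generated class

    def test_tensor_on_device(self) -> None:
        # A tensor created with device=self.device should report that device type.
        t = torch.zeros(2, 2, device=self.device)
        self.assertEqual(t.device.type, self.device)


if __name__ == "__main__":
    unittest.main()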