Skip to content

Commit 19d7e3f

Browse files
vivekmig authored and facebook-github-bot committed
LLM Attribution GPU Tests (#1210)
Summary: Adds GPU tests for LLM attribution by parametrizing test class and verifying output attribution device. Currently also includes changes for KernelShap and Lime since there are dependencies, will rebase once that PR is merged. Pull Request resolved: #1210 Reviewed By: aobo-y Differential Revision: D51579521 Pulled By: vivekmig fbshipit-source-id: 01fefe1fbb3b3784c80e2afc781c11575627669f
1 parent b321570 commit 19d7e3f

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

tests/attr/test_llm_attr.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
1212
from captum.attr._core.shapley_value import ShapleyValueSampling
1313
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
14-
from parameterized import parameterized
14+
from parameterized import parameterized, parameterized_class
1515
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
1616
from torch import nn, Tensor
1717

@@ -60,17 +60,25 @@ def forward(self, input_ids, *args, **kwargs):
6060
def generate(self, input_ids, *args, mock_response=None, **kwargs):
6161
assert mock_response, "must mock response to use DummyLLM to geenrate"
6262
response = self.tokenizer.encode(mock_response)[1:]
63-
return torch.cat([input_ids, torch.tensor([response])], dim=1)
63+
return torch.cat(
64+
[input_ids, torch.tensor([response], device=self.device)], dim=1
65+
)
6466

6567
@property
6668
def device(self):
6769
return next(self.parameters()).device
6870

6971

72+
@parameterized_class(
73+
("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
74+
)
7075
class TestLLMAttr(BaseTest):
76+
device: str
77+
7178
@parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
7279
def test_llm_attr(self, AttrClass) -> None:
7380
llm = DummyLLM()
81+
llm.to(self.device)
7482
tokenizer = DummyTokenizer()
7583
llm_attr = LLMAttribution(AttrClass(llm), tokenizer)
7684

@@ -81,9 +89,12 @@ def test_llm_attr(self, AttrClass) -> None:
8189
self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
8290
self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
8391
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
92+
self.assertEqual(res.seq_attr.device.type, self.device)
93+
self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
8494

8595
def test_llm_attr_without_target(self) -> None:
8696
llm = DummyLLM()
97+
llm.to(self.device)
8798
tokenizer = DummyTokenizer()
8899
fa = FeatureAblation(llm)
89100
llm_fa = LLMAttribution(fa, tokenizer)
@@ -95,9 +106,12 @@ def test_llm_attr_without_target(self) -> None:
95106
self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
96107
self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
97108
self.assertEqual(res.output_tokens, ["x", "y", "z"])
109+
self.assertEqual(res.seq_attr.device.type, self.device)
110+
self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
98111

99112
def test_llm_attr_fa_log_prob(self) -> None:
100113
llm = DummyLLM()
114+
llm.to(self.device)
101115
tokenizer = DummyTokenizer()
102116
fa = FeatureAblation(llm)
103117
llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -112,6 +126,7 @@ def test_llm_attr_fa_log_prob(self) -> None:
112126
@parameterized.expand([(Lime,), (KernelShap,)])
113127
def test_llm_attr_without_token(self, AttrClass) -> None:
114128
llm = DummyLLM()
129+
llm.to(self.device)
115130
tokenizer = DummyTokenizer()
116131
fa = AttrClass(llm)
117132
llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -120,14 +135,21 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
120135
res = llm_fa.attribute(inp, "m n o p q")
121136

122137
self.assertEqual(res.seq_attr.shape, (4,))
138+
self.assertEqual(res.seq_attr.device.type, self.device)
123139
self.assertEqual(res.token_attr, None)
124140
self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
125141
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
126142

127143

144+
@parameterized_class(
145+
("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
146+
)
128147
class TestLLMGradAttr(BaseTest):
148+
device: str
149+
129150
def test_llm_attr(self) -> None:
130151
llm = DummyLLM()
152+
llm.to(self.device)
131153
tokenizer = DummyTokenizer()
132154
attr = LayerIntegratedGradients(llm, llm.emb)
133155
llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -141,8 +163,12 @@ def test_llm_attr(self) -> None:
141163
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
142164
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
143165

166+
self.assertEqual(res.seq_attr.device.type, self.device)
167+
self.assertEqual(res.token_attr.device.type, self.device)
168+
144169
def test_llm_attr_without_target(self) -> None:
145170
llm = DummyLLM()
171+
llm.to(self.device)
146172
tokenizer = DummyTokenizer()
147173
attr = LayerIntegratedGradients(llm, llm.emb)
148174
llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -155,8 +181,12 @@ def test_llm_attr_without_target(self) -> None:
155181
self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
156182
self.assertEqual(res.output_tokens, ["x", "y", "z"])
157183

184+
self.assertEqual(res.seq_attr.device.type, self.device)
185+
self.assertEqual(res.token_attr.device.type, self.device)
186+
158187
def test_llm_attr_with_skip_tokens(self) -> None:
159188
llm = DummyLLM()
189+
llm.to(self.device)
160190
tokenizer = DummyTokenizer()
161191
attr = LayerIntegratedGradients(llm, llm.emb)
162192
llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -169,3 +199,6 @@ def test_llm_attr_with_skip_tokens(self) -> None:
169199
self.assertEqual(res.token_attr.shape, (5, 3))
170200
self.assertEqual(res.input_tokens, ["a", "b", "c"])
171201
self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
202+
203+
self.assertEqual(res.seq_attr.device.type, self.device)
204+
self.assertEqual(res.token_attr.device.type, self.device)

0 commit comments

Comments
 (0)