diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py
index cc4edc5654..17aca630a9 100644
--- a/tests/attr/test_llm_attr.py
+++ b/tests/attr/test_llm_attr.py
@@ -11,7 +11,7 @@
 from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
 from captum.attr._core.shapley_value import ShapleyValueSampling
 from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
 from torch import nn, Tensor
 
@@ -60,17 +60,25 @@ def forward(self, input_ids, *args, **kwargs):
     def generate(self, input_ids, *args, mock_response=None, **kwargs):
         assert mock_response, "must mock response to use DummyLLM to geenrate"
         response = self.tokenizer.encode(mock_response)[1:]
-        return torch.cat([input_ids, torch.tensor([response])], dim=1)
+        return torch.cat(
+            [input_ids, torch.tensor([response], device=self.device)], dim=1
+        )
 
     @property
     def device(self):
         return next(self.parameters()).device
 
 
+@parameterized_class(
+    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+)
 class TestLLMAttr(BaseTest):
+    device: str
+
     @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
     def test_llm_attr(self, AttrClass) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         llm_attr = LLMAttribution(AttrClass(llm), tokenizer)
 
@@ -81,9 +89,12 @@ def test_llm_attr(self, AttrClass) -> None:
         self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
         self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
 
     def test_llm_attr_without_target(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         fa = FeatureAblation(llm)
         llm_fa = LLMAttribution(fa, tokenizer)
@@ -95,9 +106,12 @@ def test_llm_attr_without_target(self) -> None:
         self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
         self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
         self.assertEqual(res.output_tokens, ["x", "y", "z"])
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)
 
     def test_llm_attr_fa_log_prob(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         fa = FeatureAblation(llm)
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -112,6 +126,7 @@ def test_llm_attr_fa_log_prob(self) -> None:
     @parameterized.expand([(Lime,), (KernelShap,)])
     def test_llm_attr_without_token(self, AttrClass) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         fa = AttrClass(llm)
         llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -120,14 +135,21 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
         res = llm_fa.attribute(inp, "m n o p q")
 
         self.assertEqual(res.seq_attr.shape, (4,))
+        self.assertEqual(res.seq_attr.device.type, self.device)
         self.assertEqual(res.token_attr, None)
         self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
 
 
+@parameterized_class(
+    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
+)
 class TestLLMGradAttr(BaseTest):
+    device: str
+
     def test_llm_attr(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         attr = LayerIntegratedGradients(llm, llm.emb)
         llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -141,8 +163,12 @@ def test_llm_attr(self) -> None:
         self.assertEqual(res.input_tokens, ["", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
 
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr.device.type, self.device)
+
     def test_llm_attr_without_target(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         attr = LayerIntegratedGradients(llm, llm.emb)
         llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -155,8 +181,12 @@ def test_llm_attr_without_target(self) -> None:
         self.assertEqual(res.input_tokens, ["", "a", "b", "c"])
         self.assertEqual(res.output_tokens, ["x", "y", "z"])
 
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr.device.type, self.device)
+
     def test_llm_attr_with_skip_tokens(self) -> None:
         llm = DummyLLM()
+        llm.to(self.device)
         tokenizer = DummyTokenizer()
         attr = LayerIntegratedGradients(llm, llm.emb)
         llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -169,3 +199,6 @@ def test_llm_attr_with_skip_tokens(self) -> None:
         self.assertEqual(res.token_attr.shape, (5, 3))
         self.assertEqual(res.input_tokens, ["a", "b", "c"])
         self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
+
+        self.assertEqual(res.seq_attr.device.type, self.device)
+        self.assertEqual(res.token_attr.device.type, self.device)
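
Side note, not part of the patch: a minimal, self-contained sketch of the pattern this diff relies on. parameterized_class generates one subclass of the decorated TestCase per row of input values, setting the class attribute named in the first argument (here "device"), so every test method runs once on "cpu" and, when available, once on "cuda". The class and test names below (TestDevicePlacement, test_tensor_lands_on_device) are made up for illustration and do not appear in the patch.

# Illustrative sketch only; the decorator usage mirrors the patch above,
# everything else (names, test body) is hypothetical.
import unittest

import torch
from parameterized import parameterized_class


@parameterized_class(
    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
class TestDevicePlacement(unittest.TestCase):
    # Filled in per generated class (e.g. TestDevicePlacement_0) by the decorator.
    device: str

    def test_tensor_lands_on_device(self) -> None:
        x = torch.zeros(2, 3, device=self.device)
        self.assertEqual(x.device.type, self.device)


if __name__ == "__main__":
    unittest.main()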