37 changes: 35 additions & 2 deletions tests/attr/test_llm_attr.py
@@ -11,7 +11,7 @@
from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution
from captum.attr._core.shapley_value import ShapleyValueSampling
from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from torch import nn, Tensor

@@ -60,17 +60,25 @@ def forward(self, input_ids, *args, **kwargs):
    def generate(self, input_ids, *args, mock_response=None, **kwargs):
        assert mock_response, "must mock response to use DummyLLM to generate"
        response = self.tokenizer.encode(mock_response)[1:]
        return torch.cat([input_ids, torch.tensor([response])], dim=1)
        return torch.cat(
            [input_ids, torch.tensor([response], device=self.device)], dim=1
        )

    @property
    def device(self):
        return next(self.parameters()).device


@parameterized_class(
    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
class TestLLMAttr(BaseTest):
    device: str

    @parameterized.expand([(FeatureAblation,), (ShapleyValueSampling,)])
    def test_llm_attr(self, AttrClass) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        llm_attr = LLMAttribution(AttrClass(llm), tokenizer)

@@ -81,9 +89,12 @@ def test_llm_attr(self, AttrClass) -> None:
        self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)

    def test_llm_attr_without_target(self) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        fa = FeatureAblation(llm)
        llm_fa = LLMAttribution(fa, tokenizer)
@@ -95,9 +106,12 @@ def test_llm_attr_without_target(self) -> None:
        self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4))
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(res.output_tokens, ["x", "y", "z"])
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)

    def test_llm_attr_fa_log_prob(self) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        fa = FeatureAblation(llm)
        llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -112,6 +126,7 @@ def test_llm_attr_fa_log_prob(self) -> None:
    @parameterized.expand([(Lime,), (KernelShap,)])
    def test_llm_attr_without_token(self, AttrClass) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        fa = AttrClass(llm)
        llm_fa = LLMAttribution(fa, tokenizer, attr_target="log_prob")
@@ -120,14 +135,21 @@ def test_llm_attr_without_token(self, AttrClass) -> None:
        res = llm_fa.attribute(inp, "m n o p q")

        self.assertEqual(res.seq_attr.shape, (4,))
        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(res.token_attr, None)
        self.assertEqual(res.input_tokens, ["a", "c", "d", "f"])
        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])


@parameterized_class(
    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
class TestLLMGradAttr(BaseTest):
    device: str

    def test_llm_attr(self) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        attr = LayerIntegratedGradients(llm, llm.emb)
        llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -141,8 +163,12 @@ def test_llm_attr(self) -> None:
        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(res.token_attr.device.type, self.device)

    def test_llm_attr_without_target(self) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        attr = LayerIntegratedGradients(llm, llm.emb)
        llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -155,8 +181,12 @@ def test_llm_attr_without_target(self) -> None:
        self.assertEqual(res.input_tokens, ["<sos>", "a", "b", "c"])
        self.assertEqual(res.output_tokens, ["x", "y", "z"])

        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(res.token_attr.device.type, self.device)

    def test_llm_attr_with_skip_tokens(self) -> None:
        llm = DummyLLM()
        llm.to(self.device)
        tokenizer = DummyTokenizer()
        attr = LayerIntegratedGradients(llm, llm.emb)
        llm_attr = LLMGradientAttribution(attr, tokenizer)
@@ -169,3 +199,6 @@ def test_llm_attr_with_skip_tokens(self) -> None:
        self.assertEqual(res.token_attr.shape, (5, 3))
        self.assertEqual(res.input_tokens, ["a", "b", "c"])
        self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"])

        self.assertEqual(res.seq_attr.device.type, self.device)
        self.assertEqual(res.token_attr.device.type, self.device)
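Note: the @parameterized_class decorator used in this diff comes from the parameterized package. It generates one copy of the decorated test class per parameter tuple and sets the named attribute ("device" here) on each copy, so every test runs once on CPU and, when CUDA is available, once on GPU. Below is a minimal, self-contained sketch of the same pattern; DummySum and TestDeviceParam are hypothetical stand-ins for illustration only and are not part of this PR.

import unittest

import torch
from parameterized import parameterized_class
from torch import nn


class DummySum(nn.Module):
    # Tiny stand-in model with one parameter so .to(device) has an effect.
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(3))

    def forward(self, x):
        return (x * self.weight).sum(dim=-1)


@parameterized_class(
    ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)]
)
class TestDeviceParam(unittest.TestCase):
    device: str  # filled in per generated class by parameterized_class

    def test_forward_on_device(self) -> None:
        model = DummySum().to(self.device)
        out = model(torch.ones(2, 3, device=self.device))
        # The output should live on the same device the model was moved to,
        # mirroring the res.seq_attr.device.type checks added in this PR.
        self.assertEqual(out.device.type, self.device)


if __name__ == "__main__":
    unittest.main()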