Commit 31e7983

Merge branch 'master' into export-D41830245

2 parents 1ab6ae4 + 332c8c8

File tree

6 files changed: +275 -17 lines

captum/attr/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -30,6 +30,7 @@
 )
 from captum.attr._core.layer.layer_lrp import LayerLRP  # noqa
 from captum.attr._core.lime import Lime, LimeBase  # noqa
+from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution  # noqa
 from captum.attr._core.lrp import LRP  # noqa
 from captum.attr._core.neuron.neuron_conductance import NeuronConductance  # noqa
 from captum.attr._core.neuron.neuron_deep_lift import (  # noqa
@@ -67,6 +68,11 @@
     PerturbationAttribution,
 )
 from captum.attr._utils.class_summarizer import ClassSummarizer
+from captum.attr._utils.interpretable_input import (  # noqa
+    InterpretableInput,
+    TextTemplateInput,
+    TextTokenInput,
+)
 from captum.attr._utils.stat import (
     CommonStats,
     Count,
@@ -108,7 +114,10 @@
     "LayerGradientXActivation",
     "LayerActivation",
     "LayerFeatureAblation",
+    "LLMAttribution",
+    "LLMGradientAttribution",
     "InternalInfluence",
+    "InterpretableInput",
     "LayerGradCam",
     "LayerDeepLift",
     "LayerDeepLiftShap",
@@ -127,6 +136,8 @@
     "NoiseTunnel",
     "GradientShap",
     "InterpretableEmbeddingBase",
+    "TextTemplateInput",
+    "TextTokenInput",
     "TokenReferenceBase",
     "visualization",
     "configure_interpretable_embedding_layer",

captum/attr/_core/layer/layer_integrated_gradients.py

Lines changed: 2 additions & 0 deletions
@@ -467,6 +467,8 @@ def layer_forward_hook(

                 hooks.append(hook)

+                # inputs is an empty tuple here
+                # because it was prepended into additional_forward_args
                 output = _run_forward(
                     self.forward_func, tuple(), target_ind, additional_forward_args
                 )
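
The new comment documents a non-obvious call convention: by this point the real inputs have already been folded into additional_forward_args, and the hooked layers inject the actual activations, so inputs must be an empty tuple to avoid passing them twice. A rough sketch of the pattern; the helper name run_forward_sketch is hypothetical and _run_forward's real signature is richer:

from typing import Any, Callable, Optional, Tuple

def run_forward_sketch(
    forward_fn: Callable,
    inputs: Tuple[Any, ...],
    additional_forward_args: Optional[Tuple[Any, ...]] = None,
) -> Any:
    # inputs come first, then the extra args, each passed exactly once;
    # a caller that already moved its inputs into additional_forward_args
    # must therefore pass inputs=tuple() so they are not duplicated
    all_args = tuple(inputs) + tuple(additional_forward_args or ())
    return forward_fn(*all_args)

# roughly equivalent to the call in the diff:
# output = run_forward_sketch(self.forward_func, tuple(), additional_forward_args)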

captum/attr/_core/llm_attr.py

Lines changed: 208 additions & 11 deletions
@@ -4,15 +4,17 @@

 import torch
 from captum.attr._core.feature_ablation import FeatureAblation
+from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
 from captum.attr._utils.attribution import Attribution
-from captum.attr._utils.interpretable_input import InterpretableInput, TextTemplateInput
+from captum.attr._utils.interpretable_input import (
+    InterpretableInput,
+    TextTemplateInput,
+    TextTokenInput,
+)
 from torch import nn, Tensor


-SUPPORTED_METHODS = (FeatureAblation, ShapleyValueSampling, ShapleyValues)
-SUPPORTED_INPUTS = (TextTemplateInput,)
-
 DEFAULT_GEN_ARGS = {"max_new_tokens": 25, "do_sample": False}


@@ -57,6 +59,9 @@ class LLMAttribution(Attribution):
     and returns LLMAttributionResult
     """

+    SUPPORTED_METHODS = (FeatureAblation, ShapleyValueSampling, ShapleyValues)
+    SUPPORTED_INPUTS = (TextTemplateInput, TextTokenInput)
+
     def __init__(
         self,
         attr_method: Attribution,
@@ -75,7 +80,7 @@ class created with the llm model that follows huggingface style
         """

         assert isinstance(
-            attr_method, SUPPORTED_METHODS
+            attr_method, self.SUPPORTED_METHODS
         ), f"LLMAttribution does not support {type(attr_method)}"

         super().__init__(attr_method.forward_func)
@@ -86,6 +91,7 @@ class created with the llm model that follows huggingface style
         self.attr_method.forward_func = self._forward_func

         # alias, we really need a model and don't support wrapper functions
+        # because we need to call model.forward, model.generate, etc.
         self.model = cast(nn.Module, self.forward_func)

         self.tokenizer = tokenizer
@@ -103,14 +109,12 @@ class created with the llm model that follows huggingface style

     def _forward_func(
         self,
-        perturbed_feature,
-        input_feature,
+        perturbed_tensor,
+        inp,
         target_tokens,
         _inspect_forward,
     ):
-        perturbed_input = self._format_model_input(
-            input_feature.to_model_input(perturbed_feature)
-        )
+        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input

         model_inp = init_model_inp
@@ -192,7 +196,7 @@ def attribute(
         """

         assert isinstance(
-            inp, SUPPORTED_INPUTS
+            inp, self.SUPPORTED_INPUTS
         ), f"LLMAttribution does not support input type {type(inp)}"

         if target is None:
@@ -214,6 +218,7 @@ def attribute(
             if type(target) is str:
                 # exclude sos
                 target_tokens = self.tokenizer.encode(target)[1:]
+                target_tokens = torch.tensor(target_tokens)
             elif type(target) is torch.Tensor:
                 target_tokens = target

@@ -249,3 +254,195 @@ def attribute(
             inp.values,
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )
+
+
+class LLMGradientAttribution(Attribution):
+    """
+    Attribution class for large language models. It wraps a gradient-based
+    attribution algorithm to produce the attribution results commonly
+    of interest for the text-generation use case.
+    The wrapped instance will calculate attribution in the
+    same way as configured in the original attribution algorithm,
+    with respect to the log probabilities of each
+    generated token and the whole sequence. It will provide a
+    new "attribute" function which accepts text-based inputs
+    and returns LLMAttributionResult
+    """
+
+    SUPPORTED_METHODS = (LayerIntegratedGradients,)
+    SUPPORTED_INPUTS = (TextTokenInput,)
+
+    def __init__(
+        self,
+        attr_method,
+        tokenizer,
+    ):
+        """
+        Args:
+            attr_method (Attribution): instance of a supported gradient attribution
+                    class created with the llm model that follows huggingface style
+                    interface convention
+            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
+        """
+        assert isinstance(
+            attr_method, self.SUPPORTED_METHODS
+        ), f"LLMGradientAttribution does not support {type(attr_method)}"
+
+        super().__init__(attr_method.forward_func)
+
+        # shallow copy is enough to avoid modifying the original instance
+        self.attr_method = copy(attr_method)
+        self.attr_method.forward_func = self._forward_func
+
+        # alias, we really need a model and don't support wrapper functions
+        # because we need to call model.forward, model.generate, etc.
+        self.model = cast(nn.Module, self.forward_func)
+
+        self.tokenizer = tokenizer
+        self.device = (
+            cast(torch.device, self.model.device)
+            if hasattr(self.model, "device")
+            else next(self.model.parameters()).device
+        )
+
+    def _forward_func(
+        self,
+        perturbed_tensor: Tensor,
+        inp: InterpretableInput,
+        target_tokens: Tensor,  # 1D tensor of target token ids
+        cur_target_idx: int,  # current target index
+    ):
+        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
+
+        if cur_target_idx:
+            # the input batch size can be expanded by the attr method
+            output_token_tensor = (
+                target_tokens[:cur_target_idx]
+                .unsqueeze(0)
+                .expand(perturbed_input.size(0), -1)
+                .to(self.device)
+            )
+            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
+        else:
+            new_input_tensor = perturbed_input
+
+        output_logits = self.model(new_input_tensor)
+
+        new_token_logits = output_logits.logits[:, -1]
+        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
+
+        target_token = target_tokens[cur_target_idx]
+        token_log_probs = log_probs[..., target_token]
+
+        # the attribution target is limited to the log probability
+        return token_log_probs
+
+    def _format_model_input(self, model_input):
+        """
+        Move the model input tensor onto the model's device
+        """
+        return model_input.to(self.device)
+
+    def attribute(
+        self,
+        inp: InterpretableInput,
+        target: Union[str, torch.Tensor, None] = None,
+        gen_args: Optional[Dict] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            inp (InterpretableInput): input prompt for which attributions are computed
+            target (str or Tensor, optional): target response with respect to
+                    which attributions are computed. If None, it uses the model
+                    to generate the target based on the input and gen_args.
+                    Default: None
+            gen_args (dict, optional): arguments for generating the target. Only used if
+                    target is not given. When None, the default arguments are used,
+                    {"max_new_tokens": 25, "do_sample": False}
+                    Default: None
+            **kwargs (Any): any extra keyword arguments passed to the call of the
+                    underlying attribute function of the given attribution instance

+        Returns:
+
+            attr (LLMAttributionResult): attribution result
+        """
+
+        assert isinstance(
+            inp, self.SUPPORTED_INPUTS
+        ), f"LLMGradientAttribution does not support input type {type(inp)}"
+
+        if target is None:
+            # generate when None
+            assert hasattr(self.model, "generate") and callable(self.model.generate), (
+                "The model does not have a recognizable generate function."
+                " Target must be given for attribution"
+            )
+
+            if not gen_args:
+                gen_args = DEFAULT_GEN_ARGS
+
+            model_inp = self._format_model_input(inp.to_model_input())
+            output_tokens = self.model.generate(model_inp, **gen_args)
+            target_tokens = output_tokens[0][model_inp.size(1) :]
+        else:
+            assert gen_args is None, "gen_args must be None when target is given"
+
+            if type(target) is str:
+                # exclude sos
+                target_tokens = self.tokenizer.encode(target)[1:]
+                target_tokens = torch.tensor(target_tokens)
+            elif type(target) is torch.Tensor:
+                target_tokens = target
+
+        attr_inp = inp.to_tensor().to(self.device)
+
+        attr_list = []
+        for cur_target_idx, _ in enumerate(target_tokens):
+            # attr in shape (batch_size, input+output_len, emb_dim)
+            attr = self.attr_method.attribute(
+                attr_inp,
+                additional_forward_args=(
+                    inp,
+                    target_tokens,
+                    cur_target_idx,
+                ),
+                **kwargs,
+            )
+            attr = cast(Tensor, attr)
+
+            # attr also covers the previous output tokens;
+            # cut it to shape (batch_size, inp_len, emb_dim)
+            if cur_target_idx:
+                attr = attr[:, :-cur_target_idx]
+
+            # the author of IG uses sum
+            # https://github.com/ankurtaly/Integrated-Gradients/blob/master/BertModel/bert_model_utils.py#L350
+            attr = attr.sum(-1)
+
+            attr_list.append(attr)
+
+        # assume the inp batch only has one instance;
+        # concatenate to shape (n_output_token, ...)
+        attr = torch.cat(attr_list, dim=0)
+
+        # grad attr methods do not care about the length of the features in
+        # the interpretable format; they attribute to all the elements of the
+        # specified layer's output, so inp types that skip some elements
+        # need special handling
+        if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
+            itp_mask = inp.itp_mask.to(self.device)
+            itp_mask = itp_mask.expand_as(attr)
+            attr = attr[itp_mask].view(attr.size(0), -1)
+
+        # for all the gradient methods we support in this class,
+        # the seq attr is the sum of all the token attrs when the attr target
+        # is log_prob, shape (n_input_features)
+        seq_attr = attr.sum(0)
+
+        return LLMAttributionResult(
+            seq_attr,
+            attr,  # shape (n_output_token, n_input_features)
+            inp.values,
+            self.tokenizer.convert_ids_to_tokens(target_tokens),
+        )
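
For orientation, a hedged usage sketch of the new class. It assumes a HuggingFace-style GPT-2 checkpoint; the checkpoint name, the embedding-layer path model.transformer.wte, the TextTokenInput(text, tokenizer) constructor shape, and the result field names are illustrative assumptions, not part of this commit. The final seq_attr = attr.sum(0) is justified because the attribution target is the per-token log probability and log P(target sequence) = sum over t of log P(token_t | prefix), so token attributions add up to the sequence attribution.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from captum.attr import LayerIntegratedGradients, LLMGradientAttribution, TextTokenInput

# hypothetical checkpoint; any HuggingFace-style causal LM should work
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# attribute w.r.t. the input token embeddings; the embedding attribute path
# differs per architecture (model.transformer.wte is GPT-2's token embedding)
lig = LayerIntegratedGradients(model, model.transformer.wte)
llm_attr = LLMGradientAttribution(lig, tokenizer)

# assumed constructor shape: TextTokenInput(text, tokenizer)
inp = TextTokenInput("Palm Coast is a city in", tokenizer)

# target=None: the model generates the target itself using DEFAULT_GEN_ARGS
result = llm_attr.attribute(inp)

# assuming the positional fields map to seq_attr / token_attr as in the
# LLMAttributionResult(...) call at the end of attribute() above
print(result.seq_attr)    # shape (n_input_features,)
print(result.token_attr)  # shape (n_output_token, n_input_features)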

captum/attr/_utils/interpretable_input.py

Lines changed: 4 additions & 3 deletions
@@ -398,7 +398,7 @@ def __init__(
         self.skip_tokens = skip_tokens

         # feature values, the tokens
-        self.values = tokenizer.convert_ids_to_tokens(self.itp_tensor[0])
+        self.values = tokenizer.convert_ids_to_tokens(self.itp_tensor[0].tolist())
         self.tokenizer = tokenizer
         self.n_itp_features = len(self.values)

@@ -409,6 +409,7 @@ def __init__(
         )

     def to_tensor(self) -> torch.Tensor:
+        # return the perturbation indicator as the interpretable tensor, not the token ids
         return torch.ones_like(self.itp_tensor)

     def to_model_input(self, perturbed_tensor=None) -> torch.Tensor:
@@ -422,14 +423,14 @@ def to_model_input(self, perturbed_tensor=None) -> torch.Tensor:
             # perturb_per_eval or gradient-based methods can expand the batch dim
             expand_shape = (perturbed_tensor.size(0), -1)

-            perturb_itp_tensor = self.itp_tensor.expand(*expand_shape).to(device)
+            perturb_itp_tensor = self.itp_tensor.expand(*expand_shape).clone().to(device)
             perturb_itp_tensor[perturb_mask] = self.baselines

         # if no interpretable mask, the interpretable tensor is the input tensor
         if self.itp_mask is None:
             return perturb_itp_tensor

-        perturb_inp_tensor = self.inp_tensor.expand(*expand_shape).to(device)
+        perturb_inp_tensor = self.inp_tensor.expand(*expand_shape).clone().to(device)
         itp_mask = self.itp_mask.expand(*expand_shape).to(device)

         perturb_inp_tensor[itp_mask] = perturb_itp_tensor.view(-1)
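
The two .clone() additions fix an in-place write on an expanded view: Tensor.expand does not copy data, so all expanded rows alias the same storage, and the masked assignments that follow would write more than one element into a single memory location. PyTorch rejects that; cloning first materializes independent rows. A standalone illustration (not captum code):

import torch

base = torch.tensor([[1, 2, 3]])

view = base.expand(2, -1)  # both rows alias the same storage, no copy
mask = torch.tensor([[True, False, False], [False, False, False]])
try:
    view[mask] = 9  # in-place write through an expanded view
except RuntimeError as err:
    # "... more than one element of the written-to tensor refers to a
    # single memory location. Please clone() the tensor ..."
    print(err)

safe = base.expand(2, -1).clone()  # clone() materializes independent rows
safe[mask] = 9
print(safe)  # tensor([[9, 2, 3], [1, 2, 3]])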

setup.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def get_package_files(root, subdirs):
     long_description=long_description,
     long_description_content_type="text/markdown",
     python_requires=">=3.6",
-    install_requires=["matplotlib", "numpy", "torch>=1.6"],
+    install_requires=["matplotlib", "numpy", "torch>=1.6", "tqdm"],
     packages=find_packages(exclude=("tests", "tests.*")),
     extras_require={
         "dev": DEV_REQUIRES,
