@@ -4,13 +4,15 @@
 
 import torch
 from captum.attr._core.feature_ablation import FeatureAblation
+from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
 from captum.attr._utils.attribution import Attribution
-from captum.attr._utils.interpretable_input import InterpretableInput, TextTemplateInput
+from captum.attr._utils.interpretable_input import (
+    InterpretableInput,
+    TextTemplateInput,
+    TextTokenInput,
+)
 from torch import nn, Tensor
 
 
-SUPPORTED_METHODS = (FeatureAblation, ShapleyValueSampling, ShapleyValues)
-SUPPORTED_INPUTS = (TextTemplateInput,)
-
 DEFAULT_GEN_ARGS = {"max_new_tokens": 25, "do_sample": False}
1719
1820
@@ -57,6 +59,9 @@ class LLMAttribution(Attribution):
     and returns LLMAttributionResult
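+
+    Example (a minimal usage sketch; assumes a huggingface-style causal LM
+    ``model`` and its ``tokenizer`` are already loaded; the template values
+    are illustrative)::
+
+        >>> fa = FeatureAblation(model)
+        >>> llm_attr = LLMAttribution(fa, tokenizer)
+        >>> inp = TextTemplateInput(
+        ...     "{} lives in {} and works as a {}.",
+        ...     ["Dave", "Palm Coast", "lawyer"],
+        ... )
+        >>> res = llm_attr.attribute(inp, target="playing golf")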
     """
 
+    SUPPORTED_METHODS = (FeatureAblation, ShapleyValueSampling, ShapleyValues)
+    SUPPORTED_INPUTS = (TextTemplateInput, TextTokenInput)
+
     def __init__(
         self,
         attr_method: Attribution,
@@ -75,7 +80,7 @@ class created with the llm model that follows huggingface style
         """
 
         assert isinstance(
-            attr_method, SUPPORTED_METHODS
+            attr_method, self.SUPPORTED_METHODS
         ), f"LLMAttribution does not support {type(attr_method)}"
 
         super().__init__(attr_method.forward_func)
@@ -86,6 +91,7 @@ class created with the llm model that follows huggingface style
         self.attr_method.forward_func = self._forward_func
 
         # alias, we really need a model and don't support wrapper functions
+        # because we need to call model.forward, model.generate, etc.
         self.model = cast(nn.Module, self.forward_func)
 
         self.tokenizer = tokenizer
@@ -103,14 +109,12 @@ class created with the llm model that follows huggingface style
 
     def _forward_func(
         self,
-        perturbed_feature,
-        input_feature,
+        perturbed_tensor,
+        inp,
         target_tokens,
         _inspect_forward,
     ):
-        perturbed_input = self._format_model_input(
-            input_feature.to_model_input(perturbed_feature)
-        )
+        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
         model_inp = init_model_inp
@@ -192,7 +196,7 @@ def attribute(
         """
 
         assert isinstance(
-            inp, SUPPORTED_INPUTS
+            inp, self.SUPPORTED_INPUTS
         ), f"LLMAttribution does not support input type {type(inp)}"
 
         if target is None:
@@ -214,6 +218,7 @@ def attribute(
             if type(target) is str:
                 # exclude sos
                 target_tokens = self.tokenizer.encode(target)[1:]
+                target_tokens = torch.tensor(target_tokens)
             elif type(target) is torch.Tensor:
                 target_tokens = target
 
@@ -249,3 +254,195 @@ def attribute(
             inp.values,
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )
+
+
+class LLMGradientAttribution(Attribution):
+    """
+    Attribution class for large language models. It wraps a gradient-based
+    attribution algorithm to produce the attribution results commonly of
+    interest for the text-generation use case.
+    The wrapped instance will calculate attributions in the same way as
+    configured in the original attribution algorithm, with respect to the
+    log probabilities of each generated token and of the whole sequence.
+    It provides a new "attribute" function which accepts text-based inputs
+    and returns an LLMAttributionResult
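+
+    Example (a minimal usage sketch; assumes a huggingface-style causal LM
+    ``model`` exposing its input embedding layer, plus its ``tokenizer``;
+    the prompt and target are illustrative)::
+
+        >>> lig = LayerIntegratedGradients(model, model.get_input_embeddings())
+        >>> llm_attr = LLMGradientAttribution(lig, tokenizer)
+        >>> inp = TextTokenInput("Palm Coast is a city in", tokenizer)
+        >>> res = llm_attr.attribute(inp, target="Florida")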
+    """
+
+    SUPPORTED_METHODS = (LayerIntegratedGradients,)
+    SUPPORTED_INPUTS = (TextTokenInput,)
+
+    def __init__(
+        self,
+        attr_method,
+        tokenizer,
+    ):
+        """
+        Args:
+            attr_method (Attribution): instance of a supported gradient attribution
+                    class created with the llm model that follows huggingface style
+                    interface convention
+            tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method
+        """
+        assert isinstance(
+            attr_method, self.SUPPORTED_METHODS
+        ), f"LLMGradientAttribution does not support {type(attr_method)}"
+
+        super().__init__(attr_method.forward_func)
+
+        # shallow copy is enough to avoid modifying original instance
+        self.attr_method = copy(attr_method)
+        self.attr_method.forward_func = self._forward_func
+
+        # alias, we really need a model and don't support wrapper functions
+        # because we need to call model.forward, model.generate, etc.
+        self.model = cast(nn.Module, self.forward_func)
+
+        self.tokenizer = tokenizer
+        self.device = (
+            cast(torch.device, self.model.device)
+            if hasattr(self.model, "device")
+            else next(self.model.parameters()).device
+        )
+
+    def _forward_func(
+        self,
+        perturbed_tensor: Tensor,
+        inp: InterpretableInput,
+        target_tokens: Tensor,  # 1D tensor of target token ids
+        cur_target_idx: int,  # current target index
+    ):
+        perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
+
+        if cur_target_idx:
+            # append the target tokens generated so far to the input, so the
+            # model predicts the current target token (teacher forcing);
+            # the input batch size can be expanded by the attr method
+            output_token_tensor = (
+                target_tokens[:cur_target_idx]
+                .unsqueeze(0)
+                .expand(perturbed_input.size(0), -1)
+                .to(self.device)
+            )
+            new_input_tensor = torch.cat([perturbed_input, output_token_tensor], dim=1)
+        else:
+            new_input_tensor = perturbed_input
+
+        output_logits = self.model(new_input_tensor)
+
+        new_token_logits = output_logits.logits[:, -1]
+        log_probs = torch.nn.functional.log_softmax(new_token_logits, dim=1)
+
+        target_token = target_tokens[cur_target_idx]
+        token_log_probs = log_probs[..., target_token]
+
+        # the attribution target is limited to the log probability
+        return token_log_probs
+
+    def _format_model_input(self, model_input):
+        """
+        Move the tokenized input tensor onto the model device
+        """
+        return model_input.to(self.device)
+
+    def attribute(
+        self,
+        inp: InterpretableInput,
+        target: Union[str, torch.Tensor, None] = None,
+        gen_args: Optional[Dict] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            inp (InterpretableInput): input prompt for which attributions are computed
+            target (str or Tensor, optional): target response with respect to
+                    which attributions are computed. If None, it uses the model
+                    to generate the target based on the input and gen_args.
+                    Default: None
+            gen_args (dict, optional): arguments for generating the target. Only used if
+                    target is not given. When None, the default arguments are used,
+                    {"max_new_tokens": 25, "do_sample": False}
+                    Default: None
+            **kwargs (Any): any extra keyword arguments passed to the call of the
+                    underlying attribute function of the given attribution instance
+
+        Returns:
+
+            attr (LLMAttributionResult): attribution result
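+
+        Example (a minimal sketch; ``llm_attr`` is an LLMGradientAttribution
+        instance and ``inp`` a TextTokenInput, both assumed to exist)::
+
+            >>> # target=None lets the wrapped model generate the target
+            >>> res = llm_attr.attribute(inp, gen_args={"max_new_tokens": 10})
+            >>> # or supply the target response explicitly
+            >>> res = llm_attr.attribute(inp, target="Florida")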
+        """
+
+        assert isinstance(
+            inp, self.SUPPORTED_INPUTS
+        ), f"LLMGradientAttribution does not support input type {type(inp)}"
+
+        if target is None:
+            # generate when None
+            assert hasattr(self.model, "generate") and callable(self.model.generate), (
+                "The model does not have a recognizable generate function."
+                " Target must be given for attribution."
+            )
+
+            if not gen_args:
+                gen_args = DEFAULT_GEN_ARGS
+
+            model_inp = self._format_model_input(inp.to_model_input())
+            output_tokens = self.model.generate(model_inp, **gen_args)
+            target_tokens = output_tokens[0][model_inp.size(1) :]
+        else:
+            assert gen_args is None, "gen_args must be None when target is given"
+
+            if type(target) is str:
+                # exclude sos
+                target_tokens = self.tokenizer.encode(target)[1:]
+                target_tokens = torch.tensor(target_tokens)
+            elif type(target) is torch.Tensor:
+                target_tokens = target
+
+        attr_inp = inp.to_tensor().to(self.device)
+
+        attr_list = []
+        for cur_target_idx, _ in enumerate(target_tokens):
+            # attr in shape(batch_size, input+output_len, emb_dim)
+            attr = self.attr_method.attribute(
+                attr_inp,
+                additional_forward_args=(
+                    inp,
+                    target_tokens,
+                    cur_target_idx,
+                ),
+                **kwargs,
+            )
+            attr = cast(Tensor, attr)
+
+            # the attr also covers the previously generated output tokens;
+            # cut it back to shape(batch_size, inp_len, emb_dim)
+            if cur_target_idx:
+                attr = attr[:, :-cur_target_idx]
+
+            # aggregate over the embedding dimension; the author of IG uses sum
+            # https://github.com/ankurtaly/Integrated-Gradients/blob/master/BertModel/bert_model_utils.py#L350
+            attr = attr.sum(-1)
+
+            attr_list.append(attr)
+
+        # assume the inp batch only has one instance
+        # to shape(n_output_token, ...)
+        attr = torch.cat(attr_list, dim=0)
+
+        # gradient attribution methods do not use the length of the features in
+        # the interpretable format; they attribute to every element of the
+        # specified layer's output, so input types that only keep a subset of
+        # the elements need special handling
+        if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
+            itp_mask = inp.itp_mask.to(self.device)
+            itp_mask = itp_mask.expand_as(attr)
+            attr = attr[itp_mask].view(attr.size(0), -1)
+
+        # for all the gradient methods we support in this class,
+        # the seq attr is the sum of all the token attrs when the
+        # attribution target is the log probability, shape(n_input_features)
+        seq_attr = attr.sum(0)
+
+        return LLMAttributionResult(
+            seq_attr,
+            attr,  # shape(n_output_token, n_input_features)
+            inp.values,
+            self.tokenizer.convert_ids_to_tokens(target_tokens),
+        )