From 1aab81ad0bed68c6b9db438cd74a0f84ff18dff0 Mon Sep 17 00:00:00 2001
From: Sarah Tran
Date: Thu, 13 Mar 2025 09:02:17 -0700
Subject: [PATCH] Avoid unnecessary tensor construction when creating input masks for permutation/ablation (#1527)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1527

Study: https://docs.google.com/spreadsheets/d/1GyNJJBrNkazGOyJQLv00QV4phX2R3488oNgVPT17qzU/edit?gid=0#gid=0

Saw a regression in the new logic introduced in D69531512 with one of the models, for both the permutation and ablation methods, potentially due to large sparse features. vivekmig suggested we can avoid creating all of these zero tensors.

Reviewed By: craymichael

Differential Revision: D71057703
---
 captum/attr/_core/feature_ablation.py    | 56 ++++++++++++++++--------
 captum/attr/_core/feature_permutation.py | 21 +++++----
 2 files changed, 51 insertions(+), 26 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 5375dbb638..d7f2570c9b 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -3,7 +3,18 @@
 # pyre-strict
 
 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -465,6 +476,13 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
@@ -472,6 +490,7 @@ def _attribute_with_cross_tensor_feature_masks(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
         ],
         None,
         None,
     ]:
-        unique_feature_ids = torch.unique(
-            torch.cat([mask.flatten() for mask in input_mask])
-        ).tolist()
-
         if isinstance(baselines, torch.Tensor):
             baselines = baselines.reshape((1,) + tuple(baselines.shape))
 
         # Process one feature per time, rather than processing every input tensor
-        for feature_idx in unique_feature_ids:
+        for feature_idx in feature_idx_to_tensor_idx.keys():
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                 )
             )
             yield ablated_inputs, current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
         inputs: Tuple[Tensor, ...],
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
 
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)
diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index 1fc85d16fe..3657c00fc2 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,7 +26,7 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
 
 
 def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
+    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
 ) -> Tuple[Tensor, ...]:
     """
     Permutes features across multiple input tensors using the corresponding
@@ -34,7 +34,7 @@ def _permute_features_across_tensors(
     """
     permuted_outputs = []
     for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if not feature_mask.any():
+        if feature_mask is None or not feature_mask.any():
             permuted_outputs.append(input_tensor)
             continue
         n = input_tensor.size(0)
@@ -103,7 +103,7 @@ def __init__(
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
         perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
+            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
         ] = _permute_features_across_tensors,
     ) -> None:
         r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
-        feature_masks = tuple(
-            (mask == feature_idx).to(inputs[0].device) for mask in input_mask
-        )
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        current_masks: List[Optional[Tensor]] = []
+        for i, mask in enumerate(input_mask):
+            if i in tensor_idxs:
+                current_masks.append((mask == feature_idx).to(inputs[0].device))
+            else:
+                current_masks.append(None)
+        feature_masks = tuple(current_masks)
         permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
         return permuted_outputs, feature_masks
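
A minimal standalone sketch of the idea behind this change, with illustrative helper names (build_feature_idx_to_tensor_idx, masks_for_feature) that are not Captum APIs: precompute which mask tensors contain each feature id, then materialize a boolean mask only for those tensors and pass None for the rest, instead of allocating an all-zero mask for every input tensor.

from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor


def build_feature_idx_to_tensor_idx(
    feature_mask: Tuple[Tensor, ...],
) -> Dict[int, List[int]]:
    # Map each feature id to the indices of the mask tensors that contain it.
    feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
    for i, mask in enumerate(feature_mask):
        for feature_idx in torch.unique(mask).tolist():
            feature_idx_to_tensor_idx.setdefault(feature_idx, []).append(i)
    return feature_idx_to_tensor_idx


def masks_for_feature(
    feature_mask: Tuple[Tensor, ...],
    feature_idx: int,
    tensor_idxs: List[int],
) -> Tuple[Optional[Tensor], ...]:
    # Materialize a boolean mask only for tensors that actually contain
    # feature_idx; downstream code treats None as "leave this tensor alone".
    return tuple(
        (mask == feature_idx) if i in tensor_idxs else None
        for i, mask in enumerate(feature_mask)
    )


if __name__ == "__main__":
    # Feature ids 0/1 appear only in the first mask, 2/3 only in the second.
    masks = (torch.tensor([[0, 0, 1, 1]]), torch.tensor([[2, 2, 3, 3]]))
    mapping = build_feature_idx_to_tensor_idx(masks)
    print(mapping)  # {0: [0], 1: [0], 2: [1], 3: [1]}
    # For feature 0, the entry for the second tensor is None rather than an
    # all-zero mask tensor.
    print(masks_for_feature(masks, 0, mapping[0]))

Skipping the zero-mask construction matters most when a model has many input tensors but each feature id occurs in only one or a few of them, as with the large sparse features mentioned in the summary.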