diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 5375dbb638..d7f2570c9b 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -3,7 +3,18 @@
 # pyre-strict
 
 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -465,6 +476,13 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
@@ -472,6 +490,7 @@ def _attribute_with_cross_tensor_feature_masks(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
@@ -511,27 +530,28 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
         ],
         None,
         None,
     ]:
-        unique_feature_ids = torch.unique(
-            torch.cat([mask.flatten() for mask in input_mask])
-        ).tolist()
-
         if isinstance(baselines, torch.Tensor):
             baselines = baselines.reshape((1,) + tuple(baselines.shape))
 
         # Process one feature per time, rather than processing every input tensor
-        for feature_idx in unique_feature_ids:
+        for feature_idx in feature_idx_to_tensor_idx.keys():
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                 )
             )
             yield ablated_inputs, current_masks
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
 
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
 
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)
diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index 1fc85d16fe..3657c00fc2 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,7 +26,7 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
 
 
 def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
+    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
 ) -> Tuple[Tensor, ...]:
     """
     Permutes features across multiple input tensors using the corresponding
@@ -34,7 +34,7 @@ def _permute_features_across_tensors(
     """
     permuted_outputs = []
     for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if not feature_mask.any():
+        if feature_mask is None or not feature_mask.any():
            permuted_outputs.append(input_tensor)
            continue
        n = input_tensor.size(0)
@@ -103,7 +103,7 @@ def __init__(
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
         perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
+            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
         ] = _permute_features_across_tensors,
     ) -> None:
         r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
-        feature_masks = tuple(
-            (mask == feature_idx).to(inputs[0].device) for mask in input_mask
-        )
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        current_masks: List[Optional[Tensor]] = []
+        for i, mask in enumerate(input_mask):
+            if i in tensor_idxs:
+                current_masks.append((mask == feature_idx).to(inputs[0].device))
+            else:
+                current_masks.append(None)
+        feature_masks = tuple(current_masks)
         permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
         return permuted_outputs, feature_masks
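
For reference, a minimal standalone sketch of the feature-id-to-tensor-index mapping the hunks above introduce (the helper name and the demo below are illustrative only, not captum's API): each feature id is mapped to the positions of the input tensors whose mask contains it, so the per-feature loop only constructs masks for those tensors and passes None for every other tensor.

    import torch
    from torch import Tensor
    from typing import Dict, List, Tuple

    def build_feature_idx_to_tensor_idx(
        feature_masks: Tuple[Tensor, ...],
    ) -> Dict[int, List[int]]:
        # For every feature id appearing in any mask, record which input tensors
        # (by position) contain that id; tensors not listed for a feature can be
        # skipped entirely when that feature is ablated or permuted.
        mapping: Dict[int, List[int]] = {}
        for i, mask in enumerate(feature_masks):
            for feature_idx in torch.unique(mask):
                mapping.setdefault(int(feature_idx.item()), []).append(i)
        return mapping

    # Example: feature 1 spans both tensors, features 0 and 2 each live in one.
    mask_a = torch.tensor([[0, 0, 1, 1]])
    mask_b = torch.tensor([[1, 2]])
    print(build_feature_idx_to_tensor_idx((mask_a, mask_b)))
    # {0: [0], 1: [0, 1], 2: [1]}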