diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index fab34221ba..6355b62bd3 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -53,6 +53,16 @@ def __init__(self, forward_func: Callable) -> None:
         PerturbationAttribution.__init__(self, forward_func)
         self.use_weights = False
 
+        # only used when perturbations_per_eval > 1, where the 1st dim of
+        # forward_func's output must grow with the input batch size. If
+        # forward's output is aggregated, we cannot expand the input to
+        # evaluate more perturbations in one call. While this flag is False,
+        # we validate by comparing the output of the original input against
+        # the output of the modified input whose batch size is expanded by
+        # perturbations_per_eval. Once the output grows as expected, the flag
+        # is set to True; we then assume the behavior stays consistent and skip the check.
+        self._is_output_shape_valid = False
+
     @log_usage()
     def attribute(
         self,
@@ -291,21 +301,10 @@ def attribute(
 
         # flatten eval outputs into 1D (n_outputs)
         # add the leading dim for n_feature_perturbed
-        initial_eval = initial_eval.reshape(1, -1)
-
-        agg_output_mode = FeatureAblation._find_output_mode(
-            perturbations_per_eval, feature_mask
-        )
-
-        if not agg_output_mode:
-            assert n_outputs == num_examples, (
-                "expected output of `forward_func` to have "
-                + "`batch_size` elements for perturbations_per_eval > 1 "
-                + "and all feature_mask.shape[0] > 1"
-            )
+        flattened_initial_eval = initial_eval.reshape(1, -1)
 
         # Initialize attribution totals and counts
-        attrib_type = cast(dtype, initial_eval.dtype)
+        attrib_type = cast(dtype, flattened_initial_eval.dtype)
 
         total_attrib = [
             # attribute w.r.t each output element
@@ -362,21 +361,43 @@ def attribute(
                 if show_progress:
                     attr_progress.update()
 
-                if not agg_output_mode:
-                    # current_batch_size is not n_examples
-                    # it may get expanded by n_feature_perturbed
+                # if perturbations_per_eval > 1, the output's 1st dim must grow
+                # with the input batch size, i.e., it must not be aggregated
+                if perturbations_per_eval > 1 and not self._is_output_shape_valid:
                     current_batch_size = current_inputs[0].shape[0]
+
+                    # number of perturbations, which may be smaller than
+                    # perturbations_per_eval when not enough features are left
+                    n_perturb = current_batch_size / num_examples
+
+                    current_output_shape = modified_eval.shape
+
+                    # use initial_eval as the forward output of perturbations_per_eval = 1
+                    initial_output_shape = initial_eval.shape
+
                     assert (
-                        modified_eval.numel() == current_batch_size
-                    ), """expected output of forward_func to grow with
-                    batch_size. If this is not the case for your model
-                    please set perturbations_per_eval = 1"""
+                        # check that the output is not a scalar
+                        current_output_shape
+                        and initial_output_shape
+                        # check that the output grows in the same ratio, i.e., not agg
+                        and current_output_shape[0]
+                        == n_perturb * initial_output_shape[0]
+                    ), (
+                        "When perturbations_per_eval > 1, forward_func's output "
+                        "should be a tensor whose 1st dim grows with the input "
+                        f"batch size: when input batch size is {num_examples}, "
+                        f"the output shape is {initial_output_shape}; "
+                        f"when input batch size is {current_batch_size}, "
+                        f"the output shape is {current_output_shape}"
+                    )
+
+                    self._is_output_shape_valid = True
 
                 # reshape the leading dim for n_feature_perturbed
                 # flatten each feature's eval outputs into 1D of (n_outputs)
                 modified_eval = modified_eval.reshape(-1, n_outputs)
                 # eval_diff in shape (n_feature_perturbed, n_outputs)
-                eval_diff = initial_eval - modified_eval
+                eval_diff = flattened_initial_eval - modified_eval
 
                 # append the shape of one input example
                 # to make it broadcastable to mask
@@ -572,28 +593,6 @@ def _get_feature_counts(self, inputs, feature_mask, **kwargs):
             for inp, mask in zip(inputs, feature_mask)
         )
 
-    @staticmethod
-    def _find_output_mode(
-        perturbations_per_eval: int,
-        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
-    ) -> bool:
-        """
-        Returns True if the output mode is "aggregation output mode"
-
-        Aggregation output mode is defined as: when there is no 1:1 correspondence
-        with the `num_examples` (`batch_size`) and the amount of outputs your model
-        produces, i.e. the model output does not grow in size as the input becomes
-        larger.
-
-        We assume this is the case if `perturbations_per_eval == 1`
-        and your feature mask is None or is associated to all
-        examples in a batch (fm.shape[0] == 1 for all fm in feature_mask).
-        """
-        return perturbations_per_eval == 1 and (
-            feature_mask is None
-            or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
-        )
-
     def _strict_run_forward(self, *args, **kwargs) -> Tensor:
         """
         A temp wrapper for global _run_forward util to force forward output
diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py
index 290b9bf265..91ff63d259 100644
--- a/tests/attr/test_feature_ablation.py
+++ b/tests/attr/test_feature_ablation.py
@@ -345,17 +345,6 @@ def forward_func(inp):
         with self.assertRaises(AssertionError):
             _ = ablation.attribute(inp, perturbations_per_eval=2)
 
-    def test_error_agg_mode_incorrect_fm(self) -> None:
-        def forward_func(inp):
-            return inp[0].unsqueeze(0)
-
-        inp = torch.tensor([[1, 2, 3], [4, 5, 6]])
-        mask = torch.tensor([[0, 1, 2], [0, 0, 1]])
-
-        ablation = FeatureAblation(forward_func)
-        with self.assertRaises(AssertionError):
-            _ = ablation.attribute(inp, perturbations_per_eval=1, feature_mask=mask)
-
     def test_empty_sparse_features(self) -> None:
         ablation_algo = FeatureAblation(BasicModelWithSparseInputs())
         inp1 = torch.tensor([[1.0, -2.0, 3.0], [2.0, -1.0, 3.0]])
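A minimal usage sketch (not part of the patch) of the behavior the new check enforces, assuming the patched `FeatureAblation`; the forward functions below are hypothetical examples:

```python
import torch
from captum.attr import FeatureAblation

inp = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# Per-example output: the 1st dim grows with the input batch size,
# so perturbations may be batched with perturbations_per_eval > 1.
def per_example_forward(x):
    return x.sum(dim=1)  # shape (batch_size,)

attr = FeatureAblation(per_example_forward).attribute(inp, perturbations_per_eval=2)

# Aggregated output: a scalar regardless of batch size. Still supported,
# but only one perturbation can be evaluated per forward call.
def aggregated_forward(x):
    return x.sum()  # shape ()

attr = FeatureAblation(aggregated_forward).attribute(inp, perturbations_per_eval=1)

# With perturbations_per_eval > 1 the output cannot grow with the expanded
# batch, so the new assertion fires on the first perturbed batch.
try:
    FeatureAblation(aggregated_forward).attribute(inp, perturbations_per_eval=2)
except AssertionError as err:
    print(err)
```

Because the check compares actual output shapes rather than inspecting `feature_mask`, it only needs to pass once; `_is_output_shape_valid` caches the result so later batches skip the comparison.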