diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 28882c7811..fab34221ba 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -279,7 +279,7 @@ def attribute(
 
             # Computes initial evaluation with all features, which is compared
             # to each ablated result.
-            initial_eval = _run_forward(
+            initial_eval = self._strict_run_forward(
                 self.forward_func, inputs, target, additional_forward_args
             )
 
@@ -291,27 +291,21 @@ def attribute(
 
             # flatten eval outputs into 1D (n_outputs)
             # add the leading dim for n_feature_perturbed
-            if isinstance(initial_eval, Tensor):
-                initial_eval = initial_eval.reshape(1, -1)
+            initial_eval = initial_eval.reshape(1, -1)
 
             agg_output_mode = FeatureAblation._find_output_mode(
                 perturbations_per_eval, feature_mask
             )
 
             if not agg_output_mode:
-                assert isinstance(initial_eval, Tensor) and n_outputs == num_examples, (
+                assert n_outputs == num_examples, (
                     "expected output of `forward_func` to have "
                     + "`batch_size` elements for perturbations_per_eval > 1 "
                     + "and all feature_mask.shape[0] > 1"
                 )
 
             # Initialize attribution totals and counts
-            attrib_type = cast(
-                dtype,
-                initial_eval.dtype
-                if isinstance(initial_eval, Tensor)
-                else type(initial_eval),
-            )
+            attrib_type = cast(dtype, initial_eval.dtype)
 
             total_attrib = [
                 # attribute w.r.t each output element
@@ -358,7 +352,7 @@ def attribute(
                     # agg mode: (*initial_eval.shape)
                     # non-agg mode:
                     # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                    modified_eval = _run_forward(
+                    modified_eval = self._strict_run_forward(
                         self.forward_func,
                         current_inputs,
                         current_target,
@@ -368,31 +362,29 @@ def attribute(
                     if show_progress:
                         attr_progress.update()
 
-                    if not isinstance(modified_eval, torch.Tensor):
-                        eval_diff = initial_eval - modified_eval
-                    else:
-                        if not agg_output_mode:
-                            # current_batch_size is not n_examples
-                            # it may get expanded by n_feature_perturbed
-                            current_batch_size = current_inputs[0].shape[0]
-                            assert (
-                                modified_eval.numel() == current_batch_size
-                            ), """expected output of forward_func to grow with
-                                batch_size. If this is not the case for your model
-                                please set perturbations_per_eval = 1"""
-
-                        # reshape the leading dim for n_feature_perturbed
-                        # flatten each feature's eval outputs into 1D of (n_outputs)
-                        modified_eval = modified_eval.reshape(-1, n_outputs)
-                        # eval_diff in shape (n_feature_perturbed, n_outputs)
-                        eval_diff = initial_eval - modified_eval
-
-                        # append the shape of one input example
-                        # to make it broadcastable to mask
-                        eval_diff = eval_diff.reshape(
-                            eval_diff.shape + (inputs[i].dim() - 1) * (1,)
-                        )
-                        eval_diff = eval_diff.to(total_attrib[i].device)
+                    if not agg_output_mode:
+                        # current_batch_size is not n_examples
+                        # it may get expanded by n_feature_perturbed
+                        current_batch_size = current_inputs[0].shape[0]
+                        assert (
+                            modified_eval.numel() == current_batch_size
+                        ), """expected output of forward_func to grow with
+                            batch_size. If this is not the case for your model
+                            please set perturbations_per_eval = 1"""
+
+                    # reshape the leading dim for n_feature_perturbed
+                    # flatten each feature's eval outputs into 1D of (n_outputs)
+                    modified_eval = modified_eval.reshape(-1, n_outputs)
+                    # eval_diff in shape (n_feature_perturbed, n_outputs)
+                    eval_diff = initial_eval - modified_eval
+
+                    # append the shape of one input example
+                    # to make it broadcastable to mask
+                    eval_diff = eval_diff.reshape(
+                        eval_diff.shape + (inputs[i].dim() - 1) * (1,)
+                    )
+                    eval_diff = eval_diff.to(total_attrib[i].device)
+
                     if self.use_weights:
                         weights[i] += current_mask.float().sum(dim=0)
 
@@ -601,3 +593,24 @@ def _find_output_mode(
             feature_mask is None
             or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask)
         )
+
+    def _strict_run_forward(self, *args, **kwargs) -> Tensor:
+        """
+        A temp wrapper for global _run_forward util to force forward output
+        type assertion & conversion.
+        Remove after the strict logic is supported by all attr classes
+        """
+        forward_output = _run_forward(*args, **kwargs)
+        if isinstance(forward_output, Tensor):
+            return forward_output
+
+        output_type = type(forward_output)
+        assert output_type is int or output_type is float, (
+            "the return of forward_func must be a tensor, int, or float,"
+            f" received: {forward_output}"
+        )
+
+        # using python built-in type as torch dtype
+        # int -> torch.int64, float -> torch.float64
+        # ref: https://github.com/pytorch/pytorch/pull/21215
+        return torch.tensor(forward_output, dtype=output_type)
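
With the new `_strict_run_forward` wrapper, a `forward_func` that returns a plain Python int or float is converted to a 0-dim tensor (int -> torch.int64, float -> torch.float64) before the ablation math runs, so the downstream reshape/diff logic no longer needs the non-tensor branch. A minimal usage sketch of what this permits; the `scalar_forward` toy model below is hypothetical and not part of the patch:

    import torch
    from captum.attr import FeatureAblation

    def scalar_forward(inp: torch.Tensor) -> float:
        # toy model: returns an aggregate score as a plain Python float
        return (inp * 2).sum().item()

    inputs = torch.rand(4, 3)
    ablator = FeatureAblation(scalar_forward)
    # the float output is wrapped into a float64 tensor internally,
    # so attribution over the aggregate score proceeds as usual
    attr = ablator.attribute(inputs)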