From f2263063f56a26cdffdc6a06274e7ff9cd01f4b9 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Thu, 13 Oct 2022 23:32:00 -0700 Subject: [PATCH 1/4] refactor feature ablation --- captum/attr/_core/feature_ablation.py | 54 +++++++++++++++++---------- 1 file changed, 34 insertions(+), 20 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 0a68de505d..0f238e8e53 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -286,21 +286,22 @@ def attribute( if show_progress: attr_progress.update() + # number of elements in the output of forward_func + n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1 + + # flattent eval outputs into 1D (n_outputs) + # add the leading dim for n_feature_perturbed + if isinstance(initial_eval, Tensor): + initial_eval = initial_eval.reshape(1, -1) + agg_output_mode = FeatureAblation._find_output_mode( perturbations_per_eval, feature_mask ) - # get as a 2D tensor (if it is not a scalar) - if isinstance(initial_eval, torch.Tensor): - initial_eval = initial_eval.reshape(1, -1) - num_outputs = initial_eval.shape[1] - else: - num_outputs = 1 - if not agg_output_mode: assert ( - isinstance(initial_eval, torch.Tensor) - and num_outputs == num_examples + isinstance(initial_eval, Tensor) + and n_outputs == num_examples ), ( "expected output of `forward_func` to have " + "`batch_size` elements for perturbations_per_eval > 1 " @@ -316,8 +317,9 @@ def attribute( ) total_attrib = [ + # attribute w.r.t each output element torch.zeros( - (num_outputs,) + input.shape[1:], + (n_outputs, *input.shape[1:]), dtype=attrib_type, device=input.device, ) @@ -328,7 +330,7 @@ def attribute( if self.use_weights: weights = [ torch.zeros( - (num_outputs,) + input.shape[1:], device=input.device + (n_outputs, *input.shape[1:]), device=input.device ).float() for input in inputs ] @@ -354,8 +356,11 @@ def attribute( perturbations_per_eval, **kwargs, ): - # modified_eval dimensions: 1D tensor with length - # equal to #num_examples * #features in batch + # modified_eval has (n_feature_perturbed * n_outputs) elements + # shape: + # agg mode: (*initial_eval.shape) + # non-agg mode: + # (feature_perturbed * batch_size, *initial_eval.shape[1:]) modified_eval = _run_forward( self.forward_func, current_inputs, @@ -366,25 +371,34 @@ def attribute( if show_progress: attr_progress.update() - # (contains 1 more dimension than inputs). This adds extra - # dimensions of 1 to make the tensor broadcastable with the inputs - # tensor. if not isinstance(modified_eval, torch.Tensor): eval_diff = initial_eval - modified_eval else: if not agg_output_mode: + # current_batch_size is not n_examples + # it may get expanded by n_feature_perturbed + current_batch_size = current_inputs[0].shape[0] assert ( - modified_eval.numel() == current_inputs[0].shape[0] + modified_eval.numel() == current_batch_size ), """expected output of forward_func to grow with batch_size. If this is not the case for your model please set perturbations_per_eval = 1""" - eval_diff = ( - initial_eval - modified_eval.reshape((-1, num_outputs)) - ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,)) + # reshape the leading dim for n_feature_perturbed + # flatten each feature's eval outputs into 1D of (n_outputs) + modified_eval = modified_eval.reshape(-1, n_outputs) + # eval_diff in shape (n_feature_perturbed, n_outputs) + eval_diff = initial_eval - modified_eval + + # append the shape of one input example + # to make it broadcastable to mask + eval_diff = eval_diff.reshape( + eval_diff.shape + (inputs[i].dim() - 1) * (1,) + ) eval_diff = eval_diff.to(total_attrib[i].device) if self.use_weights: weights[i] += current_mask.float().sum(dim=0) + total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum( dim=0 ) From 081c81fdedf853152384f38fd5d8927bc391bbc0 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Fri, 14 Oct 2022 00:38:28 -0700 Subject: [PATCH 2/4] fix mypy --- captum/attr/_core/feature_ablation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 0f238e8e53..e5692fb79e 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -319,7 +319,7 @@ def attribute( total_attrib = [ # attribute w.r.t each output element torch.zeros( - (n_outputs, *input.shape[1:]), + (n_outputs,) + input.shape[1:], dtype=attrib_type, device=input.device, ) @@ -330,7 +330,7 @@ def attribute( if self.use_weights: weights = [ torch.zeros( - (n_outputs, *input.shape[1:]), device=input.device + (n_outputs,) + input.shape[1:], device=input.device ).float() for input in inputs ] From fb45bb73fb4e03202e3ba1f519c37e79a1aae837 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Fri, 14 Oct 2022 00:45:25 -0700 Subject: [PATCH 3/4] ufmt --- captum/attr/_core/feature_ablation.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index e5692fb79e..23db4f47ff 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -299,10 +299,7 @@ def attribute( ) if not agg_output_mode: - assert ( - isinstance(initial_eval, Tensor) - and n_outputs == num_examples - ), ( + assert isinstance(initial_eval, Tensor) and n_outputs == num_examples, ( "expected output of `forward_func` to have " + "`batch_size` elements for perturbations_per_eval > 1 " + "and all feature_mask.shape[0] > 1" From 965071e9a3965f0e1792483ee8f96bfa95b947e4 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Tue, 18 Oct 2022 11:30:45 -0700 Subject: [PATCH 4/4] typo --- captum/attr/_core/feature_ablation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 23db4f47ff..28882c7811 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -289,7 +289,7 @@ def attribute( # number of elements in the output of forward_func n_outputs = initial_eval.numel() if isinstance(initial_eval, Tensor) else 1 - # flattent eval outputs into 1D (n_outputs) + # flatten eval outputs into 1D (n_outputs) # add the leading dim for n_feature_perturbed if isinstance(initial_eval, Tensor): initial_eval = initial_eval.reshape(1, -1)