
Commit 288cd3a

vivekmig authored and facebook-github-bot committed
Switch from register_full_backward_hooks to tensor hooks (#979)
Summary: This switches from module full backward hooks to forward hooks that register tensor backward hooks, as suggested in #914. We initially avoided this approach because it is awkward for modules with multiple input / output tensors (a hook must be registered on each tensor separately), but all current use cases within Captum only require a single tensor input / output. The change enables in-place modules and removes the limitation on neuron input attribution. DeepLift also no longer needs its valid-module checks, since these do not apply when tensor hooks are used.

Pull Request resolved: #979

Reviewed By: NarineK

Differential Revision: D41687791

Pulled By: vivekmig

fbshipit-source-id: 2ddc5aac7b9bf70a56ffb3ace3dc026fca7d4bfa
1 parent cb44edd commit 288cd3a

21 files changed (+184 / -266 lines)
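
To make the described approach concrete, here is a minimal, self-contained sketch of the pattern this commit adopts: a forward pre-hook clones the module input and registers a tensor hook on the original input, while a forward hook registers a tensor hook on the output; together they emulate a module backward hook, including for in-place modules. The names attach_tensor_backward_hook and scale_input_grad are made up for illustration only; Captum's actual implementation is the _register_backward_hook rewrite shown in the captum/_utils/common.py diff below.

    import torch
    import torch.nn as nn

    def scale_input_grad(module, grad_input, grad_output):
        # Example gradient modification: halve the gradient flowing into the module.
        return grad_input * 0.5

    def attach_tensor_backward_hook(module, hook):
        saved_grad_out = {}

        def forward_hook(mod, inp, out):
            saved_grad_out.clear()

            def output_tensor_hook(grad_out):
                # Capture the gradient w.r.t. the module output (per device).
                saved_grad_out[grad_out.device] = grad_out

            out.register_hook(output_tensor_hook)

        def forward_pre_hook(mod, inp):
            inp = inp[0]

            def input_tensor_hook(grad_in):
                if saved_grad_out:
                    return hook(mod, grad_in, saved_grad_out[grad_in.device])

            inp.register_hook(input_tensor_hook)
            # Clone so the module operates on a non-leaf copy; this keeps
            # in-place modules safe and leaves the hook on the original tensor.
            return inp.clone()

        return [
            module.register_forward_pre_hook(forward_pre_hook),
            module.register_forward_hook(forward_hook),
        ]

    relu = nn.ReLU(inplace=True)  # in-place module, fine with tensor hooks
    handles = attach_tensor_backward_hook(relu, scale_input_grad)

    x = torch.randn(4, requires_grad=True)
    relu(x).sum().backward()
    print(x.grad)  # gradient halved by the hook

    for handle in handles:
        handle.remove()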

captum/_utils/common.py

Lines changed: 44 additions & 16 deletions
@@ -684,20 +684,48 @@ def _get_module_from_name(model: Module, layer_name: str) -> Any:
 
 def _register_backward_hook(
     module: Module, hook: Callable, attr_obj: Any
-) -> torch.utils.hooks.RemovableHandle:
-    # Special case for supporting output attributions for neuron methods
-    # This can be removed after deprecation of neuron output attributions
-    # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop
-    # in v0.6.0
-    if (
-        hasattr(attr_obj, "skip_new_hook_layer")
-        and attr_obj.skip_new_hook_layer == module
-    ):
-        return module.register_backward_hook(hook)
+) -> List[torch.utils.hooks.RemovableHandle]:
+    grad_out: Dict[device, Tensor] = {}
 
-    if _parse_version(torch.__version__) >= (1, 9, 0):
-        # Only supported for torch >= 1.9
-        return module.register_full_backward_hook(hook)
-    else:
-        # Fallback for previous versions of PyTorch
-        return module.register_backward_hook(hook)
+    def forward_hook(
+        module: Module,
+        inp: Union[Tensor, Tuple[Tensor, ...]],
+        out: Union[Tensor, Tuple[Tensor, ...]],
+    ) -> None:
+        nonlocal grad_out
+        grad_out = {}
+
+        def output_tensor_hook(output_grad: Tensor) -> None:
+            grad_out[output_grad.device] = output_grad
+
+        if isinstance(out, tuple):
+            assert (
+                len(out) == 1
+            ), "Backward hooks not supported for module with >1 output"
+            out[0].register_hook(output_tensor_hook)
+        else:
+            out.register_hook(output_tensor_hook)
+
+    def pre_hook(module, inp):
+        def input_tensor_hook(input_grad: Tensor):
+            if len(grad_out) == 0:
+                return
+            hook_out = hook(module, input_grad, grad_out[input_grad.device])
+
+            if hook_out is not None:
+                return hook_out[0] if isinstance(hook_out, tuple) else hook_out
+
+        if isinstance(inp, tuple):
+            assert (
+                len(inp) == 1
+            ), "Backward hooks not supported for module with >1 input"
+            inp[0].register_hook(input_tensor_hook)
+            return inp[0].clone()
+        else:
+            inp.register_hook(input_tensor_hook)
+            return inp.clone()
+
+    return [
+        module.register_forward_pre_hook(pre_hook),
+        module.register_forward_hook(forward_hook),
+    ]
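
A hedged usage sketch of the rewritten helper, inferred from the diff above rather than taken from Captum's tests: the utility now returns two handles (forward pre-hook plus forward hook), and the supplied hook receives single tensors for grad_input and grad_output.

    import torch
    import torch.nn as nn
    from captum._utils.common import _register_backward_hook

    def print_grads(module, grad_input, grad_output):
        # grad_input / grad_output arrive as single tensors, not tuples.
        print(type(module).__name__, grad_input.shape, grad_output.shape)

    relu = nn.ReLU(inplace=True)  # in-place modules are now supported
    handles = _register_backward_hook(relu, print_grads, None)

    x = torch.randn(2, 3, requires_grad=True)
    relu(x).sum().backward()

    # Callers now extend() their handle lists and remove both handles afterwards.
    for handle in handles:
        handle.remove()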

captum/_utils/gradient.py

Lines changed: 2 additions & 0 deletions
@@ -837,6 +837,8 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
             parameters of the i-th layer, for the j-th member of the minibatch.
     """
     with torch.autograd.set_grad_enabled(True):
+        inputs = tuple(inp.clone() for inp in inputs)
+        apply_gradient_requirements(inputs)
         sample_grad_wrapper = SampleGradientWrapper(model, layer_modules)
         try:
            sample_grad_wrapper.add_hooks()
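
The two added lines clone the inputs before applying gradient requirements, presumably so that requires_grad is toggled on copies rather than on the caller's tensors. A rough plain-torch sketch of what that preparation amounts to (apply_gradient_requirements is Captum's own helper and also does warnings and bookkeeping not shown here):

    import torch

    def prepare_inputs(inputs):
        # Clone first so requires_grad is changed on copies, not the caller's tensors.
        cloned = tuple(inp.clone() for inp in inputs)
        for inp in cloned:
            if not inp.requires_grad:
                inp.requires_grad_(True)
        return cloned

    x = torch.randn(4, 3)
    (prepared,) = prepare_inputs((x,))
    assert prepared.requires_grad and not x.requires_grad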

captum/_utils/sample_gradient.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def _register_module_hooks(self, module: torch.nn.Module) -> None:
         self.forward_hooks.append(
             module.register_forward_hook(self._forward_hook_fn)
         )
-        self.backward_hooks.append(
+        self.backward_hooks.extend(
             _register_backward_hook(module, self._backward_hook_fn, None)
         )

captum/attr/_core/deep_lift.py

Lines changed: 24 additions & 122 deletions
@@ -43,34 +43,6 @@
 from torch.utils.hooks import RemovableHandle
 
 
-# Check if module backward hook can safely be used for the module that produced
-# this inputs / outputs mapping
-def _check_valid_module(inputs_grad_fn, outputs) -> bool:
-    def is_output_cloned(output_fn, input_grad_fn) -> bool:
-        """
-        Checks if the output has been cloned. This happens especially in case of
-        layer deeplift.
-        """
-        return (
-            output_fn[0].next_functions is not None
-            and output_fn[0].next_functions[0][0] == input_grad_fn
-        )
-
-    curr_fn = outputs.grad_fn
-    first_next = curr_fn.next_functions[0]
-    try:
-        # if `inputs` in the input to the network then the grad_fn is None and
-        # for that input backward_hook isn't computed. That's the reason why we
-        # need to check on `inputs_grad_fns[first_next[1]]` being None.
-        return (
-            inputs_grad_fn is None
-            or first_next[0] == inputs_grad_fn
-            or is_output_cloned(first_next, inputs_grad_fn)
-        )
-    except IndexError:
-        return False
-
-
 class DeepLift(GradientAttribution):
     r"""
     Implements DeepLIFT algorithm based on the following paper:
@@ -112,10 +84,7 @@ def __init__(
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place nonlinear submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API
-                        starting from PyTorch v1.9.
+            model (nn.Module): The reference to PyTorch model instance.
             multiply_by_inputs (bool, optional): Indicates whether to factor
                         model inputs' multiplier in the final attribution scores.
                         In the literature this is also known as local vs global
@@ -430,25 +399,6 @@ def _forward_pre_hook(
         """
         inputs = _format_tensor_into_tuples(inputs)
         module.input = inputs[0].clone().detach()
-        module.input_grad_fns = inputs[0].grad_fn  # type: ignore
-
-        def tensor_backward_hook(grad):
-            if module.saved_grad is None:
-                raise RuntimeError(
-                    """Module {} was detected as not supporting correctly module
-                        backward hook. You should modify your hook to ignore the given
-                        grad_inputs (recompute them by hand if needed) and save the
-                        newly computed grad_inputs in module.saved_grad. See MaxPool1d
-                        as an example.""".format(
-                        module
-                    )
-                )
-            return module.saved_grad
-
-        # the hook is set by default but it will be used only for
-        # failure cases and will be removed otherwise
-        handle = inputs[0].register_hook(tensor_backward_hook)
-        module.input_hook = handle
 
     def _forward_hook(
         self,
@@ -462,30 +412,13 @@ def _forward_hook(
         """
        outputs = _format_tensor_into_tuples(outputs)
        module.output = outputs[0].clone().detach()
-        if not _check_valid_module(module.input_grad_fns, outputs[0]):
-            warnings.warn(
-                """An invalid module {} is detected. Saved gradients will
-                be used as the gradients of the module's input tensor.
-                See MaxPool1d as an example.""".format(
-                    module
-                )
-            )
-            module.is_invalid = True  # type: ignore
-            module.saved_grad = None  # type: ignore
-            self.forward_handles.append(cast(RemovableHandle, module.input_hook))
-        else:
-            module.is_invalid = False  # type: ignore
-            # removing the hook if there is no failure case
-            cast(RemovableHandle, module.input_hook).remove()
-            del module.input_hook
-        del module.input_grad_fns
 
     def _backward_hook(
         self,
         module: Module,
-        grad_input: Union[Tensor, Tuple[Tensor, ...]],
-        grad_output: Union[Tensor, Tuple[Tensor, ...]],
-    ):
+        grad_input: Tensor,
+        grad_output: Tensor,
+    ) -> Tensor:
         r"""
         `grad_input` is the gradient of the neuron with respect to its input
         `grad_output` is the gradient of the neuron with respect to its output
@@ -506,15 +439,14 @@ def _backward_hook(
                 "Please, ensure that module is being used only once in the "
                 "network.".format(module)
             )
-        multipliers = tuple(
-            SUPPORTED_NON_LINEAR[type(module)](
-                module,
-                module.input,
-                module.output,
-                grad_input,
-                grad_output,
-                eps=self.eps,
-            )
+
+        multipliers = SUPPORTED_NON_LINEAR[type(module)](
+            module,
+            module.input,
+            module.output,
+            grad_input,
+            grad_output,
+            eps=self.eps,
         )
         # remove all the properies that we set for the inputs and output
         del module.input
@@ -545,10 +477,10 @@ def _register_hooks(
         # adds forward hook to leaf nodes that are non-linear
         forward_handle = module.register_forward_hook(self._forward_hook)
         pre_forward_handle = module.register_forward_pre_hook(self._forward_pre_hook)
-        backward_handle = _register_backward_hook(module, self._backward_hook, self)
+        backward_handles = _register_backward_hook(module, self._backward_hook, self)
         self.forward_handles.append(forward_handle)
         self.forward_handles.append(pre_forward_handle)
-        self.backward_handles.append(backward_handle)
+        self.backward_handles.extend(backward_handles)
 
     def _remove_hooks(self, extra_hooks_to_remove: List[RemovableHandle]) -> None:
         for handle in extra_hooks_to_remove:
@@ -627,9 +559,7 @@ def __init__(self, model: Module, multiply_by_inputs: bool = True) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place nonlinear submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API.
+            model (nn.Module): The reference to PyTorch model instance.
             multiply_by_inputs (bool, optional): Indicates whether to factor
                         model inputs' multiplier in the final attribution scores.
                         In the literature this is also known as local vs global
@@ -941,26 +871,18 @@ def nonlinear(
     grad_input: Tensor,
     grad_output: Tensor,
     eps: float = 1e-10,
-):
+) -> Tensor:
     r"""
     grad_input: (dLoss / dprev_layer_out, dLoss / wij, dLoss / bij)
     grad_output: (dLoss / dlayer_out)
     https://github.com/pytorch/pytorch/issues/12331
    """
    delta_in, delta_out = _compute_diffs(inputs, outputs)
 
-    new_grad_inp = list(grad_input)
-
-    # supported non-linear modules take only single tensor as input hence accessing
-    # only the first element in `grad_input` and `grad_output`
-    new_grad_inp[0] = torch.where(
-        abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
+    new_grad_inp = torch.where(
+        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
     )
 
-    # If the module is invalid, save the newly computed gradients
-    # The original_grad_input will be overridden later in the Tensor hook
-    if module.is_invalid:
-        module.saved_grad = new_grad_inp[0]
     return new_grad_inp
 
 
@@ -974,15 +896,14 @@ def softmax(
 ):
     delta_in, delta_out = _compute_diffs(inputs, outputs)
 
-    new_grad_inp = list(grad_input)
     grad_input_unnorm = torch.where(
-        abs(delta_in) < eps, new_grad_inp[0], grad_output[0] * delta_out / delta_in
+        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
     )
     # normalizing
-    n = grad_input[0].numel()
+    n = grad_input.numel()
 
     # updating only the first half
-    new_grad_inp[0] = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
+    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
     return new_grad_inp
 
 
@@ -1073,7 +994,7 @@ def maxpool(
         module.ceil_mode,
         True,
     )
-    grad_output_updated = grad_output[0]
+    grad_output_updated = grad_output
     unpool_grad_out_delta, unpool_grad_out_ref_delta = torch.chunk(
         unpool_func(
             grad_output_updated * delta_out,
@@ -1089,20 +1010,7 @@ def maxpool(
     unpool_grad_out_delta = unpool_grad_out_delta + unpool_grad_out_ref_delta
     unpool_grad_out_delta = torch.cat(2 * [unpool_grad_out_delta])
 
-    # If the module is invalid, we need to recompute the grad_input
-    if module.is_invalid:
-        original_grad_input = grad_input
-        grad_input = (
-            unpool_func(
-                grad_output_updated,
-                indices,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                list(cast(torch.Size, module.input.shape)),
-            ),
-        )
-    if grad_input[0].shape != inputs.shape:
+    if grad_input.shape != inputs.shape:
         raise AssertionError(
             "A problem occurred during maxpool modul's backward pass. "
             "The gradients with respect to inputs include only a "
@@ -1118,13 +1026,7 @@ def maxpool(
     new_grad_inp = torch.where(
         abs(delta_in) < eps, grad_input[0], unpool_grad_out_delta / delta_in
     )
-    # If the module is invalid, save the newly computed gradients
-    # The original_grad_input will be overridden later in the Tensor hook
-    if module.is_invalid:
-        module.saved_grad = new_grad_inp
-        return original_grad_input
-    else:
-        return (new_grad_inp,)
+    return new_grad_inp
 
 
 def _compute_diffs(inputs: Tensor, outputs: Tensor) -> Tuple[Tensor, Tensor]:
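
With the valid-module checks and the in-place restriction removed, DeepLift is expected to work on models that use in-place nonlinearities. A hedged usage sketch with a made-up toy model:

    import torch
    import torch.nn as nn
    from captum.attr import DeepLift

    model = nn.Sequential(
        nn.Linear(8, 16),
        nn.ReLU(inplace=True),  # previously unsupported with full backward hooks
        nn.Linear(16, 2),
    )
    model.eval()

    inputs = torch.randn(4, 8)
    baselines = torch.zeros_like(inputs)

    dl = DeepLift(model)
    attributions = dl.attribute(inputs, baselines=baselines, target=0)
    print(attributions.shape)  # expected: torch.Size([4, 8])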

captum/attr/_core/guided_backprop_deconvnet.py

Lines changed: 4 additions & 8 deletions
@@ -79,8 +79,8 @@ def attribute(
 
     def _register_hooks(self, module: Module):
         if isinstance(module, torch.nn.ReLU):
-            hook = _register_backward_hook(module, self._backward_hook, self)
-            self.backward_hooks.append(hook)
+            hooks = _register_backward_hook(module, self._backward_hook, self)
+            self.backward_hooks.extend(hooks)
 
     def _backward_hook(
         self,
@@ -121,9 +121,7 @@ def __init__(self, model: Module) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place ReLU submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API.
+            model (nn.Module): The reference to PyTorch model instance.
         """
         ModifiedReluGradientAttribution.__init__(
             self, model, use_relu_grad_output=False
@@ -234,9 +232,7 @@ def __init__(self, model: Module) -> None:
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place ReLU submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API.
+            model (nn.Module): The reference to PyTorch model instance.
         """
         ModifiedReluGradientAttribution.__init__(self, model, use_relu_grad_output=True)

captum/attr/_core/guided_grad_cam.py

Lines changed: 1 addition & 4 deletions
@@ -51,10 +51,7 @@ def __init__(
         r"""
         Args:
 
-            model (nn.Module): The reference to PyTorch model instance. Model cannot
-                        contain any in-place ReLU submodules; these are not
-                        supported by the register_full_backward_hook PyTorch API
-                        starting from PyTorch v1.9.
+            model (nn.Module): The reference to PyTorch model instance.
             layer (torch.nn.Module): Layer for which GradCAM attributions are computed.
                         Currently, only layers with a single tensor output are
                         supported.
