Skip to content
Closed
Empty file modified captum/optim/__init__.py
100755 → 100644
Empty file.
20 changes: 16 additions & 4 deletions captum/optim/_core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
) -> None:
r"""
Args:

model (nn.Module): The reference to PyTorch model instance.
input_param (nn.Module, optional): A module that generates an input,
consumed by the model.
Expand All @@ -71,6 +72,7 @@ def __init__(

def loss(self) -> torch.Tensor:
r"""Compute loss value for current iteration.

Returns:
*tensor* representing **loss**:
- **loss** (*tensor*):
Expand Down Expand Up @@ -115,18 +117,26 @@ def optimize(
lr: float = 0.025,
) -> torch.Tensor:
r"""Optimize input based on loss function and objectives.

Args:

stop_criteria (StopCriteria, optional): A function that is called
every iteration and returns a bool that determines whether
to stop the optimization.
See captum.optim.typing.StopCriteria for details.
optimizer (Optimizer, optional): An torch.optim.Optimizer used to
optimize the input based on the loss function.
loss_summarize_fn (Callable, optional): The function to use for summarizing
tensor outputs from loss functions.
Default: default_loss_summarize
lr: (float, optional): If no optimizer is given, then lr is used as the
learning rate for the Adam optimizer.
Default: 0.025

Returns:
*list* of *np.arrays* representing the **history**:
- **history** (*list*):
A list of loss values per iteration.
Length of the list corresponds to the number of iterations
history (torch.Tensor): A stack of loss values per iteration. The size
of the dimension on which loss values are stacked corresponds to
the number of iterations.
"""
stop_criteria = stop_criteria or n_steps(512)
optimizer = optimizer or optim.Adam(self.parameters(), lr=lr)
Expand All @@ -150,10 +160,12 @@ def optimize(

def n_steps(n: int, show_progress: bool = True) -> StopCriteria:
"""StopCriteria generator that uses number of steps as a stop criteria.

Args:
n (int): Number of steps to run optimization.
show_progress (bool, optional): Whether or not to show progress bar.
Default: True

Returns:
*StopCriteria* callable
"""
Expand Down
54 changes: 47 additions & 7 deletions captum/optim/_core/output_hook.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,37 @@
from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType


class ModuleReuseException(Exception):
pass


class ModuleOutputsHook:
def __init__(self, target_modules: Iterable[nn.Module]) -> None:
"""
Args:

target_modules (Iterable of nn.Module): A list of nn.Module targets.
"""
self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
self.hooks = [
module.register_forward_hook(self._forward_hook())
for module in target_modules
]

def _reset_outputs(self) -> None:
"""
Delete captured activations.
"""
self.outputs = dict.fromkeys(self.outputs.keys(), None)

@property
def is_ready(self) -> bool:
return all(value is not None for value in self.outputs.values())

def _forward_hook(self) -> Callable:
"""
Return the forward_hook function.

Returns:
forward_hook (Callable): The forward_hook function.
"""

def forward_hook(
module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
) -> None:
Expand All @@ -49,6 +60,12 @@ def forward_hook(
return forward_hook

def consume_outputs(self) -> ModuleOutputMapping:
"""
Collect target activations and return them.

Returns:
outputs (ModuleOutputMapping): The captured outputs.
"""
if not self.is_ready:
warn(
"Consume captured outputs, but not all requested target outputs "
Expand All @@ -63,11 +80,16 @@ def targets(self) -> Iterable[nn.Module]:
return self.outputs.keys()

def remove_hooks(self) -> None:
"""
Remove hooks.
"""
for hook in self.hooks:
hook.remove()

def __del__(self) -> None:
# print(f"DEL HOOKS!: {list(self.outputs.keys())}")
"""
Ensure that using 'del' properly deletes hooks.
"""
self.remove_hooks()


Expand All @@ -77,16 +99,34 @@ class ActivationFetcher:
"""

def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None:
"""
Args:

model (nn.Module): The reference to PyTorch model instance.
targets (nn.Module or list of nn.Module): The target layers to
collect activations from.
"""
super(ActivationFetcher, self).__init__()
self.model = model
self.layers = ModuleOutputsHook(targets)

def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
"""
Args:

input_t (tensor or tuple of tensors, optional): The input to use
with the specified model.

Returns:
activations_dict: An dict containing the collected activations. The keys
for the returned dictionary are the target layers.
"""

try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.model(input_t)
activations = self.layers.consume_outputs()
activations_dict = self.layers.consume_outputs()
finally:
self.layers.remove_hooks()
return activations
return activations_dict
Loading