Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions captum/_utils/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,18 +849,21 @@ def _compute_jacobian_wrt_params_with_sample_wise_trick(
if labels is not None and loss_fn is not None:
loss = loss_fn(out, labels)
# TODO: allow loss_fn to be Callable
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
if (isinstance(loss_fn, Module) or callable(loss_fn)) and hasattr(
loss_fn, "reduction"
):
reduction = loss_fn.reduction # type: ignore
msg0 = (
"Please ensure that loss_fn.reduction is set to `sum` or `mean`"
)

assert loss_fn.reduction != "none", msg0
assert reduction != "none", msg0
msg1 = (
f"loss_fn.reduction ({loss_fn.reduction}) does not match"
f"loss_fn.reduction ({reduction}) does not match"
f"reduction type ({reduction_type}). Please ensure they are"
" matching."
)
assert loss_fn.reduction == reduction_type, msg1
assert reduction == reduction_type, msg1
msg2 = (
"Please ensure custom loss function is applying either a "
"sum or mean reduction."
Expand Down
118 changes: 71 additions & 47 deletions captum/influence/_core/tracincp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from captum._utils.progress import NullProgress, progress
from captum.influence._core.influence import DataInfluence
from captum.influence._utils.common import (
_check_loss_fn,
_format_inputs_dataset,
_get_k_most_influential_helper,
_gradient_dot_product,
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(
checkpoints_load_func: Callable = _load_flexible_state_dict,
loss_fn: Optional[Union[Module, Callable]] = None,
batch_size: Union[int, None] = 1,
test_loss_fn: Optional[Union[Module, Callable]] = None,
) -> None:
r"""
Args:
Expand Down Expand Up @@ -152,6 +154,19 @@ def __init__(
`train_dataset` is a Dataset. If `train_dataset`
is a DataLoader, then `batch_size` is ignored as an argument.
Default: 1
test_loss_fn (Callable, optional): In some cases, one may want to use a
separate loss functions for training examples, i.e. those in
`train_dataset`, and for test examples, i.e. those
represented by the `inputs` and `targets` arguments to the
`influence` method. For example, if one wants to calculate the
influence score of a training example on a test example's
prediction for a fixed class, `test_loss_fn` could map from the
logits for all classes to the logits for a fixed class.
`test_loss_fn` needs to satisfy the same constraints as `loss_fn`.
If not provided, the loss function for test examples is assumed to
be the same as the loss function for training examples, i.e.
`loss_fn`.
Default: None
"""

self.model = model
Expand All @@ -167,6 +182,8 @@ def __init__(

self.checkpoints_load_func = checkpoints_load_func
self.loss_fn = loss_fn
# If test_loss_fn not provided, it's assumed to be same as loss_fn
self.test_loss_fn = loss_fn if test_loss_fn is None else test_loss_fn
self.batch_size = batch_size

if not isinstance(train_dataset, DataLoader):
Expand Down Expand Up @@ -489,6 +506,7 @@ def __init__(
layers: Optional[List[str]] = None,
loss_fn: Optional[Union[Module, Callable]] = None,
batch_size: Union[int, None] = 1,
test_loss_fn: Optional[Union[Module, Callable]] = None,
sample_wise_grads_per_batch: bool = False,
) -> None:
r"""
Expand Down Expand Up @@ -561,6 +579,24 @@ def __init__(
`train_dataset` is a Dataset. If `train_dataset`
is a DataLoader, then `batch_size` is ignored as an argument.
Default: 1
test_loss_fn (Callable, optional): In some cases, one may want to use a
separate loss functions for training examples, i.e. those in
`train_dataset`, and for test examples, i.e. those
represented by the `inputs` and `targets` arguments to the
`influence` method. For example, if one wants to calculate the
influence score of a training example on a test example's
prediction for a fixed class, `test_loss_fn` could map from the
logits for all classes to the logits for a fixed class.
`test_loss_fn` needs satisfy the same constraints as `loss_fn`.
Thus, the same checks that we apply to `loss_fn` are also applied
to `test_loss_fn`, if the latter is provided. Note that the
constraints on both `loss_fn` and `test_loss_fn` both depend on
`sample_wise_grads_per_batch`. This means `loss_fn` and
`test_loss_fn` must either both be "per-example" loss functions,
or both be "reduction" loss functions. If not provided, the loss
function for test examples is assumed to be the same as the loss
function for training examples, i.e. `loss_fn`.
Default: None
sample_wise_grads_per_batch (bool, optional): PyTorch's native gradient
computations w.r.t. model parameters aggregates the results for a
batch and does not allow to access sample-wise gradients w.r.t.
Expand Down Expand Up @@ -590,51 +626,23 @@ def __init__(
checkpoints_load_func,
loss_fn,
batch_size,
test_loss_fn,
)

self.sample_wise_grads_per_batch = sample_wise_grads_per_batch

# If we are able to access the reduction used by `loss_fn`, we check whether
# the reduction is compatible with `sample_wise_grads_per_batch`
if isinstance(loss_fn, Module) and hasattr(
loss_fn, "reduction"
): # TODO: allow loss_fn to be Callable
if self.sample_wise_grads_per_batch:
assert loss_fn.reduction in ["sum", "mean"], (
'reduction for `loss_fn` must be "sum" or "mean" when '
"`sample_wise_grads_per_batch` is True"
)
self.reduction_type = str(loss_fn.reduction)
else:
assert loss_fn.reduction == "none", (
'reduction for `loss_fn` must be "none" when '
"`sample_wise_grads_per_batch` is False"
)
else:
# if we are unable to access the reduction used by `loss_fn`, we warn
# the user about the assumptions we are making regarding the reduction
# used by `loss_fn`
if self.sample_wise_grads_per_batch:
warnings.warn(
'Since `loss_fn` has no "reduction" attribute, and '
"`sample_wise_grads_per_batch` is True, the implementation assumes "
'that `loss_fn` is a "reduction" loss function that reduces the '
"per-example losses by taking their *sum*. If `loss_fn` "
"instead reduces the per-example losses by taking their mean, "
'please set the reduction attribute of `loss_fn` to "mean", i.e. '
'`loss_fn.reduction = "mean"`. Note that if '
"`sample_wise_grads_per_batch` is True, the implementation "
"assumes the reduction is either a sum or mean reduction."
)
self.reduction_type = "sum"
else:
warnings.warn(
'Since `loss_fn` has no "reduction" attribute, and '
"`sample_wise_grads_per_batch` is False, the implementation "
'assumes that `loss_fn` is a "per-example" loss function (see '
"documentation for `loss_fn` for details). Please ensure that "
"this is the case."
)
# check `loss_fn`
self.reduction_type = _check_loss_fn(
self, loss_fn, "loss_fn", sample_wise_grads_per_batch
)
# check `test_loss_fn` if it was provided
self.test_reduction_type = (
self.reduction_type
if test_loss_fn is None
else _check_loss_fn(
self, test_loss_fn, "test_loss_fn", sample_wise_grads_per_batch
)
)

r"""
TODO: Either restore model state after done (would have to place functionality
Expand Down Expand Up @@ -790,11 +798,15 @@ def get_checkpoint_contribution(checkpoint):
input_jacobians = self._basic_computation_tracincp(
inputs,
targets,
self.test_loss_fn,
self.test_reduction_type,
)
return (
_gradient_dot_product(
input_jacobians,
self._basic_computation_tracincp(batch[0:-1], batch[-1]),
self._basic_computation_tracincp(
batch[0:-1], batch[-1], self.loss_fn, self.reduction_type
),
)
* learning_rate
)
Expand Down Expand Up @@ -1055,7 +1067,10 @@ def get_checkpoint_contribution(checkpoint):
for batch in _inputs_dataset:

layer_jacobians = self._basic_computation_tracincp(
batch[0:-1], batch[-1]
batch[0:-1],
batch[-1],
self.loss_fn,
self.reduction_type,
)

# Note that all variables in this function are for an entire batch.
Expand Down Expand Up @@ -1196,11 +1211,14 @@ def _basic_computation_tracincp(
self,
inputs: Tuple[Any, ...],
targets: Optional[Tensor] = None,
loss_fn: Optional[Union[Module, Callable]] = None,
reduction_type: Optional[str] = None,
) -> Tuple[Tensor, ...]:
"""
For instances of TracInCP, computation of influence scores or self influence
scores repeatedly calls this function for different checkpoints
and batches.
and batches. In particular, this function computes the jacobian of a loss
function w.r.t. parameters in the `layers` initialization argument.

Args:

Expand All @@ -1210,20 +1228,26 @@ def _basic_computation_tracincp(
that `model(*inputs)` produces the predictions for the batch.
targets (tensor or None): If computing influence scores on a loss function,
these are the labels corresponding to the batch `inputs`.
Default: none
loss_fn (Callable, optional): The loss function to use when computing the
jacobian.
reduction_type (str, optional): The reduction type of `loss_fn`. This
argument is only used if `sample_wise_grads_per_batch` was true in
initialization.
"""
if self.sample_wise_grads_per_batch:
return _compute_jacobian_wrt_params_with_sample_wise_trick(
self.model,
inputs,
targets,
self.loss_fn,
self.reduction_type,
loss_fn,
reduction_type,
self.layer_modules,
)
return _compute_jacobian_wrt_params(
self.model,
inputs,
targets,
self.loss_fn,
loss_fn,
self.layer_modules,
)
Loading