Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions captum/influence/_core/influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@ class DataInfluence(ABC):
An abstract class to define model data influence skeleton.
"""

def __init_(
self, model: Module, influence_src_dataset: Dataset, **kwargs: Any
) -> None:
def __init_(self, model: Module, train_dataset: Dataset, **kwargs: Any) -> None:
r"""
Args:
model (torch.nn.Module): An instance of pytorch model.
influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
train_dataset (torch.utils.data.Dataset): PyTorch Dataset that is
used to create a PyTorch Dataloader to iterate over the dataset and
its labels. This is the dataset for which we will be seeking for
influential instances. In most cases this is the training dataset.
**kwargs: Additional key-value arguments that are necessary for specific
implementation of `DataInfluence` abstract class.
"""
self.model = model
self.influence_src_dataset = influence_src_dataset
self.train_dataset = train_dataset

@abstractmethod
def influence(self, inputs: Any = None, **kwargs: Any) -> Any:
Expand Down
461 changes: 211 additions & 250 deletions captum/influence/_core/tracincp.py

Large diffs are not rendered by default.

470 changes: 201 additions & 269 deletions captum/influence/_core/tracincp_fast_rand_proj.py

Large diffs are not rendered by default.

14 changes: 6 additions & 8 deletions captum/influence/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def _get_k_most_influential_helper(
influence_src_dataloader: DataLoader,
influence_batch_fn: Callable,
inputs: Tuple[Any, ...],
targets: Optional[Tensor],
k: int = 5,
proponents: bool = True,
show_progress: bool = False,
Expand All @@ -204,13 +203,12 @@ def _get_k_most_influential_helper(
influence_src_dataloader (DataLoader): The DataLoader, representing training
data, for which we want to compute proponents / opponents.
influence_batch_fn (Callable): A callable that will be called via
`influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
`influence_batch_fn(inputs, batch)`, where `batch` is a batch
in the `influence_src_dataloader` argument.
inputs (tuple[Any, ...]): A batch of examples. Does not represent labels,
which are passed as `targets`.
targets (Tensor, optional): If computing TracIn scores on a loss function,
these are the labels corresponding to the batch `inputs`.
Default: None
inputs (tuple[Any, ...]): This argument represents the test batch, and is a
single tuple of any, where the last element is assumed to be the labels
for the batch. That is, `model(*batch[0:-1])` produces the output for
`model`, and `batch[-1]` are the labels, if any.
k (int, optional): The number of proponents or opponents to return per test
instance.
Default: 5
Expand Down Expand Up @@ -272,7 +270,7 @@ def _get_k_most_influential_helper(
for batch in influence_src_dataloader:

# calculate tracin_scores for the batch
batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
batch_tracin_scores = influence_batch_fn(inputs, batch)
batch_tracin_scores *= multiplier

# get the top-k indices and tracin_scores for the batch
Expand Down
7 changes: 5 additions & 2 deletions tests/influence/_core/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
Expand Down Expand Up @@ -76,7 +77,8 @@ def test_tracin_dataloader(
)

train_scores = tracin.influence(
test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
_format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
k=None,
)

tracin_dataloader = tracin_constructor(
Expand All @@ -88,7 +90,8 @@ def test_tracin_dataloader(
)

train_scores_dataloader = tracin_dataloader.influence(
test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
_format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
k=None,
)

assertTensorAlmostEqual(
Expand Down
21 changes: 5 additions & 16 deletions tests/influence/_core/test_tracin_intermediate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
Expand Down Expand Up @@ -224,25 +225,13 @@ def test_tracin_intermediate_quantities_consistent(
)

# compute influence scores without using `compute_intermediate_quantities`
test_batch = _format_batch_into_tuple(
test_features, test_labels, unpack_inputs
)
scores = tracin.influence(
test_features, test_labels, unpack_inputs=unpack_inputs
test_batch,
)

# compute influence scores using `compute_intermediate_quantities`
# we combine `test_features` and `test_labels` into a single tuple
# `test_batch` to pass to the model, with the assumption that
# `model(test_batch[0:-1]` produces the predictions, and `test_batch[-1]`
# are the labels. We do this due to the assumptions made by the
# `compute_intermediate_quantities` method. Therefore, how we
# form `test_batch` depends on whether `unpack_inputs` is True or False
if not unpack_inputs:
# `test_features` is a Tensor
test_batch = (test_features, test_labels)
else:
# `test_features` is a tuple, so we unpack it to place in tuple,
# along with `test_labels`
test_batch = (*test_features, test_labels) # type: ignore[assignment]

# the influence score is the dot product of intermediate quantities
intermediate_quantities_scores = torch.matmul(
intermediate_quantities_tracin.compute_intermediate_quantities(
Expand Down
8 changes: 4 additions & 4 deletions tests/influence/_core/test_tracin_k_most_influential.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
Expand Down Expand Up @@ -107,15 +108,14 @@ def test_tracin_k_most_influential(
)

train_scores = tracin.influence(
test_samples, test_labels, k=None, unpack_inputs=unpack_inputs
_format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
k=None,
)
sort_idx = torch.argsort(train_scores, dim=1, descending=proponents)[:, 0:k]
idx, _train_scores = tracin.influence(
test_samples,
test_labels,
_format_batch_into_tuple(test_samples, test_labels, unpack_inputs),
k=k,
proponents=proponents,
unpack_inputs=unpack_inputs,
)
for i in range(len(idx)):
# check that idx[i] is correct
Expand Down
28 changes: 14 additions & 14 deletions tests/influence/_core/test_tracin_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,19 @@ def test_tracin_regression(
criterion,
)

train_scores = tracin.influence(train_inputs, train_labels)
train_scores = tracin.influence((train_inputs, train_labels))
idx, _ = tracin.influence(
train_inputs, train_labels, k=len(dataset), proponents=True
(train_inputs, train_labels), k=len(dataset), proponents=True
)
# check that top influence is one with maximal value
# (and hence gradient)
for i in range(len(idx)):
self.assertEqual(idx[i][0], 15)

# check influence scores of test data
test_scores = tracin.influence(test_inputs, test_labels)
test_scores = tracin.influence((test_inputs, test_labels))
idx, _ = tracin.influence(
test_inputs, test_labels, k=len(test_inputs), proponents=True
(test_inputs, test_labels), k=len(test_inputs), proponents=True
)
# check that top influence is one with maximal value
# (and hence gradient)
Expand Down Expand Up @@ -226,17 +226,17 @@ def test_tracin_regression(
sample_wise_grads_per_batch=True,
)

train_scores = tracin.influence(train_inputs, train_labels)
train_scores = tracin.influence((train_inputs, train_labels))
train_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
train_inputs, train_labels
(train_inputs, train_labels)
)
assertTensorAlmostEqual(
self, train_scores, train_scores_sample_wise_trick
)

test_scores = tracin.influence(test_inputs, test_labels)
test_scores = tracin.influence((test_inputs, test_labels))
test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
test_inputs, test_labels
(test_inputs, test_labels)
)
assertTensorAlmostEqual(
self, test_scores, test_scores_sample_wise_trick
Expand Down Expand Up @@ -288,7 +288,7 @@ def test_tracin_regression_1D_numerical(
criterion,
)

train_scores = tracin.influence(train_inputs, train_labels, k=None)
train_scores = tracin.influence((train_inputs, train_labels), k=None)

r"""
Derivation for gradient / resulting TracIn score:
Expand Down Expand Up @@ -382,9 +382,9 @@ def test_tracin_identity_regression(

# check influence scores of training data

train_scores = tracin.influence(train_inputs, train_labels)
train_scores = tracin.influence((train_inputs, train_labels))
idx, _ = tracin.influence(
train_inputs, train_labels, k=len(dataset), proponents=True
(train_inputs, train_labels), k=len(dataset), proponents=True
)

# check that top influence for an instance is itself
Expand Down Expand Up @@ -415,9 +415,9 @@ def test_tracin_identity_regression(
sample_wise_grads_per_batch=True,
)

train_scores = tracin.influence(train_inputs, train_labels)
train_scores = tracin.influence((train_inputs, train_labels))
train_scores_tracin_sample_wise_trick = (
tracin_sample_wise_trick.influence(train_inputs, train_labels)
tracin_sample_wise_trick.influence((train_inputs, train_labels))
)
assertTensorAlmostEqual(
self, train_scores, train_scores_tracin_sample_wise_trick
Expand Down Expand Up @@ -496,5 +496,5 @@ def test_loss_fn(input, target):
)

# check influence scores of training data. they should all be 0
train_scores = tracin.influence(train_inputs, train_labels, k=None)
train_scores = tracin.influence((train_inputs, train_labels), k=None)
assertTensorAlmostEqual(self, train_scores, torch.zeros(train_scores.shape))
7 changes: 4 additions & 3 deletions tests/influence/_core/test_tracin_self_influence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from parameterized import parameterized
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.influence._utils.common import (
_format_batch_into_tuple,
build_test_name_func,
DataInfluenceConstructor,
get_random_model_and_data,
Expand Down Expand Up @@ -108,10 +109,10 @@ def test_tracin_self_influence(
criterion,
)
train_scores = tracin.influence(
train_dataset.samples,
train_dataset.labels,
_format_batch_into_tuple(
train_dataset.samples, train_dataset.labels, unpack_inputs
),
k=None,
unpack_inputs=unpack_inputs,
)
# calculate self_tracin_scores
self_tracin_scores = tracin.self_influence(
Expand Down
9 changes: 3 additions & 6 deletions tests/influence/_core/test_tracin_show_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def test_tracin_show_progress(
elif mode == "influence":

tracin.influence(
test_samples,
test_labels,
(test_samples, test_labels),
k=None,
show_progress=True,
)
Expand All @@ -196,8 +195,7 @@ def test_tracin_show_progress(
elif mode == "k-most":

tracin.influence(
test_samples,
test_labels,
(test_samples, test_labels),
k=2,
proponents=True,
show_progress=True,
Expand All @@ -218,8 +216,7 @@ def test_tracin_show_progress(
mock_stderr.truncate(0)

tracin.influence(
test_samples,
test_labels,
(test_samples, test_labels),
k=2,
proponents=False,
show_progress=True,
Expand Down
2 changes: 1 addition & 1 deletion tests/influence/_core/test_tracin_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ def test_tracin_require_inputs_dataset(
batch_size=1,
)
with self.assertRaisesRegex(AssertionError, "required."):
tracin.influence(None, test_labels, k=None, unpack_inputs=False)
tracin.influence(None, k=None)
6 changes: 3 additions & 3 deletions tests/influence/_core/test_tracin_xor.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def test_tracin_xor(
batch_size,
criterion,
)
test_scores = tracin.influence(testset, testlabels)
test_scores = tracin.influence((testset, testlabels))
idx = torch.argsort(test_scores, dim=1, descending=True)
# check that top 5 influences have matching binary classification
for i in range(len(idx)):
Expand Down Expand Up @@ -288,9 +288,9 @@ def test_tracin_xor(
criterion,
sample_wise_grads_per_batch=True,
)
test_scores = tracin.influence(testset, testlabels)
test_scores = tracin.influence((testset, testlabels))
test_scores_sample_wise_trick = tracin_sample_wise_trick.influence(
testset, testlabels
(testset, testlabels)
)
assertTensorAlmostEqual(
self, test_scores, test_scores_sample_wise_trick
Expand Down
12 changes: 11 additions & 1 deletion tests/influence/_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import unittest
from functools import partial
from typing import Callable, Iterator, List, Optional, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -14,6 +14,7 @@
)
from parameterized import parameterized
from parameterized.parameterized import param
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset

Expand Down Expand Up @@ -366,3 +367,12 @@ def build_test_name_func(args_to_skip: Optional[List[str]] = None):
"""

return partial(generate_test_name, args_to_skip=args_to_skip)


def _format_batch_into_tuple(
inputs: Union[Tuple, Tensor], targets: Tensor, unpack_inputs: bool
):
if unpack_inputs:
return (*inputs, targets)
else:
return (inputs, targets)
Loading