135 changes: 88 additions & 47 deletions quantus/metrics/axiomatic/non_sensitivity.py
@@ -282,6 +282,9 @@ def custom_preprocess(
features_in_step=self.features_in_step,
input_shape=x_batch.shape[2:],
)




def evaluate_batch(
self,
@@ -292,65 +295,103 @@ def evaluate_batch(
**kwargs,
) -> List[int]:
"""
This method performs XAI evaluation on a single batch of explanations.
For more information on the specific logic, we refer to the metric's initialisation docstring.

Parameters
----------
model: ModelInterface
A ModelInterface that is subject to explanation.
x_batch: np.ndarray
The input to be evaluated on a batch-basis.
y_batch: np.ndarray
The output to be evaluated on a batch-basis.
a_batch: np.ndarray
The explanation to be evaluated on a batch-basis.
kwargs:
Unused.

Returns
-------
np.ndarray
Array of shape (batch_size,), per sample score.
Lower values are better.
"""

# Prepare shapes. Expand a_batch if not the same shape
if x_batch.shape != a_batch.shape:
a_batch = np.broadcast_to(a_batch, x_batch.shape)

# Flatten the attributions.
batch_size = a_batch.shape[0]
x_shape = x_batch.shape
x_batch = x_batch.reshape(batch_size, -1)
a_batch = a_batch.reshape(batch_size, -1)
n_features = a_batch.shape[-1]

non_features = a_batch < self.eps
features = ~non_features
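# non_features marks pixels whose attribution falls below eps ("unimportant");
# features marks the remaining pixels ("important").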

x_input = model.shape_input(x_batch, x_batch.shape, channel_first=True, batched=True)
x_input = model.shape_input(x_batch, x_shape, channel_first=True, batched=True)
y_pred = model.predict(x_input)[np.arange(batch_size), y_batch]

# Prepare lists.
n_perturbations = math.ceil(n_features / self.features_in_step)
preds = []
x_perturbed = x_batch.copy()
x_batch_shape = x_batch.shape
a_indices = np.stack([np.arange(n_features) for _ in x_batch])
for perturbation_step_index in range(n_perturbations):
# Perturb input by indices of attributions.
a_ix = a_indices[
:,
perturbation_step_index * self.features_in_step : (perturbation_step_index + 1) * self.features_in_step,
]
x_perturbed = self.perturb_func(
arr=x_batch.reshape(batch_size, -1),
indices=a_ix,
)
x_perturbed = x_perturbed.reshape(*x_batch_shape)
pixel_scores_non = self._process_mask(
model, x_batch, y_batch, non_features, x_shape
)
pixel_scores_feat = self._process_mask(
model, x_batch, y_batch, features, x_shape
)
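# pixel_scores_non / pixel_scores_feat hold, for every pixel, the prediction obtained
# after perturbing the group that contains that pixel; positions outside the respective
# mask stay zero, so their sum gives one perturbed prediction per pixel.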

# Predict on perturbed input x.
x_input = model.shape_input(x_perturbed, x_batch.shape, channel_first=True, batched=True)
y_pred_perturb = model.predict(x_input)[np.arange(batch_size), y_batch]
preds.append(y_pred_perturb)
preds = np.stack(preds, axis=1)
preds_differences = abs(preds - y_pred[:, None]) < self.eps
preds_differences = np.abs(
y_pred[:, None] - (pixel_scores_non + pixel_scores_feat)
) < self.eps
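# preds_differences is True where the prediction stays within eps of the original,
# i.e. the model is insensitive to that pixel's group. XOR-ing with non_features counts
# disagreements: pixels marked important that the model ignores, plus pixels marked
# unimportant that the model reacts to.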

return (preds_differences ^ non_features).sum(-1)

def _create_index_groups(self, mask: np.ndarray) -> List[np.ndarray]:
"""Divide mask indices into perturbation groups."""
indices = np.where(mask)[0]
return [
indices[i:i + self.features_in_step]
for i in range(0, len(indices), self.features_in_step)
]

def _perturb_sample_batch(
self, x_batch: np.ndarray, b: int, indices: np.ndarray
) -> np.ndarray:
"""Perturb a single sample within the batch."""
perturbed_flat = x_batch.copy()
indices_2d = np.expand_dims(indices, axis=0)
perturbed_flat[b] = self.perturb_func(
arr=perturbed_flat[b:b + 1, :],
indices=indices_2d,
)
return perturbed_flat

def _predict_scores(
self, model, x_batch: np.ndarray, y_batch: np.ndarray, x_shape: tuple
) -> np.ndarray:
"""Predict scores for the true labels of the given batch."""
x_input = model.shape_input(
x_batch.reshape(x_shape), x_shape, channel_first=True, batched=True
)
return model.predict(x_input)[np.arange(x_batch.shape[0]), y_batch]

def _process_mask(
self,
model,
x_batch: np.ndarray,
y_batch: np.ndarray,
mask: np.ndarray,
x_shape: tuple,
) -> np.ndarray:
"""Handle perturbation and prediction workflow for a single mask type."""
batch_size = x_batch.shape[0]
pixel_scores = np.zeros_like(x_batch, dtype=float)
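# For every sample, perturb one group of masked indices at a time and store the
# resulting prediction at exactly those index positions.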

for b in range(batch_size):
for indices in self._create_index_groups(mask[b]):
perturbed_batch = self._perturb_sample_batch(x_batch, b, indices)
preds = self._predict_scores(model, perturbed_batch, y_batch, x_shape)
pixel_scores[b, indices] = preds[b]

return pixel_scores
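The violation count returned by `evaluate_batch` is easiest to see on a toy input. The following standalone sketch uses plain NumPy only; the `perturbed_pred` values are hypothetical stand-ins for re-running the model after each pixel's group has been perturbed. It reproduces the same XOR bookkeeping on a single flattened 2x2 image, mirroring the `half_good_half_bad` scenario in the tests below.

import numpy as np

eps = 1e-5
a = np.array([10.0, 10.0, 0.0, 0.0])   # flattened attributions for a 2x2 input
y_pred = 3.0                            # model output on the unperturbed input

# Hypothetical perturbed predictions, one re-evaluation per pixel group:
# here the model only reacts when pixels 2 or 3 are perturbed.
perturbed_pred = np.array([3.0, 3.0, 7.0, 1.0])

non_features = a < eps                                 # attribution says "unimportant"
insensitive = np.abs(y_pred - perturbed_pred) < eps    # model does not react

# A violation is any disagreement between the two masks.
violations = (insensitive ^ non_features).sum()
print(violations)  # 4: pixels 0/1 are claimed important but ignored, 2/3 the reverse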
132 changes: 132 additions & 0 deletions tests/metrics/test_axiomatic_metrics.py
@@ -3,10 +3,65 @@
import pytest
from pytest_lazyfixture import lazy_fixture
import numpy as np
import torch
import torch.nn as nn

from quantus.functions.explanation_func import explain
from quantus.metrics.axiomatic import Completeness, InputInvariance, NonSensitivity

# test_axiomatic_metrics.py (or similar)

def _ensure_4d(x):
"""Make sure x is (B, C, H, W), even if passed as (B, N)."""
x = np.array(x)
if x.ndim == 2:
B, N = x.shape
side = int(np.sqrt(N))
x = x.reshape(B, 1, side, side)
elif x.ndim == 3:
x = x[:, None, :, :]
return x

class SensitiveModel(nn.Module):
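# Prediction is the sum of all pixels, so the model reacts to every feature.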
def shape_input(self, x, shape, channel_first=True, batched=True):
return x

def forward(self, x):
x = _ensure_4d(x)
return x.sum(axis=(1, 2, 3), keepdims=True)

def predict(self, x):
x = _ensure_4d(x)
return self.forward(x)

class InsensitiveModel(nn.Module):
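# Prediction is a constant (100.0), so the model ignores every feature.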
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
def forward(self, x):
x = _ensure_4d(x)
B = x.shape[0]
return np.ones((B, 1), dtype=float) * 100.0
def predict(self, x):
return self.forward(x)

class SemiSensitiveModel(nn.Module):
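# Prediction depends only on the top row of the input.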
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
def forward(self, x):
x = _ensure_4d(x)
top_sum = x[:, :, 0, :].sum(axis=(1, 2))
return top_sum[:, None]
def predict(self, x):
x = _ensure_4d(x)
return self.forward(x)

class TrickModel(nn.Module):
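# Prediction depends only on the bottom row, i.e. the opposite of a top-row attribution.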
def shape_input(self, x, shape, channel_first=True, batched=True):
return x
def predict(self, x):
x = _ensure_4d(x)
bottom_sum = x[:, :, 1, :].sum(axis=(1, 2))
return bottom_sum[:, None]

@pytest.mark.axiomatic
@pytest.mark.parametrize(
@@ -231,6 +286,7 @@ def test_completeness(
"normalise": True,
"disable_warnings": False,
"display_progressbar": False,
"features_in_step": 2,
},
"call": {
"explain_func": explain,
@@ -417,6 +473,82 @@ def test_non_sensitivity(
)
assert scores is not None, "Test failed."

@pytest.mark.axiomatic
@pytest.mark.parametrize(
"scenario,model_factory,x_batch,y_batch,a_batch,expected_violations,kwargs",
[

(
"zero_violations",
lambda: SemiSensitiveModel(),
np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
np.array([0]),
np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
0,
{"features_in_step": 2, "eps": 1e-5},
),
(
"low_attr_high_change",
lambda: SensitiveModel(),
np.array([[[[5.0, 5.0], [5.0, 5.0]]]], dtype=float),
np.array([0]),
np.random.uniform(1e-6, 2e-6, size=(1, 1, 2, 2)),
4,
{"features_in_step": 2, "eps": 1e-5},
),
(
"high_attr_low_change",
lambda: InsensitiveModel(),
np.random.rand(1, 1, 4, 4),
np.array([0]),
np.ones((1, 1, 4, 4)),
16,
{"features_in_step": 2, "eps": 1e-5},
),
(
"half_good_half_bad",
lambda: TrickModel(),
np.array([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=float),
np.array([0]),
np.array([[[[10.0, 10.0], [0.0, 0.0]]]], dtype=float),
4,
{"features_in_step": 1, "eps": 1e-5},
),
],
)
def test_my_non_sensitivity_logics(
scenario,
model_factory,
x_batch,
y_batch,
a_batch,
expected_violations,
kwargs,
):
"""
Parametrized logic-based tests for NonSensitivity.
Each scenario defines a different consistency pattern between attribution and model behavior.
"""
model = model_factory()
model.eval()
metric = NonSensitivity(
disable_warnings=True,
perturb_baseline="uniform",
normalise=False,
**kwargs,
)

scores = metric.evaluate_batch(model, x_batch, y_batch, a_batch)

# --- Assertions ---
assert isinstance(scores, np.ndarray), f"[{scenario}] Output must be np.ndarray"
assert scores.shape[0] == x_batch.shape[0], f"[{scenario}] Wrong batch size"
assert np.all(np.isfinite(scores)), f"[{scenario}] Scores contain NaN/Inf"

if expected_violations is not None:
assert scores[0] == expected_violations, (
f"[{scenario}] expected {expected_violations}, got {scores[0]}"
)
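# These scenarios can be run in isolation with, e.g.:
#   pytest -m axiomatic -k test_my_non_sensitivity_logics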

@pytest.mark.axiomatic
@pytest.mark.parametrize(