diff --git a/tf_explain/core/integrated_gradients.py b/tf_explain/core/integrated_gradients.py index 9308b30..ade4408 100644 --- a/tf_explain/core/integrated_gradients.py +++ b/tf_explain/core/integrated_gradients.py @@ -42,7 +42,7 @@ def explain(self, validation_data, model, class_index, n_steps=10): ) grayscale_integrated_gradients = transform_to_normalized_grayscale( - tf.abs(integrated_gradients) + tf.abs(images*integrated_gradients) ).numpy() grid = grid_display(grayscale_integrated_gradients)