From 5bc444bb9611c08e97e1d1bb64ca75840ccc935f Mon Sep 17 00:00:00 2001
From: Marcio Porto
Date: Mon, 18 Mar 2024 14:04:45 -0700
Subject: [PATCH] Move current_mask to perturbed tensor device (#1245)

Summary:
Currently `FeaturePermutation` and `FeatureAblation` both raise a device
mismatch error in https://fburl.com/code/9mfuidf4 because `current_mask`
is always created on the CPU and never moved to the same device as
`expanded_input` when CUDA is available.

Reviewed By: cyrjano, vivekmig

Differential Revision: D54969675
---
 captum/attr/_core/feature_ablation.py    | 1 +
 captum/attr/_core/feature_permutation.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 725b4fa9f3..e4e9719a4b 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -559,6 +559,7 @@ def _construct_ablated_input(
         current_mask = torch.stack(
             [input_mask == j for j in range(start_feature, end_feature)], dim=0
         ).long()
+        current_mask = current_mask.to(expanded_input.device)
         ablated_tensor = (
             expanded_input * (1 - current_mask).to(expanded_input.dtype)
         ) + (baseline * current_mask.to(expanded_input.dtype))
diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py
index 557c1787e9..ba23ad4ec6 100644
--- a/captum/attr/_core/feature_permutation.py
+++ b/captum/attr/_core/feature_permutation.py
@@ -301,6 +301,7 @@ def _construct_ablated_input(
         current_mask = torch.stack(
             [input_mask == j for j in range(start_feature, end_feature)], dim=0
         ).bool()
+        current_mask = current_mask.to(expanded_input.device)
         output = torch.stack(
             [
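
For context, a minimal sketch of the failure mode this patch addresses. This is a hypothetical standalone repro, not Captum's actual call site: the tensor shapes, the feature range, and the `input_mask` construction are illustrative, and it assumes a CUDA device is available.

```python
import torch

if torch.cuda.is_available():
    # Stand-in for the expanded input tensor, living on the GPU.
    expanded_input = torch.rand(2, 3, 4, device="cuda")

    # input_mask is built on the CPU, so the stacked comparison below
    # is a CPU tensor too (mirrors the torch.stack pattern in
    # _construct_ablated_input, with an illustrative feature range).
    input_mask = torch.randint(0, 2, (3, 4))
    current_mask = torch.stack(
        [input_mask == j for j in range(0, 2)], dim=0
    ).long()

    # Without the patched line, the multiplication below fails with
    # "RuntimeError: Expected all tensors to be on the same device ..."
    # because a CPU mask is combined with a CUDA tensor.
    current_mask = current_mask.to(expanded_input.device)  # the fix
    ablated = expanded_input * (1 - current_mask).to(expanded_input.dtype)
```

Moving the mask with `.to(expanded_input.device)` is a no-op on CPU-only runs, so the one-line fix is safe in both configurations.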