diff --git a/captum/_utils/common.py b/captum/_utils/common.py index bba0ea293b..e5373e9c83 100644 --- a/captum/_utils/common.py +++ b/captum/_utils/common.py @@ -174,6 +174,8 @@ def _format_tensor_into_tuples( ) -> Union[None, Tuple[Tensor, ...]]: if inputs is None: return None + if isinstance(inputs, list): + inputs = tuple(inputs) if not isinstance(inputs, tuple): assert isinstance( inputs, torch.Tensor