From b76e4d55bc15dbbf81286c7edc0e12da7c8b57d5 Mon Sep 17 00:00:00 2001
From: "Chen Chen (AI Infra)"
Date: Thu, 14 Mar 2024 13:52:28 -0700
Subject: [PATCH] Simplify the _check_loss_fn() logic (#1243)

Summary:
_check_loss_fn() has exactly the same logic in the sample_wise_grads_per_batch=None and sample_wise_grads_per_batch=True cases, so collapse the two branches into one and make True the default.

Differential Revision: D54883319
---
 captum/influence/_utils/common.py | 24 +++++-------------------
 1 file changed, 5 insertions(+), 19 deletions(-)

diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py
index 4300f0c0e1..c214ecbdf1 100644
--- a/captum/influence/_utils/common.py
+++ b/captum/influence/_utils/common.py
@@ -444,7 +444,7 @@ def _check_loss_fn(
     influence_instance: Union["TracInCPBase", "InfluenceFunctionBase"],
     loss_fn: Optional[Union[Module, Callable]],
     loss_fn_name: str,
-    sample_wise_grads_per_batch: Optional[bool] = None,
+    sample_wise_grads_per_batch: bool = True,
 ) -> str:
     """
     This checks whether `loss_fn` satisfies the requirements assumed of all
@@ -469,16 +469,13 @@ def _check_loss_fn(
     # attribute.
     if hasattr(loss_fn, "reduction"):
         reduction = loss_fn.reduction  # type: ignore
-        if sample_wise_grads_per_batch is None:
+        if sample_wise_grads_per_batch:
             assert reduction in [
                 "sum",
                 "mean",
-            ], 'reduction for `loss_fn` must be "sum" or "mean"'
-            reduction_type = str(reduction)
-        elif sample_wise_grads_per_batch:
-            assert reduction in ["sum", "mean"], (
+            ], (
                 'reduction for `loss_fn` must be "sum" or "mean" when '
-                "`sample_wise_grads_per_batch` is True"
+                "`sample_wise_grads_per_batch` is True (i.e. the default value) "
             )
             reduction_type = str(reduction)
         else:
@@ -490,18 +487,7 @@ def _check_loss_fn(
         # if we are unable to access the reduction used by `loss_fn`, we warn
         # the user about the assumptions we are making regarding the reduction
         # used by `loss_fn`
-        if sample_wise_grads_per_batch is None:
-            warnings.warn(
-                f'Since `{loss_fn_name}` has no "reduction" attribute, the '
-                f'implementation assumes that `{loss_fn_name}` is a "reduction" loss '
-                "function that reduces the per-example losses by taking their *sum*. "
-                f"If `{loss_fn_name}` instead reduces the per-example losses by "
-                f"taking their mean, please set the reduction attribute of "
-                f'`{loss_fn_name}` to "mean", i.e. '
-                f'`{loss_fn_name}.reduction = "mean"`.'
-            )
-            reduction_type = "sum"
-        elif sample_wise_grads_per_batch:
+        if sample_wise_grads_per_batch:
             warnings.warn(
                 f"Since `{loss_fn_name}`` has no 'reduction' attribute, and "
                 "`sample_wise_grads_per_batch` is True, the implementation assumes "
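
For illustration only (not part of the patch): below is a minimal standalone sketch of the reduction check that the simplified branch performs, assuming only the behavior visible in the hunks above. check_reduction is a hypothetical helper, not captum's API; the False branch and the exact wording of the fallback warning are not fully shown in the hunks and are marked as assumptions in the comments.

import warnings

import torch.nn as nn


def check_reduction(loss_fn, loss_fn_name="loss_fn", sample_wise_grads_per_batch=True):
    # Hypothetical sketch mirroring the simplified branch: with the new default
    # (True), a loss that exposes a `reduction` attribute must reduce the
    # per-example losses by "sum" or "mean".
    if hasattr(loss_fn, "reduction"):
        reduction = loss_fn.reduction
        if sample_wise_grads_per_batch:
            assert reduction in ["sum", "mean"], (
                'reduction for `loss_fn` must be "sum" or "mean" when '
                "`sample_wise_grads_per_batch` is True (i.e. the default value)"
            )
            return str(reduction)
        # The False branch is outside the hunks shown above and is omitted
        # from this sketch.
        raise NotImplementedError
    # No `reduction` attribute: warn and fall back to assuming "sum", as the
    # removed None-branch did (the commit states both branches behaved alike).
    warnings.warn(
        f"Since `{loss_fn_name}` has no 'reduction' attribute, and "
        "`sample_wise_grads_per_batch` is True, the implementation assumes "
        'a "sum" reduction.'  # assumed wording, the hunk ends before the full message
    )
    return "sum"


print(check_reduction(nn.CrossEntropyLoss(reduction="mean")))  # prints "mean"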