
Optim-wip: Fix duplicated target bug #919

Merged: 4 commits, May 23, 2022
8 changes: 8 additions & 0 deletions captum/optim/_core/loss.py
@@ -126,9 +126,14 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
return math_op(torch.mean(self(module)), torch.mean(other(module)))

name = f"Compose({', '.join([self.__name__, other.__name__])})"

# ToDo: Refine logic for self.target handling
target = (self.target if isinstance(self.target, list) else [self.target]) + (
    other.target if isinstance(other.target, list) else [other.target]
)

# Filter out duplicate targets
target = list(dict.fromkeys(target))
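
For reference, a minimal standalone sketch of the dedup idiom used in the fix; plain strings stand in here for the actual target modules:

    # list(dict.fromkeys(...)) removes duplicates while keeping first-seen
    # order, unlike set(), which does not guarantee any ordering.
    targets = ["layer3", "layer4", "layer3"]
    print(list(dict.fromkeys(targets)))  # ['layer3', 'layer4']
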
Contributor

@ProGamerGov, why would someone pass a duplicated target here? Shouldn't we add an assert here?

Contributor Author

@NarineK There are a few reasons why there can be duplicates here.

For example, optimization with transparency will use NaturalImage or a transform as the target for one or more alpha-channel-related objectives. If the user is working with a CLIP model, an L2 penalty objective will also use one of the same targets.

Contributor Author (ProGamerGov, May 17, 2022)

Using multiple different penalties on the same target will also create duplicates without that line. In the Optimizing with Transparency notebook, a duplicate would be created in the final section, where both a blurring penalty and an L2 penalty use the same target layer.

Contributor

@ProGamerGov, are you saying that in this line the targets will be duplicated (example from the Optimizing with Transparency notebook)?

loss_fn = loss_fn - MeanAlphaChannelPenalty(transforms[0])
loss_fn = loss_fn - (9 * BlurActivations(transforms[0], channel_index=3))

Is it because we are concatenating the other target to the current target here?

        target = (self.target if isinstance(self.target, list) else [self.target]) + (
            other.target if isinstance(other.target, list) else [other.target]
        )

I was wondering why we are concatenating the targets in the above line. It looks like we are not concatenating if self.target is a list, but otherwise we concatenate it with other.target. Could we elaborate on this logic a bit?

Contributor Author
The reason will be displayed to describe this comment to others. Learn more.

@NarineK Yes, those lines will result in a duplicated target, because we concatenate the target lists for every operation involving multiple loss objectives. I'll look into doing a more detailed write-up of how it works in another PR.
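
A rough illustration of the duplication being described; a plain object stands in for the shared transform target, and the concatenation mirrors the one shown in the diff above:

    class FakeTransform:
        pass

    shared = FakeTransform()   # e.g. transforms[0] in the notebook snippet

    alpha_target = shared      # MeanAlphaChannelPenalty's self.target
    blur_target = shared       # BlurActivations' self.target

    # Each composition step concatenates the per-loss target lists:
    target = [alpha_target] + [blur_target]
    print(len(target))                       # 2 -> duplicated target
    print(len(list(dict.fromkeys(target))))  # 1 once duplicates are filtered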

Contributor

@ProGamerGov, but we don't concatenate them if self.target is a list? Perhaps you can rework this in this PR, since it is very small, or we need to document in this PR that the logic requires refinement before we merge it.

Contributor Author (ProGamerGov, May 23, 2022)

@NarineK It seems like it could be a bit complicated to change at the moment. Most loss objectives store a self.target value that is then used to collect the target activations:

    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
        activations = targets_to_values[self.target]
        return activations

The self.target value then becomes a list when combined with another loss objective in the resulting CompositeLoss instance. The two original objectives can still call and use their own self.target values. The InputOptimization module also uses a list of targets, but it does not overwrite the original self.target to build its list.
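
A minimal toy model of the behaviour described above; these are simplified stand-ins, not the actual Loss/CompositeLoss classes:

    class ToyLoss:
        def __init__(self, target):
            self.target = target      # a single module, as in most objectives

        def __add__(self, other):
            # The composite gathers both targets into a deduplicated list;
            # neither original objective's self.target is modified.
            targets = [self.target, other.target]
            return ToyCompositeLoss(list(dict.fromkeys(targets)))

    class ToyCompositeLoss:
        def __init__(self, target):
            self.target = target      # a list of modules

    layer = object()
    a, b = ToyLoss(layer), ToyLoss(layer)
    composite = a + b
    print(a.target is layer, b.target is layer)  # True True
    print(len(composite.target))                 # 1, not 2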

else:
raise TypeError(
"Can only apply math operations with int, float or Loss. Received type "
@@ -720,6 +725,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
]
for target in targets
]

# Filter out duplicate targets
target = list(dict.fromkeys(target))
return CompositeLoss(loss_fn, name=name, target=target)


1 change: 1 addition & 0 deletions captum/optim/models/_common.py
@@ -193,6 +193,7 @@ def collect_activations(
"""
if not isinstance(targets, list):
    targets = [targets]
targets = list(dict.fromkeys(targets))
catch_activ = ActivationFetcher(model, targets)
activ_out = catch_activ(model_input)
return activ_out
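
A hedged usage sketch of the collect_activations change; the call signature and import path are inferred from the diff rather than confirmed, so treat the names here as assumptions:

    import torch
    import torch.nn as nn
    # Import path assumed from the file shown in the diff (optim-wip branch):
    from captum.optim.models._common import collect_activations

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
    layer = model[2]

    # Passing the same module twice is now collapsed to a single target
    # before ActivationFetcher is constructed.
    activ = collect_activations(
        model,
        targets=[layer, layer],
        model_input=torch.zeros(1, 3, 32, 32),
    )
    print(len(activ))  # one entry for the single unique target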