Optim-wip: Fix duplicated target bug #919
Conversation
@@ -129,6 +129,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
     target = (self.target if isinstance(self.target, list) else [self.target]) + (
         other.target if isinstance(other.target, list) else [other.target]
     )
+
+    # Filter out duplicate targets
+    target = list(dict.fromkeys(target))
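As a quick aside on the idiom: dict.fromkeys keeps only the first occurrence of each key, and dicts preserve insertion order in Python 3.7+, so the dedup keeps targets in their original order (unlike set, which would not). A minimal standalone check, using plain objects as stand-ins for target modules:

# Stand-ins for target modules; nn.Module instances also hash by identity.
a, b, c = object(), object(), object()
target = [a, b, a, c]
deduped = list(dict.fromkeys(target))
assert deduped == [a, b, c]  # duplicates dropped, first-seen order kept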
@ProGamerGov, why would someone pass a duplicated target here? Shouldn't we set an assert here?
@NarineK There are a few reasons why there can be duplicates here.
For example, optimization with transparency will be using NaturalImage or a transform as the target for one or more alpha channel related objectives. If the user is working with a CLIP model, an L2 penalty objective will also be using one of the same targets.
Using multiple different penalties on the same target will also create duplicates without that line. In the Optimizing with Transparency notebook, a duplicate would be created in the final section when both a blurring penalty and an L2 penalty use the same target layer.
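Here is a self-contained sketch of how that happens, assuming loss composition concatenates target lists as in the diff above (FakeLoss is a hypothetical stand-in, not one of the real objective classes):

class FakeLoss:
    def __init__(self, target):
        self.target = target
    def __sub__(self, other):
        # Mirrors the concatenation logic from the diff above.
        combined = FakeLoss(None)
        combined.target = (
            self.target if isinstance(self.target, list) else [self.target]
        ) + (other.target if isinstance(other.target, list) else [other.target])
        return combined

layer = object()  # stand-in for the shared target, e.g. transforms[0]
loss = FakeLoss(layer) - FakeLoss(layer)  # two penalties on the same layer
assert loss.target == [layer, layer]      # duplicate, before the dedup fix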
@ProGamerGov, are you saying that in this line the targets will be duplicated (example from the Optimizing with Transparency notebook)?
loss_fn = loss_fn - MeanAlphaChannelPenalty(transforms[0])
loss_fn = loss_fn - (9 * BlurActivations(transforms[0], channel_index=3))
Is it because we are concatenating the other target to the current target here?
target = (self.target if isinstance(self.target, list) else [self.target]) + (
    other.target if isinstance(other.target, list) else [other.target]
)
I was thinking: why are we concatenating the targets in the above line? It looks like we do not concatenate if self.target is a list, but otherwise we concatenate it with other.target. I was wondering if we could elaborate on this logic a bit.
@NarineK Yes, those lines will result in a duplicated target, because we concatenate the target lists for every operation involving multiple loss objectives. I'll think about doing a more detailed write-up of how it works in another PR.
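To spell out the concatenation logic (a minimal sketch; combine_targets is a hypothetical helper extracted from the line quoted above): the isinstance checks only normalize each operand to a list, so concatenation happens in every case, and targets accumulate across each composition:

def combine_targets(self_target, other_target):
    # Both operands are coerced to lists, then always concatenated.
    return (self_target if isinstance(self_target, list) else [self_target]) + (
        other_target if isinstance(other_target, list) else [other_target]
    )

a, b = object(), object()  # stand-ins for target modules
first = combine_targets(a, b)       # [a, b]
second = combine_targets(first, a)  # [a, b, a] <- duplicate when a is reused
assert second == [a, b, a]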
@ProGamerGov, but we don't concatenate them if self.target is a list? Perhaps you can rework this PR since it is very small, or we need to document in this PR that the logic requires refinement before we merge it.
@NarineK It seems like it could be a bit complicated to change at the moment. Most loss objectives store a self.target value that is then used to collect the target activations:

def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
    activations = targets_to_values[self.target]
    return activations

The self.target value then becomes a list when combined with another loss objective in the resulting CompositeLoss instance. The two original objectives can still call and use their own self.target values. The InputOptimization module also uses a list of targets, but it does not overwrite the original self.target to make its list.
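A minimal sketch of that arrangement (the class bodies here are assumptions for illustration, not the real Captum implementations): the composite carries the combined target list, while each original objective keeps its own self.target for indexing activations:

class Objective:
    def __init__(self, target):
        self.target = target
    def __call__(self, targets_to_values):
        # Each objective still indexes activations by its own target.
        return targets_to_values[self.target]

class CompositeLoss:
    def __init__(self, losses, target):
        self.losses = losses
        self.target = target  # combined list; originals left untouched

o1 = Objective("layer1")
o2 = Objective("layer1")
composite = CompositeLoss([o1, o2], target=[o1.target, o2.target])
assert o1.target == "layer1"                     # unchanged
assert composite.target == ["layer1", "layer1"]  # duplicated until deduped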
Currently a hook is created in ModuleOutputsHook for every instance of a target in the target list. Each captured set of activations for the same hook overwrites the previous set, potentially leading to negative performance impacts, as only one set of activations per target is returned. This bug also causes the warning messages in ModuleOutputsHook to repeat every iteration. This PR solves the issue by ensuring that duplicate target values are removed.
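For illustration, a self-contained sketch of the hook-side symptom (ModuleOutputsHook internals are assumed here; this just mimics registering one forward hook per target-list entry):

import torch
import torch.nn as nn

activations = {}

def hook(module, inputs, output):
    activations[module] = output  # a second hook on the same module overwrites

model = nn.Linear(4, 4)
targets = [model, model]  # duplicated target entry
handles = [t.register_forward_hook(hook) for t in targets]  # two hooks, one module
model(torch.randn(1, 4))
assert len(activations) == 1  # only one activation set survives regardless

for h in handles:
    h.remove()

# With the fix, list(dict.fromkeys(targets)) would register a single hook instead.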