File tree Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -126,9 +126,14 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
126126 return math_op (torch .mean (self (module )), torch .mean (other (module )))
127127
128128 name = f"Compose({ ', ' .join ([self .__name__ , other .__name__ ])} )"
129+
130+ # ToDo: Refine logic for self.target handling
129131 target = (self .target if isinstance (self .target , list ) else [self .target ]) + (
130132 other .target if isinstance (other .target , list ) else [other .target ]
131133 )
134+
135+ # Filter out duplicate targets
136+ target = list (dict .fromkeys (target ))
132137 else :
133138 raise TypeError (
134139 "Can only apply math operations with int, float or Loss. Received type "
@@ -875,6 +880,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
875880 ]
876881 for target in targets
877882 ]
883+
884+ # Filter out duplicate targets
885+ target = list (dict .fromkeys (target ))
878886 return CompositeLoss (loss_fn , name = name , target = target )
879887
880888
Original file line number Diff line number Diff line change @@ -255,6 +255,7 @@ def collect_activations(
255255 """
256256 if not isinstance (targets , list ):
257257 targets = [targets ]
258+ targets = list (dict .fromkeys (targets ))
258259 catch_activ = ActivationFetcher (model , targets )
259260 activ_dict = catch_activ (model_input )
260261 return activ_dict
You can’t perform that action at this time.
0 commit comments