Skip to content

Conversation

@the-moliver
Copy link

Small generalization of your code to allow the functions to which constraints are applied to be able to take arguments. This allows for using multiple loss functions, constraining a model's output to be within a certain range, etc, etc.
As a proof of concept, you can append the following to the constraints in the MNIST example to constrain the output to sum to 50:
constraints.append(EqConstraint(torch.sum, 50, scale=scale, damping=damping))
and then change the mdmm_module call to:
mdmm_return = mdmm_module(loss, [None, None, None, outputs])
The first 3 Nones are for the three constraints that take no argument, while the last constraint we added will take outputs. That's all that's needed to make it work.
The change is backward compatible and no list is required if no constraint functions need arguments. Also functions that require multiple arguments, like a secondary loss, can have their arguments provided as a tuple or list within the argument list.

This allows for multiple cost functions on the output of a model
Allow functions to take arguments
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant