Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpmpy/solvers/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def solve(self, time_limit=None, assumptions=None, **kwargs):

# set assumptions
if assumptions is not None:
assumptions = flatlist(assumptions)
assert all(v.is_bool() for v in assumptions), "Non-Boolean assumptions given to Exact: " + str([v for v in assumptions if not v.is_bool()])
assump_vals = [int(not isinstance(v, NegBoolView)) for v in assumptions]
assump_vars = [self.solver_var(v._bv if isinstance(v, NegBoolView) else v) for v in assumptions]
Expand Down
2 changes: 1 addition & 1 deletion cpmpy/solvers/ortools.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def solve(self, time_limit=None, assumptions=None, solution_callback=None, **kwa
self.ort_solver.parameters.max_time_in_seconds = float(time_limit)

if assumptions is not None:
ort_assum_vars = self.solver_vars(assumptions)
ort_assum_vars = self.solver_vars(flatlist(assumptions))
# dict mapping ortools vars to CPMpy vars
self.assumption_dict = {ort_var.Index(): cpm_var for (cpm_var, ort_var) in zip(assumptions, ort_assum_vars)}
self.ort_model.ClearAssumptions() # because add just appends
Expand Down
3 changes: 2 additions & 1 deletion cpmpy/solvers/pumpkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ..expressions.core import Expression, Comparison, Operator, BoolVal
from ..expressions.globalconstraints import GlobalConstraint
from ..expressions.variables import _BoolVarImpl, NegBoolView, _IntVarImpl, _NumVarImpl, intvar, boolvar
from ..expressions.utils import is_num, is_any_list, get_bounds
from ..expressions.utils import flatlist, is_num, is_any_list, get_bounds
from ..transformations.get_variables import get_variables
from ..transformations.linearize import canonical_comparison
from ..transformations.normalize import toplevel_list
Expand Down Expand Up @@ -151,6 +151,7 @@ def solve(self, time_limit=None, prove=False, proof_name="proof.drcp", proof_loc

elif assumptions is not None:
assert not prove, "Proof-logging under assumptions is not supported"
assumptions = flatlist(assumptions)
pum_assumptions = [self.to_predicate(a) for a in assumptions]
self.assump_map = dict(zip(pum_assumptions, assumptions))
solve_func = self.pum_solver.satisfy_under_assumptions
Expand Down
1 change: 1 addition & 0 deletions cpmpy/solvers/pysat.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def solve(self, time_limit=None, assumptions=None):
if assumptions is None:
pysat_assum_vars = [] # default if no assumptions
else:
assumptions = flatlist(assumptions)
pysat_assum_vars = self.solver_vars(assumptions)
self.assumption_vars = assumptions

Expand Down
12 changes: 8 additions & 4 deletions cpmpy/solvers/z3.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from ..expressions.globalconstraints import GlobalConstraint, DirectConstraint
from ..expressions.globalfunctions import GlobalFunction
from ..expressions.variables import _BoolVarImpl, NegBoolView, _NumVarImpl, _IntVarImpl, intvar
from ..expressions.utils import is_num, is_any_list, is_bool, is_int, is_boolexpr, eval_comparison
from ..expressions.utils import flatlist, is_num, is_any_list, is_bool, is_int, is_boolexpr, eval_comparison
from ..transformations.decompose_global import decompose_in_tree
from ..transformations.normalize import toplevel_list
from ..transformations.safening import no_partial_functions
Expand Down Expand Up @@ -139,7 +139,7 @@ def native_model(self):
return self.z3_solver


def solve(self, time_limit=None, assumptions=[], **kwargs):
def solve(self, time_limit=None, assumptions=None, **kwargs):
"""
Call the z3 solver

Expand Down Expand Up @@ -184,8 +184,12 @@ def solve(self, time_limit=None, assumptions=[], **kwargs):
self.z3_solver.set(timeout=int(time_limit*1000))


z3_assum_vars = self.solver_vars(assumptions)
self.assumption_dict = {z3_var : cpm_var for (cpm_var, z3_var) in zip(assumptions, z3_assum_vars)}
if assumptions is not None:
assumptions = flatlist(assumptions)
z3_assum_vars = self.solver_vars(assumptions)
self.assumption_dict = {z3_var : cpm_var for (cpm_var, z3_var) in zip(assumptions, z3_assum_vars)}
else:
z3_assum_vars = []


# call the solver, with parameters
Expand Down
3 changes: 3 additions & 0 deletions tests/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,9 @@ def test_incremental_assumptions(self, solver):

assert s.solve(assumptions=[])

# better for user experience: allow to use set of assumptions too
assert s.solve(assumptions={x,y})

def test_vars_not_removed(self, solver):
bvs = cp.boolvar(shape=3)
m = cp.Model([cp.any(bvs) <= 2])
Expand Down
Loading