diff --git a/cpmpy/solvers/exact.py b/cpmpy/solvers/exact.py index 96e4441ab..cb974dda0 100644 --- a/cpmpy/solvers/exact.py +++ b/cpmpy/solvers/exact.py @@ -203,6 +203,7 @@ def solve(self, time_limit=None, assumptions=None, **kwargs): # set assumptions if assumptions is not None: + assumptions = flatlist(assumptions) assert all(v.is_bool() for v in assumptions), "Non-Boolean assumptions given to Exact: " + str([v for v in assumptions if not v.is_bool()]) assump_vals = [int(not isinstance(v, NegBoolView)) for v in assumptions] assump_vars = [self.solver_var(v._bv if isinstance(v, NegBoolView) else v) for v in assumptions] diff --git a/cpmpy/solvers/ortools.py b/cpmpy/solvers/ortools.py index 2029110f5..920402d38 100644 --- a/cpmpy/solvers/ortools.py +++ b/cpmpy/solvers/ortools.py @@ -193,7 +193,7 @@ def solve(self, time_limit=None, assumptions=None, solution_callback=None, **kwa self.ort_solver.parameters.max_time_in_seconds = float(time_limit) if assumptions is not None: - ort_assum_vars = self.solver_vars(assumptions) + ort_assum_vars = self.solver_vars(flatlist(assumptions)) # dict mapping ortools vars to CPMpy vars self.assumption_dict = {ort_var.Index(): cpm_var for (cpm_var, ort_var) in zip(assumptions, ort_assum_vars)} self.ort_model.ClearAssumptions() # because add just appends diff --git a/cpmpy/solvers/pumpkin.py b/cpmpy/solvers/pumpkin.py index 1cbbb7d96..57d2bf306 100644 --- a/cpmpy/solvers/pumpkin.py +++ b/cpmpy/solvers/pumpkin.py @@ -47,7 +47,7 @@ from ..expressions.core import Expression, Comparison, Operator, BoolVal from ..expressions.globalconstraints import GlobalConstraint from ..expressions.variables import _BoolVarImpl, NegBoolView, _IntVarImpl, _NumVarImpl, intvar, boolvar -from ..expressions.utils import is_num, is_any_list, get_bounds +from ..expressions.utils import flatlist, is_num, is_any_list, get_bounds from ..transformations.get_variables import get_variables from ..transformations.linearize import canonical_comparison from ..transformations.normalize import toplevel_list @@ -151,6 +151,7 @@ def solve(self, time_limit=None, prove=False, proof_name="proof.drcp", proof_loc elif assumptions is not None: assert not prove, "Proof-logging under assumptions is not supported" + assumptions = flatlist(assumptions) pum_assumptions = [self.to_predicate(a) for a in assumptions] self.assump_map = dict(zip(pum_assumptions, assumptions)) solve_func = self.pum_solver.satisfy_under_assumptions diff --git a/cpmpy/solvers/pysat.py b/cpmpy/solvers/pysat.py index 72c920e00..e5b681152 100644 --- a/cpmpy/solvers/pysat.py +++ b/cpmpy/solvers/pysat.py @@ -251,6 +251,7 @@ def solve(self, time_limit=None, assumptions=None): if assumptions is None: pysat_assum_vars = [] # default if no assumptions else: + assumptions = flatlist(assumptions) pysat_assum_vars = self.solver_vars(assumptions) self.assumption_vars = assumptions diff --git a/cpmpy/solvers/z3.py b/cpmpy/solvers/z3.py index 26b84622f..9c40eee7a 100644 --- a/cpmpy/solvers/z3.py +++ b/cpmpy/solvers/z3.py @@ -55,7 +55,7 @@ from ..expressions.globalconstraints import GlobalConstraint, DirectConstraint from ..expressions.globalfunctions import GlobalFunction from ..expressions.variables import _BoolVarImpl, NegBoolView, _NumVarImpl, _IntVarImpl, intvar -from ..expressions.utils import is_num, is_any_list, is_bool, is_int, is_boolexpr, eval_comparison +from ..expressions.utils import flatlist, is_num, is_any_list, is_bool, is_int, is_boolexpr, eval_comparison from ..transformations.decompose_global import decompose_in_tree from ..transformations.normalize import toplevel_list from ..transformations.safening import no_partial_functions @@ -139,7 +139,7 @@ def native_model(self): return self.z3_solver - def solve(self, time_limit=None, assumptions=[], **kwargs): + def solve(self, time_limit=None, assumptions=None, **kwargs): """ Call the z3 solver @@ -184,8 +184,12 @@ def solve(self, time_limit=None, assumptions=[], **kwargs): self.z3_solver.set(timeout=int(time_limit*1000)) - z3_assum_vars = self.solver_vars(assumptions) - self.assumption_dict = {z3_var : cpm_var for (cpm_var, z3_var) in zip(assumptions, z3_assum_vars)} + if assumptions is not None: + assumptions = flatlist(assumptions) + z3_assum_vars = self.solver_vars(assumptions) + self.assumption_dict = {z3_var : cpm_var for (cpm_var, z3_var) in zip(assumptions, z3_assum_vars)} + else: + z3_assum_vars = [] # call the solver, with parameters diff --git a/tests/test_solvers.py b/tests/test_solvers.py index fac298f39..88f5d20c8 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -854,6 +854,9 @@ def test_incremental_assumptions(self, solver): assert s.solve(assumptions=[]) + # better for user experience: allow to use set of assumptions too + assert s.solve(assumptions={x,y}) + def test_vars_not_removed(self, solver): bvs = cp.boolvar(shape=3) m = cp.Model([cp.any(bvs) <= 2])