5
5
from itertools import chain
6
6
from typing import TYPE_CHECKING
7
7
8
- from .kast .inner import KApply , KInner , KRewrite , KVariable , Subst
8
+ from .kast .inner import KApply , KInner , KRewrite , KToken , KVariable , Subst , bottom_up
9
9
from .kast .kast import KAtt
10
10
from .kast .manip import (
11
+ abstract_term_safely ,
11
12
apply_existential_substitutions ,
12
13
count_vars ,
13
14
flatten_label ,
22
23
)
23
24
from .kast .outer import KClaim , KRule
24
25
from .prelude .k import GENERATED_TOP_CELL
25
- from .prelude .ml import is_top , mlAnd , mlImplies , mlTop
26
- from .utils import unique
26
+ from .prelude .kbool import orBool
27
+ from .prelude .ml import is_bottom , is_top , mlAnd , mlBottom , mlEqualsTrue , mlImplies , mlTop
28
+ from .utils import single , unique
27
29
28
30
if TYPE_CHECKING :
29
31
from collections .abc import Iterable , Iterator
30
32
from typing import Any
31
33
34
+ from .kast .outer import KDefinition
35
+
32
36
33
37
@dataclass (frozen = True , order = True )
34
38
class CTerm :
35
39
config : KInner # TODO Optional?
36
40
constraints : tuple [KInner , ...]
37
41
38
42
def __init__ (self , config : KInner , constraints : Iterable [KInner ] = ()) -> None :
39
- self ._check_config (config )
40
- constraints = self ._normalize_constraints (constraints )
43
+ if CTerm ._is_top (config ):
44
+ config = mlTop ()
45
+ constraints = ()
46
+ elif CTerm ._is_bottom (config ):
47
+ config = mlBottom ()
48
+ constraints = ()
49
+ else :
50
+ self ._check_config (config )
51
+ constraints = self ._normalize_constraints (constraints )
41
52
object .__setattr__ (self , 'config' , config )
42
53
object .__setattr__ (self , 'constraints' , constraints )
43
54
44
55
@staticmethod
45
56
def from_kast (kast : KInner ) -> CTerm :
46
- config , constraint = split_config_and_constraints (kast )
47
- constraints = flatten_label ('#And' , constraint )
48
- return CTerm (config , constraints )
57
+ if CTerm ._is_top (kast ):
58
+ return CTerm .top ()
59
+ elif CTerm ._is_bottom (kast ):
60
+ return CTerm .bottom ()
61
+ else :
62
+ config , constraint = split_config_and_constraints (kast )
63
+ constraints = flatten_label ('#And' , constraint )
64
+ return CTerm (config , constraints )
49
65
50
66
@staticmethod
51
67
def from_dict (dct : dict [str , Any ]) -> CTerm :
52
68
config = KInner .from_dict (dct ['config' ])
53
69
constraints = [KInner .from_dict (c ) for c in dct ['constraints' ]]
54
70
return CTerm (config , constraints )
55
71
72
+ @staticmethod
73
+ def top () -> CTerm :
74
+ return CTerm (mlTop (), ())
75
+
76
+ @staticmethod
77
+ def bottom () -> CTerm :
78
+ return CTerm (mlBottom (), ())
79
+
56
80
@staticmethod
57
81
def _check_config (config : KInner ) -> None :
58
82
if not isinstance (config , KApply ) or not config .is_cell :
59
- raise ValueError ('Expected cell label, found: {config.label.name }' )
83
+ raise ValueError (f 'Expected cell label, found: { config } ' )
60
84
61
85
@staticmethod
62
86
def _normalize_constraints (constraints : Iterable [KInner ]) -> tuple [KInner , ...]:
@@ -74,6 +98,20 @@ def _is_spurious_constraint(term: KInner) -> bool:
74
98
return True
75
99
return False
76
100
101
+ @staticmethod
102
+ def _is_top (kast : KInner ) -> bool :
103
+ flat = flatten_label ('#And' , kast )
104
+ if len (flat ) == 1 :
105
+ return is_top (single (flat ))
106
+ return all (CTerm ._is_top (term ) for term in flat )
107
+
108
+ @staticmethod
109
+ def _is_bottom (kast : KInner ) -> bool :
110
+ flat = flatten_label ('#And' , kast )
111
+ if len (flat ) == 1 :
112
+ return is_bottom (single (flat ))
113
+ return all (CTerm ._is_bottom (term ) for term in flat )
114
+
77
115
@staticmethod
78
116
def _constraint_sort_key (term : KInner ) -> tuple [int , str ]:
79
117
term_str = str (term )
@@ -104,6 +142,9 @@ def cells(self) -> Subst:
104
142
def cell (self , cell : str ) -> KInner :
105
143
return self .cells [cell ]
106
144
145
+ def try_cell (self , cell : str ) -> KInner | None :
146
+ return self .cells .get (cell )
147
+
107
148
def match (self , cterm : CTerm ) -> Subst | None :
108
149
csubst = self .match_with_constraint (cterm )
109
150
@@ -138,6 +179,59 @@ def _ml_impl(antecedents: Iterable[KInner], consequents: Iterable[KInner]) -> KI
138
179
def add_constraint (self , new_constraint : KInner ) -> CTerm :
139
180
return CTerm (self .config , [new_constraint ] + list (self .constraints ))
140
181
182
+ def anti_unify (
183
+ self , other : CTerm , keep_values : bool = False , kdef : KDefinition | None = None
184
+ ) -> tuple [CTerm , CSubst , CSubst ]:
185
+ def disjunction_from_substs (subst1 : Subst , subst2 : Subst ) -> KInner :
186
+ if KToken ('true' , 'Bool' ) in [subst1 .pred , subst2 .pred ]:
187
+ return mlTop ()
188
+ return mlEqualsTrue (orBool ([subst1 .pred , subst2 .pred ]))
189
+
190
+ new_config , self_subst , other_subst = anti_unify (self .config , other .config , kdef = kdef )
191
+ common_constraints = [constraint for constraint in self .constraints if constraint in other .constraints ]
192
+
193
+ new_cterm = CTerm (
194
+ config = new_config , constraints = ([disjunction_from_substs (self_subst , other_subst )] if keep_values else [])
195
+ )
196
+
197
+ new_constraints = []
198
+ fvs = free_vars (new_cterm .kast )
199
+ len_fvs = 0
200
+ while len_fvs < len (fvs ):
201
+ len_fvs = len (fvs )
202
+ for constraint in common_constraints :
203
+ if constraint not in new_constraints :
204
+ constraint_fvs = free_vars (constraint )
205
+ if any (fv in fvs for fv in constraint_fvs ):
206
+ new_constraints .append (constraint )
207
+ fvs .extend (constraint_fvs )
208
+
209
+ for constraint in new_constraints :
210
+ new_cterm = new_cterm .add_constraint (constraint )
211
+ self_csubst = new_cterm .match_with_constraint (self )
212
+ other_csubst = new_cterm .match_with_constraint (other )
213
+ if self_csubst is None or other_csubst is None :
214
+ raise ValueError (
215
+ f'Anti-unification failed to produce a more general state: { (new_cterm , (self , self_csubst ), (other , other_csubst ))} '
216
+ )
217
+ return (new_cterm , self_csubst , other_csubst )
218
+
219
+
220
+ def anti_unify (state1 : KInner , state2 : KInner , kdef : KDefinition | None = None ) -> tuple [KInner , Subst , Subst ]:
221
+ def _rewrites_to_abstractions (_kast : KInner ) -> KInner :
222
+ if type (_kast ) is KRewrite :
223
+ sort = kdef .sort (_kast ) if kdef else None
224
+ return abstract_term_safely (_kast , sort = sort )
225
+ return _kast
226
+
227
+ minimized_rewrite = push_down_rewrites (KRewrite (state1 , state2 ))
228
+ abstracted_state = bottom_up (_rewrites_to_abstractions , minimized_rewrite )
229
+ subst1 = abstracted_state .match (state1 )
230
+ subst2 = abstracted_state .match (state2 )
231
+ if subst1 is None or subst2 is None :
232
+ raise ValueError ('Anti-unification failed to produce a more general state!' )
233
+ return (abstracted_state , subst1 , subst2 )
234
+
141
235
142
236
@dataclass (frozen = True , order = True )
143
237
class CSubst :
0 commit comments