@@ -206,49 +206,28 @@ void RewriteSystem::processMergedAssociatedTypes() {
206206
207207 unsigned i = 0 ;
208208
209- // Chase the end of the vector; calls to RewriteSystem::addRule()
210- // can theoretically add new elements below.
209+ // Chase the end of the vector, since addRule() might add new elements below.
211210 while (i < MergedAssociatedTypes.size ()) {
212- auto pair = MergedAssociatedTypes[i++];
213- const auto &lhs = pair.first ;
214- const auto &rhs = pair.second ;
211+ // Copy the entry out, since addRule() might add new elements below.
212+ auto entry = MergedAssociatedTypes[i++];
215213
216- // If we have X.[P2:T] => Y.[P1:T], add a new pair of rules:
217- // X.[P1:T] => X.[P1&P2:T]
218- // X.[P2:T] => X.[P1&P2:T]
219- if (Debug.contains (DebugFlags::Merge)) {
220- llvm::dbgs () << " ## Processing associated type merge candidate " ;
221- llvm::dbgs () << lhs << " => " << rhs << " \n " ;
222- }
223-
224- auto mergedSymbol = Context.mergeAssociatedTypes (lhs.back (), rhs.back (),
225- Protos);
226214 if (Debug.contains (DebugFlags::Merge)) {
227- llvm::dbgs () << " ### Merged symbol " << mergedSymbol << " \n " ;
215+ llvm::dbgs () << " ## Processing associated type merge with " ;
216+ llvm::dbgs () << entry.rhs << " , " ;
217+ llvm::dbgs () << entry.lhsSymbol << " , " ;
218+ llvm::dbgs () << entry.mergedSymbol << " \n " ;
228219 }
229220
230- // We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
231- assert (lhs.back () != mergedSymbol &&
232- " Left hand side should not already end with merged symbol?" );
233- assert (mergedSymbol.compare (rhs.back (), Protos) <= 0 );
234- assert (rhs.back ().compare (lhs.back (), Protos) < 0 );
235-
236- // If the merge didn't actually produce a new symbol, there is nothing else
237- // to do.
238- if (rhs.back () == mergedSymbol) {
239- if (Debug.contains (DebugFlags::Merge)) {
240- llvm::dbgs () << " ### Skipping\n " ;
241- }
242-
243- continue ;
244- }
221+ // If we have X.[P2:T] => Y.[P1:T], add a new rule:
222+ // X.[P1:T] => X.[P1&P2:T]
223+ MutableTerm lhs (entry.rhs );
245224
246225 // Build the term X.[P1&P2:T].
247- MutableTerm mergedTerm = lhs ;
248- mergedTerm .back () = mergedSymbol;
226+ MutableTerm rhs (entry. rhs ) ;
227+ rhs .back () = entry. mergedSymbol ;
249228
250229 // Add the rule X.[P1:T] => X.[P1&P2:T].
251- addRule (rhs, mergedTerm );
230+ addRule (lhs, rhs );
252231
253232 // Collect new rules here so that we're not adding rules while traversing
254233 // the trie.
@@ -260,8 +239,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
260239 const auto &otherLHS = otherRule.getLHS ();
261240 if (otherLHS.size () == 2 &&
262241 otherLHS[1 ].getKind () == Symbol::Kind::Protocol) {
263- if (otherLHS[0 ] == lhs. back () ||
264- otherLHS[0 ] == rhs.back ()) {
242+ if (otherLHS[0 ] == entry. lhsSymbol ||
243+ otherLHS[0 ] == entry. rhs .back ()) {
265244 // We have a rule of the form
266245 //
267246 // [P1:T].[Q] => [P1:T]
@@ -280,11 +259,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
280259 // [P1&P2:T].[Q] => [P1&P2:T]
281260 //
282261 MutableTerm newLHS;
283- newLHS.add (mergedSymbol);
262+ newLHS.add (entry. mergedSymbol );
284263 newLHS.add (otherLHS[1 ]);
285264
286265 MutableTerm newRHS;
287- newRHS.add (mergedSymbol);
266+ newRHS.add (entry. mergedSymbol );
288267
289268 inducedRules.emplace_back (newLHS, newRHS);
290269 }
@@ -294,8 +273,8 @@ void RewriteSystem::processMergedAssociatedTypes() {
294273 // Visit rhs first to preserve the ordering of protocol requirements in the
295274 // the property map. This is just for aesthetic purposes in the debug dump,
296275 // it doesn't change behavior.
297- Trie.findAll (rhs.back (), visitRule);
298- Trie.findAll (lhs. back () , visitRule);
276+ Trie.findAll (entry. rhs .back (), visitRule);
277+ Trie.findAll (entry. lhsSymbol , visitRule);
299278
300279 // Now add the new rules.
301280 for (const auto &pair : inducedRules)
@@ -305,10 +284,58 @@ void RewriteSystem::processMergedAssociatedTypes() {
305284 MergedAssociatedTypes.clear ();
306285}
307286
287+ // / Check if we have a rule of the form
288+ // /
289+ // / X.[P1:T] => X.[P2:T]
290+ // /
291+ // / If so, record this rule for later. We'll try to merge the associated
292+ // / types in RewriteSystem::processMergedAssociatedTypes().
293+ void RewriteSystem::checkMergedAssociatedType (Term lhs, Term rhs) {
294+ if (lhs.size () == rhs.size () &&
295+ std::equal (lhs.begin (), lhs.end () - 1 , rhs.begin ()) &&
296+ lhs.back ().getKind () == Symbol::Kind::AssociatedType &&
297+ rhs.back ().getKind () == Symbol::Kind::AssociatedType &&
298+ lhs.back ().getName () == rhs.back ().getName ()) {
299+ if (Debug.contains (DebugFlags::Merge)) {
300+ llvm::dbgs () << " ## Associated type merge candidate " ;
301+ llvm::dbgs () << lhs << " => " << rhs << " \n\n " ;
302+ }
303+
304+ auto mergedSymbol = Context.mergeAssociatedTypes (lhs.back (), rhs.back (),
305+ Protos);
306+ if (Debug.contains (DebugFlags::Merge)) {
307+ llvm::dbgs () << " ### Merged symbol " << mergedSymbol << " \n " ;
308+ }
309+
310+ // We must have mergedSymbol <= rhs < lhs, therefore mergedSymbol != lhs.
311+ assert (lhs.back () != mergedSymbol &&
312+ " Left hand side should not already end with merged symbol?" );
313+ assert (mergedSymbol.compare (rhs.back (), Protos) <= 0 );
314+ assert (rhs.back ().compare (lhs.back (), Protos) < 0 );
315+
316+ // If the merge didn't actually produce a new symbol, there is nothing else
317+ // to do.
318+ if (rhs.back () == mergedSymbol) {
319+ if (Debug.contains (DebugFlags::Merge)) {
320+ llvm::dbgs () << " ### Skipping\n " ;
321+ }
322+
323+ return ;
324+ }
325+
326+ MergedAssociatedTypes.push_back ({rhs, lhs.back (), mergedSymbol});
327+ }
328+ }
329+
308330// / Compute a critical pair from the left hand sides of two rewrite rules,
309331// / where \p rhs begins at \p from, which must be an iterator pointing
310332// / into \p lhs.
311333// /
334+ // / The resulting pair is pushed onto \p result only if it is non-trivial,
335+ // / that is, the left hand side and right hand side are not equal.
336+ // /
337+ // / Returns true if the pair was non-trivial, false if it was trivial.
338+ // /
312339// / There are two cases:
313340// /
314341// / 1) lhs == TUV -> X, rhs == U -> Y. The overlapped term is TUV;
@@ -336,9 +363,11 @@ void RewriteSystem::processMergedAssociatedTypes() {
336363// / concrete substitution 'X' to get 'A.X'; the new concrete term
337364// / is now rooted at the same level as A.B in the rewrite system,
338365// / not just B.
339- std::pair<MutableTerm, MutableTerm>
366+ bool
340367RewriteSystem::computeCriticalPair (ArrayRef<Symbol>::const_iterator from,
341- const Rule &lhs, const Rule &rhs) const {
368+ const Rule &lhs, const Rule &rhs,
369+ std::vector<std::pair<MutableTerm,
370+ MutableTerm>> &result) const {
342371 auto end = lhs.getLHS ().end ();
343372 if (from + rhs.getLHS ().size () < end) {
344373 // lhs == TUV -> X, rhs == U -> Y.
@@ -352,7 +381,14 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
352381 MutableTerm t (lhs.getLHS ().begin (), from);
353382 t.append (rhs.getRHS ());
354383 t.append (from + rhs.getLHS ().size (), lhs.getLHS ().end ());
355- return std::make_pair (MutableTerm (lhs.getRHS ()), t);
384+
385+ if (lhs.getRHS ().size () == t.size () &&
386+ std::equal (lhs.getRHS ().begin (), lhs.getRHS ().end (),
387+ t.begin ())) {
388+ return false ;
389+ }
390+
391+ result.emplace_back (MutableTerm (lhs.getRHS ()), t);
356392 } else {
357393 // lhs == TU -> X, rhs == UV -> Y.
358394
@@ -372,8 +408,13 @@ RewriteSystem::computeCriticalPair(ArrayRef<Symbol>::const_iterator from,
372408 // Compute the term TY.
373409 t.append (rhs.getRHS ());
374410
375- return std::make_pair (xv, t);
411+ if (xv == t)
412+ return false ;
413+
414+ result.emplace_back (xv, t);
376415 }
416+
417+ return true ;
377418}
378419
379420// / Computes the confluent completion using the Knuth-Bendix algorithm.
@@ -439,19 +480,26 @@ RewriteSystem::computeConfluentCompletion(unsigned maxIterations,
439480 }
440481
441482 // Try to repair the confluence violation by adding a new rule.
442- resolvedCriticalPairs.push_back (computeCriticalPair (from, lhs, rhs));
443-
444- if (Debug.contains (DebugFlags::Completion)) {
445- const auto &pair = resolvedCriticalPairs.back ();
446-
447- llvm::dbgs () << " $ Overlapping rules: (#" << i << " ) " ;
448- llvm::dbgs () << lhs << " \n " ;
449- llvm::dbgs () << " -vs- (#" << j << " ) " ;
450- llvm::dbgs () << rhs << " :\n " ;
451- llvm::dbgs () << " $$ First term of critical pair is "
452- << pair.first << " \n " ;
453- llvm::dbgs () << " $$ Second term of critical pair is "
454- << pair.second << " \n\n " ;
483+ if (computeCriticalPair (from, lhs, rhs, resolvedCriticalPairs)) {
484+ if (Debug.contains (DebugFlags::Completion)) {
485+ const auto &pair = resolvedCriticalPairs.back ();
486+
487+ llvm::dbgs () << " $ Overlapping rules: (#" << i << " ) " ;
488+ llvm::dbgs () << lhs << " \n " ;
489+ llvm::dbgs () << " -vs- (#" << j << " ) " ;
490+ llvm::dbgs () << rhs << " :\n " ;
491+ llvm::dbgs () << " $$ First term of critical pair is "
492+ << pair.first << " \n " ;
493+ llvm::dbgs () << " $$ Second term of critical pair is "
494+ << pair.second << " \n\n " ;
495+ }
496+ } else {
497+ if (Debug.contains (DebugFlags::Completion)) {
498+ llvm::dbgs () << " $ Trivially overlapping rules: (#" << i << " ) " ;
499+ llvm::dbgs () << lhs << " \n " ;
500+ llvm::dbgs () << " -vs- (#" << j << " ) " ;
501+ llvm::dbgs () << rhs << " :\n " ;
502+ }
455503 }
456504 });
457505
0 commit comments