@@ -359,8 +359,6 @@ class PartitionOp {
359359 }
360360};
361361
362- struct PartitionOpEvaluator ;
363-
364362// / A map from Element -> Region that represents the current partition set.
365363// /
366364// /
@@ -369,7 +367,6 @@ class Partition {
369367 // / A class defined in PartitionUtils unittest used to grab state from
370368 // / Partition without exposing it to other users.
371369 struct PartitionTester ;
372- friend PartitionOpEvaluator;
373370
374371 using Element = PartitionPrimitives::Element;
375372 using Region = PartitionPrimitives::Region;
@@ -451,14 +448,16 @@ class Partition {
451448 return fst.elementToRegionMap == snd.elementToRegionMap ;
452449 }
453450
454- bool isTracked (Element val) const { return elementToRegionMap.count (val); }
451+ bool isTrackingElement (Element val) const {
452+ return elementToRegionMap.count (val);
453+ }
455454
456455 // / Mark val as transferred.
457456 void markTransferred (Element val,
458457 TransferringOperandSet *transferredOperandSet) {
459458 // First see if our val is tracked. If it is not tracked, insert it and mark
460459 // its new region as transferred.
461- if (!isTracked (val)) {
460+ if (!isTrackingElement (val)) {
462461 elementToRegionMap.insert_or_assign (val, fresh_label);
463462 regionToTransferredOpMap.insert ({fresh_label, transferredOperandSet});
464463 fresh_label = Region (fresh_label + 1 );
@@ -485,7 +484,7 @@ class Partition {
485484 // / we found that \p val was transferred. We return false otherwise.
486485 bool undoTransfer (Element val) {
487486 // First see if our val is tracked. If it is not tracked, insert it.
488- if (!isTracked (val)) {
487+ if (!isTrackingElement (val)) {
489488 elementToRegionMap.insert_or_assign (val, fresh_label);
490489 fresh_label = Region (fresh_label + 1 );
491490 canonical = false ;
@@ -499,7 +498,7 @@ class Partition {
499498 return regionToTransferredOpMap.erase (iter1->second );
500499 }
501500
502- void addElement (Element newElt) {
501+ void trackNewElement (Element newElt) {
503502 // Map index newElt to a fresh label.
504503 elementToRegionMap.insert_or_assign (newElt, fresh_label);
505504
@@ -508,6 +507,22 @@ class Partition {
508507 canonical = false ;
509508 }
510509
510+ void assignElement (Element oldElt, Element newElt) {
511+ elementToRegionMap.insert_or_assign (oldElt, elementToRegionMap.at (newElt));
512+ canonical = false ;
513+ }
514+
515+ bool areElementsInSameRegion (Element firstElt, Element secondElt) const {
516+ return elementToRegionMap.at (firstElt) == elementToRegionMap.at (secondElt);
517+ }
518+
519+ Region getRegion (Element elt) const { return elementToRegionMap.at (elt); }
520+
521+ using iterator = std::map<Element, Region>::iterator;
522+ iterator begin () { return elementToRegionMap.begin (); }
523+ iterator end () { return elementToRegionMap.end (); }
524+ llvm::iterator_range<iterator> range () { return {begin (), end ()}; }
525+
511526 // / Construct the partition corresponding to the union of the two passed
512527 // / partitions.
513528 // /
@@ -754,10 +769,12 @@ class Partition {
754769 return set;
755770 }
756771
757- private:
758772 // / Used only in assertions, check that Partitions promised to be canonical
759773 // / are actually canonical
760774 bool is_canonical_correct () {
775+ #ifdef NDEBUG
776+ return true ;
777+ #else
761778 if (!canonical)
762779 return true ; // vacuously correct
763780
@@ -796,8 +813,54 @@ class Partition {
796813 }
797814
798815 return true ;
816+ #endif
799817 }
800818
819+ // / Merge the regions of two indices while maintaining canonicality. Returns
820+ // / the final region used.
821+ // /
822+ // / This runs in linear time.
823+ Region merge (Element fst, Element snd) {
824+ assert (elementToRegionMap.count (fst) && elementToRegionMap.count (snd));
825+
826+ auto fstRegion = elementToRegionMap.at (fst);
827+ auto sndRegion = elementToRegionMap.at (snd);
828+
829+ if (fstRegion == sndRegion)
830+ return fstRegion;
831+
832+ // Maintain canonicality by renaming the greater-numbered region to the
833+ // smaller region.
834+ std::optional<Region> result;
835+ if (fstRegion < sndRegion) {
836+ result = fstRegion;
837+
838+ // Rename snd to use first region.
839+ horizontalUpdate (elementToRegionMap, snd, fstRegion);
840+ auto iter = regionToTransferredOpMap.find (sndRegion);
841+ if (iter != regionToTransferredOpMap.end ()) {
842+ auto operand = iter->second ;
843+ regionToTransferredOpMap.erase (iter);
844+ regionToTransferredOpMap.try_emplace (fstRegion, operand);
845+ }
846+ } else {
847+ result = sndRegion;
848+
849+ horizontalUpdate (elementToRegionMap, fst, sndRegion);
850+ auto iter = regionToTransferredOpMap.find (fstRegion);
851+ if (iter != regionToTransferredOpMap.end ()) {
852+ auto operand = iter->second ;
853+ regionToTransferredOpMap.erase (iter);
854+ regionToTransferredOpMap.try_emplace (sndRegion, operand);
855+ }
856+ }
857+
858+ assert (is_canonical_correct ());
859+ assert (elementToRegionMap.at (fst) == elementToRegionMap.at (snd));
860+ return *result;
861+ }
862+
863+ private:
801864 // / For each region label that occurs, find the first index at which it occurs
802865 // / and relabel all instances of it to that index. This excludes the -1 label
803866 // / for transferred regions.
@@ -845,51 +908,6 @@ class Partition {
845908 assert (is_canonical_correct ());
846909 }
847910
848- // / Merge the regions of two indices while maintaining canonicality. Returns
849- // / the final region used.
850- // /
851- // / This runs in linear time.
852- Region merge (Element fst, Element snd) {
853- assert (elementToRegionMap.count (fst) && elementToRegionMap.count (snd));
854-
855- auto fstRegion = elementToRegionMap.at (fst);
856- auto sndRegion = elementToRegionMap.at (snd);
857-
858- if (fstRegion == sndRegion)
859- return fstRegion;
860-
861- // Maintain canonicality by renaming the greater-numbered region to the
862- // smaller region.
863- std::optional<Region> result;
864- if (fstRegion < sndRegion) {
865- result = fstRegion;
866-
867- // Rename snd to use first region.
868- horizontalUpdate (elementToRegionMap, snd, fstRegion);
869- auto iter = regionToTransferredOpMap.find (sndRegion);
870- if (iter != regionToTransferredOpMap.end ()) {
871- auto operand = iter->second ;
872- regionToTransferredOpMap.erase (iter);
873- regionToTransferredOpMap.try_emplace (fstRegion, operand);
874- }
875- } else {
876- result = sndRegion;
877-
878- horizontalUpdate (elementToRegionMap, fst, sndRegion);
879- auto iter = regionToTransferredOpMap.find (fstRegion);
880- if (iter != regionToTransferredOpMap.end ()) {
881- auto operand = iter->second ;
882- regionToTransferredOpMap.erase (iter);
883- regionToTransferredOpMap.try_emplace (sndRegion, operand);
884- }
885- }
886-
887- assert (is_canonical_correct ());
888- assert (elementToRegionMap.at (fst) == elementToRegionMap.at (snd));
889- return *result;
890- }
891-
892- private:
893911 // / For the passed `map`, ensure that `key` maps to `val`. If `key` already
894912 // / mapped to a different value, ensure that all other keys mapped to that
895913 // / value also now map to `val`. This is a relatively expensive (linear time)
@@ -1035,7 +1053,7 @@ struct PartitionOpEvaluator {
10351053 case PartitionOpKind::Assign:
10361054 assert (op.getOpArgs ().size () == 2 &&
10371055 " Assign PartitionOp should be passed 2 arguments" );
1038- assert (p.elementToRegionMap . count (op.getOpArgs ()[1 ]) &&
1056+ assert (p.isTrackingElement (op.getOpArgs ()[1 ]) &&
10391057 " Assign PartitionOp's source argument should be already tracked" );
10401058 // If we are using a region that was transferred as our assignment source
10411059 // value... emit an error.
@@ -1044,37 +1062,30 @@ struct PartitionOpEvaluator {
10441062 handleFailure (op, op.getOpArgs ()[1 ], transferredOperand);
10451063 }
10461064 }
1047-
1048- p.elementToRegionMap .insert_or_assign (
1049- op.getOpArgs ()[0 ], p.elementToRegionMap .at (op.getOpArgs ()[1 ]));
1050-
1051- // assignment could have invalidated canonicality of either the old region
1052- // of op.getOpArgs()[0] or the region of op.getOpArgs()[1], or both
1053- p.canonical = false ;
1065+ p.assignElement (op.getOpArgs ()[0 ], op.getOpArgs ()[1 ]);
10541066 return ;
10551067 case PartitionOpKind::AssignFresh:
10561068 assert (op.getOpArgs ().size () == 1 &&
10571069 " AssignFresh PartitionOp should be passed 1 argument" );
10581070
1059- p.addElement (op.getOpArgs ()[0 ]);
1071+ p.trackNewElement (op.getOpArgs ()[0 ]);
10601072 return ;
10611073 case PartitionOpKind::Transfer: {
10621074 assert (op.getOpArgs ().size () == 1 &&
10631075 " Transfer PartitionOp should be passed 1 argument" );
1064- assert (p.elementToRegionMap . count (op.getOpArgs ()[0 ]) &&
1076+ assert (p.isTrackingElement (op.getOpArgs ()[0 ]) &&
10651077 " Transfer PartitionOp's argument should already be tracked" );
10661078
10671079 // check if any nontransferrables are transferred here, and handle the
10681080 // failure if so
10691081 for (Element nonTransferrable : nonTransferrableElements) {
10701082 assert (
1071- p.elementToRegionMap . count (nonTransferrable) &&
1083+ p.isTrackingElement (nonTransferrable) &&
10721084 " nontransferrables should be function args and self, and therefore"
10731085 " always present in the label map because of initialization at "
10741086 " entry" );
10751087 if (!p.isTransferred (nonTransferrable) &&
1076- p.elementToRegionMap .at (nonTransferrable) ==
1077- p.elementToRegionMap .at (op.getOpArgs ()[0 ])) {
1088+ p.areElementsInSameRegion (nonTransferrable, op.getOpArgs ()[0 ])) {
10781089 return handleTransferNonTransferrable (op, nonTransferrable);
10791090 }
10801091 }
@@ -1090,8 +1101,8 @@ struct PartitionOpEvaluator {
10901101 bool isClosureCapturedElt =
10911102 isClosureCaptured (op.getOpArgs ()[0 ], op.getSourceOp ());
10921103
1093- Region elementRegion = p.elementToRegionMap . at (op.getOpArgs ()[0 ]);
1094- for (const auto &pair : p.elementToRegionMap ) {
1104+ Region elementRegion = p.getRegion (op.getOpArgs ()[0 ]);
1105+ for (const auto &pair : p.range () ) {
10951106 if (pair.second == elementRegion && isActorDerived (pair.first ))
10961107 return handleTransferNonTransferrable (op, op.getOpArgs ()[0 ]);
10971108 isClosureCapturedElt |= isClosureCaptured (pair.first , op.getSourceOp ());
@@ -1106,7 +1117,7 @@ struct PartitionOpEvaluator {
11061117 case PartitionOpKind::UndoTransfer: {
11071118 assert (op.getOpArgs ().size () == 1 &&
11081119 " UndoTransfer PartitionOp should be passed 1 argument" );
1109- assert (p.elementToRegionMap . count (op.getOpArgs ()[0 ]) &&
1120+ assert (p.isTrackingElement (op.getOpArgs ()[0 ]) &&
11101121 " UndoTransfer PartitionOp's argument should already be tracked" );
11111122
11121123 // Mark op.getOpArgs()[0] as not transferred.
@@ -1116,8 +1127,8 @@ struct PartitionOpEvaluator {
11161127 case PartitionOpKind::Merge:
11171128 assert (op.getOpArgs ().size () == 2 &&
11181129 " Merge PartitionOp should be passed 2 arguments" );
1119- assert (p.elementToRegionMap . count (op.getOpArgs ()[0 ]) &&
1120- p.elementToRegionMap . count (op.getOpArgs ()[1 ]) &&
1130+ assert (p.isTrackingElement (op.getOpArgs ()[0 ]) &&
1131+ p.isTrackingElement (op.getOpArgs ()[1 ]) &&
11211132 " Merge PartitionOp's arguments should already be tracked" );
11221133
11231134 // if attempting to merge a transferred region, handle the failure
@@ -1137,7 +1148,7 @@ struct PartitionOpEvaluator {
11371148 case PartitionOpKind::Require:
11381149 assert (op.getOpArgs ().size () == 1 &&
11391150 " Require PartitionOp should be passed 1 argument" );
1140- assert (p.elementToRegionMap . count (op.getOpArgs ()[0 ]) &&
1151+ assert (p.isTrackingElement (op.getOpArgs ()[0 ]) &&
11411152 " Require PartitionOp's argument should already be tracked" );
11421153 if (auto *transferredOperandSet = p.getTransferred (op.getOpArgs ()[0 ])) {
11431154 for (auto transferredOperand : transferredOperandSet->data ()) {
0 commit comments