@@ -371,6 +371,38 @@ static VectorType getDistributedType(VectorType originalType, AffineMap map,
371371 return targetType;
372372}
373373
374+ // / Given a warpOp that contains ops with regions, the corresponding op's
375+ // / "inner" region and the distributionMapFn, get all values used by the op's
376+ // / region that are defined within the warpOp, but outside the inner region.
377+ // / Return the set of values, their types and their distributed types.
378+ std::tuple<llvm::SmallSetVector<Value, 32 >, SmallVector<Type>,
379+ SmallVector<Type>>
380+ getInnerRegionEscapingValues (WarpExecuteOnLane0Op warpOp, Region &innerRegion,
381+ DistributionMapFn distributionMapFn) {
382+ llvm::SmallSetVector<Value, 32 > escapingValues;
383+ SmallVector<Type> escapingValueTypes;
384+ SmallVector<Type> escapingValueDistTypes; // to yield from the new warpOp
385+ if (innerRegion.empty ())
386+ return {std::move (escapingValues), std::move (escapingValueTypes),
387+ std::move (escapingValueDistTypes)};
388+ mlir::visitUsedValuesDefinedAbove (innerRegion, [&](OpOperand *operand) {
389+ Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
390+ if (warpOp->isAncestor (parent)) {
391+ if (!escapingValues.insert (operand->get ()))
392+ return ;
393+ Type distType = operand->get ().getType ();
394+ if (auto vecType = dyn_cast<VectorType>(distType)) {
395+ AffineMap map = distributionMapFn (operand->get ());
396+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
397+ }
398+ escapingValueTypes.push_back (operand->get ().getType ());
399+ escapingValueDistTypes.push_back (distType);
400+ }
401+ });
402+ return {std::move (escapingValues), std::move (escapingValueTypes),
403+ std::move (escapingValueDistTypes)};
404+ }
405+
374406// / Distribute transfer_write ops based on the affine map returned by
375407// / `distributionMapFn`. Writes of size more than `maxNumElementToExtract`
376408// / will not be distributed (it should be less than the warp size).
@@ -1713,6 +1745,215 @@ struct WarpOpInsert : public WarpDistributionPattern {
17131745 }
17141746};
17151747
1748+ // / Sink scf.if out of WarpExecuteOnLane0Op. This can be done only if
1749+ // / the scf.if is the last operation in the region so that it doesn't
1750+ // / change the order of execution. This creates a new scf.if after the
1751+ // / WarpExecuteOnLane0Op. Each branch of the new scf.if is enclosed in
1752+ // / the "inner" WarpExecuteOnLane0Op. Example:
1753+ // / ```
1754+ // / gpu.warp_execute_on_lane_0(%laneid)[32] {
1755+ // / %payload = ... : vector<32xindex>
1756+ // / scf.if %pred {
1757+ // / vector.store %payload, %buffer[%idx] : memref<128xindex>,
1758+ // / vector<32xindex>
1759+ // / }
1760+ // / gpu.yield
1761+ // / }
1762+ // / ```
1763+ // / %r = gpu.warp_execute_on_lane_0(%laneid)[32] {
1764+ // / %payload = ... : vector<32xindex>
1765+ // / gpu.yield %payload : vector<32xindex>
1766+ // / }
1767+ // / scf.if %pred {
1768+ // / gpu.warp_execute_on_lane_0(%laneid)[32] args(%r : vector<1xindex>) {
1769+ // / ^bb0(%arg1: vector<32xindex>):
1770+ // / vector.store %arg1, %buffer[%idx] : memref<128xindex>, vector<32xindex>
1771+ // / }
1772+ // / }
1773+ // / ```
1774+ struct WarpOpScfIfOp : public WarpDistributionPattern {
1775+ WarpOpScfIfOp (MLIRContext *ctx, DistributionMapFn fn, PatternBenefit b = 1 )
1776+ : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {}
1777+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1778+ PatternRewriter &rewriter) const override {
1779+ gpu::YieldOp warpOpYield = warpOp.getTerminator ();
1780+ // Only pick up `IfOp` if it is the last op in the region.
1781+ Operation *lastNode = warpOpYield->getPrevNode ();
1782+ auto ifOp = dyn_cast_or_null<scf::IfOp>(lastNode);
1783+ if (!ifOp)
1784+ return failure ();
1785+
1786+ // The current `WarpOp` can yield two types of values:
1787+ // 1. Not results of `IfOp`:
1788+ // Preserve them in the new `WarpOp`.
1789+ // Collect their yield index to remap the usages.
1790+ // 2. Results of `IfOp`:
1791+ // They are not part of the new `WarpOp` results.
1792+ // Map current warp's yield operand index to `IfOp` result idx.
1793+ SmallVector<Value> nonIfYieldValues;
1794+ SmallVector<unsigned > nonIfYieldIndices;
1795+ llvm::SmallDenseMap<unsigned , unsigned > ifResultMapping;
1796+ llvm::SmallDenseMap<unsigned , VectorType> ifResultDistTypes;
1797+ for (OpOperand &yieldOperand : warpOpYield->getOpOperands ()) {
1798+ const unsigned yieldOperandIdx = yieldOperand.getOperandNumber ();
1799+ if (yieldOperand.get ().getDefiningOp () != ifOp.getOperation ()) {
1800+ nonIfYieldValues.push_back (yieldOperand.get ());
1801+ nonIfYieldIndices.push_back (yieldOperandIdx);
1802+ continue ;
1803+ }
1804+ OpResult ifResult = cast<OpResult>(yieldOperand.get ());
1805+ const unsigned ifResultIdx = ifResult.getResultNumber ();
1806+ ifResultMapping[yieldOperandIdx] = ifResultIdx;
1807+ // If this `ifOp` result is vector type and it is yielded by the
1808+ // `WarpOp`, we keep track the distributed type for this result.
1809+ if (!isa<VectorType>(ifResult.getType ()))
1810+ continue ;
1811+ VectorType distType =
1812+ cast<VectorType>(warpOp.getResult (yieldOperandIdx).getType ());
1813+ ifResultDistTypes[ifResultIdx] = distType;
1814+ }
1815+
1816+ // Collect `WarpOp`-defined values used in `ifOp`, the new warp op returns
1817+ // them
1818+ auto [escapingValuesThen, escapingValueInputTypesThen,
1819+ escapingValueDistTypesThen] =
1820+ getInnerRegionEscapingValues (warpOp, ifOp.getThenRegion (),
1821+ distributionMapFn);
1822+ auto [escapingValuesElse, escapingValueInputTypesElse,
1823+ escapingValueDistTypesElse] =
1824+ getInnerRegionEscapingValues (warpOp, ifOp.getElseRegion (),
1825+ distributionMapFn);
1826+ if (llvm::is_contained (escapingValueDistTypesThen, Type{}) ||
1827+ llvm::is_contained (escapingValueDistTypesElse, Type{}))
1828+ return failure ();
1829+
1830+ // The new `WarpOp` groups yields values in following order:
1831+ // 1. Branch condition
1832+ // 2. Escaping values then branch
1833+ // 3. Escaping values else branch
1834+ // 4. All non-`ifOp` yielded values.
1835+ SmallVector<Value> newWarpOpYieldValues{ifOp.getCondition ()};
1836+ newWarpOpYieldValues.append (escapingValuesThen.begin (),
1837+ escapingValuesThen.end ());
1838+ newWarpOpYieldValues.append (escapingValuesElse.begin (),
1839+ escapingValuesElse.end ());
1840+ SmallVector<Type> newWarpOpDistTypes{ifOp.getCondition ().getType ()};
1841+ newWarpOpDistTypes.append (escapingValueDistTypesThen.begin (),
1842+ escapingValueDistTypesThen.end ());
1843+ newWarpOpDistTypes.append (escapingValueDistTypesElse.begin (),
1844+ escapingValueDistTypesElse.end ());
1845+
1846+ llvm::SmallDenseMap<unsigned , unsigned > origToNewYieldIdx;
1847+ for (auto [idx, val] :
1848+ llvm::zip_equal (nonIfYieldIndices, nonIfYieldValues)) {
1849+ origToNewYieldIdx[idx] = newWarpOpYieldValues.size ();
1850+ newWarpOpYieldValues.push_back (val);
1851+ newWarpOpDistTypes.push_back (warpOp.getResult (idx).getType ());
1852+ }
1853+ // Create the new `WarpOp` with the updated yield values and types.
1854+ WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns (
1855+ rewriter, warpOp, newWarpOpYieldValues, newWarpOpDistTypes);
1856+ // `ifOp` returns the result of the inner warp op.
1857+ SmallVector<Type> newIfOpDistResTypes;
1858+ for (auto [i, res] : llvm::enumerate (ifOp.getResults ())) {
1859+ Type distType = cast<Value>(res).getType ();
1860+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1861+ AffineMap map = distributionMapFn (cast<Value>(res));
1862+ // Fallback to affine map if the dist result was not previously recorded
1863+ distType = ifResultDistTypes.count (i)
1864+ ? ifResultDistTypes[i]
1865+ : getDistributedType (vecType, map, warpOp.getWarpSize ());
1866+ }
1867+ newIfOpDistResTypes.push_back (distType);
1868+ }
1869+ // Create a new `IfOp` outside the new `WarpOp` region.
1870+ OpBuilder::InsertionGuard g (rewriter);
1871+ rewriter.setInsertionPointAfter (newWarpOp);
1872+ auto newIfOp = scf::IfOp::create (
1873+ rewriter, ifOp.getLoc (), newIfOpDistResTypes, newWarpOp.getResult (0 ),
1874+ static_cast <bool >(ifOp.thenBlock ()),
1875+ static_cast <bool >(ifOp.elseBlock ()));
1876+ auto encloseRegionInWarpOp =
1877+ [&](Block *oldIfBranch, Block *newIfBranch,
1878+ llvm::SmallSetVector<Value, 32 > &escapingValues,
1879+ SmallVector<Type> &escapingValueInputTypes,
1880+ size_t warpResRangeStart) {
1881+ OpBuilder::InsertionGuard g (rewriter);
1882+ if (!newIfBranch)
1883+ return ;
1884+ rewriter.setInsertionPointToStart (newIfBranch);
1885+ llvm::SmallDenseMap<Value, int64_t > escapeValToBlockArgIndex;
1886+ SmallVector<Value> innerWarpInputVals;
1887+ SmallVector<Type> innerWarpInputTypes;
1888+ for (size_t i = 0 ; i < escapingValues.size ();
1889+ ++i, ++warpResRangeStart) {
1890+ innerWarpInputVals.push_back (
1891+ newWarpOp.getResult (warpResRangeStart));
1892+ escapeValToBlockArgIndex[escapingValues[i]] =
1893+ innerWarpInputTypes.size ();
1894+ innerWarpInputTypes.push_back (escapingValueInputTypes[i]);
1895+ }
1896+ auto innerWarp = WarpExecuteOnLane0Op::create (
1897+ rewriter, newWarpOp.getLoc (), newIfOp.getResultTypes (),
1898+ newWarpOp.getLaneid (), newWarpOp.getWarpSize (),
1899+ innerWarpInputVals, innerWarpInputTypes);
1900+
1901+ innerWarp.getWarpRegion ().takeBody (*oldIfBranch->getParent ());
1902+ innerWarp.getWarpRegion ().addArguments (
1903+ innerWarpInputTypes,
1904+ SmallVector<Location>(innerWarpInputTypes.size (), ifOp.getLoc ()));
1905+
1906+ SmallVector<Value> yieldOperands;
1907+ for (Value operand : oldIfBranch->getTerminator ()->getOperands ())
1908+ yieldOperands.push_back (operand);
1909+ rewriter.eraseOp (oldIfBranch->getTerminator ());
1910+
1911+ rewriter.setInsertionPointToEnd (innerWarp.getBody ());
1912+ gpu::YieldOp::create (rewriter, innerWarp.getLoc (), yieldOperands);
1913+ rewriter.setInsertionPointAfter (innerWarp);
1914+ scf::YieldOp::create (rewriter, ifOp.getLoc (), innerWarp.getResults ());
1915+
1916+ // Update any users of escaping values that were forwarded to the
1917+ // inner `WarpOp`. These values are arguments of the inner `WarpOp`.
1918+ innerWarp.walk ([&](Operation *op) {
1919+ for (OpOperand &operand : op->getOpOperands ()) {
1920+ auto it = escapeValToBlockArgIndex.find (operand.get ());
1921+ if (it == escapeValToBlockArgIndex.end ())
1922+ continue ;
1923+ operand.set (innerWarp.getBodyRegion ().getArgument (it->second ));
1924+ }
1925+ });
1926+ mlir::vector::moveScalarUniformCode (innerWarp);
1927+ };
1928+ encloseRegionInWarpOp (&ifOp.getThenRegion ().front (),
1929+ &newIfOp.getThenRegion ().front (), escapingValuesThen,
1930+ escapingValueInputTypesThen, 1 );
1931+ if (!ifOp.getElseRegion ().empty ())
1932+ encloseRegionInWarpOp (&ifOp.getElseRegion ().front (),
1933+ &newIfOp.getElseRegion ().front (),
1934+ escapingValuesElse, escapingValueInputTypesElse,
1935+ 1 + escapingValuesThen.size ());
1936+ // Update the users of `<- WarpOp.yield <- IfOp.yield` to use the new `IfOp`
1937+ // result.
1938+ for (auto [origIdx, newIdx] : ifResultMapping)
1939+ rewriter.replaceAllUsesExcept (warpOp.getResult (origIdx),
1940+ newIfOp.getResult (newIdx), newIfOp);
1941+ // Similarly, update any users of the `WarpOp` results that were not
1942+ // results of the `IfOp`.
1943+ for (auto [origIdx, newIdx] : origToNewYieldIdx)
1944+ rewriter.replaceAllUsesWith (warpOp.getResult (origIdx),
1945+ newWarpOp.getResult (newIdx));
1946+ // Remove the original `WarpOp` and `IfOp`, they should not have any uses
1947+ // at this point.
1948+ rewriter.eraseOp (ifOp);
1949+ rewriter.eraseOp (warpOp);
1950+ return success ();
1951+ }
1952+
1953+ private:
1954+ DistributionMapFn distributionMapFn;
1955+ };
1956+
17161957// / Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
17171958// / the scf.ForOp is the last operation in the region so that it doesn't
17181959// / change the order of execution. This creates a new scf.for region after the
@@ -1759,25 +2000,9 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
17592000 return failure ();
17602001 // Collect Values that come from the `WarpOp` but are outside the `ForOp`.
17612002 // Those Values need to be returned by the new warp op.
1762- llvm::SmallSetVector<Value, 32 > escapingValues;
1763- SmallVector<Type> escapingValueInputTypes;
1764- SmallVector<Type> escapingValueDistTypes;
1765- mlir::visitUsedValuesDefinedAbove (
1766- forOp.getBodyRegion (), [&](OpOperand *operand) {
1767- Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
1768- if (warpOp->isAncestor (parent)) {
1769- if (!escapingValues.insert (operand->get ()))
1770- return ;
1771- Type distType = operand->get ().getType ();
1772- if (auto vecType = dyn_cast<VectorType>(distType)) {
1773- AffineMap map = distributionMapFn (operand->get ());
1774- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1775- }
1776- escapingValueInputTypes.push_back (operand->get ().getType ());
1777- escapingValueDistTypes.push_back (distType);
1778- }
1779- });
1780-
2003+ auto [escapingValues, escapingValueInputTypes, escapingValueDistTypes] =
2004+ getInnerRegionEscapingValues (warpOp, forOp.getBodyRegion (),
2005+ distributionMapFn);
17812006 if (llvm::is_contained (escapingValueDistTypes, Type{}))
17822007 return failure ();
17832008 // `WarpOp` can yield two types of values:
@@ -2068,6 +2293,8 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
20682293 benefit);
20692294 patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
20702295 benefit);
2296+ patterns.add <WarpOpScfIfOp>(patterns.getContext (), distributionMapFn,
2297+ benefit);
20712298}
20722299
20732300void mlir::vector::populateDistributeReduction (
0 commit comments