@@ -807,200 +807,6 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };

-/// Distribute a scattered store op. The offsets argument is required.
-/// Both the offset and mask vectors must be 1D and have #subgroup_size
-/// elements. The layouts are fixed and implicit: one offset/mask per lane.
-/// The pass changes the offset/mask vector shapes to single-element vectors;
-/// **it is assumed that their producers will also be distributed**. The
-/// payload vector also has a fixed distribution:
-///   no chunk size -> vector of one element.
-///   chunk size    -> vector of the innermost dimension of the SG payload.
-/// Example 1 (no chunk size):
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
-///     memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-///     memref<256xf16>, vector<1xindex>, vector<1xi1>
-/// Example 2 (chunk size, same mask and offsets):
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// To
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-struct StoreDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
-    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
-    if (!storeScatterOp)
-      return failure();
-    auto offsets = storeScatterOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()))
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Store op must have a vector of offsets argument");
-    VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets and mask vector");
-    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
-    if (storeVecTy.getRank() > 2)
-      return rewriter.notifyMatchFailure(
-          storeScatterOp, "Expected at most 2D result at SG level");
-
-    std::string layoutPayloadName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
-    std::string layoutOffsetsName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
-    std::string layoutMaskName =
-        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
-
-    xegpu::LayoutAttr layoutPayload =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
-    xegpu::LayoutAttr layoutOffsets =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
-
-    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
-    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
-    FailureOr<VectorType> distMaskByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
-    if (failed(distStoreVecByWarpOpOrFailure) ||
-        failed(distOffsetsByWarpOpOrFailure) ||
-        failed(distMaskByWarpOpOrFailure)) {
-      return rewriter.notifyMatchFailure(
-          storeScatterOp,
-          "Some vector operands have no layouts, using defaults instead.");
-    }
-    VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
-    VectorType expectedPayloadTy = VectorType::get(
-        {distPayloadTy.getNumElements()}, distPayloadTy.getElementType());
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        expectedPayloadTy, operands[1].getType(),
-        distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
-
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
-    rewriter.setInsertionPointAfter(newWarpOp);
-    xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
-        rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
-        storeScatterOp->getAttrs());
-    xegpu::removeLayoutAttrs(newOp);
-    rewriter.eraseOp(storeScatterOp);
-    return success();
-  }
-};
-
-/// Distribute a scattered load op. The logic and requirements are the same as
-/// for the scattered store distribution. The warpOp's payload vector is
-/// expected to be distributed by the load's result consumer.
-/// Example 1 (no chunk size):
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// To
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
-/// Example 2 (chunk size, same mask and offsets):
-///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// To
-///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-struct LoadDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
-      // Check that the yield operand was produced by the *last* scattered
-      // load op, to avoid sinking it across barriers (maintain memory order).
-      return isa<xegpu::LoadGatherOp>(op) &&
-             warpOp.getTerminator()->getPrevNode() == op;
-    });
-    if (!producedByLastLoad)
-      return rewriter.notifyMatchFailure(
-          warpOp, "The last op is not xegpu::LoadGatherOp");
-
-    auto loadGatherOp =
-        producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
-    auto offsets = loadGatherOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()) ||
-        !isa<VectorType>(loadGatherOp.getMask().getType()))
-      return rewriter.notifyMatchFailure(
-          loadGatherOp,
-          "Load op must have vector arguments for offsets and mask");
-    VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
-    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
-      return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets and mask vector");
-    // Assume the offset and mask producers will be distributed as well.
-    std::string layoutOffsetsName =
-        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
-    std::string layoutMaskName =
-        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
-
-    xegpu::LayoutAttr layoutOffsets =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
-    xegpu::LayoutAttr layoutMask =
-        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
-
-    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
-    FailureOr<VectorType> distMaskByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
-    if (failed(distOffsetsByWarpOpOrFailure) ||
-        failed(distMaskByWarpOpOrFailure)) {
-      return rewriter.notifyMatchFailure(
-          loadGatherOp,
-          "Some vector operands have no layouts, using defaults instead.");
-    }
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield = {
-        operands[0].getType(), distOffsetsByWarpOpOrFailure.value(),
-        distMaskByWarpOpOrFailure.value()};
-
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-
-    SmallVector<Value> newLoadGatherOperands = llvm::map_to_vector(
-        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
-
-    const unsigned operandIdx = producedByLastLoad->getOperandNumber();
-    VectorType loadVecTy =
-        cast<VectorType>(warpOp.getResult(operandIdx).getType());
-
-    rewriter.setInsertionPointAfter(newWarpOp);
-    xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
-        newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
-        loadGatherOp->getAttrs());
-    xegpu::removeLayoutAttrs(newOp);
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
-    return success();
-  }
-};
-
 } // namespace

 namespace {
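For context, the removed patterns sink the scattered access out of the `gpu.warp_execute_on_lane_0` region by yielding per-lane fragments of the payload, offsets, and mask. Below is a minimal sketch of the store case (the SSA names, `%laneid`/`%src` values, and subgroup size of 16 are assumed for illustration, not taken from an in-tree test):

```mlir
// Before: the scattered store is the last op in the warp region, with one
// offset/mask element per lane (no chunk size).
gpu.warp_execute_on_lane_0(%laneid)[16] {
  %payload = "producer_op"() : () -> vector<16xf16>
  %offset = "producer_op"() : () -> vector<16xindex>
  %mask = "producer_op"() : () -> vector<16xi1>
  xegpu.store %payload, %src[%offset], %mask
      : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}

// After: the warp op yields one payload element, offset, and mask per lane
// (the memref passes through unchanged), and the store executes outside the
// region on the distributed fragments.
%r:4 = gpu.warp_execute_on_lane_0(%laneid)[16]
    -> (vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
  %payload = "producer_op"() : () -> vector<16xf16>
  %offset = "producer_op"() : () -> vector<16xindex>
  %mask = "producer_op"() : () -> vector<16xi1>
  gpu.yield %payload, %src, %offset, %mask
      : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
}
xegpu.store %r#0, %r#1[%r#2], %r#3
    : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
```

The load case is symmetric, except the distributed payload flows the other way: the warp op already returns the per-lane fragment of the load result, so the new `xegpu.load` is created after the warp op and replaces that result.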
@@ -1013,11 +819,10 @@ struct XeGPUSubgroupDistributePass final

 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
-           GpuBarrierDistribution, LoadDistribution, StoreDistribution>(
-          patterns.getContext());
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               UpdateNdOffsetDistribution, GpuBarrierDistribution>(
+      patterns.getContext());
 }

 void XeGPUSubgroupDistributePass::runOnOperation() {
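The populate hook slots into the usual MLIR pattern-driver flow. A minimal sketch of a caller, assuming the standard greedy rewrite driver and header paths (the in-tree pass wires this up inside `runOnOperation`, possibly with additional configuration and companion patterns):

```cpp
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical standalone driver for the distribution patterns; shown only
// to illustrate how the populate function is consumed.
static mlir::LogicalResult distributeSubgroupOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::xegpu::populateXeGPUSubgroupDistributePatterns(patterns);
  // Greedily apply the registered patterns until a fixed point is reached.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```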