@@ -844,6 +844,147 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
844844 return success ();
845845}
846846
847+ // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
848+ // operation into the corresponding ROCDL instructions.
849+ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern <DPPOp> {
850+ AMDGPUDPPLowering (LLVMTypeConverter &converter, Chipset chipset)
851+ : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
852+ Chipset chipset;
853+
854+ LogicalResult
855+ matchAndRewrite (DPPOp DppOp, DPPOp::Adaptor adaptor,
856+ ConversionPatternRewriter &rewriter) const override {
857+
858+ // Convert the source operand to the corresponding LLVM type
859+ Location loc = DppOp.getLoc ();
860+ Value src = adaptor.getSrc ();
861+ Value old = adaptor.getOld ();
862+ Type srcType = src.getType ();
863+ Type oldType = old.getType ();
864+ auto llvmI32Type = typeConverter->convertType (rewriter.getI32Type ());
865+ auto llvmSrcIntType = typeConverter->convertType (
866+ rewriter.getIntegerType (srcType.getIntOrFloatBitWidth ()));
867+
868+ // If the source type is less or equal to i32 or f32, use bitcast to convert
869+ // it to i32.
870+ auto convertOperand = [&](Value operand, Type operandType) {
871+ if (llvm::isa<FloatType>(operandType)) {
872+ operand =
873+ rewriter.create <LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
874+ }
875+
876+ if (operandType.getIntOrFloatBitWidth () < 32 ) {
877+ auto llvmVecType = typeConverter->convertType (mlir::VectorType::get (
878+ 32 / operandType.getIntOrFloatBitWidth (), llvmSrcIntType));
879+ Value undefVec = rewriter.create <LLVM::UndefOp>(loc, llvmVecType);
880+ operand = rewriter.create <LLVM::InsertElementOp>(
881+ loc, undefVec, operand, createI32Constant (rewriter, loc, 0 ));
882+ operand = rewriter.create <LLVM::BitcastOp>(loc, llvmI32Type, operand);
883+ }
884+ return operand;
885+ };
886+
887+ src = convertOperand (src, srcType);
888+ old = convertOperand (old, oldType);
889+
890+ // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
891+ enum DppCtrl : unsigned {
892+ ROW_SHL0 = 0x100 ,
893+ ROW_SHR0 = 0x110 ,
894+ ROW_ROR0 = 0x120 ,
895+ WAVE_SHL1 = 0x130 ,
896+ WAVE_ROL1 = 0x134 ,
897+ WAVE_SHR1 = 0x138 ,
898+ WAVE_ROR1 = 0x13C ,
899+ ROW_MIRROR = 0x140 ,
900+ ROW_HALF_MIRROR = 0x141 ,
901+ BCAST15 = 0x142 ,
902+ BCAST31 = 0x143 ,
903+ };
904+
905+ auto kind = DppOp.getKind ();
906+ auto permArgument = DppOp.getPermArgument ();
907+ uint32_t DppCtrl = 0 ;
908+
909+ switch (kind) {
910+
911+ case DPPPerm::quad_perm:
912+ if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
913+ int32_t i = 0 ;
914+ for (auto elem : quadPermAttr.getAsRange <IntegerAttr>()) {
915+ uint32_t num = elem.getInt ();
916+ DppCtrl |= num << (i * 2 );
917+ i++;
918+ }
919+ }
920+ break ;
921+ case DPPPerm::row_shl:
922+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
923+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHL0;
924+ }
925+ break ;
926+ case DPPPerm::row_shr:
927+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
928+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHR0;
929+ }
930+ break ;
931+ case DPPPerm::row_ror:
932+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
933+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_ROR0;
934+ }
935+ break ;
936+ case DPPPerm::wave_shl:
937+ DppCtrl = DppCtrl::WAVE_SHL1;
938+ break ;
939+ case DPPPerm::wave_shr:
940+ DppCtrl = DppCtrl::WAVE_SHR1;
941+ break ;
942+ case DPPPerm::wave_rol:
943+ DppCtrl = DppCtrl::WAVE_ROL1;
944+ break ;
945+ case DPPPerm::wave_ror:
946+ DppCtrl = DppCtrl::WAVE_ROR1;
947+ break ;
948+ case DPPPerm::row_mirror:
949+ DppCtrl = DppCtrl::ROW_MIRROR;
950+ break ;
951+ case DPPPerm::row_half_mirror:
952+ DppCtrl = DppCtrl::ROW_HALF_MIRROR;
953+ break ;
954+ case DPPPerm::row_bcast_15:
955+ DppCtrl = DppCtrl::BCAST15;
956+ break ;
957+ case DPPPerm::row_bcast_31:
958+ DppCtrl = DppCtrl::BCAST31;
959+ break ;
960+ }
961+
962+ // Check for row_mask, bank_mask, bound_ctrl if they exist and create
963+ // constants
964+ auto rowMask = DppOp->getAttrOfType <IntegerAttr>(" row_mask" ).getInt ();
965+ auto bankMask = DppOp->getAttrOfType <IntegerAttr>(" bank_mask" ).getInt ();
966+ bool boundCtrl = DppOp->getAttrOfType <BoolAttr>(" bound_ctrl" ).getValue ();
967+
968+ // create a ROCDL_DPPMovOp instruction with the appropriate attributes
969+ auto dppMovOp = rewriter.create <ROCDL::DPPUpdateOp>(
970+ loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
971+
972+ Value result = dppMovOp.getRes ();
973+ if (srcType.getIntOrFloatBitWidth () < 32 ) {
974+ result = rewriter.create <LLVM::TruncOp>(loc, llvmSrcIntType, result);
975+ }
976+
977+ if (!llvm::isa<IntegerType>(srcType)) {
978+ result = rewriter.create <LLVM::BitcastOp>(loc, srcType, result);
979+ }
980+
981+ // We are replacing the AMDGPU_DPPOp instruction with the new
982+ // ROCDL_DPPMovOp instruction
983+ rewriter.replaceOp (DppOp, ValueRange (result));
984+ return success ();
985+ }
986+ };
987+
847988struct ConvertAMDGPUToROCDLPass
848989 : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
849990 ConvertAMDGPUToROCDLPass () = default ;
@@ -895,9 +1036,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
8951036 ROCDL::RawPtrBufferAtomicUminOp>,
8961037 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
8971038 ROCDL::RawPtrBufferAtomicCmpSwap>,
898- LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
899- WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
900- PackedStochRoundFp8OpLowering>(converter, chipset);
1039+ AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1040+ MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1041+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1042+ chipset);
9011043}
9021044
9031045std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass () {
0 commit comments