@@ -844,6 +844,155 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
844844 return success ();
845845}
846846
847+ // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
848+ // operation into the corresponding ROCDL instructions.
849+ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern <DPPOp> {
850+ AMDGPUDPPLowering (LLVMTypeConverter &converter, Chipset chipset)
851+ : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
852+ Chipset chipset;
853+
854+ LogicalResult
855+ matchAndRewrite (DPPOp DppOp, DPPOp::Adaptor adaptor,
856+ ConversionPatternRewriter &rewriter) const override {
857+
858+ // Convert the source operand to the corresponding LLVM type
859+ Location loc = DppOp.getLoc ();
860+ Value src = adaptor.getSrc ();
861+ Value old = adaptor.getOld ();
862+ Type srcType = src.getType ();
863+ Type oldType = old.getType ();
864+ Type llvmType = nullptr ;
865+ if (srcType.getIntOrFloatBitWidth () < 32 ) {
866+ llvmType = rewriter.getI32Type ();
867+ } else if (isa<FloatType>(srcType)) {
868+ llvmType = (srcType.getIntOrFloatBitWidth () == 32 )
869+ ? rewriter.getF32Type ()
870+ : rewriter.getF64Type ();
871+ } else if (isa<IntegerType>(srcType)) {
872+ llvmType = (srcType.getIntOrFloatBitWidth () == 32 )
873+ ? rewriter.getI32Type ()
874+ : rewriter.getI64Type ();
875+ }
876+ auto llvmSrcIntType = typeConverter->convertType (
877+ rewriter.getIntegerType (srcType.getIntOrFloatBitWidth ()));
878+
879+ // If the source type is less of 32, use bitcast to convert it to i32.
880+ auto convertOperand = [&](Value operand, Type operandType) {
881+ if (operandType.getIntOrFloatBitWidth () <= 16 ) {
882+ if (llvm::isa<FloatType>(operandType)) {
883+ operand =
884+ rewriter.create <LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
885+ }
886+ auto llvmVecType = typeConverter->convertType (mlir::VectorType::get (
887+ 32 / operandType.getIntOrFloatBitWidth (), llvmSrcIntType));
888+ Value undefVec = rewriter.create <LLVM::UndefOp>(loc, llvmVecType);
889+ operand = rewriter.create <LLVM::InsertElementOp>(
890+ loc, undefVec, operand, createI32Constant (rewriter, loc, 0 ));
891+ operand = rewriter.create <LLVM::BitcastOp>(loc, llvmType, operand);
892+ }
893+ return operand;
894+ };
895+
896+ src = convertOperand (src, srcType);
897+ old = convertOperand (old, oldType);
898+
899+ // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
900+ enum DppCtrl : unsigned {
901+ ROW_SHL0 = 0x100 ,
902+ ROW_SHR0 = 0x110 ,
903+ ROW_ROR0 = 0x120 ,
904+ WAVE_SHL1 = 0x130 ,
905+ WAVE_ROL1 = 0x134 ,
906+ WAVE_SHR1 = 0x138 ,
907+ WAVE_ROR1 = 0x13C ,
908+ ROW_MIRROR = 0x140 ,
909+ ROW_HALF_MIRROR = 0x141 ,
910+ BCAST15 = 0x142 ,
911+ BCAST31 = 0x143 ,
912+ };
913+
914+ auto kind = DppOp.getKind ();
915+ auto permArgument = DppOp.getPermArgument ();
916+ uint32_t DppCtrl = 0 ;
917+
918+ switch (kind) {
919+
920+ case DPPPerm::quad_perm:
921+ if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
922+ int32_t i = 0 ;
923+ for (auto elem : quadPermAttr.getAsRange <IntegerAttr>()) {
924+ uint32_t num = elem.getInt ();
925+ DppCtrl |= num << (i * 2 );
926+ i++;
927+ }
928+ }
929+ break ;
930+ case DPPPerm::row_shl:
931+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
932+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHL0;
933+ }
934+ break ;
935+ case DPPPerm::row_shr:
936+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
937+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHR0;
938+ }
939+ break ;
940+ case DPPPerm::row_ror:
941+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
942+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_ROR0;
943+ }
944+ break ;
945+ case DPPPerm::wave_shl:
946+ DppCtrl = DppCtrl::WAVE_SHL1;
947+ break ;
948+ case DPPPerm::wave_shr:
949+ DppCtrl = DppCtrl::WAVE_SHR1;
950+ break ;
951+ case DPPPerm::wave_rol:
952+ DppCtrl = DppCtrl::WAVE_ROL1;
953+ break ;
954+ case DPPPerm::wave_ror:
955+ DppCtrl = DppCtrl::WAVE_ROR1;
956+ break ;
957+ case DPPPerm::row_mirror:
958+ DppCtrl = DppCtrl::ROW_MIRROR;
959+ break ;
960+ case DPPPerm::row_half_mirror:
961+ DppCtrl = DppCtrl::ROW_HALF_MIRROR;
962+ break ;
963+ case DPPPerm::row_bcast_15:
964+ DppCtrl = DppCtrl::BCAST15;
965+ break ;
966+ case DPPPerm::row_bcast_31:
967+ DppCtrl = DppCtrl::BCAST31;
968+ break ;
969+ }
970+
971+ // Check for row_mask, bank_mask, bound_ctrl if they exist and create
972+ // constants
973+ auto rowMask = DppOp->getAttrOfType <IntegerAttr>(" row_mask" ).getInt ();
974+ auto bankMask = DppOp->getAttrOfType <IntegerAttr>(" bank_mask" ).getInt ();
975+ bool boundCtrl = DppOp->getAttrOfType <BoolAttr>(" bound_ctrl" ).getValue ();
976+
977+ // create a ROCDL_DPPMovOp instruction with the appropriate attributes
978+ auto dppMovOp = rewriter.create <ROCDL::DPPUpdateOp>(
979+ loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
980+
981+ Value result = dppMovOp.getRes ();
982+ if (srcType.getIntOrFloatBitWidth () < 32 ) {
983+ result = rewriter.create <LLVM::TruncOp>(loc, llvmSrcIntType, result);
984+ if (!llvm::isa<IntegerType>(srcType)) {
985+ result = rewriter.create <LLVM::BitcastOp>(loc, srcType, result);
986+ }
987+ }
988+
989+ // We are replacing the AMDGPU_DPPOp instruction with the new
990+ // ROCDL_DPPMovOp instruction
991+ rewriter.replaceOp (DppOp, ValueRange (result));
992+ return success ();
993+ }
994+ };
995+
847996struct ConvertAMDGPUToROCDLPass
848997 : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
849998 ConvertAMDGPUToROCDLPass () = default ;
@@ -895,9 +1044,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
8951044 ROCDL::RawPtrBufferAtomicUminOp>,
8961045 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
8971046 ROCDL::RawPtrBufferAtomicCmpSwap>,
898- LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering,
899- WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
900- PackedStochRoundFp8OpLowering>(converter, chipset);
1047+ AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1048+ MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1049+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1050+ chipset);
9011051}
9021052
9031053std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass () {
0 commit comments