@@ -828,6 +828,147 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
828828 return success ();
829829}
830830
831+ // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
832+ // operation into the corresponding ROCDL instructions.
833+ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern <DPPOp> {
834+ AMDGPUDPPLowering (LLVMTypeConverter &converter, Chipset chipset)
835+ : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
836+ Chipset chipset;
837+
838+ LogicalResult
839+ matchAndRewrite (DPPOp DppOp, DPPOp::Adaptor adaptor,
840+ ConversionPatternRewriter &rewriter) const override {
841+
842+ // Convert the source operand to the corresponding LLVM type
843+ Location loc = DppOp.getLoc ();
844+ Value src = adaptor.getSrc ();
845+ Value old = adaptor.getOld ();
846+ Type srcType = src.getType ();
847+ Type oldType = old.getType ();
848+ auto llvmI32Type = typeConverter->convertType (rewriter.getI32Type ());
849+ auto llvmSrcIntType = typeConverter->convertType (
850+ rewriter.getIntegerType (srcType.getIntOrFloatBitWidth ()));
851+
852+ // If the source type is less or equal to i32 or f32, use bitcast to convert
853+ // it to i32.
854+ auto convertOperand = [&](Value operand, Type operandType) {
855+ if (llvm::isa<FloatType>(operandType)) {
856+ operand =
857+ rewriter.create <LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
858+ }
859+
860+ if (operandType.getIntOrFloatBitWidth () < 32 ) {
861+ auto llvmVecType = typeConverter->convertType (mlir::VectorType::get (
862+ 32 / operandType.getIntOrFloatBitWidth (), llvmSrcIntType));
863+ Value undefVec = rewriter.create <LLVM::UndefOp>(loc, llvmVecType);
864+ operand = rewriter.create <LLVM::InsertElementOp>(
865+ loc, undefVec, operand, createI32Constant (rewriter, loc, 0 ));
866+ operand = rewriter.create <LLVM::BitcastOp>(loc, llvmI32Type, operand);
867+ }
868+ return operand;
869+ };
870+
871+ src = convertOperand (src, srcType);
872+ old = convertOperand (old, oldType);
873+
874+ // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
875+ enum DppCtrl : unsigned {
876+ ROW_SHL0 = 0x100 ,
877+ ROW_SHR0 = 0x110 ,
878+ ROW_ROR0 = 0x120 ,
879+ WAVE_SHL1 = 0x130 ,
880+ WAVE_ROL1 = 0x134 ,
881+ WAVE_SHR1 = 0x138 ,
882+ WAVE_ROR1 = 0x13C ,
883+ ROW_MIRROR = 0x140 ,
884+ ROW_HALF_MIRROR = 0x141 ,
885+ BCAST15 = 0x142 ,
886+ BCAST31 = 0x143 ,
887+ };
888+
889+ auto kind = DppOp.getKind ();
890+ auto permArgument = DppOp.getPermArgument ();
891+ uint32_t DppCtrl = 0 ;
892+
893+ switch (kind) {
894+
895+ case DPPPerm::quad_perm:
896+ if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
897+ int32_t i = 0 ;
898+ for (auto elem : quadPermAttr.getAsRange <IntegerAttr>()) {
899+ uint32_t num = elem.getInt ();
900+ DppCtrl |= num << (i * 2 );
901+ i++;
902+ }
903+ }
904+ break ;
905+ case DPPPerm::row_shl:
906+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
907+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHL0;
908+ }
909+ break ;
910+ case DPPPerm::row_shr:
911+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
912+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHR0;
913+ }
914+ break ;
915+ case DPPPerm::row_ror:
916+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
917+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_ROR0;
918+ }
919+ break ;
920+ case DPPPerm::wave_shl:
921+ DppCtrl = DppCtrl::WAVE_SHL1;
922+ break ;
923+ case DPPPerm::wave_shr:
924+ DppCtrl = DppCtrl::WAVE_SHR1;
925+ break ;
926+ case DPPPerm::wave_rol:
927+ DppCtrl = DppCtrl::WAVE_ROL1;
928+ break ;
929+ case DPPPerm::wave_ror:
930+ DppCtrl = DppCtrl::WAVE_ROR1;
931+ break ;
932+ case DPPPerm::row_mirror:
933+ DppCtrl = DppCtrl::ROW_MIRROR;
934+ break ;
935+ case DPPPerm::row_half_mirror:
936+ DppCtrl = DppCtrl::ROW_HALF_MIRROR;
937+ break ;
938+ case DPPPerm::row_bcast_15:
939+ DppCtrl = DppCtrl::BCAST15;
940+ break ;
941+ case DPPPerm::row_bcast_31:
942+ DppCtrl = DppCtrl::BCAST31;
943+ break ;
944+ }
945+
946+ // Check for row_mask, bank_mask, bound_ctrl if they exist and create
947+ // constants
948+ auto rowMask = DppOp->getAttrOfType <IntegerAttr>(" row_mask" ).getInt ();
949+ auto bankMask = DppOp->getAttrOfType <IntegerAttr>(" bank_mask" ).getInt ();
950+ bool boundCtrl = DppOp->getAttrOfType <BoolAttr>(" bound_ctrl" ).getValue ();
951+
952+ // create a ROCDL_DPPMovOp instruction with the appropriate attributes
953+ auto dppMovOp = rewriter.create <ROCDL::DPPUpdateOp>(
954+ loc, llvmI32Type, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
955+
956+ Value result = dppMovOp.getRes ();
957+ if (srcType.getIntOrFloatBitWidth () < 32 ) {
958+ result = rewriter.create <LLVM::TruncOp>(loc, llvmSrcIntType, result);
959+ }
960+
961+ if (!llvm::isa<IntegerType>(srcType)) {
962+ result = rewriter.create <LLVM::BitcastOp>(loc, srcType, result);
963+ }
964+
965+ // We are replacing the AMDGPU_DPPOp instruction with the new
966+ // ROCDL_DPPMovOp instruction
967+ rewriter.replaceOp (DppOp, ValueRange (result));
968+ return success ();
969+ }
970+ };
971+
831972struct ConvertAMDGPUToROCDLPass
832973 : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
833974 ConvertAMDGPUToROCDLPass () = default ;
@@ -879,8 +1020,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
8791020 ROCDL::RawPtrBufferAtomicUminOp>,
8801021 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
8811022 ROCDL::RawPtrBufferAtomicCmpSwap>,
882- LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering ,
883- ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1023+ AMDGPUDPPLowering, LDSBarrierOpLowering, MFMAOpLowering ,
1024+ WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
8841025 PackedStochRoundFp8OpLowering>(converter, chipset);
8851026}
8861027
0 commit comments