@@ -829,6 +829,144 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
829829 return success ();
830830}
831831
832+ // Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
833+ // operation into the corresponding ROCDL instructions.
834+ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern <DPPOp> {
835+ AMDGPUDPPLowering (LLVMTypeConverter &converter, Chipset chipset)
836+ : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
837+ Chipset chipset;
838+
839+ LogicalResult
840+ matchAndRewrite (DPPOp DppOp, DPPOp::Adaptor adaptor,
841+ ConversionPatternRewriter &rewriter) const override {
842+
843+ // Convert the source operand to the corresponding LLVM type
844+ Location loc = DppOp.getLoc ();
845+ Value src = adaptor.getSrc ();
846+ Type srcType = src.getType ();
847+ auto llvmI32Type = typeConverter->convertType (rewriter.getI32Type ());
848+ auto llvmSrcIntType = typeConverter->convertType (
849+ rewriter.getIntegerType (srcType.getIntOrFloatBitWidth ()));
850+
851+ // If the source type is less or equal to i32 or f32, use bitcast to convert
852+ // it to i32.
853+ if (llvm::isa<FloatType>(srcType)) {
854+ src = rewriter.create <LLVM::BitcastOp>(loc, llvmSrcIntType, src);
855+ }
856+
857+ if (srcType.getIntOrFloatBitWidth () < 32 ) {
858+ auto llvmVecType = typeConverter->convertType (mlir::VectorType::get (
859+ 32 / srcType.getIntOrFloatBitWidth (), llvmSrcIntType));
860+ Value undefVec = rewriter.create <LLVM::UndefOp>(loc, llvmVecType);
861+ src = rewriter.create <LLVM::InsertElementOp>(
862+ loc, undefVec, src, createI32Constant (rewriter, loc, 0 ));
863+ src = rewriter.create <LLVM::BitcastOp>(loc, llvmI32Type, src);
864+ }
865+
866+ // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
867+ enum DppCtrl : unsigned {
868+ ROW_SHL0 = 0x100 ,
869+ ROW_SHR0 = 0x110 ,
870+ ROW_ROR0 = 0x120 ,
871+ WAVE_SHL1 = 0x130 ,
872+ WAVE_ROL1 = 0x134 ,
873+ WAVE_SHR1 = 0x138 ,
874+ WAVE_ROR1 = 0x13C ,
875+ ROW_MIRROR = 0x140 ,
876+ ROW_HALF_MIRROR = 0x141 ,
877+ BCAST15 = 0x142 ,
878+ BCAST31 = 0x143 ,
879+ };
880+
881+ auto kind = DppOp.getKind ();
882+ auto permArgument = DppOp.getPermArgument ();
883+ uint32_t DppCtrl = 0 ;
884+
885+ switch (kind) {
886+
887+ case DPPPerm::quad_perm:
888+ if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
889+ int32_t i = 0 ;
890+ for (auto elem : quadPermAttr.getAsRange <IntegerAttr>()) {
891+ uint32_t num = elem.getInt ();
892+ DppCtrl |= num << (i * 2 );
893+ i++;
894+ }
895+ }
896+ break ;
897+ case DPPPerm::row_shl:
898+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
899+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHL0;
900+ }
901+ break ;
902+ case DPPPerm::row_shr:
903+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
904+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_SHR0;
905+ }
906+ break ;
907+ case DPPPerm::row_ror:
908+ if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
909+ DppCtrl = intAttr.getInt () + DppCtrl::ROW_ROR0;
910+ }
911+ break ;
912+ case DPPPerm::wave_shl:
913+ DppCtrl = DppCtrl::WAVE_SHL1;
914+ break ;
915+ case DPPPerm::wave_shr:
916+ DppCtrl = DppCtrl::WAVE_SHR1;
917+ break ;
918+ case DPPPerm::wave_rol:
919+ DppCtrl = DppCtrl::WAVE_ROL1;
920+ break ;
921+ case DPPPerm::wave_ror:
922+ DppCtrl = DppCtrl::WAVE_ROR1;
923+ break ;
924+ case DPPPerm::row_mirror:
925+ DppCtrl = DppCtrl::ROW_MIRROR;
926+ break ;
927+ case DPPPerm::row_half_mirror:
928+ DppCtrl = DppCtrl::ROW_HALF_MIRROR;
929+ break ;
930+ case DPPPerm::row_bcast_15:
931+ DppCtrl = DppCtrl::BCAST15;
932+ break ;
933+ case DPPPerm::row_bcast_31:
934+ DppCtrl = DppCtrl::BCAST31;
935+ break ;
936+ }
937+
938+ // Check for row_mask, bank_mask, bound_ctrl if they exist and create
939+ // constants
940+ auto rowMask = DppOp->getAttrOfType <IntegerAttr>(" row_mask" )
941+ .dyn_cast <IntegerAttr>()
942+ .getInt ();
943+ auto bankMask = DppOp->getAttrOfType <IntegerAttr>(" bank_mask" )
944+ .dyn_cast <IntegerAttr>()
945+ .getInt ();
946+ bool boundCtrl = DppOp->getAttrOfType <IntegerAttr>(" bound_ctrl" )
947+ .dyn_cast <BoolAttr>()
948+ .getValue ();
949+
950+ // create a ROCDL_DPPMovOp instruction with the appropriate attributes
951+ auto dppMovOp = rewriter.create <ROCDL::DPPMovOp>(
952+ loc, llvmI32Type, src, DppCtrl, rowMask, bankMask, boundCtrl);
953+
954+ Value result = dppMovOp.getRes ();
955+ if (srcType.getIntOrFloatBitWidth () < 32 ) {
956+ result = rewriter.create <LLVM::TruncOp>(loc, llvmSrcIntType, result);
957+ }
958+
959+ if (!llvm::isa<IntegerType>(srcType)) {
960+ result = rewriter.create <LLVM::BitcastOp>(loc, srcType, result);
961+ }
962+
963+ // We are replacing the AMDGPU_DPPOp instruction with the new
964+ // ROCDL_DPPMovOp instruction
965+ rewriter.replaceOp (DppOp, ValueRange (result));
966+ return success ();
967+ }
968+ };
969+
832970struct ConvertAMDGPUToROCDLPass
833971 : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
834972 ConvertAMDGPUToROCDLPass () = default ;
@@ -880,8 +1018,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
8801018 ROCDL::RawPtrBufferAtomicUminOp>,
8811019 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
8821020 ROCDL::RawPtrBufferAtomicCmpSwap>,
883- LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering ,
884- ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1021+ AMDGPUDPPLowering, LDSBarrierOpLowering, MFMAOpLowering ,
1022+ WMMAOpLowering, ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
8851023 PackedStochRoundFp8OpLowering>(converter, chipset);
8861024}
8871025
0 commit comments