diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 8a1ef94c853a5..64db4448bc2f2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -253,8 +253,8 @@ def AMDGPU_RawBufferAtomicCmpswapOp :
 // Raw buffer atomic floating point add
 def AMDGPU_RawBufferAtomicFaddOp :
     AMDGPU_Op<"raw_buffer_atomic_fadd", [AllElementTypesMatch<["value", "memref"]>,
-                                         AttrSizedOperandSegments]>,
-    Arguments<(ins F32:$value,
+                                         AttrSizedOperandSegments]>,
+    Arguments<(ins AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16]>]>:$value,
                    Arg<AnyMemRef, "Memory to operate on", [MemRead, MemWrite]>:$memref,
                    Variadic<I32>:$indices,
                    DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 96b433294d258..fc5dd7c560212 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -115,15 +115,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
           rewriter.getIntegerType(floatType.getWidth()));
     }
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
+      uint32_t vecLen = dataVector.getNumElements();
       uint32_t elemBits = dataVector.getElementTypeBitWidth();
-      uint32_t totalBits = elemBits * dataVector.getNumElements();
+      uint32_t totalBits = elemBits * vecLen;
+      bool usePackedFp16 =
+          dyn_cast_or_null<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
       if (totalBits > maxVectorOpWidth)
         return gpuOp.emitOpError(
             "Total width of loads or stores must be no more than " +
             Twine(maxVectorOpWidth) + " bits, but we call for " +
             Twine(totalBits) +
             " bits. This should've been caught in validation");
-      if (elemBits < 32) {
+      else if (!usePackedFp16 && elemBits < 32) {
         if (totalBits > 32) {
           if (totalBits % 32 != 0)
             return gpuOp.emitOpError("Load or store of more than 32-bits that "
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 717667c22af80..cc51a8c40942f 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -151,6 +151,17 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_f32(%value: f32, %buf: memref<64xf32>,
   func.return
 }
 
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2f16
+func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: memref<64xf16>, %idx: i32) {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
+  // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
+  // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
+  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
+  // CHECK: rocdl.raw.ptr.buffer.atomic.fadd %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : vector<2xf16>
+  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
+  func.return
+}
+
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fmax_f32
 func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
   // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
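
Note on the lowering (not part of the patch): the new usePackedFp16 path exists because gfx9+ hardware exposes a packed half-precision buffer atomic (buffer_atomic_pk_add_f16), so a vector<2xf16> operand can be handed to the ROCDL op as-is instead of being repacked into 32-bit integer words the way other sub-32-bit vectors are. A minimal before/after sketch follows; the SSA names %rsrc, %offset, %soffset, and %aux are illustrative assumptions, not taken from the patch:

  // Source form enabled by the AMDGPU.td change:
  amdgpu.raw_buffer_atomic_fadd {boundsCheck = true} %value -> %buf[%idx] : vector<2xf16> -> memref<64xf16>, i32
  // Lowers, roughly, to a direct packed-fp16 ROCDL atomic (operand order as in
  // the CHECK line above: value, resource, offset, soffset, aux):
  rocdl.raw.ptr.buffer.atomic.fadd %value, %rsrc, %offset, %soffset, %aux : vector<2xf16>
  // Without the usePackedFp16 special case, the generic elemBits < 32 branch
  // would first bitcast the payload to i32, which the packed atomic must not do.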