diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 030160821bd82..88e82ce48959b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1737,7 +1737,8 @@ def LLVM_ConstantOp // Atomic operations. // -def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>; +def LLVM_AtomicRMWType + : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_AnyFixedVector]>; def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [ TypesMatchWith<"result #0 and operand #1 have the same type", diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 0561c364c7d59..fb7024a14f8d4 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -3010,8 +3010,16 @@ LogicalResult AtomicRMWOp::verify() { auto valType = getVal().getType(); if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub || getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) { - if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) + if (isCompatibleVectorType(valType)) { + if (isScalableVectorType(valType)) + return emitOpError("expected LLVM IR fixed vector type"); + Type elemType = getVectorElementType(valType); + if (!isCompatibleFloatingPointType(elemType)) + return emitOpError( + "expected LLVM IR floating point type for vector element"); + } else if (!isCompatibleFloatingPointType(valType)) { return emitOpError("expected LLVM IR floating point type"); + } } else if (getBinOp() == AtomicBinOp::xchg) { DataLayout dataLayout = DataLayout::closest(*this); if (!isTypeCompatibleWithAtomicOp(valType, dataLayout)) diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 9388d7ef24936..5677d7ff41202 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -643,6 +643,21 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) { // ----- +func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32>) { + // expected-error@+1 {{'val' must be floating point LLVM type or LLVM pointer type or signless integer or LLVM dialect-compatible fixed-length vector type}} + %0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32> + llvm.return +} +// ----- + +func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) { + // expected-error@+1 {{expected LLVM IR floating point type for vector element}} + %0 = llvm.atomicrmw fadd %ptr, %i32_vec unordered : !llvm.ptr, vector<3xi32> + llvm.return +} + +// ----- + func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) { // expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}} %0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1 diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 62f1de2b7fe7d..3062cdc38c0ab 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) { } // CHECK-LABEL: @atomicrmw -func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) { +func.func @atomicrmw(%ptr : !llvm.ptr, %f32 : f32, %f16_vec : vector<2xf16>) { // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32 - %0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32 + %0 = llvm.atomicrmw fadd %ptr, %f32 monotonic : !llvm.ptr, f32 // CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32 - %1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32 + %1 = llvm.atomicrmw volatile fsub %ptr, %f32 syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32 + // CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16> + %2 = llvm.atomicrmw fmin %ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16> llvm.return } diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 007284d0ca443..327c9f05f4c72 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1496,7 +1496,8 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i // CHECK-LABEL: @atomicrmw llvm.func @atomicrmw( %f32_ptr : !llvm.ptr, %f32 : f32, - %i32_ptr : !llvm.ptr, %i32 : i32) { + %i32_ptr : !llvm.ptr, %i32 : i32, + %f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>) { // CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic %0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32 // CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic @@ -1535,11 +1536,19 @@ llvm.func @atomicrmw( %17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32 // CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic %18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32 + // CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic + %19 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16> + // CHECK: atomicrmw fsub ptr %{{.*}}, <2 x half> %{{.*}} monotonic + %20 = llvm.atomicrmw fsub %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16> + // CHECK: atomicrmw fmax ptr %{{.*}}, <2 x half> %{{.*}} monotonic + %21 = llvm.atomicrmw fmax %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16> + // CHECK: atomicrmw fmin ptr %{{.*}}, <2 x half> %{{.*}} monotonic + %22 = llvm.atomicrmw fmin %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16> // CHECK: atomicrmw volatile // CHECK-SAME: syncscope("singlethread") // CHECK-SAME: align 8 - %19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32 + %23 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32 llvm.return }