diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp index d2a36d4bdcc86..1ae8bbeef8dd8 100644 --- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp @@ -8974,10 +8974,12 @@ IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType, llvm::ArrayRef args) { constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and"; mlir::MLIRContext *context = builder.getContext(); + mlir::Type i32 = builder.getI32Type(); mlir::FunctionType ftype = - mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); + mlir::FunctionType::get(context, {resultType}, {i32}); auto funcOp = builder.createFunction(loc, funcName, ftype); - return fir::CallOp::create(builder, loc, funcOp, args).getResult(0); + mlir::Value arg = builder.createConvert(loc, i32, args[0]); + return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0); } // SYNCTHREADS_COUNT @@ -8986,10 +8988,12 @@ IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType, llvm::ArrayRef args) { constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc"; mlir::MLIRContext *context = builder.getContext(); + mlir::Type i32 = builder.getI32Type(); mlir::FunctionType ftype = - mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); + mlir::FunctionType::get(context, {resultType}, {i32}); auto funcOp = builder.createFunction(loc, funcName, ftype); - return fir::CallOp::create(builder, loc, funcOp, args).getResult(0); + mlir::Value arg = builder.createConvert(loc, i32, args[0]); + return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0); } // SYNCTHREADS_OR @@ -8998,10 +9002,12 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType, llvm::ArrayRef args) { constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or"; mlir::MLIRContext *context = builder.getContext(); + mlir::Type i32 = builder.getI32Type(); mlir::FunctionType ftype = - mlir::FunctionType::get(context, {resultType}, {args[0].getType()}); + mlir::FunctionType::get(context, {resultType}, {i32}); auto funcOp = builder.createFunction(loc, funcName, ftype); - return fir::CallOp::create(builder, loc, funcOp, args).getResult(0); + mlir::Value arg = builder.createConvert(loc, i32, args[0]); + return fir::CallOp::create(builder, loc, funcOp, {arg}).getResult(0); } // SYNCWARP diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf index 55bb587dcf681..7d6caf58d71b3 100644 --- a/flang/test/Lower/CUDA/cuda-device-proc.cuf +++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf @@ -110,17 +110,20 @@ end ! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32 -! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CMP]]) +! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32 +! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%[[CONV]]) ! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath : (i32) -> i32 ! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32 -! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CMP]]) fastmath : (i1) -> i32 +! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32 +! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%[[CONV]]) fastmath : (i32) -> i32 ! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath : (i32) -> i32 ! CHECK: %[[A:.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: %[[B:.*]] = fir.load %{{.*}} : !fir.ref ! CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[A]], %[[B]] : i32 -! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CMP]]) fastmath : (i1) -> i32 +! CHECK: %[[CONV:.*]] = fir.convert %[[CMP]] : (i1) -> i32 +! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%[[CONV]]) fastmath : (i32) -> i32 ! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32 ! CHECK: %{{.*}} = llvm.atomicrmw add %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i64 ! CHECK: %{{.*}} = llvm.atomicrmw fadd %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, f32