diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 40d874dc99dd9..8e333def33869 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1626,7 +1626,8 @@ def Vector_LoadOp : Vector_Op<"load"> { let arguments = (ins Arg:$base, - Variadic:$indices); + Variadic:$indices, + DefaultValuedOptionalAttr:$nontemporal); let results = (outs AnyVectorOfAnyRank:$result); let extraClassDeclaration = [{ @@ -1710,7 +1711,8 @@ def Vector_StoreOp : Vector_Op<"store"> { AnyVectorOfAnyRank:$valueToStore, Arg:$base, - Variadic:$indices + Variadic:$indices, + DefaultValuedOptionalAttr:$nontemporal ); let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a24fb6f839153..b66b55ae8d57f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -192,7 +192,9 @@ static void replaceLoadOrStoreOp(vector::LoadOp loadOp, vector::LoadOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { - rewriter.replaceOpWithNewOp(loadOp, vectorTy, ptr, align); + rewriter.replaceOpWithNewOp(loadOp, vectorTy, ptr, align, + /*volatile_=*/false, + loadOp.getNontemporal()); } static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, @@ -208,7 +210,8 @@ static void replaceLoadOrStoreOp(vector::StoreOp storeOp, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(storeOp, adaptor.getValueToStore(), - ptr, align); + ptr, align, /*volatile_=*/false, + storeOp.getNontemporal()); } static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7ea0197bdecb3..0ae191dd58aea 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2023,6 +2023,20 @@ func.func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) // ----- +func.func @vector_load_op_nontemporal(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> { + %0 = vector.load %memref[%i, %j] {nontemporal = true} : memref<200x100xf32>, vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @vector_load_op_nontemporal +// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64 +// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64 +// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: llvm.load %[[gep]] {alignment = 4 : i64, nontemporal} : !llvm.ptr -> vector<8xf32> + +// ----- + func.func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) -> vector<8xindex> { %0 = vector.load %memref[%i, %j] : memref<200x100xindex>, vector<8xindex> return %0 : vector<8xindex> @@ -2049,6 +2063,21 @@ func.func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index // ----- +func.func @vector_store_op_nontemporal(%memref : memref<200x100xf32>, %i : index, %j : index) { + %val = arith.constant dense<11.0> : vector<4xf32> + vector.store %val, %memref[%i, %j] {nontemporal = true} : memref<200x100xf32>, vector<4xf32> + return +} + +// CHECK-LABEL: func @vector_store_op_nontemporal +// CHECK: %[[c100:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK: %[[mul:.*]] = llvm.mul %{{.*}}, %[[c100]] : i64 +// CHECK: %[[add:.*]] = llvm.add %[[mul]], %{{.*}} : i64 +// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}}[%[[add]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK: llvm.store %{{.*}}, %[[gep]] {alignment = 4 : i64, nontemporal} : vector<4xf32>, !llvm.ptr + +// ----- + func.func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) { %val = arith.constant dense<11> : vector<4xindex> vector.store %val, %memref[%i, %j] : memref<200x100xindex>, vector<4xindex>