Skip to content

Commit 9aa5d5a

Browse files
[MLIR] Add sincos intrinsic to LLVM dialect (#160561)
Adds llvm.intr.sincos operation using LLVM_TwoResultIntrOp in the mold of the frexp intrinsic.
1 parent 832a342 commit 9aa5d5a

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@ def LLVM_UMinOp : LLVM_BinarySameArgsIntrOpI<"umin">;
184184
def LLVM_SinOp : LLVM_UnaryIntrOpF<"sin">;
185185
def LLVM_CosOp : LLVM_UnaryIntrOpF<"cos">;
186186
def LLVM_TanOp : LLVM_UnaryIntrOpF<"tan">;
187+
def LLVM_SincosOp : LLVM_TwoResultIntrOp<"sincos", [], [0],
188+
[Pure], /*requiresFastmath=*/1> {
189+
let arguments =
190+
(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$val,
191+
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
192+
let assemblyFormat = "`(` operands `)` attr-dict `:` "
193+
"functional-type(operands, results)";
194+
let hasVerifier = 1;
195+
}
187196

188197
def LLVM_ASinOp : LLVM_UnaryIntrOpF<"asin">;
189198
def LLVM_ACosOp : LLVM_UnaryIntrOpF<"acos">;

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4085,6 +4085,25 @@ printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
40854085
p << "]";
40864086
}
40874087

4088+
//===----------------------------------------------------------------------===//
4089+
// SincosOp (intrinsic)
4090+
//===----------------------------------------------------------------------===//
4091+
4092+
LogicalResult LLVM::SincosOp::verify() {
4093+
auto operandType = getOperand().getType();
4094+
auto resultType = getResult().getType();
4095+
auto resultStructType =
4096+
mlir::dyn_cast<mlir::LLVM::LLVMStructType>(resultType);
4097+
if (!resultStructType || resultStructType.getBody().size() != 2 ||
4098+
resultStructType.getBody()[0] != operandType ||
4099+
resultStructType.getBody()[1] != operandType) {
4100+
return emitOpError("expected result type to be an homogeneous struct with "
4101+
"two elements matching the operand type, but got ")
4102+
<< resultType;
4103+
}
4104+
return success();
4105+
}
4106+
40884107
//===----------------------------------------------------------------------===//
40894108
// AssumeOp (intrinsic)
40904109
//===----------------------------------------------------------------------===//

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,3 +2014,24 @@ llvm.mlir.alias external @alias_resolver : !llvm.ptr {
20142014
}
20152015
// expected-error@+1 {{'llvm.mlir.ifunc' op must have a function resolver}}
20162016
llvm.mlir.ifunc external @foo : !llvm.func<void (ptr, i32)>, !llvm.ptr @alias_resolver {dso_local}
2017+
2018+
// -----
2019+
2020+
llvm.func @invalid_sincos_nonhomogeneous_return_type(%f: f32) -> () {
2021+
// expected-error@+1 {{op expected result type to be an homogeneous struct with two elements matching the operand type}}
2022+
llvm.intr.sincos(%f) : (f32) -> !llvm.struct<(f32, f64)>
2023+
}
2024+
2025+
// -----
2026+
2027+
llvm.func @invalid_sincos_non_struct_return_type(%f: f32) -> () {
2028+
// expected-error@+1 {{op expected result type to be an homogeneous struct with two elements matching the operand type}}
2029+
llvm.intr.sincos(%f) : (f32) -> f32
2030+
}
2031+
2032+
// -----
2033+
2034+
llvm.func @invalid_sincos_gt_2_element_struct_return_type(%f: f32) -> () {
2035+
// expected-error@+1 {{op expected result type to be an homogeneous struct with two elements matching the operand type}}
2036+
llvm.intr.sincos(%f) : (f32) -> !llvm.struct<(f32, f32, f32)>
2037+
}

mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ llvm.func @trig_test(%arg0: f32, %arg1: vector<8xf32>) {
146146
llvm.intr.tan(%arg0) : (f32) -> f32
147147
// CHECK: call <8 x float> @llvm.tan.v8f32
148148
llvm.intr.tan(%arg1) : (vector<8xf32>) -> vector<8xf32>
149+
150+
// CHECK: call { float, float } @llvm.sincos.f32
151+
llvm.intr.sincos(%arg0) : (f32) -> !llvm.struct<(f32, f32)>
152+
// CHECK: call { <8 x float>, <8 x float> } @llvm.sincos.v8f32
153+
llvm.intr.sincos(%arg1) : (vector<8xf32>) -> !llvm.struct<(vector<8xf32>, vector<8xf32>)>
149154
llvm.return
150155
}
151156

@@ -1302,6 +1307,8 @@ llvm.func @experimental_constrained_fpext(%s: f32, %v: vector<4xf32>) {
13021307
// CHECK-DAG: declare <8 x float> @llvm.ceil.v8f32(<8 x float>) #0
13031308
// CHECK-DAG: declare float @llvm.cos.f32(float)
13041309
// CHECK-DAG: declare <8 x float> @llvm.cos.v8f32(<8 x float>) #0
1310+
// CHECK-DAG: declare { float, float } @llvm.sincos.f32(float)
1311+
// CHECK-DAG: declare { <8 x float>, <8 x float> } @llvm.sincos.v8f32(<8 x float>) #0
13051312
// CHECK-DAG: declare float @llvm.copysign.f32(float, float)
13061313
// CHECK-DAG: declare float @llvm.rint.f32(float)
13071314
// CHECK-DAG: declare double @llvm.rint.f64(double)

0 commit comments

Comments
 (0)