Skip to content

Commit 226ea65

Browse files
committed
implement select intrinsic
1 parent e454d31 commit 226ea65

File tree

4 files changed

+155
-0
lines changed

4 files changed

+155
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4751,6 +4751,12 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
47514751
let Prototype = "void(...)";
47524752
}
47534753

4754+
def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
4755+
let Spellings = ["__builtin_hlsl_select"];
4756+
let Attributes = [NoThrow, Const];
4757+
let Prototype = "void(...)";
4758+
}
4759+
47544760
// Builtins for XRay.
47554761
def XRayCustomEvent : Builtin {
47564762
let Spellings = ["__xray_customevent"];

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18695,6 +18695,47 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1869518695
CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
1869618696
nullptr, "hlsl.saturate");
1869718697
}
18698+
case Builtin::BI__builtin_hlsl_select: {
18699+
Value *OpCond = EmitScalarExpr(E->getArg(0));
18700+
Value *OpTrue = EmitScalarExpr(E->getArg(1));
18701+
Value *OpFalse = EmitScalarExpr(E->getArg(2));
18702+
llvm::Type *TCond = OpCond->getType();
18703+
18704+
// if cond is a bool emit a select instruction
18705+
if (TCond->isIntegerTy(1))
18706+
return Builder.CreateSelect(OpCond, OpTrue, OpFalse);
18707+
18708+
// if cond is a vector of bools lower to a shufflevector
18709+
// todo check if that true and false are vectors
18710+
// todo check that the size of true and false and cond are the same
18711+
if (TCond->isVectorTy() &&
18712+
E->getArg(0)->getType()->getAs<VectorType>()->isBooleanType()) {
18713+
assert(OpTrue->getType()->isVectorTy() && OpFalse->getType()->isVectorTy() &&
18714+
"Select's second and third operands must be vectors if first operand is a vector.");
18715+
18716+
auto *VecTyTrue = E->getArg(1)->getType()->getAs<VectorType>();
18717+
auto *VecTyFalse = E->getArg(2)->getType()->getAs<VectorType>();
18718+
18719+
assert(VecTyTrue->getElementType() == VecTyFalse->getElementType() &&
18720+
"Select's second and third vectors need the same element types.");
18721+
18722+
const unsigned N = VecTyTrue->getNumElements();
18723+
assert(N == VecTyFalse->getNumElements() &&
18724+
N == E->getArg(0)->getType()->getAs<VectorType>()->getNumElements() &&
18725+
"Select requires vectors to be of the same size.");
18726+
18727+
llvm::SmallVector<Value *> Mask;
18728+
for (unsigned I = 0; I < N; I++) {
18729+
Value *Index = ConstantInt::get(IntTy, I);
18730+
Value *IndexBool = Builder.CreateExtractElement(OpCond, Index);
18731+
Mask.push_back(Builder.CreateSelect(IndexBool, Index, ConstantInt::get(IntTy, I + N)));
18732+
}
18733+
18734+
return Builder.CreateShuffleVector(OpTrue, OpFalse, BuildVector(Mask));
18735+
}
18736+
18737+
llvm_unreachable("Select requires a bool or vector of bools as its first operand.");
18738+
}
1869818739
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1869918740
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1870018741
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,30 @@ double3 saturate(double3);
16031603
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_saturate)
16041604
double4 saturate(double4);
16051605

1606+
//===----------------------------------------------------------------------===//
1607+
// select builtins
1608+
//===----------------------------------------------------------------------===//
1609+
1610+
/// \fn T select(bool Cond, T TrueVal, T FalseVal)
1611+
/// \brief ternary operator.
1612+
/// \param Cond The Condition input value.
1613+
/// \param TrueVal The Value returned if Cond is true.
1614+
/// \param FalseVal The Value returned if Cond is false.
1615+
1616+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
1617+
template<typename T>
1618+
T select(bool, T, T);
1619+
1620+
/// \fn vector<T,Sz> select(vector<bool,Sz> Conds, vector<T,Sz> TrueVals, vector<T,Sz>, FalseVals)
1621+
/// \brief ternary operator for vectors. All vectors must be the same size.
1622+
/// \param Conds The Condition input values.
1623+
/// \param TrueVals The vector values are chosen from when conditions are true.
1624+
/// \param FalseVals The vector values are chosen from when conditions are false.
1625+
1626+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_select)
1627+
template<typename T, int Sz>
1628+
vector<T,Sz> select(vector<bool,Sz>, vector<T,Sz>, vector<T,Sz>);
1629+
16061630
//===----------------------------------------------------------------------===//
16071631
// sin builtins
16081632
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,66 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
10131013
TheCall->setType(ReturnType);
10141014
}
10151015

1016+
bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
1017+
assert(TheCall->getNumArgs() == 3);
1018+
Expr *Arg1 = TheCall->getArg(1);
1019+
Expr *Arg2 = TheCall->getArg(2);
1020+
if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
1021+
Arg2->getType())) {
1022+
S->Diag(TheCall->getBeginLoc(),
1023+
diag::err_typecheck_call_different_arg_types)
1024+
<< Arg1->getType() << Arg2->getType()
1025+
<< Arg1->getSourceRange() << Arg2->getSourceRange();
1026+
return true;
1027+
}
1028+
1029+
TheCall->setType(Arg1->getType());
1030+
return false;
1031+
}
1032+
1033+
bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
1034+
assert(TheCall->getNumArgs() == 3);
1035+
Expr *Arg1 = TheCall->getArg(1);
1036+
Expr *Arg2 = TheCall->getArg(2);
1037+
if(!Arg1->getType()->isVectorType()) {
1038+
S->Diag(Arg1->getBeginLoc(),
1039+
diag::err_builtin_non_vector_type)
1040+
<< "Second" << "__builtin_hlsl_select" << Arg1->getType()
1041+
<< Arg1->getSourceRange();
1042+
return true;
1043+
}
1044+
1045+
if(!Arg2->getType()->isVectorType()) {
1046+
S->Diag(Arg2->getBeginLoc(),
1047+
diag::err_builtin_non_vector_type)
1048+
<< "Third" << "__builtin_hlsl_select" << Arg2->getType()
1049+
<< Arg2->getSourceRange();
1050+
return true;
1051+
}
1052+
1053+
if(!S->Context.hasSameUnqualifiedType(Arg1->getType(),
1054+
Arg2->getType())) {
1055+
S->Diag(TheCall->getBeginLoc(),
1056+
diag::err_typecheck_call_different_arg_types)
1057+
<< Arg1->getType() << Arg2->getType()
1058+
<< Arg1->getSourceRange() << Arg2->getSourceRange();
1059+
return true;
1060+
}
1061+
1062+
// caller has checked that Arg0 is a vector.
1063+
// check all three args have the same length.
1064+
if(TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() !=
1065+
Arg1->getType()->getAs<VectorType>()->getNumElements()) {
1066+
S->Diag(TheCall->getBeginLoc(),
1067+
diag::err_typecheck_vector_lengths_not_equal)
1068+
<< TheCall->getArg(0)->getType() << Arg1->getType()
1069+
<< TheCall->getArg(0)->getSourceRange() << Arg1->getSourceRange();
1070+
return true;
1071+
}
1072+
1073+
return false;
1074+
}
1075+
10161076
// Note: returning true in this case results in CheckBuiltinFunctionCall
10171077
// returning an ExprError
10181078
bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -1046,6 +1106,30 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
10461106
break;
10471107
}
10481108
case Builtin::BI__builtin_hlsl_elementwise_saturate:
1109+
case Builtin::BI__builtin_hlsl_select: {
1110+
if (SemaRef.checkArgCount(TheCall, 3))
1111+
return true;
1112+
QualType ArgTy = TheCall->getArg(0)->getType();
1113+
if (ArgTy->isBooleanType()) {
1114+
if (CheckBoolSelect(&SemaRef, TheCall))
1115+
return true;
1116+
} else if (ArgTy->isVectorType() &&
1117+
ArgTy->getAs<VectorType>()->getElementType()->isBooleanType()) {
1118+
if (CheckVectorSelect(&SemaRef, TheCall))
1119+
return true;
1120+
} else { // first operand is not a bool or a vector of bools.
1121+
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
1122+
diag::err_typecheck_convert_incompatible)
1123+
<< TheCall->getArg(0)->getType() << SemaRef.Context.getBOOLType()
1124+
<< 1 << 0 << 0;
1125+
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
1126+
diag::err_builtin_non_vector_type)
1127+
<< "First" << "__builtin_hlsl_select" << TheCall->getArg(0)->getType()
1128+
<< TheCall->getArg(0)->getSourceRange();
1129+
return true;
1130+
}
1131+
break;
1132+
}
10491133
case Builtin::BI__builtin_hlsl_elementwise_rcp: {
10501134
if (CheckAllArgsHaveFloatRepresentation(&SemaRef, TheCall))
10511135
return true;

0 commit comments

Comments
 (0)