Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 66 additions & 14 deletions clang/lib/Sema/HLSLExternalSemaSource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ struct BuiltinTypeDeclBuilder {
BuiltinTypeDeclBuilder &addDecrementCounterMethod();
BuiltinTypeDeclBuilder &addHandleAccessFunction(DeclarationName &Name,
bool IsConst, bool IsRef);
BuiltinTypeDeclBuilder &addAppendMethod();
BuiltinTypeDeclBuilder &addConsumeMethod();
};

struct TemplateParameterListBuilder {
Expand Down Expand Up @@ -428,14 +430,26 @@ struct BuiltinTypeMethodBuilder {
llvm::SmallVector<Stmt *> StmtsList;

// Argument placeholders, inspired by std::placeholder. These are the indices
// of arguments to forward to `callBuiltin`, and additionally `Handle` which
// refers to the resource handle.
enum class PlaceHolder { _0, _1, _2, _3, Handle = 127 };
// of arguments to forward to `callBuiltin` and other method builder methods.
// Additional special values are:
// Handle - refers to the resource handle.
// LastStmt - refers to the last statement in the method body; referencing
// LastStmt will remove the statement from the method body since
// it will be linked from the new expression being constructed.
enum class PlaceHolder { _0, _1, _2, _3, Handle = 128, LastStmt };

Expr *convertPlaceholder(PlaceHolder PH) {
if (PH == PlaceHolder::Handle)
return getResourceHandleExpr();

if (PH == PlaceHolder::LastStmt) {
assert(!StmtsList.empty() && "no statements in the list");
Stmt *LastStmt = StmtsList.pop_back_val();
assert(isa<ValueStmt>(LastStmt) &&
"last statement does not have a value");
return cast<ValueStmt>(LastStmt)->getExprStmt();
}

ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
ParmVarDecl *ParamDecl = Method->getParamDecl(static_cast<unsigned>(PH));
return DeclRefExpr::Create(
Expand Down Expand Up @@ -558,17 +572,25 @@ struct BuiltinTypeMethodBuilder {
return *this;
}

BuiltinTypeMethodBuilder &dereference() {
assert(!StmtsList.empty() && "Nothing to dereference");
ASTContext &AST = DeclBuilder.SemaRef.getASTContext();
template <typename TLHS, typename TRHS>
BuiltinTypeMethodBuilder &assign(TLHS LHS, TRHS RHS) {
Expr *LHSExpr = convertPlaceholder(LHS);
Expr *RHSExpr = convertPlaceholder(RHS);
Stmt *AssignStmt = BinaryOperator::Create(
DeclBuilder.SemaRef.getASTContext(), LHSExpr, RHSExpr, BO_Assign,
LHSExpr->getType(), ExprValueKind::VK_PRValue,
ExprObjectKind::OK_Ordinary, SourceLocation(), FPOptionsOverride());
StmtsList.push_back(AssignStmt);
return *this;
}

Expr *LastExpr = dyn_cast<Expr>(StmtsList.back());
assert(LastExpr && "No expression to dereference");
Expr *Deref = UnaryOperator::Create(
AST, LastExpr, UO_Deref, LastExpr->getType()->getPointeeType(),
VK_PRValue, OK_Ordinary, SourceLocation(),
/*CanOverflow=*/false, FPOptionsOverride());
StmtsList.pop_back();
template <typename T> BuiltinTypeMethodBuilder &dereference(T Ptr) {
Expr *PtrExpr = convertPlaceholder(Ptr);
Expr *Deref =
UnaryOperator::Create(DeclBuilder.SemaRef.getASTContext(), PtrExpr,
UO_Deref, PtrExpr->getType()->getPointeeType(),
VK_PRValue, OK_Ordinary, SourceLocation(),
/*CanOverflow=*/false, FPOptionsOverride());
StmtsList.push_back(Deref);
return *this;
}
Expand Down Expand Up @@ -670,7 +692,35 @@ BuiltinTypeDeclBuilder::addHandleAccessFunction(DeclarationName &Name,
.addParam("Index", AST.UnsignedIntTy)
.callBuiltin("__builtin_hlsl_resource_getpointer", ElemPtrTy, PH::Handle,
PH::_0)
.dereference()
.dereference(PH::LastStmt)
.finalizeMethod();
}

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
ASTContext &AST = SemaRef.getASTContext();
QualType ElemTy = getHandleElementType();
return BuiltinTypeMethodBuilder(*this, "Append", AST.VoidTy)
.addParam("value", ElemTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
PH::Handle, getConstantIntExpr(1))
.callBuiltin("__builtin_hlsl_resource_getpointer",
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
.dereference(PH::LastStmt)
.assign(PH::LastStmt, PH::_0)
.finalizeMethod();
}

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addConsumeMethod() {
using PH = BuiltinTypeMethodBuilder::PlaceHolder;
ASTContext &AST = SemaRef.getASTContext();
QualType ElemTy = getHandleElementType();
return BuiltinTypeMethodBuilder(*this, "Consume", ElemTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", AST.UnsignedIntTy,
PH::Handle, getConstantIntExpr(-1))
.callBuiltin("__builtin_hlsl_resource_getpointer",
AST.getPointerType(ElemTy), PH::Handle, PH::LastStmt)
.dereference(PH::LastStmt)
.finalizeMethod();
}

Expand Down Expand Up @@ -900,6 +950,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer,
/*IsROV=*/false, /*RawBuffer=*/true)
.addAppendMethod()
.completeDefinition();
});

Expand All @@ -910,6 +961,7 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
onCompletion(Decl, [this](CXXRecordDecl *Decl) {
setupBufferType(Decl, *SemaPtr, ResourceClass::UAV, ResourceKind::RawBuffer,
/*IsROV=*/false, /*RawBuffer=*/true)
.addConsumeMethod()
.completeDefinition();
});

Expand Down
46 changes: 44 additions & 2 deletions clang/test/AST/HLSL/StructuredBuffers-AST.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
//
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \
// RUN: -DRESOURCE=AppendStructuredBuffer %s | FileCheck -DRESOURCE=AppendStructuredBuffer \
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-APPEND %s
//
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \
// RUN: -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \
// RUN: -check-prefix=EMPTY %s
//
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump \
// RUN: -DRESOURCE=ConsumeStructuredBuffer %s | FileCheck -DRESOURCE=ConsumeStructuredBuffer \
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT %s
// RUN: -check-prefixes=CHECK,CHECK-UAV,CHECK-NOSUBSCRIPT,CHECK-CONSUME %s
//
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-library -x hlsl -ast-dump -DEMPTY \
// RUN: -DRESOURCE=RasterizerOrderedStructuredBuffer %s | FileCheck -DRESOURCE=RasterizerOrderedStructuredBuffer \
Expand Down Expand Up @@ -135,6 +135,48 @@ RESOURCE<float> Buffer;
// CHECK-COUNTER-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1
// CHECK-COUNTER-NEXT: AlwaysInlineAttr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> Implicit always_inline

// CHECK-APPEND: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Append 'void (element_type)'
// CHECK-APPEND-NEXT: ParmVarDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> value 'element_type'
// CHECK-APPEND-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
// CHECK-APPEND-NEXT: BinaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' '='
// CHECK-APPEND-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]]
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
// CHECK-APPEND-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int'
// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) noexcept'
// CHECK-APPEND-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::raw_buffer]]
// CHECK-APPEND-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
// CHECK-APPEND-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
// CHECK-APPEND-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' 1
// CHECK-APPEND-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' ParmVar 0x{{[0-9A-Fa-f]+}} 'value' 'element_type'

// CHECK-CONSUME: CXXMethodDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> Consume 'element_type ()'
// CHECK-CONSUME-NEXT: CompoundStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
// CHECK-CONSUME-NEXT: ReturnStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>>
// CHECK-CONSUME-NEXT: UnaryOperator 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type' prefix '*' cannot overflow
// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'element_type *'
// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_resource_getpointer' 'void (...) noexcept'
// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]]
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
// CHECK-CONSUME-NEXT: CallExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'unsigned int'
// CHECK-CONSUME-NEXT: DeclRefExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '<builtin fn type>' Function 0x{{[0-9A-Fa-f]+}} '__builtin_hlsl_buffer_update_counter' 'unsigned int (...) noexcept'
// CHECK-CONSUME-NEXT: MemberExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '__hlsl_resource_t
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::resource_class(UAV)]]
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::raw_buffer]]
// CHECK-CONSUME-SAME{LITERAL}: [[hlsl::contained_type(element_type)]]' lvalue .__handle
// CHECK-CONSUME-NEXT: CXXThisExpr 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> '[[RESOURCE]]<element_type>' lvalue implicit this
// CHECK-CONSUME-NEXT: IntegerLiteral 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> 'int' -1

// CHECK: ClassTemplateSpecializationDecl 0x{{[0-9A-Fa-f]+}} <<invalid sloc>> <invalid sloc> class [[RESOURCE]] definition

// CHECK: TemplateArgument type 'float'
Expand Down
40 changes: 32 additions & 8 deletions clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,45 @@

RWStructuredBuffer<float> RWSB1 : register(u0);
RWStructuredBuffer<float> RWSB2 : register(u1);
AppendStructuredBuffer<float> ASB : register(u2);
ConsumeStructuredBuffer<float> CSB : register(u3);

// CHECK: %"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", float, 1, 0) }

export void TestIncrementCounter() {
RWSB1.IncrementCounter();
export int TestIncrementCounter() {
return RWSB1.IncrementCounter();
}

// CHECK: define void @_Z20TestIncrementCounterv()
// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
// CHECK: define noundef i32 @_Z20TestIncrementCounterv()
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
// CHECK-DXIL: ret i32 %[[INDEX]]
export int TestDecrementCounter() {
return RWSB2.DecrementCounter();
}

// CHECK: define noundef i32 @_Z20TestDecrementCounterv()
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1)
// CHECK-DXIL: ret i32 %[[INDEX]]

export void TestAppend(float value) {
ASB.Append(value);
}

// CHECK: define void @_Z10TestAppendf(float noundef %value)
// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %value.addr, align 4
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 1)
// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i32 %[[INDEX]])
// CHECK-DXIL: store float %[[VALUE]], ptr %[[RESPTR]], align 4

export void TestDecrementCounter() {
RWSB2.DecrementCounter();
export float TestConsume() {
return CSB.Consume();
}

// CHECK: define void @_Z20TestDecrementCounterv()
// CHECK-DXIL: call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %{{[0-9]+}}, i8 -1)
// CHECK: define noundef float @_Z11TestConsumev()
// CHECK-DXIL: %[[INDEX:.*]] = call i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %1, i8 -1)
// CHECK-DXIL: %[[RESPTR:.*]] = call ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0) %0, i32 %[[INDEX]])
// CHECK-DXIL: %[[VALUE:.*]] = load float, ptr %[[RESPTR]], align 4
// CHECK-DXIL: ret float %[[VALUE]]

// CHECK: declare i32 @llvm.dx.bufferUpdateCounter.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i8)
// CHECK: declare ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_f32_1_0t(target("dx.RawBuffer", float, 1, 0), i32)
Loading