Skip to content

Commit ba9720a

Browse files
committed
[𝘀𝗽𝗿] initial version
Created using spr 1.3.5-bogner
2 parents c490658 + 8f26a1c commit ba9720a

File tree

7 files changed

+165
-9
lines changed

7 files changed

+165
-9
lines changed

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ def int_dx_handle_fromBinding
3030
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
3131
[IntrNoMem]>;
3232

33+
def int_dx_typedBufferLoad
34+
: DefaultAttrsIntrinsic<
35+
[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
36+
[llvm_any_ty, llvm_i32_ty]>;
37+
3338
// Cast between target extension handle types and dxil-style opaque handles
3439
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
3540

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def Int64Ty : DXILOpParamType;
4040
def HalfTy : DXILOpParamType;
4141
def FloatTy : DXILOpParamType;
4242
def DoubleTy : DXILOpParamType;
43-
def ResRetTy : DXILOpParamType;
43+
def ResRetHalfTy : DXILOpParamType;
44+
def ResRetFloatTy : DXILOpParamType;
45+
def ResRetInt16Ty : DXILOpParamType;
46+
def ResRetInt32Ty : DXILOpParamType;
4447
def HandleTy : DXILOpParamType;
4548
def ResBindTy : DXILOpParamType;
4649
def ResPropsTy : DXILOpParamType;
@@ -693,6 +696,17 @@ def CreateHandle : DXILOp<57, createHandle> {
693696
let stages = [Stages<DXIL1_0, [all_stages]>, Stages<DXIL1_6, [removed]>];
694697
}
695698

699+
def BufferLoad : DXILOp<68, bufferLoad> {
700+
let Doc = "reads from a TypedBuffer";
701+
// Handle, Coord0, Coord1
702+
let arguments = [HandleTy, Int32Ty, Int32Ty];
703+
let result = OverloadTy;
704+
let overloads =
705+
[Overloads<DXIL1_0,
706+
[ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>];
707+
let stages = [Stages<DXIL1_0, [all_stages]>];
708+
}
709+
696710
def ThreadId : DXILOp<93, threadId> {
697711
let Doc = "Reads the thread ID";
698712
let LLVMIntrinsic = int_dx_thread_id;

llvm/lib/Target/DirectX/DXILOpBuilder.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,12 @@ static OverloadKind getOverloadKind(Type *Ty) {
120120
}
121121
case Type::PointerTyID:
122122
return OverloadKind::UserDefineType;
123-
case Type::StructTyID:
124-
return OverloadKind::ObjectType;
123+
case Type::StructTyID: {
124+
// TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
125+
// how we're handling overloads and remove the `OverloadKind` proxy enum.
126+
StructType *ST = cast<StructType>(Ty);
127+
return getOverloadKind(ST->getElementType(0));
128+
}
125129
default:
126130
llvm_unreachable("invalid overload type");
127131
return OverloadKind::VOID;
@@ -195,10 +199,11 @@ static StructType *getOrCreateStructType(StringRef Name,
195199
return StructType::create(Ctx, EltTys, Name);
196200
}
197201

198-
static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
199-
OverloadKind Kind = getOverloadKind(OverloadTy);
202+
static StructType *getResRetType(Type *ElementTy) {
203+
LLVMContext &Ctx = ElementTy->getContext();
204+
OverloadKind Kind = getOverloadKind(ElementTy);
200205
std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
201-
Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
206+
Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
202207
Type::getInt32Ty(Ctx)};
203208
return getOrCreateStructType(TypeName, FieldTypes, Ctx);
204209
}
@@ -248,8 +253,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
248253
return Type::getInt64Ty(Ctx);
249254
case OpParamType::OverloadTy:
250255
return OverloadTy;
251-
case OpParamType::ResRetTy:
252-
return getResRetType(OverloadTy, Ctx);
256+
case OpParamType::ResRetHalfTy:
257+
return getResRetType(Type::getHalfTy(Ctx));
258+
case OpParamType::ResRetFloatTy:
259+
return getResRetType(Type::getFloatTy(Ctx));
260+
case OpParamType::ResRetInt16Ty:
261+
return getResRetType(Type::getInt16Ty(Ctx));
262+
case OpParamType::ResRetInt32Ty:
263+
return getResRetType(Type::getInt32Ty(Ctx));
253264
case OpParamType::HandleTy:
254265
return getHandleType(Ctx);
255266
case OpParamType::ResBindTy:
@@ -391,6 +402,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
391402
return makeOpError(OpCode, "Wrong number of arguments");
392403
OverloadTy = Args[ArgIndex]->getType();
393404
}
405+
394406
FunctionType *DXILOpFT =
395407
getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
396408

@@ -451,6 +463,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
451463
return *Result;
452464
}
453465

466+
StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
467+
return ::getResRetType(ElementTy);
468+
}
469+
454470
StructType *DXILOpBuilder::getHandleType() {
455471
return ::getHandleType(IRB.getContext());
456472
}

llvm/lib/Target/DirectX/DXILOpBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class DXILOpBuilder {
4646
Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args,
4747
Type *RetTy = nullptr);
4848

49+
/// Get a `%dx.types.ResRet` type with the given element type.
50+
StructType *getResRetType(Type *ElementTy);
4951
/// Get the `%dx.types.Handle` type.
5052
StructType *getHandleType();
5153

llvm/lib/Target/DirectX/DXILOpLowering.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,45 @@ class OpLowerer {
259259
lowerToBindAndAnnotateHandle(F);
260260
}
261261

262+
void lowerTypedBufferLoad(Function &F) {
263+
IRBuilder<> &IRB = OpBuilder.getIRB();
264+
Type *Int32Ty = IRB.getInt32Ty();
265+
266+
replaceFunction(F, [&](CallInst *CI) -> Error {
267+
IRB.SetInsertPoint(CI);
268+
269+
Value *Handle =
270+
createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
271+
Value *Index0 = CI->getArgOperand(1);
272+
Value *Index1 = UndefValue::get(Int32Ty);
273+
Type *ElTy = cast<StructType>(CI->getType())->getElementType(0);
274+
Type *RetTy = OpBuilder.getResRetType(ElTy);
275+
276+
std::array<Value *, 3> Args{Handle, Index0, Index1};
277+
Expected<CallInst *> OpCall =
278+
OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
279+
if (Error E = OpCall.takeError())
280+
return E;
281+
282+
// We've switched the return type from an anonymous struct to a named one,
283+
// so we need to update the types in the uses.
284+
for (Use &U : make_early_inc_range(CI->uses())) {
285+
// Uses other than extract value should be impossible at this point, as
286+
// we shouldn't be able to call functions with the anonymous struct or
287+
// store these directly.
288+
auto *EVI = cast<ExtractValueInst>(U.getUser());
289+
IRB.SetInsertPoint(EVI);
290+
auto *NewEVI = IRB.CreateExtractValue(*OpCall, EVI->getIndices());
291+
292+
EVI->replaceAllUsesWith(NewEVI);
293+
EVI->eraseFromParent();
294+
}
295+
296+
CI->eraseFromParent();
297+
return Error::success();
298+
});
299+
}
300+
262301
bool lowerIntrinsics() {
263302
bool Updated = false;
264303

@@ -276,6 +315,10 @@ class OpLowerer {
276315
#include "DXILOperation.inc"
277316
case Intrinsic::dx_handle_fromBinding:
278317
lowerHandleFromBinding(F);
318+
break;
319+
case Intrinsic::dx_typedBufferLoad:
320+
lowerTypedBufferLoad(F);
321+
break;
279322
}
280323
Updated = true;
281324
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
; RUN: opt -S -dxil-op-lower %s | FileCheck %s
2+
3+
target triple = "dxil-pc-shadermodel6.6-compute"
4+
5+
declare void @scalar_user(float)
6+
7+
define void @loadfloats() {
8+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
9+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
10+
%buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0)
11+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0(
12+
i32 0, i32 0, i32 1, i32 0, i1 false)
13+
14+
; The temporary casts should all have been cleaned up
15+
; CHECK-NOT: %dx.cast_handle
16+
17+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
18+
%data0 = call { float, float, float, float } @llvm.dx.typedBufferLoad.f32(
19+
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
20+
21+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
22+
%data0_0 = extractvalue {float, float, float, float} %data0, 0
23+
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
24+
%data0_2 = extractvalue {float, float, float, float} %data0, 2
25+
26+
call void @scalar_user(float %data0_0)
27+
call void @scalar_user(float %data0_2)
28+
29+
ret void
30+
}
31+
32+
define void @loadint() {
33+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
34+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
35+
%buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0)
36+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0(
37+
i32 0, i32 0, i32 1, i32 0, i1 false)
38+
39+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
40+
%data0 = call {i32, i32, i32, i32} @llvm.dx.typedBufferLoad.i32(
41+
target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
42+
43+
ret void
44+
}
45+
46+
define void @loadhalf() {
47+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
48+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
49+
%buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0)
50+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0(
51+
i32 0, i32 0, i32 1, i32 0, i1 false)
52+
53+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
54+
%data0 = call {half, half, half, half} @llvm.dx.typedBufferLoad.f16(
55+
target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
56+
57+
ret void
58+
}
59+
60+
define void @loadi16() {
61+
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
62+
; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]]
63+
%buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0)
64+
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0(
65+
i32 0, i32 0, i32 1, i32 0, i1 false)
66+
67+
; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
68+
%data0 = call {i16, i16, i16, i16} @llvm.dx.typedBufferLoad.i16(
69+
target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
70+
71+
ret void
72+
}

llvm/utils/TableGen/DXILEmitter.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) {
187187
.Case("Int8Ty", "OverloadKind::I8")
188188
.Case("Int16Ty", "OverloadKind::I16")
189189
.Case("Int32Ty", "OverloadKind::I32")
190-
.Case("Int64Ty", "OverloadKind::I64");
190+
.Case("Int64Ty", "OverloadKind::I64")
191+
.Case("ResRetHalfTy", "OverloadKind::HALF")
192+
.Case("ResRetFloatTy", "OverloadKind::FLOAT")
193+
.Case("ResRetInt16Ty", "OverloadKind::I16")
194+
.Case("ResRetInt32Ty", "OverloadKind::I32");
191195
}
192196

193197
/// Return a string representation of valid overload information denoted

0 commit comments

Comments
 (0)