Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def int_dx_handle_fromBinding
[IntrNoMem]>;

def int_dx_typedBufferLoad
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[llvm_any_ty, llvm_i32_ty]>;
: DefaultAttrsIntrinsic<
[llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
[llvm_any_ty, llvm_i32_ty]>;

// Cast between target extension handle types and dxil-style opaque handles
def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
Expand Down
5 changes: 1 addition & 4 deletions llvm/lib/Target/DirectX/DXILOpBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
// TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
// how we're handling overloads and remove the `OverloadKind` proxy enum.
StructType *ST = cast<StructType>(Ty);
if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
return getOverloadKind(ST->getElementType(0));

return OverloadKind::ObjectType;
return getOverloadKind(ST->getElementType(0));
}
default:
llvm_unreachable("invalid overload type");
Expand Down
42 changes: 14 additions & 28 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,41 +270,27 @@ class OpLowerer {
createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
Value *Index0 = CI->getArgOperand(1);
Value *Index1 = UndefValue::get(Int32Ty);
Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
Type *ElTy = cast<StructType>(CI->getType())->getElementType(0);
Type *RetTy = OpBuilder.getResRetType(ElTy);

std::array<Value *, 3> Args{Handle, Index0, Index1};
Expected<CallInst *> OpCall =
OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy);
if (Error E = OpCall.takeError())
return E;

std::array<Value *, 4> Extracts = {};

// We've switched the return type from a vector to a struct, but at this
// point most vectors have probably already been scalarized. Try to
// forward arguments directly rather than inserting into and immediately
// extracting from a vector.
for (Use &U : make_early_inc_range(CI->uses()))
if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
size_t IndexVal = Index->getZExtValue();
assert(IndexVal < 4 && "Index into buffer load out of range");
if (!Extracts[IndexVal])
Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
EEI->replaceAllUsesWith(Extracts[IndexVal]);
EEI->eraseFromParent();
}

// If there are still uses then we need to create a vector.
if (!CI->use_empty()) {
for (int I = 0, E = 4; I != E; ++I)
if (!Extracts[I])
Extracts[I] = IRB.CreateExtractValue(*OpCall, I);

Value *Vec = UndefValue::get(CI->getType());
for (int I = 0, E = 4; I != E; ++I)
Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
CI->replaceAllUsesWith(Vec);
// We've switched the return type from an anonymous struct to a named one,
// so we need to update the types in the uses.
for (Use &U : make_early_inc_range(CI->uses())) {
// Uses other than extract value should be impossible at this point, as
// we shouldn't be able to call functions with the anonymous struct or
// store these directly.
auto *EVI = cast<ExtractValueInst>(U.getUser());
IRB.SetInsertPoint(EVI);
auto *NewEVI = IRB.CreateExtractValue(*OpCall, EVI->getIndices());

EVI->replaceAllUsesWith(NewEVI);
EVI->eraseFromParent();
}

CI->eraseFromParent();
Expand Down
46 changes: 8 additions & 38 deletions llvm/test/CodeGen/DirectX/BufferLoad.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
target triple = "dxil-pc-shadermodel6.6-compute"

declare void @scalar_user(float)
declare void @vector_user(<4 x float>)

define void @loadfloats() {
; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
Expand All @@ -16,46 +15,17 @@ define void @loadfloats() {
; CHECK-NOT: %dx.cast_handle

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x float> @llvm.dx.typedBufferLoad(
%data0 = call { float, float, float, float } @llvm.dx.typedBufferLoad.f32(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)

; The extract order depends on the users, so don't enforce that here.
; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
%data0_0 = extractelement <4 x float> %data0, i32 0
; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
%data0_2 = extractelement <4 x float> %data0, i32 2
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
%data0_0 = extractvalue {float, float, float, float} %data0, 0
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
%data0_2 = extractvalue {float, float, float, float} %data0, 2

; If all of the uses are extracts, we skip creating a vector
; CHECK-NOT: insertelement
call void @scalar_user(float %data0_0)
call void @scalar_user(float %data0_2)

; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
%data4 = call <4 x float> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)

; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
; CHECK: insertelement <4 x float> undef
; CHECK: insertelement <4 x float>
; CHECK: insertelement <4 x float>
; CHECK: insertelement <4 x float>
call void @vector_user(<4 x float> %data4)

; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
%data12 = call <4 x float> @llvm.dx.typedBufferLoad(
target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)

; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
%data12_3 = extractelement <4 x float> %data12, i32 3

; If there are a mix of users we need the vector, but extracts are direct
; CHECK: call void @scalar_user(float [[DATA12_3]])
call void @scalar_user(float %data12_3)
call void @vector_user(<4 x float> %data12)

ret void
}

Expand All @@ -67,7 +37,7 @@ define void @loadint() {
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
%data0 = call {i32, i32, i32, i32} @llvm.dx.typedBufferLoad.i32(
target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)

ret void
Expand All @@ -81,7 +51,7 @@ define void @loadhalf() {
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x half> @llvm.dx.typedBufferLoad(
%data0 = call {half, half, half, half} @llvm.dx.typedBufferLoad.f16(
target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)

ret void
Expand All @@ -95,7 +65,7 @@ define void @loadi16() {
i32 0, i32 0, i32 1, i32 0, i1 false)

; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
%data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
%data0 = call {i16, i16, i16, i16} @llvm.dx.typedBufferLoad.i16(
target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)

ret void
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.