diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 6150f55796d0e..4aec85fb24212 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -31,8 +31,9 @@ def int_dx_handle_fromBinding
           [IntrNoMem]>;
 
 def int_dx_typedBufferLoad
-    : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
-                            [llvm_any_ty, llvm_i32_ty]>;
+    : DefaultAttrsIntrinsic<
+          [llvm_any_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
+          [llvm_any_ty, llvm_i32_ty]>;
 
 // Cast between target extension handle types and dxil-style opaque handles
 def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>;
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 1594fa533379b..acfefe8d2b875 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -124,10 +124,7 @@ static OverloadKind getOverloadKind(Type *Ty) {
     // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
     // how we're handling overloads and remove the `OverloadKind` proxy enum.
     StructType *ST = cast<StructType>(Ty);
-    if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet"))
-      return getOverloadKind(ST->getElementType(0));
-
-    return OverloadKind::ObjectType;
+    return getOverloadKind(ST->getElementType(0));
   }
   default:
     llvm_unreachable("invalid overload type");
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index f31ec85641471..c0f3133f415ac 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -270,7 +270,8 @@ class OpLowerer {
         createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
     Value *Index0 = CI->getArgOperand(1);
     Value *Index1 = UndefValue::get(Int32Ty);
-    Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType());
+    Type *ElTy = cast<StructType>(CI->getType())->getElementType(0);
+    Type *RetTy = OpBuilder.getResRetType(ElTy);
 
     std::array<Value *, 3> Args{Handle, Index0, Index1};
     Expected<CallInst *> OpCall =
@@ -278,33 +279,18 @@ class OpLowerer {
     if (Error E = OpCall.takeError())
      return E;
 
-    std::array<Value *, 4> Extracts = {};
-
-    // We've switched the return type from a vector to a struct, but at this
-    // point most vectors have probably already been scalarized. Try to
-    // forward arguments directly rather than inserting into and immediately
-    // extracting from a vector.
-    for (Use &U : make_early_inc_range(CI->uses()))
-      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser()))
-        if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
-          size_t IndexVal = Index->getZExtValue();
-          assert(IndexVal < 4 && "Index into buffer load out of range");
-          if (!Extracts[IndexVal])
-            Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal);
-          EEI->replaceAllUsesWith(Extracts[IndexVal]);
-          EEI->eraseFromParent();
-        }
-
-    // If there are still uses then we need to create a vector.
-    if (!CI->use_empty()) {
-      for (int I = 0, E = 4; I != E; ++I)
-        if (!Extracts[I])
-          Extracts[I] = IRB.CreateExtractValue(*OpCall, I);
-
-      Value *Vec = UndefValue::get(CI->getType());
-      for (int I = 0, E = 4; I != E; ++I)
-        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
-      CI->replaceAllUsesWith(Vec);
+    // We've switched the return type from an anonymous struct to a named one,
+    // so we need to update the types in the uses.
+    for (Use &U : make_early_inc_range(CI->uses())) {
+      // Uses other than extract value should be impossible at this point, as
+      // we shouldn't be able to call functions with the anonymous struct or
+      // store these directly.
+      auto *EVI = cast<ExtractValueInst>(U.getUser());
+      IRB.SetInsertPoint(EVI);
+      auto *NewEVI = IRB.CreateExtractValue(*OpCall, EVI->getIndices());
+
+      EVI->replaceAllUsesWith(NewEVI);
+      EVI->eraseFromParent();
     }
 
     CI->eraseFromParent();
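To make the new lowering concrete, here is a minimal IR sketch of the before/after, assuming %buffer and %handle come from the usual binding calls as in the tests below (the value names are illustrative, not from the patch). The intrinsic now returns an anonymous four-element struct, so the lowering only has to re-create each extractvalue against the named ResRet struct; in DXIL the named struct also carries a trailing i32 status word, which is why the uses are rebuilt per index rather than replacing the struct value wholesale.

  ; Before lowering: the intrinsic returns an anonymous four-element struct.
  %data = call { float, float, float, float } @llvm.dx.typedBufferLoad.f32(
      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
  %x = extractvalue { float, float, float, float } %data, 0

  ; After lowering: the same load as a DXIL op returning the named struct,
  ; with the extractvalue re-created at the new type. Indices 0-3 carry over
  ; unchanged.
  %data.dx = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(
      i32 68, %dx.types.Handle %handle, i32 0, i32 undef)
  %x.dx = extractvalue %dx.types.ResRet.f32 %data.dx, 0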
diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll
index c3bb96dbdf909..7e42084dfce78 100644
--- a/llvm/test/CodeGen/DirectX/BufferLoad.ll
+++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll
@@ -3,7 +3,6 @@
 target triple = "dxil-pc-shadermodel6.6-compute"
 
 declare void @scalar_user(float)
-declare void @vector_user(<4 x float>)
 
 define void @loadfloats() {
   ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding
@@ -16,46 +15,17 @@ define void @loadfloats() {
   ; CHECK-NOT: %dx.cast_handle
 
   ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
-  %data0 = call <4 x float> @llvm.dx.typedBufferLoad(
+  %data0 = call { float, float, float, float } @llvm.dx.typedBufferLoad.f32(
       target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
 
-  ; The extract order depends on the users, so don't enforce that here.
-  ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
-  %data0_0 = extractelement <4 x float> %data0, i32 0
-  ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
-  %data0_2 = extractelement <4 x float> %data0, i32 2
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0
+  %data0_0 = extractvalue {float, float, float, float} %data0, 0
+  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2
+  %data0_2 = extractvalue {float, float, float, float} %data0, 2
 
-  ; If all of the uses are extracts, we skip creating a vector
-  ; CHECK-NOT: insertelement
   call void @scalar_user(float %data0_0)
   call void @scalar_user(float %data0_2)
 
-  ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef)
-  %data4 = call <4 x float> @llvm.dx.typedBufferLoad(
-      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4)
-
-  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0
-  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1
-  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2
-  ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3
-  ; CHECK: insertelement <4 x float> undef
-  ; CHECK: insertelement <4 x float>
-  ; CHECK: insertelement <4 x float>
-  ; CHECK: insertelement <4 x float>
-  call void @vector_user(<4 x float> %data4)
-
-  ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef)
-  %data12 = call <4 x float> @llvm.dx.typedBufferLoad(
-      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12)
-
-  ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3
-  %data12_3 = extractelement <4 x float> %data12, i32 3
-
-  ; If there are a mix of users we need the vector, but extracts are direct
-  ; CHECK: call void @scalar_user(float [[DATA12_3]])
-  call void @scalar_user(float %data12_3)
-  call void @vector_user(<4 x float> %data12)
-
   ret void
 }
 
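The @vector_user cases above are deleted rather than updated because the intrinsic no longer produces a vector at all, and the lowering now treats any use other than extractvalue as a programming error. A producer that still wants a <4 x float> would be expected to reassemble it before this pass runs; a hypothetical sketch of what that could look like (none of this appears in the patch):

  ; Hypothetical frontend-side reassembly, shown only for illustration.
  %v = call { float, float, float, float } @llvm.dx.typedBufferLoad.f32(
      target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0)
  %e0 = extractvalue { float, float, float, float } %v, 0
  %e1 = extractvalue { float, float, float, float } %v, 1
  %e2 = extractvalue { float, float, float, float } %v, 2
  %e3 = extractvalue { float, float, float, float } %v, 3
  %vec.0 = insertelement <4 x float> poison, float %e0, i32 0
  %vec.1 = insertelement <4 x float> %vec.0, float %e1, i32 1
  %vec.2 = insertelement <4 x float> %vec.1, float %e2, i32 2
  %vec.3 = insertelement <4 x float> %vec.2, float %e3, i32 3
  call void @vector_user(<4 x float> %vec.3)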
@@ -67,7 +37,7 @@ define void @loadint() {
       i32 0, i32 0, i32 1, i32 0, i1 false)
 
   ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
-  %data0 = call <4 x i32> @llvm.dx.typedBufferLoad(
+  %data0 = call {i32, i32, i32, i32} @llvm.dx.typedBufferLoad.i32(
       target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0)
 
   ret void
@@ -81,7 +51,7 @@ define void @loadhalf() {
       i32 0, i32 0, i32 1, i32 0, i1 false)
 
   ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
-  %data0 = call <4 x half> @llvm.dx.typedBufferLoad(
+  %data0 = call {half, half, half, half} @llvm.dx.typedBufferLoad.f16(
       target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0)
 
   ret void
@@ -95,7 +65,7 @@ define void @loadi16() {
       i32 0, i32 0, i32 1, i32 0, i1 false)
 
   ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef)
-  %data0 = call <4 x i16> @llvm.dx.typedBufferLoad(
+  %data0 = call {i16, i16, i16, i16} @llvm.dx.typedBufferLoad.i16(
       target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0)
 
   ret void
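Across the four tests, the overload suffix tracks the element type (f32, i32, f16, i16) and the return type is always a homogeneous four-element struct of that type, matching the LLVMMatchType<0> signature in the .td change. Reconstructed from the calls above, the implied declarations would look something like the following (the patch itself never spells them out, so take the exact mangled names as an assumption):

  ; Implied intrinsic declarations, reconstructed from the test calls.
  declare { float, float, float, float } @llvm.dx.typedBufferLoad.f32(
      target("dx.TypedBuffer", <4 x float>, 0, 0, 0), i32)
  declare { i32, i32, i32, i32 } @llvm.dx.typedBufferLoad.i32(
      target("dx.TypedBuffer", <4 x i32>, 0, 0, 0), i32)
  declare { half, half, half, half } @llvm.dx.typedBufferLoad.f16(
      target("dx.TypedBuffer", <4 x half>, 0, 0, 0), i32)
  declare { i16, i16, i16, i16 } @llvm.dx.typedBufferLoad.i16(
      target("dx.TypedBuffer", <4 x i16>, 0, 0, 0), i32)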