Skip to content

Commit b221b97

Browse files
Add support for SPIR-V extension: SPV_INTEL_subgroups (#81023)
The goal of this PR is to implement SPV_INTEL_subgroups extension in SPIR-V Backend.
1 parent 1e36d92 commit b221b97

File tree

7 files changed

+373
-4
lines changed

7 files changed

+373
-4
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "SPIRVBuiltins.h"
1515
#include "SPIRV.h"
16+
#include "SPIRVSubtarget.h"
1617
#include "SPIRVUtils.h"
1718
#include "llvm/ADT/StringExtras.h"
1819
#include "llvm/Analysis/ValueTracking.h"
@@ -82,6 +83,16 @@ struct GroupBuiltin {
8283
#define GET_GroupBuiltins_DECL
8384
#define GET_GroupBuiltins_IMPL
8485

86+
struct IntelSubgroupsBuiltin {
87+
StringRef Name;
88+
uint32_t Opcode;
89+
bool IsBlock;
90+
bool IsWrite;
91+
};
92+
93+
#define GET_IntelSubgroupsBuiltins_DECL
94+
#define GET_IntelSubgroupsBuiltins_IMPL
95+
8596
struct GetBuiltin {
8697
StringRef Name;
8798
InstructionSet::InstructionSet Set;
@@ -549,6 +560,7 @@ static bool buildAtomicCompareExchangeInst(const SPIRV::IncomingCall *Call,
549560
assert(GR->getSPIRVTypeForVReg(ObjectPtr)->getOpcode() ==
550561
SPIRV::OpTypePointer);
551562
unsigned ExpectedType = GR->getSPIRVTypeForVReg(ExpectedArg)->getOpcode();
563+
(void)ExpectedType;
552564
assert(IsCmpxchg ? ExpectedType == SPIRV::OpTypeInt
553565
: ExpectedType == SPIRV::OpTypePointer);
554566
assert(GR->isScalarOfType(Desired, SPIRV::OpTypeInt));
@@ -849,6 +861,7 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
849861
if (GroupBuiltin->HasBoolArg) {
850862
Register ConstRegister = Call->Arguments[0];
851863
auto ArgInstruction = getDefInstrMaybeConstant(ConstRegister, MRI);
864+
(void)ArgInstruction;
852865
// TODO: support non-constant bool values.
853866
assert(ArgInstruction->getOpcode() == TargetOpcode::G_CONSTANT &&
854867
"Only constant bool value args are supported");
@@ -900,6 +913,67 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
900913
return true;
901914
}
902915

916+
static bool generateIntelSubgroupsInst(const SPIRV::IncomingCall *Call,
917+
MachineIRBuilder &MIRBuilder,
918+
SPIRVGlobalRegistry *GR) {
919+
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
920+
MachineFunction &MF = MIRBuilder.getMF();
921+
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
922+
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
923+
std::string DiagMsg = std::string(Builtin->Name) +
924+
": the builtin requires the following SPIR-V "
925+
"extension: SPV_INTEL_subgroups";
926+
report_fatal_error(DiagMsg.c_str(), false);
927+
}
928+
const SPIRV::IntelSubgroupsBuiltin *IntelSubgroups =
929+
SPIRV::lookupIntelSubgroupsBuiltin(Builtin->Name);
930+
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
931+
932+
uint32_t OpCode = IntelSubgroups->Opcode;
933+
if (IntelSubgroups->IsBlock) {
934+
// Minimal number or arguments set in TableGen records is 1
935+
if (SPIRVType *Arg0Type = GR->getSPIRVTypeForVReg(Call->Arguments[0])) {
936+
if (Arg0Type->getOpcode() == SPIRV::OpTypeImage) {
937+
// TODO: add required validation from the specification:
938+
// "'Image' must be an object whose type is OpTypeImage with a 'Sampled'
939+
// operand of 0 or 2. If the 'Sampled' operand is 2, then some
940+
// dimensions require a capability."
941+
switch (OpCode) {
942+
case SPIRV::OpSubgroupBlockReadINTEL:
943+
OpCode = SPIRV::OpSubgroupImageBlockReadINTEL;
944+
break;
945+
case SPIRV::OpSubgroupBlockWriteINTEL:
946+
OpCode = SPIRV::OpSubgroupImageBlockWriteINTEL;
947+
break;
948+
}
949+
}
950+
}
951+
}
952+
953+
// TODO: opaque pointers types should be eventually resolved in such a way
954+
// that validation of block read is enabled with respect to the following
955+
// specification requirement:
956+
// "'Result Type' may be a scalar or vector type, and its component type must
957+
// be equal to the type pointed to by 'Ptr'."
958+
// For example, function parameter type should not be default i8 pointer, but
959+
// depend on the result type of the instruction where it is used as a pointer
960+
// argument of OpSubgroupBlockReadINTEL
961+
962+
// Build Intel subgroups instruction
963+
MachineInstrBuilder MIB =
964+
IntelSubgroups->IsWrite
965+
? MIRBuilder.buildInstr(OpCode)
966+
: MIRBuilder.buildInstr(OpCode)
967+
.addDef(Call->ReturnRegister)
968+
.addUse(GR->getSPIRVTypeID(Call->ReturnType));
969+
for (size_t i = 0; i < Call->Arguments.size(); ++i) {
970+
MIB.addUse(Call->Arguments[i]);
971+
MRI->setRegClass(Call->Arguments[i], &SPIRV::IDRegClass);
972+
}
973+
974+
return true;
975+
}
976+
903977
// These queries ask for a single size_t result for a given dimension index, e.g
904978
// size_t get_global_id(uint dimindex). In SPIR-V, the builtins corresonding to
905979
// these values are all vec3 types, so we need to extract the correct index or
@@ -1199,6 +1273,7 @@ static bool generateImageMiscQueryInst(const SPIRV::IncomingCall *Call,
11991273
MIRBuilder.getMRI()->setRegClass(Image, &SPIRV::IDRegClass);
12001274
SPIRV::Dim::Dim ImageDimensionality = static_cast<SPIRV::Dim::Dim>(
12011275
GR->getSPIRVTypeForVReg(Image)->getOperand(2).getImm());
1276+
(void)ImageDimensionality;
12021277

12031278
switch (Opcode) {
12041279
case SPIRV::OpImageQuerySamples:
@@ -1976,6 +2051,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
19762051
return generateVectorLoadStoreInst(Call.get(), MIRBuilder, GR);
19772052
case SPIRV::LoadStore:
19782053
return generateLoadStoreInst(Call.get(), MIRBuilder, GR);
2054+
case SPIRV::IntelSubgroups:
2055+
return generateIntelSubgroupsInst(Call.get(), MIRBuilder, GR);
19792056
}
19802057
return false;
19812058
}
@@ -2119,6 +2196,7 @@ parseBuiltinTypeNameToTargetExtType(std::string TypeName,
21192196
for (unsigned i = HasTypeParameter ? 1 : 0; i < Parameters.size(); i++) {
21202197
unsigned IntParameter = 0;
21212198
bool ValidLiteral = !Parameters[i].getAsInteger(10, IntParameter);
2199+
(void)ValidLiteral;
21222200
assert(ValidLiteral &&
21232201
"Invalid format of SPIR-V builtin parameter literal!");
21242202
IntParameters.push_back(IntParameter);

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def Enqueue : BuiltinGroup;
5454
def AsyncCopy : BuiltinGroup;
5555
def VectorLoadStore : BuiltinGroup;
5656
def LoadStore : BuiltinGroup;
57+
def IntelSubgroups : BuiltinGroup;
5758

5859
//===----------------------------------------------------------------------===//
5960
// Class defining a demangled builtin record. The information in the record
@@ -625,7 +626,7 @@ def GroupBuiltins : GenericTable {
625626
"IsBallotFindBit", "IsLogical", "NoGroupOperation", "HasBoolArg"];
626627
}
627628

628-
// Function to lookup native builtins by their name and set.
629+
// Function to lookup group builtins by their name and set.
629630
def lookupGroupBuiltin : SearchIndex {
630631
let Table = GroupBuiltins;
631632
let Key = ["Name"];
@@ -871,6 +872,61 @@ defm : DemangledGroupBuiltin<"group_non_uniform_scan_inclusive_logical_xors", Wo
871872
defm : DemangledGroupBuiltin<"group_non_uniform_scan_exclusive_logical_xors", WorkOrSub, OpGroupNonUniformLogicalXor>;
872873
defm : DemangledGroupBuiltin<"group_clustered_reduce_logical_xor", WorkOrSub, OpGroupNonUniformLogicalXor>;
873874

875+
//===----------------------------------------------------------------------===//
876+
// Class defining a sub group builtin that should be translated into a
877+
// SPIR-V instruction using the SPV_INTEL_subgroups extension.
878+
//
879+
// name is the demangled name of the given builtin.
880+
// opcode specifies the SPIR-V operation code of the generated instruction.
881+
//===----------------------------------------------------------------------===//
882+
class IntelSubgroupsBuiltin<string name, Op operation> {
883+
string Name = name;
884+
Op Opcode = operation;
885+
bit IsBlock = !or(!eq(operation, OpSubgroupBlockReadINTEL),
886+
!eq(operation, OpSubgroupBlockWriteINTEL));
887+
bit IsWrite = !eq(operation, OpSubgroupBlockWriteINTEL);
888+
}
889+
890+
// Table gathering all the Intel sub group builtins.
891+
def IntelSubgroupsBuiltins : GenericTable {
892+
let FilterClass = "IntelSubgroupsBuiltin";
893+
let Fields = ["Name", "Opcode", "IsBlock", "IsWrite"];
894+
}
895+
896+
// Function to lookup group builtins by their name and set.
897+
def lookupIntelSubgroupsBuiltin : SearchIndex {
898+
let Table = IntelSubgroupsBuiltins;
899+
let Key = ["Name"];
900+
}
901+
902+
// Multiclass used to define incoming builtin records for the SPV_INTEL_subgroups extension
903+
// and corresponding work/sub group builtin records.
904+
multiclass DemangledIntelSubgroupsBuiltin<string name, bits<8> minNumArgs, bits<8> maxNumArgs, Op operation> {
905+
def : DemangledBuiltin<!strconcat("intel_sub_group_", name), OpenCL_std, IntelSubgroups, minNumArgs, maxNumArgs>;
906+
def : IntelSubgroupsBuiltin<!strconcat("intel_sub_group_", name), operation>;
907+
}
908+
909+
// cl_intel_subgroups
910+
defm : DemangledIntelSubgroupsBuiltin<"shuffle", 2, 2, OpSubgroupShuffleINTEL>;
911+
defm : DemangledIntelSubgroupsBuiltin<"shuffle_down", 3, 3, OpSubgroupShuffleDownINTEL>;
912+
defm : DemangledIntelSubgroupsBuiltin<"shuffle_up", 3, 3, OpSubgroupShuffleUpINTEL>;
913+
defm : DemangledIntelSubgroupsBuiltin<"shuffle_xor", 2, 2, OpSubgroupShuffleXorINTEL>;
914+
foreach i = ["", "2", "4", "8"] in {
915+
// cl_intel_subgroups
916+
defm : DemangledIntelSubgroupsBuiltin<!strconcat("block_read", i), 1, 2, OpSubgroupBlockReadINTEL>;
917+
defm : DemangledIntelSubgroupsBuiltin<!strconcat("block_write", i), 2, 3, OpSubgroupBlockWriteINTEL>;
918+
// cl_intel_subgroups_short
919+
defm : DemangledIntelSubgroupsBuiltin<!strconcat("block_read_ui", i), 1, 2, OpSubgroupBlockReadINTEL>;
920+
defm : DemangledIntelSubgroupsBuiltin<!strconcat("block_write_ui", i), 2, 3, OpSubgroupBlockWriteINTEL>;
921+
}
922+
// cl_intel_subgroups_char, cl_intel_subgroups_short, cl_intel_subgroups_long
923+
foreach i = ["", "2", "4", "8", "16"] in {
924+
foreach j = ["c", "s", "l"] in {
925+
defm : DemangledIntelSubgroupsBuiltin<!strconcat("block_read_u", j, i), 1, 2, OpSubgroupBlockReadINTEL>;
926+
defm : DemangledIntelSubgroupsBuiltin<!strconcat("block_write_u", j, i), 2, 3, OpSubgroupBlockWriteINTEL>;
927+
}
928+
}
929+
// OpSubgroupImageBlockReadINTEL and OpSubgroupImageBlockWriteINTEL are to be resolved later on (in code)
874930

875931
//===----------------------------------------------------------------------===//
876932
// Class defining a get builtin record used for lowering builtin calls such as

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,3 +761,21 @@ def OpGroupNonUniformBitwiseXor: OpGroupNUGroup<"BitwiseXor", 361>;
761761
def OpGroupNonUniformLogicalAnd: OpGroupNUGroup<"LogicalAnd", 362>;
762762
def OpGroupNonUniformLogicalOr: OpGroupNUGroup<"LogicalOr", 363>;
763763
def OpGroupNonUniformLogicalXor: OpGroupNUGroup<"LogicalXor", 364>;
764+
765+
// 3.49.21. Group and Subgroup Instructions
766+
def OpSubgroupShuffleINTEL: Op<5571, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$invocationId),
767+
"$res = OpSubgroupShuffleINTEL $type $data $invocationId">;
768+
def OpSubgroupShuffleDownINTEL: Op<5572, (outs ID:$res), (ins TYPE:$type, ID:$current, ID:$next, ID:$delta),
769+
"$res = OpSubgroupShuffleDownINTEL $type $current $next $delta">;
770+
def OpSubgroupShuffleUpINTEL: Op<5573, (outs ID:$res), (ins TYPE:$type, ID:$previous, ID:$current, ID:$delta),
771+
"$res = OpSubgroupShuffleUpINTEL $type $previous $current $delta">;
772+
def OpSubgroupShuffleXorINTEL: Op<5574, (outs ID:$res), (ins TYPE:$type, ID:$data, ID:$value),
773+
"$res = OpSubgroupShuffleXorINTEL $type $data $value">;
774+
def OpSubgroupBlockReadINTEL: Op<5575, (outs ID:$res), (ins TYPE:$type, ID:$ptr),
775+
"$res = OpSubgroupBlockReadINTEL $type $ptr">;
776+
def OpSubgroupBlockWriteINTEL: Op<5576, (outs), (ins ID:$ptr, ID:$data),
777+
"OpSubgroupBlockWriteINTEL $ptr $data">;
778+
def OpSubgroupImageBlockReadINTEL: Op<5577, (outs ID:$res), (ins TYPE:$type, ID:$image, ID:$coordinate),
779+
"$res = OpSubgroupImageBlockReadINTEL $type $image $coordinate">;
780+
def OpSubgroupImageBlockWriteINTEL: Op<5578, (outs), (ins ID:$image, ID:$coordinate, ID:$data),
781+
"OpSubgroupImageBlockWriteINTEL $image $coordinate $data">;

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,29 @@ void addInstrRequirements(const MachineInstr &MI,
908908
case SPIRV::OpGroupNonUniformBallotFindMSB:
909909
Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
910910
break;
911+
case SPIRV::OpSubgroupShuffleINTEL:
912+
case SPIRV::OpSubgroupShuffleDownINTEL:
913+
case SPIRV::OpSubgroupShuffleUpINTEL:
914+
case SPIRV::OpSubgroupShuffleXorINTEL:
915+
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
916+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
917+
Reqs.addCapability(SPIRV::Capability::SubgroupShuffleINTEL);
918+
}
919+
break;
920+
case SPIRV::OpSubgroupBlockReadINTEL:
921+
case SPIRV::OpSubgroupBlockWriteINTEL:
922+
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
923+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
924+
Reqs.addCapability(SPIRV::Capability::SubgroupBufferBlockIOINTEL);
925+
}
926+
break;
927+
case SPIRV::OpSubgroupImageBlockReadINTEL:
928+
case SPIRV::OpSubgroupImageBlockWriteINTEL:
929+
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_subgroups)) {
930+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_subgroups);
931+
Reqs.addCapability(SPIRV::Capability::SubgroupImageBlockIOINTEL);
932+
}
933+
break;
911934
case SPIRV::OpAssumeTrueKHR:
912935
case SPIRV::OpExpectKHR:
913936
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {

llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ cl::list<SPIRV::Extension::Extension> Extensions(
3737
clEnumValN(SPIRV::Extension::SPV_INTEL_optnone, "SPV_INTEL_optnone",
3838
"Adds OptNoneINTEL value for Function Control mask that "
3939
"indicates a request to not optimize the function"),
40+
clEnumValN(SPIRV::Extension::SPV_INTEL_subgroups, "SPV_INTEL_subgroups",
41+
"Allows work items in a subgroup to share data without the "
42+
"use of local memory and work group barriers, and to "
43+
"utilize specialized hardware to load and store blocks of "
44+
"data from images or buffers."),
4045
clEnumValN(SPIRV::Extension::SPV_KHR_no_integer_wrap_decoration,
4146
"SPV_KHR_no_integer_wrap_decoration",
4247
"Adds decorations to indicate that a given instruction does "

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,9 @@ defm InputAttachmentArrayNonUniformIndexingEXT : CapabilityOperand<5310, 0, 0, [
431431
defm UniformTexelBufferArrayNonUniformIndexingEXT : CapabilityOperand<5311, 0, 0, [], [SampledBuffer, ShaderNonUniformEXT]>;
432432
defm StorageTexelBufferArrayNonUniformIndexingEXT : CapabilityOperand<5312, 0, 0, [], [ImageBuffer, ShaderNonUniformEXT]>;
433433
defm RayTracingNV : CapabilityOperand<5340, 0, 0, [], [Shader]>;
434-
defm SubgroupShuffleINTEL : CapabilityOperand<5568, 0, 0, [], []>;
435-
defm SubgroupBufferBlockIOINTEL : CapabilityOperand<5569, 0, 0, [], []>;
436-
defm SubgroupImageBlockIOINTEL : CapabilityOperand<5570, 0, 0, [], []>;
434+
defm SubgroupShuffleINTEL : CapabilityOperand<5568, 0, 0, [SPV_INTEL_subgroups], []>;
435+
defm SubgroupBufferBlockIOINTEL : CapabilityOperand<5569, 0, 0, [SPV_INTEL_subgroups], []>;
436+
defm SubgroupImageBlockIOINTEL : CapabilityOperand<5570, 0, 0, [SPV_INTEL_subgroups], []>;
437437
defm SubgroupImageMediaBlockIOINTEL : CapabilityOperand<5579, 0, 0, [], []>;
438438
defm SubgroupAvcMotionEstimationINTEL : CapabilityOperand<5696, 0, 0, [], []>;
439439
defm SubgroupAvcMotionEstimationIntraINTEL : CapabilityOperand<5697, 0, 0, [], []>;

0 commit comments

Comments
 (0)