Skip to content

[SandboxIR] Implement ScalableVectorType #108124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 68 additions & 15 deletions llvm/include/llvm/SandboxIR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Context;
class PointerType;
class VectorType;
class FixedVectorType;
class ScalableVectorType;
class IntegerType;
class FunctionType;
class ArrayType;
Expand All @@ -39,21 +40,22 @@ class StructType;
class Type {
protected:
llvm::Type *LLVMTy;
friend class ArrayType; // For LLVMTy.
friend class StructType; // For LLVMTy.
friend class VectorType; // For LLVMTy.
friend class FixedVectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
friend class IntegerType; // For LLVMTy.
friend class Function; // For LLVMTy.
friend class CallBase; // For LLVMTy.
friend class ConstantInt; // For LLVMTy.
friend class ConstantArray; // For LLVMTy.
friend class ConstantStruct; // For LLVMTy.
friend class ConstantVector; // For LLVMTy.
friend class CmpInst; // For LLVMTy. TODO: Cleanup after
// sandboxir::VectorType is more complete.
friend class ArrayType; // For LLVMTy.
friend class StructType; // For LLVMTy.
friend class VectorType; // For LLVMTy.
friend class FixedVectorType; // For LLVMTy.
friend class ScalableVectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
friend class IntegerType; // For LLVMTy.
friend class Function; // For LLVMTy.
friend class CallBase; // For LLVMTy.
friend class ConstantInt; // For LLVMTy.
friend class ConstantArray; // For LLVMTy.
friend class ConstantStruct; // For LLVMTy.
friend class ConstantVector; // For LLVMTy.
friend class CmpInst; // For LLVMTy. TODO: Cleanup after
// sandboxir::VectorType is more complete.

// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
Expand Down Expand Up @@ -390,6 +392,57 @@ class FixedVectorType : public VectorType {
}
};

class ScalableVectorType : public VectorType {
public:
static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);

static ScalableVectorType *get(Type *ElementType,
const ScalableVectorType *SVTy) {
return get(ElementType, SVTy->getMinNumElements());
}

static ScalableVectorType *getInteger(ScalableVectorType *VTy) {
return cast<ScalableVectorType>(VectorType::getInteger(VTy));
}

static ScalableVectorType *
getExtendedElementVectorType(ScalableVectorType *VTy) {
return cast<ScalableVectorType>(
VectorType::getExtendedElementVectorType(VTy));
}

static ScalableVectorType *
getTruncatedElementVectorType(ScalableVectorType *VTy) {
return cast<ScalableVectorType>(
VectorType::getTruncatedElementVectorType(VTy));
}

static ScalableVectorType *getSubdividedVectorType(ScalableVectorType *VTy,
int NumSubdivs) {
return cast<ScalableVectorType>(
VectorType::getSubdividedVectorType(VTy, NumSubdivs));
}

static ScalableVectorType *
getHalfElementsVectorType(ScalableVectorType *VTy) {
return cast<ScalableVectorType>(VectorType::getHalfElementsVectorType(VTy));
}

static ScalableVectorType *
getDoubleElementsVectorType(ScalableVectorType *VTy) {
return cast<ScalableVectorType>(
VectorType::getDoubleElementsVectorType(VTy));
}

unsigned getMinNumElements() const {
return cast<llvm::ScalableVectorType>(LLVMTy)->getMinNumElements();
}

static bool classof(const Type *T) {
return isa<llvm::ScalableVectorType>(T->LLVMTy);
}
};

class FunctionType : public Type {
public:
// TODO: add missing functions
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/SandboxIR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
}

ScalableVectorType *ScalableVectorType::get(Type *ElementType,
unsigned NumElts) {
return cast<ScalableVectorType>(ElementType->getContext().getType(
llvm::ScalableVectorType::get(ElementType->LLVMTy, NumElts)));
}

IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
Expand Down
59 changes: 59 additions & 0 deletions llvm/unittests/SandboxIR/TypesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,65 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
}

TEST_F(SandboxTypeTest, ScalableVectorType) {
parseIR(C, R"IR(
define void @foo(<vscale x 4 x i16> %vi0, <vscale x 4 x float> %vf1, i8 %i0) {
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
// Check classof(), creation, accessors
auto *Vec4i16Ty =
cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType());
EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
EXPECT_EQ(Vec4i16Ty->getMinNumElements(), 4u);

// get(ElementType, NumElements)
EXPECT_EQ(
sandboxir::ScalableVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
F->getArg(0)->getType());
// get(ElementType, Other)
EXPECT_EQ(sandboxir::ScalableVectorType::get(
sandboxir::Type::getInt16Ty(Ctx),
cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType())),
F->getArg(0)->getType());
auto *Vec4FTy = cast<sandboxir::ScalableVectorType>(F->getArg(1)->getType());
EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
// getInteger
auto *Vec4i32Ty = sandboxir::ScalableVectorType::getInteger(Vec4FTy);
EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
EXPECT_EQ(Vec4i32Ty->getMinNumElements(), Vec4FTy->getMinNumElements());
// getExtendedElementCountVectorType
auto *Vec4i64Ty =
sandboxir::ScalableVectorType::getExtendedElementVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec4i64Ty->getElementType()->isIntegerTy(32));
EXPECT_EQ(Vec4i64Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
// getTruncatedElementVectorType
auto *Vec4i8Ty =
sandboxir::ScalableVectorType::getTruncatedElementVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
EXPECT_EQ(Vec4i8Ty->getMinNumElements(), Vec4i8Ty->getMinNumElements());
// getSubdividedVectorType
auto *Vec8i8Ty =
sandboxir::ScalableVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
// getMinNumElements
EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
// getHalfElementsVectorType
auto *Vec2i16Ty =
sandboxir::ScalableVectorType::getHalfElementsVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
EXPECT_EQ(Vec2i16Ty->getMinNumElements(), 2u);
// getDoubleElementsVectorType
auto *Vec8i16Ty =
sandboxir::ScalableVectorType::getDoubleElementsVectorType(Vec4i16Ty);
EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
EXPECT_EQ(Vec8i16Ty->getMinNumElements(), 8u);
}

TEST_F(SandboxTypeTest, FunctionType) {
parseIR(C, R"IR(
define void @foo() {
Expand Down
Loading