From 4abb2ac407ff9c89d951def478349f924fb7139d Mon Sep 17 00:00:00 2001 From: Kohei Yamaguchi Date: Tue, 30 Jan 2024 15:43:57 +0000 Subject: [PATCH 1/3] [mlir][spirv] Fix a crash of typeConverter with non supported type --- mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index c62e676efc159..c7c67c04c8919 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -367,12 +367,13 @@ class AccessChainPattern : public SPIRVToLLVMConversion { Value zero = rewriter.create( op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); - rewriter.replaceOpWithNewOp( - op, dstType, - typeConverter.convertType( - cast(op.getBasePtr().getType()) - .getPointeeType()), - adaptor.getBasePtr(), indices); + + auto elementType = typeConverter.convertType( + cast(op.getBasePtr().getType()).getPointeeType()); + if (!elementType) + return failure(); + rewriter.replaceOpWithNewOp(op, dstType, elementType, + adaptor.getBasePtr(), indices); return success(); } }; From c0beb3feae88e878fb7db9647563430ea05f353f Mon Sep 17 00:00:00 2001 From: Kohei Yamaguchi Date: Wed, 31 Jan 2024 09:40:39 +0000 Subject: [PATCH 2/3] addressed comment --- mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index c7c67c04c8919..607f4c595169f 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -371,7 +371,7 @@ class AccessChainPattern : public SPIRVToLLVMConversion { auto elementType = typeConverter.convertType( cast(op.getBasePtr().getType()).getPointeeType()); if (!elementType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.replaceOpWithNewOp(op, dstType, elementType, adaptor.getBasePtr(), indices); return success(); From ea81352d770087f789d610fc5ad1f45d7147def5 Mon Sep 17 00:00:00 2001 From: Kohei Yamaguchi Date: Wed, 31 Jan 2024 10:54:24 +0000 Subject: [PATCH 3/3] more support --- .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 106 ++++++++++-------- 1 file changed, 57 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 607f4c595169f..11d2312b9492f 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -240,7 +240,7 @@ static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, if (auto loadOp = dyn_cast(op)) { auto dstType = typeConverter.convertType(loadOp.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.replaceOpWithNewOp( loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, isVolatile, isNonTemporal); @@ -357,13 +357,13 @@ class AccessChainPattern : public SPIRVToLLVMConversion { ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.getComponentPtr().getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); // To use GEP we need to add a first 0 index to go through the pointer. auto indices = llvm::to_vector<4>(adaptor.getIndices()); Type indexType = op.getIndices().front().getType(); auto llvmIndexType = typeConverter.convertType(indexType); if (!llvmIndexType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Value zero = rewriter.create( op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); @@ -387,7 +387,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion { ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.getPointer().getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.replaceOpWithNewOp(op, dstType, op.getVariable()); return success(); @@ -405,7 +405,7 @@ class BitFieldInsertPattern auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. @@ -452,7 +452,7 @@ class ConstantScalarAndVectorPattern auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(constOp, "type conversion failed"); // SPIR-V constant can be a signed/unsigned integer, which has to be // casted to signless integer when converting to LLVM dialect. Removing the @@ -493,7 +493,7 @@ class BitFieldSExtractPattern auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. @@ -546,7 +546,7 @@ class BitFieldUExtractPattern auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Location loc = op.getLoc(); // Process `Offset` and `Count`: broadcast and extend/truncate if needed. @@ -622,7 +622,7 @@ class CompositeExtractPattern ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Type containerType = op.getComposite().getType(); if (isa(containerType)) { @@ -654,7 +654,7 @@ class CompositeInsertPattern ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Type containerType = op.getComposite().getType(); if (isa(containerType)) { @@ -681,13 +681,13 @@ class DirectConversionPattern : public SPIRVToLLVMConversion { using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = this->typeConverter.convertType(operation.getType()); + auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.template replaceOpWithNewOp( - operation, dstType, adaptor.getOperands(), operation->getAttrs()); + op, dstType, adaptor.getOperands(), op->getAttrs()); return success(); } }; @@ -791,7 +791,7 @@ class GlobalVariablePattern auto srcType = cast(op.getType()); auto dstType = typeConverter.convertType(srcType.getPointeeType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); // Limit conversion to the current invocation only or `StorageBuffer` // required by SPIR-V runner. @@ -844,23 +844,23 @@ class IndirectCastPattern : public SPIRVToLLVMConversion { using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type fromType = operation.getOperand().getType(); - Type toType = operation.getType(); + Type fromType = op.getOperand().getType(); + Type toType = op.getType(); auto dstType = this->typeConverter.convertType(toType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); if (getBitWidth(fromType) < getBitWidth(toType)) { - rewriter.template replaceOpWithNewOp(operation, dstType, + rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); return success(); } if (getBitWidth(fromType) > getBitWidth(toType)) { - rewriter.template replaceOpWithNewOp(operation, dstType, + rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); return success(); } @@ -884,6 +884,8 @@ class FunctionCallPattern // Function returns a single result. auto dstType = typeConverter.convertType(callOp.getType(0)); + if (!dstType) + return rewriter.notifyMatchFailure(callOp, "type conversion failed"); rewriter.replaceOpWithNewOp( callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); return success(); @@ -897,16 +899,15 @@ class FComparePattern : public SPIRVToLLVMConversion { using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = this->typeConverter.convertType(operation.getType()); + auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.template replaceOpWithNewOp( - operation, dstType, predicate, operation.getOperand1(), - operation.getOperand2()); + op, dstType, predicate, op.getOperand1(), op.getOperand2()); return success(); } }; @@ -918,16 +919,15 @@ class IComparePattern : public SPIRVToLLVMConversion { using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = this->typeConverter.convertType(operation.getType()); + auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); rewriter.template replaceOpWithNewOp( - operation, dstType, predicate, operation.getOperand1(), - operation.getOperand2()); + op, dstType, predicate, op.getOperand1(), op.getOperand2()); return success(); } }; @@ -943,7 +943,7 @@ class InverseSqrtPattern auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); @@ -1001,7 +1001,7 @@ class NotPattern : public SPIRVToLLVMConversion { auto srcType = notOp.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(notOp, "type conversion failed"); Location loc = notOp.getLoc(); IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); @@ -1227,18 +1227,18 @@ class ShiftPattern : public SPIRVToLLVMConversion { using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto dstType = this->typeConverter.convertType(operation.getType()); + auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(op, "type conversion failed"); - Type op1Type = operation.getOperand1().getType(); - Type op2Type = operation.getOperand2().getType(); + Type op1Type = op.getOperand1().getType(); + Type op2Type = op.getOperand2().getType(); if (op1Type == op2Type) { - rewriter.template replaceOpWithNewOp(operation, dstType, + rewriter.template replaceOpWithNewOp(op, dstType, adaptor.getOperands()); return success(); } @@ -1251,7 +1251,7 @@ class ShiftPattern : public SPIRVToLLVMConversion { if (!dstTypeWidth || !op2TypeWidth) return failure(); - Location loc = operation.getLoc(); + Location loc = op.getLoc(); Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { @@ -1269,7 +1269,7 @@ class ShiftPattern : public SPIRVToLLVMConversion { Value result = rewriter.template create( loc, dstType, adaptor.getOperand1(), extended); - rewriter.replaceOp(operation, result); + rewriter.replaceOp(op, result); return success(); } }; @@ -1283,7 +1283,7 @@ class TanPattern : public SPIRVToLLVMConversion { ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(tanOp.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); Location loc = tanOp.getLoc(); Value sin = rewriter.create(loc, dstType, tanOp.getOperand()); @@ -1309,7 +1309,7 @@ class TanhPattern : public SPIRVToLLVMConversion { auto srcType = tanhOp.getType(); auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); Location loc = tanhOp.getLoc(); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); @@ -1343,17 +1343,23 @@ class VariablePattern : public SPIRVToLLVMConversion { auto dstType = typeConverter.convertType(srcType); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(varOp, "type conversion failed"); Location loc = varOp.getLoc(); Value size = createI32ConstantOf(loc, rewriter, 1); if (!init) { - rewriter.replaceOpWithNewOp( - varOp, dstType, typeConverter.convertType(pointerTo), size); + auto elementType = typeConverter.convertType(pointerTo); + if (!elementType) + return rewriter.notifyMatchFailure(varOp, "type conversion failed"); + rewriter.replaceOpWithNewOp(varOp, dstType, elementType, + size); return success(); } - Value allocated = rewriter.create( - loc, dstType, typeConverter.convertType(pointerTo), size); + auto elementType = typeConverter.convertType(pointerTo); + if (!elementType) + return rewriter.notifyMatchFailure(varOp, "type conversion failed"); + Value allocated = + rewriter.create(loc, dstType, elementType, size); rewriter.create(loc, adaptor.getInitializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); @@ -1374,7 +1380,7 @@ class BitcastConversionPattern ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(bitcastOp.getType()); if (!dstType) - return failure(); + return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed"); // LLVM's opaque pointers do not require bitcasts. if (isa(dstType)) { @@ -1500,6 +1506,8 @@ class VectorShufflePattern } auto dstType = typeConverter.convertType(op.getType()); + if (!dstType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); auto scalarType = cast(dstType).getElementType(); auto componentsArray = components.getValue(); auto *context = rewriter.getContext();