@@ -2896,10 +2896,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
28962896 linearize (completePositions, computeStrides (destTy.getShape ()));
28972897
28982898 SmallVector<Attribute> insertedValues;
2899- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst))
2900- llvm::append_range (insertedValues, denseSource.getValues <Attribute>());
2901- else
2902- insertedValues.push_back (sourceCst);
2899+ Type destEltType = destTy.getElementType ();
2900+
2901+ // The `convertIntegerAttr` method specifically handles the case
2902+ // for `llvm.mlir.constant` which can hold an attribute with a
2903+ // different type than the return type.
2904+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
2905+ for (auto value : denseSource.getValues <Attribute>())
2906+ insertedValues.push_back (convertIntegerAttr (value, destEltType));
2907+ } else {
2908+ insertedValues.push_back (convertIntegerAttr (sourceCst, destEltType));
2909+ }
29032910
29042911 auto allValues = llvm::to_vector (denseDest.getValues <Attribute>());
29052912 copy (insertedValues, allValues.begin () + insertBeginPosition);
@@ -2908,6 +2915,17 @@ class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
29082915 rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
29092916 return success ();
29102917 }
2918+
2919+ private:
2920+ // / Converts the expected type to an IntegerAttr if there's
2921+ // / a mismatch.
2922+ Attribute convertIntegerAttr (Attribute attr, Type expectedType) const {
2923+ if (auto intAttr = attr.dyn_cast <IntegerAttr>()) {
2924+ if (intAttr.getType () != expectedType)
2925+ return IntegerAttr::get (expectedType, intAttr.getInt ());
2926+ }
2927+ return attr;
2928+ }
29112929};
29122930
29132931} // namespace
0 commit comments