Skip to content
Merged
278 changes: 181 additions & 97 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "IRModule.h"

#include "PybindUtils.h"
#include <pybind11/numpy.h>

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -757,103 +758,10 @@ class PyDenseElementsAttribute
throw py::error_already_set();
}
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
shape.append(view.shape, view.shape + view.ndim);
}

MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();

// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes, bool (which needs to be bit-packed) and
// other exotics which do not have a direct representation in the buffer
// protocol (i.e. complex, etc).
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
} else {
std::string_view format(view.format);
if (format == "f") {
// f32
assert(view.itemsize == 4 && "mismatched array itemsize");
bulkLoadElementType = mlirF32TypeGet(context);
} else if (format == "d") {
// f64
assert(view.itemsize == 8 && "mismatched array itemsize");
bulkLoadElementType = mlirF64TypeGet(context);
} else if (format == "e") {
// f16
assert(view.itemsize == 2 && "mismatched array itemsize");
bulkLoadElementType = mlirF16TypeGet(context);
} else if (isSignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
} else if (view.itemsize == 8) {
// i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeSignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// unsigned i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
} else if (view.itemsize == 8) {
// unsigned i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeUnsignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeUnsignedGet(context, 16);
}
}
if (!bulkLoadElementType) {
throw std::invalid_argument(
std::string("unimplemented array format conversion from format: ") +
std::string(format));
}
}

MlirType shapedType;
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
throw std::invalid_argument("Shape can only be specified explicitly "
"when the type is not a shaped type.");
}
shapedType = *bulkLoadElementType;
} else {
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
*bulkLoadElementType, encodingAttr);
}
size_t rawBufferSize = view.len;
MlirAttribute attr =
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to move the if-statements to the getAttributeFromBuffer method, as now the i1 case will not follow the usual flow, but instead call getBitpackedAttributeFromBooleanBuffer to construct the MlirAttribute.

explicitShape, context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
Expand Down Expand Up @@ -963,6 +871,13 @@ class PyDenseElementsAttribute
// unsigned i16
return bufferInfo<uint16_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 1) {
// i1 / bool
// We can not send the buffer directly back to Python, because the i1
// values are bitpacked within MLIR. We call numpy's unpackbits function
// to convert the bytes.
return getBooleanBufferFromBitpackedAttribute();
}

// TODO: Currently crashes the program.
Expand Down Expand Up @@ -1016,14 +931,183 @@ class PyDenseElementsAttribute
code == 'q';
}

static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
std::optional<std::vector<int64_t>> explicitShape,
Py_buffer &view) {
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
shape.append(view.shape, view.shape + view.ndim);
}

if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
throw std::invalid_argument("Shape can only be specified explicitly "
"when the type is not a shaped type.");
}
return *bulkLoadElementType;
} else {
MlirAttribute encodingAttr = mlirAttributeGetNull();
return mlirRankedTensorTypeGet(shape.size(), shape.data(),
*bulkLoadElementType, encodingAttr);
}
}

static MlirAttribute getAttributeFromBuffer(
Py_buffer &view, bool signless, std::optional<PyType> explicitType,
std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes exotics types which do not have a direct
// representation in the buffer protocol (i.e. complex, etc).
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
} else {
std::string_view format(view.format);
if (format == "f") {
// f32
assert(view.itemsize == 4 && "mismatched array itemsize");
bulkLoadElementType = mlirF32TypeGet(context);
} else if (format == "d") {
// f64
assert(view.itemsize == 8 && "mismatched array itemsize");
bulkLoadElementType = mlirF64TypeGet(context);
} else if (format == "e") {
// f16
assert(view.itemsize == 2 && "mismatched array itemsize");
bulkLoadElementType = mlirF16TypeGet(context);
} else if (format == "?") {
// i1
// The i1 type needs to be bit-packed, so we will handle it seperately
return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
context);
} else if (isSignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
} else if (view.itemsize == 8) {
// i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeSignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// unsigned i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
} else if (view.itemsize == 8) {
// unsigned i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeUnsignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeUnsignedGet(context, 16);
}
}
if (!bulkLoadElementType) {
throw std::invalid_argument(
std::string("unimplemented array format conversion from format: ") +
std::string(format));
}
}

MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}

// There is a complication for boolean numpy arrays, as numpy represents them
// as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans
// per byte.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
MlirContext &context) {
if (llvm::endianness::native != llvm::endianness::little) {
// Given we have no good way of testing the behavior on big-endian systems
// we will throw
throw py::type_error("Constructing a bit-packed MLIR attribute is "
"unsupported on big-endian systems");
}

py::array_t<uint8_t> unpackedArray(view.len,
static_cast<uint8_t *>(view.buf));

py::module numpy = py::module::import("numpy");
py::object packbits_func = numpy.attr("packbits");
py::object packed_booleans =
packbits_func(unpackedArray, "bitorder"_a = "little");
Comment on lines +1060 to +1061
Copy link
Contributor

@makslevental makslevental Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

................ you can do this ..................... wow color me shocked I never noticed/knew you could do "kwargs" like this on cpp side.

py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();

MlirType bitpackedType =
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
pythonBuffer.ptr);
}

// This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
if (llvm::endianness::native != llvm::endianness::little) {
// Given we have no good way of testing the behavior on big-endian systems
// we will throw
throw py::type_error("Constructing a numpy array from a MLIR attribute "
"is unsupported on big-endian systems");
}

int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
uint8_t *bitpackedData = static_cast<uint8_t *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

py::module numpy = py::module::import("numpy");
py::object unpackbits_func = numpy.attr("unpackbits");
py::object unpacked_booleans =
unpackbits_func(packedArray, "bitorder"_a = "little");
py::buffer_info pythonBuffer =
unpacked_booleans.cast<py::buffer>().request();

MlirType shapedType = mlirAttributeGetType(*this);
return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
}

template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the data for the buffer_info.
// Buffer is configured for read-only access below.
// Buffer is configured for read-only access inside the `bufferInfo` call.
Type *data = static_cast<Type *>(
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
return bufferInfo<Type>(shapedType, data, explicitFormat);
}

template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType, Type *data,
const char *explicitFormat = nullptr) {
intptr_t rank = mlirShapedTypeGetRank(shapedType);
// Prepare the shape for the buffer_info.
SmallVector<intptr_t, 4> shape;
for (intptr_t i = 0; i < rank; ++i)
Expand Down
72 changes: 72 additions & 0 deletions mlir/test/python/ir/array_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,78 @@ def testGetDenseElementsF64():
print(np.array(attr))


### 1 bit/boolean integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
@run
def testGetDenseElementsI1Signless():
with Context():
array = np.array([True], dtype=np.bool_)
attr = DenseElementsAttr.get(array)
# CHECK: dense<true> : tensor<1xi1>
print(attr)
# CHECK{LITERAL}: [ True]
print(np.array(attr))

array = np.array([[True, False, True], [True, True, False]], dtype=np.bool_)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
print(attr)
# CHECK{LITERAL}: [[ True False True]
# CHECK{LITERAL}: [ True True False]]
print(np.array(attr))

array = np.array(
[[True, True, False, False], [True, False, True, False]], dtype=np.bool_
)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
print(attr)
# CHECK{LITERAL}: [[ True True False False]
# CHECK{LITERAL}: [ True False True False]]
print(np.array(attr))

array = np.array(
[
[True, True, False, False],
[True, False, True, False],
[False, False, False, False],
[True, True, True, True],
[True, False, False, True],
],
dtype=np.bool_,
)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
print(attr)
# CHECK{LITERAL}: [[ True True False False]
# CHECK{LITERAL}: [ True False True False]
# CHECK{LITERAL}: [False False False False]
# CHECK{LITERAL}: [ True True True True]
# CHECK{LITERAL}: [ True False False True]]
print(np.array(attr))

array = np.array(
[
[True, True, False, False, True, True, False, False, False],
[False, False, False, True, False, True, True, False, True],
],
dtype=np.bool_,
)
attr = DenseElementsAttr.get(array)
# CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
print(attr)
# CHECK{LITERAL}: [[ True True False False True True False False False]
# CHECK{LITERAL}: [False False False True False True True False True]]
print(np.array(attr))

array = np.array([], dtype=np.bool_)
attr = DenseElementsAttr.get(array)
# CHECK: dense<> : tensor<0xi1>
print(attr)
# CHECK{LITERAL}: []
print(np.array(attr))


### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
Expand Down