Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 43 additions & 14 deletions mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
TransformMapKeyTy key = {m, r};
int64_t retRows = 1;
Value matmulRetValue = extractFilter;
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix G.
auto it = GMatrices.find(key);
Expand All @@ -399,8 +401,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,

retRows = GMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, filterW}, elementType);
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
elementType);
auto empty =
builder
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
.getResult();
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType);
// Multiply G x g.
Expand All @@ -418,8 +423,11 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,

auto matmulType =
RankedTensorType::get({retRows, GTMatrix.cols}, elementType);
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
elementType);
auto empty =
builder
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
.getResult();
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType);
// Multiply u = (G x g) x GT.
Expand Down Expand Up @@ -523,6 +531,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
int64_t retRows = 1;
int64_t retCols = 1;
Value matmulRetValue = extractInput;
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix BT.
auto it = BTMatrices.find(key);
Expand All @@ -532,8 +542,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,

retRows = BTMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType);
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
elementType);
auto empty =
builder
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
.getResult();
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

Value BT =
create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
Expand All @@ -552,8 +565,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,

retCols = BMatrix.cols;
auto matmulType = RankedTensorType::get({retRows, retCols}, elementType);
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
elementType);
auto empty =
builder
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
.getResult();
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);
Value B =
create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
// Multiply v = (BT x d) x B.
Expand Down Expand Up @@ -636,8 +652,13 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
{inputShape[0] * inputShape[1],
inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]},
outputElementType);
Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
outputElementType);
Value empty = rewriter
.create<tensor::EmptyOp>(loc, matmulType.getShape(),
outputElementType)
.getResult();
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(outputElementType));
Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);

auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
loc, matmulType, ValueRange({collapseInput, collapseFilter}),
Expand Down Expand Up @@ -725,6 +746,8 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
int64_t leftScalarFactor = 1;
int64_t rightScalarFactor = 1;
Value matmulRetValue = extractValue;
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
if (leftTransform) {
// Get constant transform matrix AT.
auto it = ATMatrices.find(key);
Expand All @@ -735,8 +758,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
leftScalarFactor = ATMatrix.scalarFactor;
retRows = ATMatrix.rows;
auto matmulType = RankedTensorType::get({retRows, valueW}, elementType);
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
elementType);
auto empty =
builder
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
.getResult();
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType);
// Multiply AT x m.
Expand All @@ -756,8 +782,11 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
auto matmulType =
RankedTensorType::get({retRows, AMatrix.cols}, elementType);
retCols = AMatrix.cols;
auto init = builder.create<tensor::EmptyOp>(loc, matmulType.getShape(),
elementType);
auto empty =
builder
.create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType)
.getResult();
auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0);

Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType);
// Multiply y = (AT x m) x A.
Expand Down
Loading