Skip to content

Commit 993a38f

Browse files
authored
[MLIR][Affine] Extend getVectorReductionOp to support xor/maxnumf/minnumf (#163310)
This PR extends the `getVectorReductionOp` function, which is used by the affine vectorizer, to also recognize and support `xor/maxnumf/minnumf` reduction operations.
1 parent 95d6caa commit 993a38f

File tree

4 files changed

+148
-3
lines changed

4 files changed

+148
-3
lines changed

mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,10 @@ static Value getSupportedReduction(AffineForOp forOp, unsigned pos,
6666
.Case([](arith::MaxSIOp) { return arith::AtomicRMWKind::maxs; })
6767
.Case([](arith::MinUIOp) { return arith::AtomicRMWKind::minu; })
6868
.Case([](arith::MaxUIOp) { return arith::AtomicRMWKind::maxu; })
69+
.Case([](arith::XOrIOp) { return arith::AtomicRMWKind::xori; })
70+
.Case([](arith::MaxNumFOp) { return arith::AtomicRMWKind::maxnumf; })
71+
.Case([](arith::MinNumFOp) { return arith::AtomicRMWKind::minnumf; })
6972
.Default([](Operation *) -> std::optional<arith::AtomicRMWKind> {
70-
// TODO: AtomicRMW supports other kinds of reductions this is
71-
// currently not detecting, add those when the need arises.
7273
return std::nullopt;
7374
});
7475
if (!maybeKind)

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,15 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
717717
case arith::AtomicRMWKind::ori:
718718
return vector::ReductionOp::create(builder, vector.getLoc(),
719719
CombiningKind::OR, vector);
720-
// TODO: Add remaining reduction operations.
720+
case arith::AtomicRMWKind::minnumf:
721+
return vector::ReductionOp::create(builder, vector.getLoc(),
722+
CombiningKind::MINNUMF, vector);
723+
case arith::AtomicRMWKind::maxnumf:
724+
return vector::ReductionOp::create(builder, vector.getLoc(),
725+
CombiningKind::MAXNUMF, vector);
726+
case arith::AtomicRMWKind::xori:
727+
return vector::ReductionOp::create(builder, vector.getLoc(),
728+
CombiningKind::XOR, vector);
721729
default:
722730
(void)emitOptionalError(loc, "Reduction operation type not supported");
723731
break;

mlir/test/Conversion/ConvertToSPIRV/vector.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,42 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
275275

276276
// -----
277277

278+
// CHECK-LABEL: spirv.func @reduction_minnumf(
279+
// CHECK-SAME: %[[V:.*]]: vector<3xf32>,
280+
// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" {
281+
// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
282+
// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
283+
// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
284+
// CHECK: %[[MIN0:.*]] = spirv.GL.FMin %[[S0]], %[[S1]] : f32
285+
// CHECK: %[[MIN1:.*]] = spirv.GL.FMin %[[MIN0]], %[[S2]] : f32
286+
// CHECK: %[[MIN2:.*]] = spirv.GL.FMin %[[MIN1]], %[[S]] : f32
287+
// CHECK: spirv.ReturnValue %[[MIN2]] : f32
288+
// CHECK: }
289+
func.func @reduction_minnumf(%v : vector<3xf32>, %s: f32) -> f32 {
290+
%reduce = vector.reduction <minnumf>, %v, %s : vector<3xf32> into f32
291+
return %reduce : f32
292+
}
293+
294+
// -----
295+
296+
// CHECK-LABEL: spirv.func @reduction_maxnumf(
297+
// CHECK-SAME: %[[V:.*]]: vector<3xf32>,
298+
// CHECK-SAME: %[[S:.*]]: f32) -> f32 "None" {
299+
// CHECK: %[[S0:.*]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
300+
// CHECK: %[[S1:.*]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
301+
// CHECK: %[[S2:.*]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
302+
// CHECK: %[[MAX0:.*]] = spirv.GL.FMax %[[S0]], %[[S1]] : f32
303+
// CHECK: %[[MAX1:.*]] = spirv.GL.FMax %[[MAX0]], %[[S2]] : f32
304+
// CHECK: %[[MAX2:.*]] = spirv.GL.FMax %[[MAX1]], %[[S]] : f32
305+
// CHECK: spirv.ReturnValue %[[MAX2]] : f32
306+
// CHECK: }
307+
func.func @reduction_maxnumf(%v : vector<3xf32>, %s: f32) -> f32 {
308+
%reduce = vector.reduction <maxnumf>, %v, %s : vector<3xf32> into f32
309+
return %reduce : f32
310+
}
311+
312+
// -----
313+
278314
// CHECK-LABEL: func @reduction_maxsi
279315
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
280316
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>

mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,106 @@ func.func @vecdim_reduction_ori(%in: memref<256x512xi32>, %out: memref<256xi32>)
243243
// CHECK: affine.store %[[final_red]], %{{.*}} : memref<256xi32>
244244
// CHECK: }
245245

246+
// -----
247+
248+
func.func @vecdim_reduction_xori(%in: memref<256x512xi32>, %out: memref<256xi32>) {
249+
%cst = arith.constant 0 : i32
250+
affine.for %i = 0 to 256 {
251+
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (i32) {
252+
%ld = affine.load %in[%i, %j] : memref<256x512xi32>
253+
%xor = arith.xori %red_iter, %ld : i32
254+
affine.yield %xor : i32
255+
}
256+
affine.store %final_red, %out[%i] : memref<256xi32>
257+
}
258+
return
259+
}
260+
261+
// CHECK-LABEL: func.func @vecdim_reduction_xori(
262+
// CHECK-SAME: %[[input:.*]]: memref<256x512xi32>,
263+
// CHECK-SAME: %[[output:.*]]: memref<256xi32>) {
264+
// CHECK: %[[cst:.*]] = arith.constant 0 : i32
265+
// CHECK: affine.for %{{.*}} = 0 to 256 {
266+
// CHECK: %[[vzero:.*]] = arith.constant dense<0> : vector<128xi32>
267+
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xi32>) {
268+
// CHECK: %[[poison:.*]] = ub.poison : i32
269+
// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xi32>, vector<128xi32>
270+
// CHECK: %[[xor:.*]] = arith.xori %[[red_iter]], %[[ld]] : vector<128xi32>
271+
// CHECK: affine.yield %[[xor]] : vector<128xi32>
272+
// CHECK: }
273+
// CHECK: %[[final_red:.*]] = vector.reduction <xor>, %[[vred]] : vector<128xi32> into i32
274+
// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xi32>
275+
// CHECK: }
276+
// CHECK: return
277+
// CHECK: }
278+
279+
// -----
280+
281+
func.func @vecdim_reduction_minnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
282+
%cst = arith.constant 0xFF800000 : f32
283+
affine.for %i = 0 to 256 {
284+
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
285+
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
286+
%min = arith.minnumf %red_iter, %ld : f32
287+
affine.yield %min : f32
288+
}
289+
affine.store %final_red, %out[%i] : memref<256xf32>
290+
}
291+
return
292+
}
293+
294+
// CHECK-LABEL: func.func @vecdim_reduction_minnumf(
295+
// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>,
296+
// CHECK-SAME: %[[output:.*]]: memref<256xf32>) {
297+
// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32
298+
// CHECK: affine.for %{{.*}} = 0 to 256 {
299+
// CHECK: %[[vzero:.*]] = arith.constant dense<0x7FC00000> : vector<128xf32>
300+
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
301+
// CHECK: %[[poison:.*]] = ub.poison : f32
302+
// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
303+
// CHECK: %[[min:.*]] = arith.minnumf %[[red_iter]], %[[ld]] : vector<128xf32>
304+
// CHECK: affine.yield %[[min]] : vector<128xf32>
305+
// CHECK: }
306+
// CHECK: %[[red_scalar:.*]] = vector.reduction <minnumf>, %[[vred]] : vector<128xf32> into f32
307+
// CHECK: %[[final_red:.*]] = arith.minnumf %[[red_scalar]], %[[cst]] : f32
308+
// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
309+
// CHECK: }
310+
// CHECK: return
311+
// CHECK: }
312+
313+
// -----
314+
315+
func.func @vecdim_reduction_maxnumf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
316+
%cst = arith.constant 0xFF800000 : f32
317+
affine.for %i = 0 to 256 {
318+
%final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
319+
%ld = affine.load %in[%i, %j] : memref<256x512xf32>
320+
%max = arith.maxnumf %red_iter, %ld : f32
321+
affine.yield %max : f32
322+
}
323+
affine.store %final_red, %out[%i] : memref<256xf32>
324+
}
325+
return
326+
}
327+
328+
// CHECK-LABEL: func.func @vecdim_reduction_maxnumf(
329+
// CHECK-SAME: %[[input:.*]]: memref<256x512xf32>,
330+
// CHECK-SAME: %[[output:.*]]: memref<256xf32>) {
331+
// CHECK: %[[cst:.*]] = arith.constant 0xFF800000 : f32
332+
// CHECK: affine.for %{{.*}} = 0 to 256 {
333+
// CHECK: %[[vzero:.*]] = arith.constant dense<0xFFC00000> : vector<128xf32>
334+
// CHECK: %[[vred:.*]] = affine.for %{{.*}} = 0 to 512 step 128 iter_args(%[[red_iter:.*]] = %[[vzero]]) -> (vector<128xf32>) {
335+
// CHECK: %[[poison:.*]] = ub.poison : f32
336+
// CHECK: %[[ld:.*]] = vector.transfer_read %[[input]]{{\[}}%{{.*}}, %{{.*}}], %[[poison]] : memref<256x512xf32>, vector<128xf32>
337+
// CHECK: %[[max:.*]] = arith.maxnumf %[[red_iter]], %[[ld]] : vector<128xf32>
338+
// CHECK: affine.yield %[[max]] : vector<128xf32>
339+
// CHECK: }
340+
// CHECK: %[[red_scalar:.*]] = vector.reduction <maxnumf>, %[[vred]] : vector<128xf32> into f32
341+
// CHECK: %[[final_red:.*]] = arith.maxnumf %[[red_scalar]], %[[cst]] : f32
342+
// CHECK: affine.store %[[final_red]], %[[output]]{{\[}}%{{.*}}] : memref<256xf32>
343+
// CHECK: }
344+
// CHECK: return
345+
// CHECK: }
246346

247347
// -----
248348

0 commit comments

Comments
 (0)