Skip to content

Commit 3873eda

Browse files
PR review round 0
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 2b6d917 commit 3873eda

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -653,26 +653,24 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
653653
switch (op.getOperandNumber()) {
654654
case 3:
655655
return smfma.getScalesIdxA() != 0;
656-
break;
657656
case 4:
658657
return smfma.getScalesIdxB() != 0;
659-
break;
660658
default:
661-
return true;
662659
break;
663660
}
664661
}
662+
return true;
665663
};
666664

667665
auto setOpsel = [&](unsigned idx, int64_t val) {
668666
switch (idx) {
669667
case 3:
670-
return op.setScalesIdxA(val);
668+
op.setScalesIdxA(val);
671669
break;
672670
case 4:
673-
return op.setScalesIdxB(val);
671+
op.setScalesIdxB(val);
674672
break;
675-
default:
673+
default:
676674
break;
677675
}
678676
};
@@ -695,7 +693,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
695693
SmallVector<int64_t> res;
696694
ShapedType shapedty = static_cast<ShapedType>(ty);
697695
int64_t numElements = shapedty.getNumElements();
698-
for (auto size : shapedty.getShape()) {
696+
for (unsigned size : shapedty.getShape()) {
699697
numElements /= size;
700698
res.push_back(idx / numElements);
701699
idx -= (idx / numElements) * size;
@@ -706,17 +704,19 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
706704
// For every scale operand of this ScaledMFMAOp, if the scale follows the
707705
// following pattern:
708706
//
709-
// %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
710-
// %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
711-
// amdgpu.scaled_mfma(%scale[0] * ...
707+
// %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
708+
// vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
709+
// into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
712710
//
713711
// rewrite to:
714712
//
715-
// %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
716-
// %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
713+
// %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
714+
// vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
715+
// vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
717716
// amdgpu.scaled_mfma(%scale[0-3] * ...
718717
//
719-
// This creates duplicate shape_casts for every use but these will be removed in CSE.
718+
// This creates duplicate shape_casts for every use but these will be
719+
// removed in CSE.
720720
for (auto opIdx : SmallVector<int64_t>({3, 4})) {
721721
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
722722
if (!insertOp) {

0 commit comments

Comments
 (0)