@@ -653,26 +653,24 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
653653 switch (op.getOperandNumber ()) {
654654 case 3 :
655655 return smfma.getScalesIdxA () != 0 ;
656- break ;
657656 case 4 :
658657 return smfma.getScalesIdxB () != 0 ;
659- break ;
660658 default :
661- return true ;
662659 break ;
663660 }
664661 }
662+ return true ;
665663 };
666664
667665 auto setOpsel = [&](unsigned idx, int64_t val) {
668666 switch (idx) {
669667 case 3 :
670- return op.setScalesIdxA (val);
668+ op.setScalesIdxA (val);
671669 break ;
672670 case 4 :
673- return op.setScalesIdxB (val);
671+ op.setScalesIdxB (val);
674672 break ;
675- default :
673+ default :
676674 break ;
677675 }
678676 };
@@ -695,7 +693,7 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
695693 SmallVector<int64_t > res;
696694 ShapedType shapedty = static_cast <ShapedType>(ty);
697695 int64_t numElements = shapedty.getNumElements ();
698- for (auto size : shapedty.getShape ()) {
696+ for (unsigned size : shapedty.getShape ()) {
699697 numElements /= size;
700698 res.push_back (idx / numElements);
701699 idx -= (idx / numElements) * size;
@@ -706,17 +704,19 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
706704 // For every scale operand of this ScaledMFMAOp, if the scale follows the
707705 // following pattern:
708706 //
709- // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
710- // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
711- // amdgpu.scaled_mfma(%scale[0] * ...
707+ // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from
708+ // vector<?x?x?xf8E8M0FNU> %scale = vector.insert %unit, ... : f8E8M0FNU
709+ // into vector<4xf8E8M0FNU> amdgpu.scaled_mfma(%scale[0] * ...
712710 //
713711 // rewrite to:
714712 //
715- // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
716- // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
713+ // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to
714+ // vector<?x4xf8E8M0FNU> %scale = vector.extract %reshaped[?] :
715+ // vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
717716 // amdgpu.scaled_mfma(%scale[0-3] * ...
718717 //
719- // This creates duplicate shape_casts for every use but these will be removed in CSE.
718+ // This creates duplicate shape_casts for every use but these will be
719+ // removed in CSE.
720720 for (auto opIdx : SmallVector<int64_t >({3 , 4 })) {
721721 auto insertOp = op.getOperand (opIdx).getDefiningOp <vector::InsertOp>();
722722 if (!insertOp) {
0 commit comments