@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
176
176
ImmTyWaitVAVDst,
177
177
ImmTyWaitVMVSrc,
178
178
ImmTyBitOp3,
179
+ ImmTyMatrixAFMT,
180
+ ImmTyMatrixBFMT,
179
181
ImmTyMatrixAReuse,
180
182
ImmTyMatrixBReuse,
181
183
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
423
425
// Operand predicate: forwards to isImmTy() with ImmTyIndexKey8bit.
bool isIndexKey8bit () const { return isImmTy (ImmTyIndexKey8bit); }
424
426
// Operand predicate: forwards to isImmTy() with ImmTyIndexKey16bit.
bool isIndexKey16bit () const { return isImmTy (ImmTyIndexKey16bit); }
425
427
// Operand predicate: forwards to isImmTy() with ImmTyIndexKey32bit.
bool isIndexKey32bit () const { return isImmTy (ImmTyIndexKey32bit); }
428
// Operand predicate for the matrix A format modifier: forwards to isImmTy()
// with ImmTyMatrixAFMT.
+ bool isMatrixAFMT () const { return isImmTy (ImmTyMatrixAFMT); }
429
// Operand predicate for the matrix B format modifier: forwards to isImmTy()
// with ImmTyMatrixBFMT.
+ bool isMatrixBFMT () const { return isImmTy (ImmTyMatrixBFMT); }
426
430
// Operand predicate: forwards to isImmTy() with ImmTyMatrixAReuse.
bool isMatrixAReuse () const { return isImmTy (ImmTyMatrixAReuse); }
427
431
// Operand predicate: forwards to isImmTy() with ImmTyMatrixBReuse.
bool isMatrixBReuse () const { return isImmTy (ImmTyMatrixBReuse); }
428
432
// Operand predicate: forwards to isImmTy() with ImmTyTFE.
bool isTFE () const { return isImmTy (ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
1174
1178
case ImmTyWaitVAVDst: OS << " WaitVAVDst" ; break ;
1175
1179
case ImmTyWaitVMVSrc: OS << " WaitVMVSrc" ; break ;
1176
1180
case ImmTyBitOp3: OS << " BitOp3" ; break ;
1181
+ case ImmTyMatrixAFMT: OS << " ImmTyMatrixAFMT" ; break ;
1182
+ case ImmTyMatrixBFMT: OS << " ImmTyMatrixBFMT" ; break ;
1177
1183
case ImmTyMatrixAReuse: OS << " ImmTyMatrixAReuse" ; break ;
1178
1184
case ImmTyMatrixBReuse: OS << " ImmTyMatrixBReuse" ; break ;
1179
1185
case ImmTyByteSel: OS << " ByteSel" ; break ;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
1714
1720
ParseStatus parseIndexKey8bit (OperandVector &Operands);
1715
1721
ParseStatus parseIndexKey16bit (OperandVector &Operands);
1716
1722
ParseStatus parseIndexKey32bit (OperandVector &Operands);
1723
+ ParseStatus tryParseMatrixFMT (OperandVector &Operands, StringRef Name,
1724
+ AMDGPUOperand::ImmTy Type);
1725
+ ParseStatus parseMatrixAFMT (OperandVector &Operands);
1726
+ ParseStatus parseMatrixBFMT (OperandVector &Operands);
1717
1727
1718
1728
ParseStatus parseDfmtNfmt (int64_t &Format);
1719
1729
ParseStatus parseUfmt (int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
1849
1859
const unsigned CPol);
1850
1860
bool validateTFE (const MCInst &Inst, const OperandVector &Operands);
1851
1861
std::optional<StringRef> validateLdsDirect (const MCInst &Inst);
1862
+ bool validateWMMA (const MCInst &Inst, const OperandVector &Operands);
1852
1863
unsigned getConstantBusLimit (unsigned Opcode) const ;
1853
1864
bool usesConstantBus (const MCInst &Inst, unsigned OpIdx);
1854
1865
bool isInlineConstant (const MCInst &Inst, unsigned OpIdx) const ;
@@ -5409,6 +5420,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
5409
5420
return true ;
5410
5421
}
5411
5422
5423
// Checks that the register tuples used for src0 and src1 have the size
// required by the instruction's matrix_a_fmt and matrix_b_fmt operands.
//
// Inst     - the instruction being validated.
// Operands - the parsed operands, used only to locate the offending register
//            when reporting a diagnostic.
//
// Returns true when both checks pass, including the case where the opcode has
// no matrix_a_fmt / matrix_b_fmt operand. On a mismatch an error is reported
// through Error() and false is returned.
+ bool AMDGPUAsmParser::validateWMMA (const MCInst &Inst,
5424
+ const OperandVector &Operands) {
5425
+ unsigned Opc = Inst.getOpcode ();
5426
+ const MCRegisterInfo *TRI = getContext ().getRegisterInfo ();
5427
+ const MCInstrDesc &Desc = MII.get (Opc);
5428
+
5429
// Validates one (format operand, source operand) pair of this instruction.
+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
5430
+ int FmtIdx = AMDGPU::getNamedOperandIdx (Opc, FmtOp);
5431
// The opcode has no such format operand, so there is nothing to check.
+ if (FmtIdx == -1 )
5432
+ return true ;
5433
+ unsigned Fmt = Inst.getOperand (FmtIdx).getImm ();
5434
// NOTE(review): SrcIdx is used below without a -1 check; presumably every
// opcode that has the format operand also has the matching source operand.
// Confirm against the instruction definitions.
+ int SrcIdx = AMDGPU::getNamedOperandIdx (Opc, SrcOp);
5435
// Size in bits of the register class the descriptor declares for the source.
+ unsigned RegSize =
5436
+ TRI->getRegClass (Desc.operands ()[SrcIdx].RegClass ).getSizeInBits ();
5437
+
5438
// The helper returns a register count for the format; the factor of 32 is
// presumably the width in bits of one register. TODO confirm.
+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs (Fmt) * 32 )
5439
+ return true ;
5440
+
5441
// Names used in the diagnostic, indexed by the format immediate.
// NOTE(review): Fmt indexes this five-entry table without a bounds check;
// this relies on the operand parser rejecting values above 4. Verify.
// NOTE(review): the same list of names is repeated in tryParseMatrixFMT and
// the two copies must be kept in the same order.
+ static const char *FmtNames[] = {" MATRIX_FMT_FP8" , " MATRIX_FMT_BF8" ,
5442
+ " MATRIX_FMT_FP6" , " MATRIX_FMT_BF6" ,
5443
+ " MATRIX_FMT_FP4" };
5444
+
5445
// Point the diagnostic at the source register operand.
+ Error (getRegLoc (mc2PseudoReg (Inst.getOperand (SrcIdx).getReg ()), Operands),
5446
+ " wrong register tuple size for " + Twine (FmtNames[Fmt]));
5447
+ return false ;
5448
+ };
5449
+
5450
// && short-circuits: if the A/src0 check fails, the B/src1 check is not run,
// so at most one diagnostic is reported per call.
+ return validateFmt (AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
5451
+ validateFmt (AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
5452
+ }
5453
+
5412
5454
bool AMDGPUAsmParser::validateInstruction (const MCInst &Inst,
5413
5455
const SMLoc &IDLoc,
5414
5456
const OperandVector &Operands) {
@@ -5542,6 +5584,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
5542
5584
if (!validateTFE (Inst, Operands)) {
5543
5585
return false ;
5544
5586
}
5587
+ if (!validateWMMA (Inst, Operands)) {
5588
+ return false ;
5589
+ }
5545
5590
5546
5591
return true ;
5547
5592
}
@@ -7215,6 +7260,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
7215
7260
return tryParseIndexKey (Operands, AMDGPUOperand::ImmTyIndexKey32bit);
7216
7261
}
7217
7262
7263
// Shared parser for the matrix format modifiers.
//
// Operands - operand list passed through to parseStringOrIntWithPrefix.
// Name     - the modifier prefix to look for.
// Type     - the immediate type passed through for the parsed operand.
//
// Returns whatever parseStringOrIntWithPrefix returns. Judging by its name,
// that helper presumably accepts either one of the listed names or an integer
// after the prefix. TODO confirm against its definition.
+ ParseStatus AMDGPUAsmParser::tryParseMatrixFMT (OperandVector &Operands,
7264
+ StringRef Name,
7265
+ AMDGPUOperand::ImmTy Type) {
7266
// NOTE(review): this list has the same entries, in the same order, as the
// FmtNames table in validateWMMA; keep the two in sync.
+ return parseStringOrIntWithPrefix (Operands, Name,
7267
+ {" MATRIX_FMT_FP8" , " MATRIX_FMT_BF8" ,
7268
+ " MATRIX_FMT_FP6" , " MATRIX_FMT_BF6" ,
7269
+ " MATRIX_FMT_FP4" },
7270
+ Type);
7271
+ }
7272
+
7273
// Parses the matrix_a_fmt modifier by delegating to tryParseMatrixFMT with
// the immediate type ImmTyMatrixAFMT. Returns that call's ParseStatus.
+ ParseStatus AMDGPUAsmParser::parseMatrixAFMT (OperandVector &Operands) {
7274
+ return tryParseMatrixFMT (Operands, " matrix_a_fmt" ,
7275
+ AMDGPUOperand::ImmTyMatrixAFMT);
7276
+ }
7277
+
7278
// Parses the matrix_b_fmt modifier by delegating to tryParseMatrixFMT with
// the immediate type ImmTyMatrixBFMT. Returns that call's ParseStatus.
+ ParseStatus AMDGPUAsmParser::parseMatrixBFMT (OperandVector &Operands) {
7279
+ return tryParseMatrixFMT (Operands, " matrix_b_fmt" ,
7280
+ AMDGPUOperand::ImmTyMatrixBFMT);
7281
+ }
7282
+
7218
7283
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
7219
7284
// values to live in a joint format operand in the MCInst encoding.
7220
7285
ParseStatus AMDGPUAsmParser::parseDfmtNfmt (int64_t &Format) {
@@ -9316,6 +9381,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
9316
9381
DefaultVal);
9317
9382
}
9318
9383
9384
+ int MatrixAFMTIdx =
9385
+ AMDGPU::getNamedOperandIdx (Opc, AMDGPU::OpName::matrix_a_fmt);
9386
+ if (MatrixAFMTIdx != -1 ) {
9387
+ addOptionalImmOperand (Inst, Operands, OptIdx,
9388
+ AMDGPUOperand::ImmTyMatrixAFMT, 0 );
9389
+ }
9390
+
9391
+ int MatrixBFMTIdx =
9392
+ AMDGPU::getNamedOperandIdx (Opc, AMDGPU::OpName::matrix_b_fmt);
9393
+ if (MatrixBFMTIdx != -1 ) {
9394
+ addOptionalImmOperand (Inst, Operands, OptIdx,
9395
+ AMDGPUOperand::ImmTyMatrixBFMT, 0 );
9396
+ }
9397
+
9319
9398
if (AMDGPU::hasNamedOperand (Opc, AMDGPU::OpName::matrix_a_reuse))
9320
9399
addOptionalImmOperand (Inst, Operands, OptIdx,
9321
9400
AMDGPUOperand::ImmTyMatrixAReuse, 0 );
0 commit comments