@@ -16,10 +16,10 @@ public sealed class BaseTestClassAnalyzer : DiagnosticAnalyzer
1616 private const string Category = "Test" ;
1717 internal const string DiagnosticId = "MSML_ExtendBaseTestClass" ;
1818
19- private const string Title = "Test classes should be derived from BaseTestClass" ;
20- private const string Format = "Test class '{0}' should extend BaseTestClass." ;
19+ private const string Title = "Test classes should be derived from BaseTestClass or FunctionalTestBaseClass " ;
20+ private const string Format = "Test class '{0}' should extend BaseTestClass or FunctionalTestBaseClass ." ;
2121 private const string Description =
22- "Test classes should be derived from BaseTestClass." ;
22+ "Test classes should be derived from BaseTestClass or FunctionalTestBaseClass ." ;
2323
2424 private static DiagnosticDescriptor Rule =
2525 new DiagnosticDescriptor ( DiagnosticId , Title , Format , Category ,
@@ -51,13 +51,15 @@ private sealed class AnalyzerImpl
5151 private readonly Compilation _compilation ;
5252 private readonly INamedTypeSymbol _factAttribute ;
5353 private readonly INamedTypeSymbol _baseTestClass ;
54+ private readonly INamedTypeSymbol _FTbaseTestClass ;
5455 private readonly ConcurrentDictionary < INamedTypeSymbol , bool > _knownTestAttributes = new ConcurrentDictionary < INamedTypeSymbol , bool > ( ) ;
5556
5657 public AnalyzerImpl ( Compilation compilation , INamedTypeSymbol factAttribute )
5758 {
5859 _compilation = compilation ;
5960 _factAttribute = factAttribute ;
6061 _baseTestClass = _compilation . GetTypeByMetadataName ( "Microsoft.ML.TestFramework.BaseTestClass" ) ;
62+ _FTbaseTestClass = _compilation . GetTypeByMetadataName ( "Microsoft.ML.Functional.Tests.FunctionalTestBaseClass" ) ;
6163 }
6264
6365 public void AnalyzeNamedType ( SymbolAnalysisContext context )
@@ -87,12 +89,14 @@ public void AnalyzeNamedType(SymbolAnalysisContext context)
8789
8890 private bool ExtendsBaseTestClass ( INamedTypeSymbol namedType )
8991 {
90- if ( _baseTestClass is null )
92+ if ( _baseTestClass is null &&
93+ _FTbaseTestClass is null )
9194 return false ;
9295
9396 for ( var current = namedType ; current is object ; current = current . BaseType )
9497 {
95- if ( Equals ( current , _baseTestClass ) )
98+ if ( Equals ( current , _baseTestClass ) ||
99+ Equals ( current , _FTbaseTestClass ) )
96100 return true ;
97101 }
98102
0 commit comments