@@ -739,5 +739,64 @@ public void TestCrossValidationMacroWithNonDefaultNames()
739739 }
740740 }
741741 }
742+
743+ [ Fact ]
744+ public void TestOvaMacro ( )
745+ {
746+ var dataPath = GetDataPath ( @"iris.txt" ) ;
747+ using ( var env = new TlcEnvironment ( 42 ) )
748+ {
749+ // Specify subgraph for OVA
750+ var subGraph = env . CreateExperiment ( ) ;
751+ var learnerInput = new Trainers . StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 } ;
752+ var learnerOutput = subGraph . Add ( learnerInput ) ;
753+ // Create pipeline with OVA and multiclass scoring.
754+ var experiment = env . CreateExperiment ( ) ;
755+ var importInput = new ML . Data . TextLoader ( dataPath ) ;
756+ importInput . Arguments . Column = new TextLoaderColumn [ ]
757+ {
758+ new TextLoaderColumn { Name = "Label" , Source = new [ ] { new TextLoaderRange ( 0 ) } } ,
759+ new TextLoaderColumn { Name = "Features" , Source = new [ ] { new TextLoaderRange ( 1 , 4 ) } }
760+ } ;
761+ var importOutput = experiment . Add ( importInput ) ;
762+ var oneVersusAll = new Models . OneVersusAll
763+ {
764+ TrainingData = importOutput . Data ,
765+ Nodes = subGraph ,
766+ UseProbabilities = true ,
767+ } ;
768+ var ovaOutput = experiment . Add ( oneVersusAll ) ;
769+ var scoreInput = new ML . Transforms . DatasetScorer
770+ {
771+ Data = importOutput . Data ,
772+ PredictorModel = ovaOutput . PredictorModel
773+ } ;
774+ var scoreOutput = experiment . Add ( scoreInput ) ;
775+ var evalInput = new ML . Models . ClassificationEvaluator
776+ {
777+ Data = scoreOutput . ScoredData
778+ } ;
779+ var evalOutput = experiment . Add ( evalInput ) ;
780+ experiment . Compile ( ) ;
781+ experiment . SetInput ( importInput . InputFile , new SimpleFileHandle ( env , dataPath , false , false ) ) ;
782+ experiment . Run ( ) ;
783+
784+ var data = experiment . GetOutput ( evalOutput . OverallMetrics ) ;
785+ var schema = data . Schema ;
786+ var b = schema . TryGetColumnIndex ( MultiClassClassifierEvaluator . AccuracyMacro , out int accCol ) ;
787+ Assert . True ( b ) ;
788+ using ( var cursor = data . GetRowCursor ( col => col == accCol ) )
789+ {
790+ var getter = cursor . GetGetter < double > ( accCol ) ;
791+ b = cursor . MoveNext ( ) ;
792+ Assert . True ( b ) ;
793+ double acc = 0 ;
794+ getter ( ref acc ) ;
795+ Assert . Equal ( 0.96 , acc , 2 ) ;
796+ b = cursor . MoveNext ( ) ;
797+ Assert . False ( b ) ;
798+ }
799+ }
800+ }
742801 }
743802}
0 commit comments