@@ -42,19 +42,48 @@ public void AutoFitBinaryTest()
4242 Assert . NotNull ( result . BestRun . TrainerName ) ;
4343 }
4444
45- [ Fact ]
46- public void AutoFitMultiTest ( )
45+ [ Theory ]
46+ [ InlineData ( true ) ]
47+ [ InlineData ( false ) ]
48+ public void AutoFitMultiTest ( bool useNumberOfCVFolds )
4749 {
4850 var context = new MLContext ( 0 ) ;
4951 var columnInference = context . Auto ( ) . InferColumns ( DatasetUtil . TrivialMulticlassDatasetPath , DatasetUtil . TrivialMulticlassDatasetLabel ) ;
5052 var textLoader = context . Data . CreateTextLoader ( columnInference . TextLoaderOptions ) ;
5153 var trainData = textLoader . Load ( DatasetUtil . TrivialMulticlassDatasetPath ) ;
52- var result = context . Auto ( )
53- . CreateMulticlassClassificationExperiment ( 0 )
54- . Execute ( trainData , 5 , DatasetUtil . TrivialMulticlassDatasetLabel ) ;
55- Assert . True ( result . BestRun . Results . First ( ) . ValidationMetrics . MicroAccuracy >= 0.7 ) ;
56- var scoredData = result . BestRun . Results . First ( ) . Model . Transform ( trainData ) ;
57- Assert . Equal ( NumberDataViewType . Single , scoredData . Schema [ DefaultColumnNames . PredictedLabel ] . Type ) ;
54+
55+ if ( useNumberOfCVFolds )
56+ {
57+ // When setting numberOfCVFolds
58+ // The results object is a CrossValidationExperimentResults<> object
59+ uint numberOfCVFolds = 5 ;
60+ var result = context . Auto ( )
61+ . CreateMulticlassClassificationExperiment ( 0 )
62+ . Execute ( trainData , numberOfCVFolds , DatasetUtil . TrivialMulticlassDatasetLabel ) ;
63+
64+ Assert . True ( result . BestRun . Results . First ( ) . ValidationMetrics . MicroAccuracy >= 0.7 ) ;
65+ var scoredData = result . BestRun . Results . First ( ) . Model . Transform ( trainData ) ;
66+ Assert . Equal ( NumberDataViewType . Single , scoredData . Schema [ DefaultColumnNames . PredictedLabel ] . Type ) ;
67+ }
68+ else
69+ {
70+ // When using this other API, if the trainset is under the
71+ // crossValRowCounThreshold, AutoML will also perform CrossValidation
72+ // but through a very different path that the one above,
73+ // throw a CrossValSummaryRunner and will return
74+ // a different type of object as "result" which would now be
75+ // simply a ExperimentResult<> object
76+
77+ int crossValRowCountThreshold = 15000 ;
78+ trainData = context . Data . TakeRows ( trainData , crossValRowCountThreshold - 1 ) ;
79+ var result = context . Auto ( )
80+ . CreateMulticlassClassificationExperiment ( 0 )
81+ . Execute ( trainData , DatasetUtil . TrivialMulticlassDatasetLabel ) ;
82+
83+ Assert . True ( result . BestRun . ValidationMetrics . MicroAccuracy >= 0.7 ) ;
84+ var scoredData = result . BestRun . Model . Transform ( trainData ) ;
85+ Assert . Equal ( NumberDataViewType . Single , scoredData . Schema [ DefaultColumnNames . PredictedLabel ] . Type ) ;
86+ }
5887 }
5988
6089 [ TensorFlowFact ]
0 commit comments