@@ -25,6 +25,31 @@ public abstract class TrainCatalogBase
2525 [ BestFriend ]
2626 internal IHostEnvironment Environment => Host ;
2727
28+ /// <summary>
29+ /// A pair of datasets, for the train and test set.
30+ /// </summary>
31+ public struct TrainTestData
32+ {
33+ /// <summary>
34+ /// Training set.
35+ /// </summary>
36+ public readonly IDataView TrainSet ;
37+ /// <summary>
38+ /// Testing set.
39+ /// </summary>
40+ public readonly IDataView TestSet ;
41+ /// <summary>
42+ /// Create pair of datasets.
43+ /// </summary>
44+ /// <param name="trainSet">Training set.</param>
45+ /// <param name="testSet">Testing set.</param>
46+ internal TrainTestData ( IDataView trainSet , IDataView testSet )
47+ {
48+ TrainSet = trainSet ;
49+ TestSet = testSet ;
50+ }
51+ }
52+
2853 /// <summary>
2954 /// Split the dataset into the train set and test set according to the given fraction.
3055 /// Respects the <paramref name="stratificationColumn"/> if provided.
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase
3762 /// <param name="seed">Optional parameter used in combination with the <paramref name="stratificationColumn"/>.
3863 /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
3964 /// And if it is not provided, the default value will be used.</param>
40- /// <returns>A pair of datasets, for the train and test set.</returns>
41- public ( IDataView trainSet , IDataView testSet ) TrainTestSplit ( IDataView data , double testFraction = 0.1 , string stratificationColumn = null , uint ? seed = null )
65+ public TrainTestData TrainTestSplit ( IDataView data , double testFraction = 0.1 , string stratificationColumn = null , uint ? seed = null )
4266 {
4367 Host . CheckValue ( data , nameof ( data ) ) ;
4468 Host . CheckParam ( 0 < testFraction && testFraction < 1 , nameof ( testFraction ) , "Must be between 0 and 1 exclusive" ) ;
@@ -61,14 +85,71 @@ public abstract class TrainCatalogBase
6185 Complement = false
6286 } , data ) ;
6387
64- return ( trainFilter , testFilter ) ;
88+ return new TrainTestData ( trainFilter , testFilter ) ;
89+ }
90+
91+ /// <summary>
92+ /// Results for specific cross-validation fold.
93+ /// </summary>
94+ protected internal struct CrossValidationResult
95+ {
96+ /// <summary>
97+ /// Model trained during cross validation fold.
98+ /// </summary>
99+ public readonly ITransformer Model ;
100+ /// <summary>
101+ /// Scored test set with <see cref="Model"/> for this fold.
102+ /// </summary>
103+ public readonly IDataView Scores ;
104+ /// <summary>
105+ /// Fold number.
106+ /// </summary>
107+ public readonly int Fold ;
108+
109+ public CrossValidationResult ( ITransformer model , IDataView scores , int fold )
110+ {
111+ Model = model ;
112+ Scores = scores ;
113+ Fold = fold ;
114+ }
115+ }
116+ /// <summary>
117+ /// Results of running cross-validation.
118+ /// </summary>
119+ /// <typeparam name="T">Type of metric class.</typeparam>
120+ public sealed class CrossValidationResult < T > where T : class
121+ {
122+ /// <summary>
123+ /// Metrics for this cross-validation fold.
124+ /// </summary>
125+ public readonly T Metrics ;
126+ /// <summary>
127+ /// Model trained during cross-validation fold.
128+ /// </summary>
129+ public readonly ITransformer Model ;
130+ /// <summary>
131+ /// The scored hold-out set for this fold.
132+ /// </summary>
133+ public readonly IDataView ScoredHoldOutSet ;
134+ /// <summary>
135+ /// Fold number.
136+ /// </summary>
137+ public readonly int Fold ;
138+
139+ internal CrossValidationResult ( ITransformer model , T metrics , IDataView scores , int fold )
140+ {
141+ Model = model ;
142+ Metrics = metrics ;
143+ ScoredHoldOutSet = scores ;
144+ Fold = fold ;
145+ }
65146 }
66147
67148 /// <summary>
68149 /// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
69150 /// Return each model and each scored test dataset.
70151 /// </summary>
71- protected internal ( IDataView scoredTestSet , ITransformer model ) [ ] CrossValidateTrain ( IDataView data , IEstimator < ITransformer > estimator ,
152+ protected internal CrossValidationResult [ ] CrossValidateTrain ( IDataView data , IEstimator < ITransformer > estimator ,
72153 int numFolds , string stratificationColumn , uint ? seed = null )
73154 {
74155 Host . CheckValue ( data , nameof ( data ) ) ;
@@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
78159
79160 EnsureStratificationColumn ( ref data , ref stratificationColumn , seed ) ;
80161
81- Func < int , ( IDataView scores , ITransformer model ) > foldFunction =
162+ Func < int , CrossValidationResult > foldFunction =
82163 fold =>
83164 {
84165 var trainFilter = new RangeFilter ( Host , new RangeFilter . Options
@@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
98179
99180 var model = estimator . Fit ( trainFilter ) ;
100181 var scoredTest = model . Transform ( testFilter ) ;
101- return ( scoredTest , model ) ;
182+ return new CrossValidationResult ( model , scoredTest , fold ) ;
102183 } ;
103184
104185 // Sequential per-fold training.
105186 // REVIEW: we could have a parallel implementation here. We would need to
106187 // spawn off a separate host per fold in that case.
107- var result = new List < ( IDataView scores , ITransformer model ) > ( ) ;
188+ var result = new CrossValidationResult [ numFolds ] ;
108189 for ( int fold = 0 ; fold < numFolds ; fold ++ )
109- result . Add ( foldFunction ( fold ) ) ;
190+ result [ fold ] = foldFunction ( fold ) ;
110191
111- return result . ToArray ( ) ;
192+ return result ;
112193 }
113194
114195 protected internal TrainCatalogBase ( IHostEnvironment env , string registrationName )
@@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
263344 /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
264345 /// And if it is not provided, the default value will be used.</param>
265346 /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
266- public ( BinaryClassificationMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidateNonCalibrated (
347+ public CrossValidationResult < BinaryClassificationMetrics > [ ] CrossValidateNonCalibrated (
267348 IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
268349 string stratificationColumn = null , uint ? seed = null )
269350 {
270351 Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
271352 var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
272- return result . Select ( x => ( EvaluateNonCalibrated ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
353+ return result . Select ( x => new CrossValidationResult < BinaryClassificationMetrics > ( x . Model ,
354+ EvaluateNonCalibrated ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
273355 }
274356
275357 /// <summary>
@@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
287369 /// train to the test set.</remarks>
288370 /// <param name="seed">If <paramref name="stratificationColumn"/> not present in dataset we will generate random filled column based on provided <paramref name="seed"/>.</param>
289371 /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
290- public ( CalibratedBinaryClassificationMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
372+ public CrossValidationResult < CalibratedBinaryClassificationMetrics > [ ] CrossValidate (
291373 IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
292374 string stratificationColumn = null , uint ? seed = null )
293375 {
294376 Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
295377 var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
296- return result . Select ( x => ( Evaluate ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
378+ return result . Select ( x => new CrossValidationResult < CalibratedBinaryClassificationMetrics > ( x . Model ,
379+ Evaluate ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
297380 }
298381 }
299382
@@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data,
369452 /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
370453 /// And if it is not provided, the default value will be used.</param>
371454 /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
372- public ( ClusteringMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
455+ public CrossValidationResult < ClusteringMetrics > [ ] CrossValidate (
373456 IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = null , string featuresColumn = null ,
374457 string stratificationColumn = null , uint ? seed = null )
375458 {
376459 var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
377- return result . Select ( x => ( Evaluate ( x . scoredTestSet , label : labelColumn , features : featuresColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
460+ return result . Select ( x => new CrossValidationResult < ClusteringMetrics > ( x . Model ,
461+ Evaluate ( x . Scores , label : labelColumn , features : featuresColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
378462 }
379463 }
380464
@@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau
444528 /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
445529 /// And if it is not provided, the default value will be used.</param>
446530 /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
447- public ( MultiClassClassifierMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
531+ public CrossValidationResult < MultiClassClassifierMetrics > [ ] CrossValidate (
448532 IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
449533 string stratificationColumn = null , uint ? seed = null )
450534 {
451535 Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
452536 var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
453- return result . Select ( x => ( Evaluate ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
537+ return result . Select ( x => new CrossValidationResult < MultiClassClassifierMetrics > ( x . Model ,
538+ Evaluate ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
454539 }
455540 }
456541
@@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
511596 /// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it, will use this seed as value.
512597 /// And if it is not provided, the default value will be used.</param>
513598 /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
514- public ( RegressionMetrics metrics , ITransformer model , IDataView scoredTestData ) [ ] CrossValidate (
599+ public CrossValidationResult < RegressionMetrics > [ ] CrossValidate (
515600 IDataView data , IEstimator < ITransformer > estimator , int numFolds = 5 , string labelColumn = DefaultColumnNames . Label ,
516601 string stratificationColumn = null , uint ? seed = null )
517602 {
518603 Host . CheckNonEmpty ( labelColumn , nameof ( labelColumn ) ) ;
519604 var result = CrossValidateTrain ( data , estimator , numFolds , stratificationColumn , seed ) ;
520- return result . Select ( x => ( Evaluate ( x . scoredTestSet , labelColumn ) , x . model , x . scoredTestSet ) ) . ToArray ( ) ;
605+ return result . Select ( x => new CrossValidationResult < RegressionMetrics > ( x . Model ,
606+ Evaluate ( x . Scores , labelColumn ) , x . Scores , x . Fold ) ) . ToArray ( ) ;
521607 }
522608 }
523609
0 commit comments