33// See the LICENSE file in the project root for more information.
44
55using Microsoft . Data . DataView ;
6+ using Microsoft . ML . Core . Data ;
67using Microsoft . ML . Data ;
8+ using Microsoft . ML . Internal . Internallearn ;
79using Microsoft . ML . RunTests ;
810using Microsoft . ML . TestFramework ;
911using Microsoft . ML . Trainers . HalLearners ;
1012using Xunit ;
13+ using static Microsoft . ML . RunTests . TestDataViewBase ;
1114
1215namespace Microsoft . ML . Functional . Tests
1316{
@@ -23,20 +26,20 @@ public class ValidationScenarios
2326 [ Fact ]
2427 void CrossValidation ( )
2528 {
26- var mlContext = new MLContext ( seed : 789 ) ;
29+ var mlContext = new MLContext ( seed : 1 , conc : 1 ) ;
2730
28- // Get the dataset
31+ // Get the dataset.
2932 var data = mlContext . Data . CreateTextLoader ( TestDatasets . housing . GetLoaderColumns ( ) , hasHeader : true )
3033 . Read ( BaseTestClass . GetDataPath ( TestDatasets . housing . trainFilename ) ) ;
3134
32- // Create a pipeline to train on the sentiment data
35+ // Create a pipeline to train on the housing data.
3336 var pipeline = mlContext . Transforms . Concatenate ( "Features" , new string [ ] {
3437 "CrimesPerCapita" , "PercentResidental" , "PercentNonRetail" , "CharlesRiver" , "NitricOxides" , "RoomsPerDwelling" ,
3538 "PercentPre40s" , "EmploymentDistance" , "HighwayDistance" , "TaxRate" , "TeacherRatio" } )
3639 . Append ( mlContext . Transforms . CopyColumns ( "Label" , "MedianHomeValue" ) )
3740 . Append ( mlContext . Regression . Trainers . OrdinaryLeastSquares ( ) ) ;
3841
39- // Compute the CV result
42+ // Compute the CV result.
4043 var cvResult = mlContext . Regression . CrossValidate ( data , pipeline , numFolds : 5 ) ;
4144
4245 // Check that the results are valid.
@@ -45,9 +48,58 @@ void CrossValidation()
4548 Assert . True ( cvResult [ 0 ] . ScoredHoldOutSet is IDataView ) ;
4649 Assert . Equal ( 5 , cvResult . Length ) ;
4750
48- // And validate the metrics
51+ // And validate the metrics.
4952 foreach ( var result in cvResult )
5053 Common . CheckMetrics ( result . Metrics ) ;
5154 }
55+
/// <summary>
/// Trains a FastTree regression model on the housing dataset with an early-stopping
/// rule configured and a held-out validation set passed to <c>Train</c>, then checks
/// the regression metrics computed on both the training and the validation split.
/// </summary>
[Fact]
public void TrainWithValidationSet()
{
    // Fixed seed and conc: 1 — presumably to keep the run reproducible; confirm against MLContext docs.
    var mlContext = new MLContext(seed: 1, conc: 1);

    // Get the dataset and hold out 20% of it to serve as the validation set.
    var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
        .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
    var dataSplit = mlContext.Regression.TrainTestSplit(data, testFraction: 0.2);
    var trainData = dataSplit.TrainSet;
    var validData = dataSplit.TestSet;

    // Create a pipeline to featurize the dataset.
    // The local is typed explicitly so the compiler checks the conversion to
    // IEstimator<ITransformer>; an `as` cast would silently yield null on a mismatch
    // and only fail later, at Fit().
    IEstimator<ITransformer> pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
            "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
            "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio" })
        .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
        .AppendCacheCheckpoint(mlContext);

    // Preprocess the datasets. The preprocessor is fit on the training split only
    // and then applied to both splits.
    var preprocessor = pipeline.Fit(trainData);
    var preprocessedTrainData = preprocessor.Transform(trainData);
    var preprocessedValidData = preprocessor.Transform(validData);

    // Train the model with a validation set.
    var trainerOptions = new Trainers.FastTree.FastTreeRegressionTrainer.Options
    {
        NumTrees = 2,
        // NOTE(review): 2 presumably selects the L2 metric for regression — confirm against the FastTree options docs.
        EarlyStoppingMetrics = 2,
        EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments()
    };
    var trainedModel = mlContext.Regression.Trainers.FastTree(trainerOptions)
        .Train(trainData: preprocessedTrainData, validationData: preprocessedValidData);

    // Combine the preprocessing transformer and the trained predictor into one model
    // so that it can be applied directly to the raw (unpreprocessed) splits.
    var model = preprocessor.Append(trainedModel);

    // Score the data sets.
    var scoredTrainData = model.Transform(trainData);
    var scoredValidData = model.Transform(validData);

    // Evaluate both splits and validate the metrics.
    var trainMetrics = mlContext.Regression.Evaluate(scoredTrainData);
    var validMetrics = mlContext.Regression.Evaluate(scoredValidData);

    Common.CheckMetrics(trainMetrics);
    Common.CheckMetrics(validMetrics);
}
52104 }
53105}
0 commit comments