Skip to content

Commit e818e35

Browse files
authored
Add validation scenario tests (#2503)
Add a validation scenario test.
1 parent 5c442a9 commit e818e35

File tree

2 files changed

+57
-40
lines changed

2 files changed

+57
-40
lines changed

test/Microsoft.ML.Functional.Tests/Validation.cs

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.Data.DataView;
6+
using Microsoft.ML.Core.Data;
67
using Microsoft.ML.Data;
8+
using Microsoft.ML.Internal.Internallearn;
79
using Microsoft.ML.RunTests;
810
using Microsoft.ML.TestFramework;
911
using Microsoft.ML.Trainers.HalLearners;
1012
using Xunit;
13+
using static Microsoft.ML.RunTests.TestDataViewBase;
1114

1215
namespace Microsoft.ML.Functional.Tests
1316
{
@@ -23,20 +26,20 @@ public class ValidationScenarios
2326
[Fact]
2427
void CrossValidation()
2528
{
26-
var mlContext = new MLContext(seed: 789);
29+
var mlContext = new MLContext(seed: 1, conc: 1);
2730

28-
// Get the dataset
31+
// Get the dataset.
2932
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
3033
.Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
3134

32-
// Create a pipeline to train on the sentiment data
35+
// Create a pipeline to train on the housing data.
3336
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
3437
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
3538
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
3639
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
3740
.Append(mlContext.Regression.Trainers.OrdinaryLeastSquares());
3841

39-
// Compute the CV result
42+
// Compute the CV result.
4043
var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5);
4144

4245
// Check that the results are valid.
@@ -45,9 +48,58 @@ void CrossValidation()
4548
Assert.True(cvResult[0].ScoredHoldOutSet is IDataView);
4649
Assert.Equal(5, cvResult.Length);
4750

48-
// And validate the metrics
51+
// And validate the metrics.
4952
foreach (var result in cvResult)
5053
Common.CheckMetrics(result.Metrics);
5154
}
55+
56+
/// <summary>
57+
/// Train a FastTree regression model on the housing dataset with a held-out validation set, then check the regression metrics on both the training and validation data.
58+
/// </summary>
59+
[Fact]
60+
public void TrainWithValidationSet()
61+
{
62+
var mlContext = new MLContext(seed: 1, conc: 1);
63+
64+
// Load the housing dataset and hold out 20% of it (the split's TestSet) to use as the validation set.
65+
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
66+
.Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
67+
var dataSplit = mlContext.Regression.TrainTestSplit(data, testFraction: 0.2);
68+
var trainData = dataSplit.TrainSet;
69+
var validData = dataSplit.TestSet;
70+
71+
// Create a pipeline to featurize the dataset: concatenate the input columns into "Features", copy "MedianHomeValue" to "Label", and append a cache checkpoint.
72+
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
73+
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
74+
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
75+
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
76+
.AppendCacheCheckpoint(mlContext) as IEstimator<ITransformer>;
77+
78+
// Preprocess the datasets: fit the featurization pipeline on the training data only, then apply it to both the training and validation sets.
79+
var preprocessor = pipeline.Fit(trainData);
80+
var preprocessedTrainData = preprocessor.Transform(trainData);
81+
var preprocessedValidData = preprocessor.Transform(validData);
82+
83+
// Train a FastTree regressor on the preprocessed training data, passing the preprocessed validation set as validationData (an early stopping rule is set in the options).
84+
var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options {
85+
NumTrees = 2,
86+
EarlyStoppingMetrics = 2,
87+
EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments()
88+
})
89+
.Train(trainData: preprocessedTrainData, validationData: preprocessedValidData);
90+
91+
// Combine the fitted preprocessor and the trained model into a single model.
92+
var model = preprocessor.Append(trainedModel);
93+
94+
// Score the original (not yet preprocessed) training and validation sets with the combined model.
95+
var scoredTrainData = model.Transform(trainData);
96+
var scoredValidData = model.Transform(validData);
97+
98+
var trainMetrics = mlContext.Regression.Evaluate(scoredTrainData);
99+
var validMetrics = mlContext.Regression.Evaluate(scoredValidData);
100+
101+
Common.CheckMetrics(trainMetrics);
102+
Common.CheckMetrics(validMetrics);
103+
}
52104
}
53105
}

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs

Lines changed: 0 additions & 35 deletions
This file was deleted.

0 commit comments

Comments
 (0)