diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs
index b04eff387a..c1d98881c2 100644
--- a/test/Microsoft.ML.Functional.Tests/Validation.cs
+++ b/test/Microsoft.ML.Functional.Tests/Validation.cs
@@ -3,11 +3,14 @@
// See the LICENSE file in the project root for more information.
using Microsoft.Data.DataView;
+using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers.HalLearners;
using Xunit;
+using static Microsoft.ML.RunTests.TestDataViewBase;
namespace Microsoft.ML.Functional.Tests
{
@@ -23,20 +26,20 @@ public class ValidationScenarios
[Fact]
void CrossValidation()
{
- var mlContext = new MLContext(seed: 789);
+ var mlContext = new MLContext(seed: 1, conc: 1);
- // Get the dataset
+ // Get the dataset.
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
.Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
- // Create a pipeline to train on the sentiment data
+ // Create a pipeline to train on the housing data.
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
.Append(mlContext.Regression.Trainers.OrdinaryLeastSquares());
- // Compute the CV result
+ // Compute the CV result.
var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5);
// Check that the results are valid
@@ -45,9 +48,58 @@ void CrossValidation()
Assert.True(cvResult[0].ScoredHoldOutSet is IDataView);
Assert.Equal(5, cvResult.Length);
- // And validate the metrics
+ // And validate the metrics.
foreach (var result in cvResult)
Common.CheckMetrics(result.Metrics);
}
+
+        /// <summary>
+        /// Train with validation set: fit FastTree with an early-stopping rule, passing a held-out split as validation data.
+        /// </summary>
+        [Fact]
+        public void TrainWithValidationSet()
+        {
+            var mlContext = new MLContext(seed: 1, conc: 1);
+
+            // Get the dataset and hold out 20% of it to serve as the validation set.
+            var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
+                .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
+            var dataSplit = mlContext.Regression.TrainTestSplit(data, testFraction: 0.2);
+            var trainData = dataSplit.TrainSet;
+            var validData = dataSplit.TestSet;
+
+            // Create a pipeline to featurize the dataset, viewed as IEstimator<ITransformer> so Fit yields an ITransformer.
+            var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
+                    "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
+                    "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
+                .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
+                .AppendCacheCheckpoint(mlContext) as IEstimator<ITransformer>;
+
+            // Preprocess the datasets: fit the featurizer on the training split only, then apply it to both splits.
+            var preprocessor = pipeline.Fit(trainData);
+            var preprocessedTrainData = preprocessor.Transform(trainData);
+            var preprocessedValidData = preprocessor.Transform(validData);
+
+            // Train the model with a validation set; the trainer consumes the already-featurized data.
+            var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options {
+                    NumTrees = 2,
+                    EarlyStoppingMetrics = 2,
+                    EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments()
+                })
+                .Train(trainData: preprocessedTrainData, validationData: preprocessedValidData);
+
+            // Combine the featurizer and the trained predictor into a single model.
+            var model = preprocessor.Append(trainedModel);
+
+            // Score the raw (un-featurized) data sets with the combined model.
+            var scoredTrainData = model.Transform(trainData);
+            var scoredValidData = model.Transform(validData);
+
+            var trainMetrics = mlContext.Regression.Evaluate(scoredTrainData);
+            var validMetrics = mlContext.Regression.Evaluate(scoredValidData);
+
+            Common.CheckMetrics(trainMetrics);
+            Common.CheckMetrics(validMetrics);
+        }
}
}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs
deleted file mode 100644
index 9fde458562..0000000000
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs
+++ /dev/null
@@ -1,35 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using Microsoft.ML.RunTests;
-using Xunit;
-
-namespace Microsoft.ML.Tests.Scenarios.Api
-{
- public partial class ApiScenariosTests
- {
- ///
- /// Train with validation set: Similar to the simple train scenario, but also support a validation set.
- /// The learner might be trees with early stopping.
- ///
- [Fact]
- public void TrainWithValidationSet()
- {
- var ml = new MLContext(seed: 1, conc: 1);
- // Pipeline.
- var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);
- var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText");
-
- // Train the pipeline, prepare train and validation set.
- var preprocess = pipeline.Fit(data);
- var trainData = preprocess.Transform(data);
- var validDataSource = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.testFilename), hasHeader: true);
- var validData = preprocess.Transform(validDataSource);
-
- // Train model with validation set.
- var trainer = ml.BinaryClassification.Trainers.FastTree("Label","Features");
- var model = trainer.Train(ml.Data.Cache(trainData), ml.Data.Cache(validData));
- }
- }
-}