diff --git a/test/Microsoft.ML.Functional.Tests/Validation.cs b/test/Microsoft.ML.Functional.Tests/Validation.cs
index b04eff387a..c1d98881c2 100644
--- a/test/Microsoft.ML.Functional.Tests/Validation.cs
+++ b/test/Microsoft.ML.Functional.Tests/Validation.cs
@@ -3,11 +3,14 @@
 // See the LICENSE file in the project root for more information.
 
 using Microsoft.Data.DataView;
+using Microsoft.ML.Core.Data;
 using Microsoft.ML.Data;
+using Microsoft.ML.Internal.Internallearn;
 using Microsoft.ML.RunTests;
 using Microsoft.ML.TestFramework;
 using Microsoft.ML.Trainers.HalLearners;
 using Xunit;
+using static Microsoft.ML.RunTests.TestDataViewBase;
 
 namespace Microsoft.ML.Functional.Tests
 {
@@ -23,20 +26,20 @@ public class ValidationScenarios
         [Fact]
         void CrossValidation()
         {
-            var mlContext = new MLContext(seed: 789);
+            var mlContext = new MLContext(seed: 1, conc: 1);
 
-            // Get the dataset
+            // Get the dataset.
             var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
                 .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
 
-            // Create a pipeline to train on the sentiment data
+            // Create a pipeline to train on the housing data.
             var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
                     "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
                     "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
                 .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
                 .Append(mlContext.Regression.Trainers.OrdinaryLeastSquares());
 
-            // Compute the CV result
+            // Compute the CV result.
             var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5);
 
             // Check that the results are valid
@@ -45,9 +48,60 @@ void CrossValidation()
             Assert.True(cvResult[0].ScoredHoldOutSet is IDataView);
             Assert.Equal(5, cvResult.Length);
 
-            // And validate the metrics
+            // And validate the metrics.
             foreach (var result in cvResult)
                 Common.CheckMetrics(result.Metrics);
         }
+
+        /// <summary>
+        /// Train with validation set: split the housing data, then train FastTree with the held-out split as validation data.
+        /// </summary>
+        [Fact]
+        public void TrainWithValidationSet()
+        {
+            var mlContext = new MLContext(seed: 1, conc: 1);
+
+            // Get the dataset.
+            var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
+                .Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
+            var dataSplit = mlContext.Regression.TrainTestSplit(data, testFraction: 0.2);
+            var trainData = dataSplit.TrainSet;
+            var validData = dataSplit.TestSet;
+
+            // Create a pipeline to featurize the dataset.
+            var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
+                    "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
+                    "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
+                .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
+                .AppendCacheCheckpoint(mlContext) as IEstimator<ITransformer>;
+
+            // Preprocess the datasets.
+            var preprocessor = pipeline.Fit(trainData);
+            var preprocessedTrainData = preprocessor.Transform(trainData);
+            var preprocessedValidData = preprocessor.Transform(validData);
+
+            // Train the model with a validation set and an early stopping rule configured.
+            var trainedModel = mlContext.Regression.Trainers.FastTree(new Trainers.FastTree.FastTreeRegressionTrainer.Options {
+                    NumTrees = 2,
+                    // NOTE(review): 2 is a magic number; presumably selects the L2 metric for regression - confirm against the FastTree options documentation.
+                    EarlyStoppingMetrics = 2,
+                    EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments()
+                })
+                .Train(trainData: preprocessedTrainData, validationData: preprocessedValidData);
+
+            // Combine the model.
+            var model = preprocessor.Append(trainedModel);
+
+            // Score the data sets.
+            var scoredTrainData = model.Transform(trainData);
+            var scoredValidData = model.Transform(validData);
+
+            // Evaluate the scored train and validation sets, then check both sets of metrics.
+            var trainMetrics = mlContext.Regression.Evaluate(scoredTrainData);
+            var validMetrics = mlContext.Regression.Evaluate(scoredValidData);
+
+            Common.CheckMetrics(trainMetrics);
+            Common.CheckMetrics(validMetrics);
+        }
     }
 }
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs
deleted file mode 100644
index 9fde458562..0000000000
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs
+++ /dev/null
@@ -1,35 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using Microsoft.ML.RunTests;
-using Xunit;
-
-namespace Microsoft.ML.Tests.Scenarios.Api
-{
-    public partial class ApiScenariosTests
-    {
-        ///
-        /// Train with validation set: Similar to the simple train scenario, but also support a validation set.
-        /// The learner might be trees with early stopping.
-        ///
-        [Fact]
-        public void TrainWithValidationSet()
-        {
-            var ml = new MLContext(seed: 1, conc: 1);
-            // Pipeline.
-            var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);
-            var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText");
-
-            // Train the pipeline, prepare train and validation set.
-            var preprocess = pipeline.Fit(data);
-            var trainData = preprocess.Transform(data);
-            var validDataSource = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.testFilename), hasHeader: true);
-            var validData = preprocess.Transform(validDataSource);
-
-            // Train model with validation set.
-            var trainer = ml.BinaryClassification.Trainers.FastTree("Label","Features");
-            var model = trainer.Train(ml.Data.Cache(trainData), ml.Data.Cache(validData));
-        }
-    }
-}