Skip to content
62 changes: 57 additions & 5 deletions test/Microsoft.ML.Functional.Tests/Validation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
// See the LICENSE file in the project root for more information.

using Microsoft.Data.DataView;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Internallearn;
using Microsoft.ML.RunTests;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Trainers.HalLearners;
using Xunit;
using static Microsoft.ML.RunTests.TestDataViewBase;

namespace Microsoft.ML.Functional.Tests
{
Expand All @@ -23,20 +26,20 @@ public class ValidationScenarios
[Fact]
void CrossValidation()
{
var mlContext = new MLContext(seed: 789);
var mlContext = new MLContext(seed: 1, conc: 1);

// Get the dataset
// Get the dataset.
var data = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true)
.Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));

// Create a pipeline to train on the sentiment data
// Create a pipeline to train on the housing data.
var pipeline = mlContext.Transforms.Concatenate("Features", new string[] {
"CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
"PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"})
.Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
.Append(mlContext.Regression.Trainers.OrdinaryLeastSquares());

// Compute the CV result
// Compute the CV result.
var cvResult = mlContext.Regression.CrossValidate(data, pipeline, numFolds: 5);

// Check that the results are valid
Expand All @@ -45,9 +48,58 @@ void CrossValidation()
Assert.True(cvResult[0].ScoredHoldOutSet is IDataView);
Assert.Equal(5, cvResult.Length);

// And validate the metrics
// And validate the metrics.
foreach (var result in cvResult)
Common.CheckMetrics(result.Metrics);
}

/// <summary>
/// Validation scenario: train a FastTree regression model on the housing data while
/// passing a held-out validation set to the trainer, then check metrics on both splits.
/// </summary>
[Fact]
public void TrainWithValidationSet()
{
    var mlContext = new MLContext(seed: 1, conc: 1);

    // Read the housing dataset and hold out 20% of it as the validation set.
    var loader = mlContext.Data.CreateTextLoader(TestDatasets.housing.GetLoaderColumns(), hasHeader: true);
    var data = loader.Read(BaseTestClass.GetDataPath(TestDatasets.housing.trainFilename));
    var split = mlContext.Regression.TrainTestSplit(data, testFraction: 0.2);
    var trainingSet = split.TrainSet;
    var validationSet = split.TestSet;

    // Columns that are concatenated into the single "Features" column.
    var featureColumns = new string[] {
        "CrimesPerCapita", "PercentResidental", "PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling",
        "PercentPre40s", "EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio"};

    // Featurization-only pipeline: build "Features", copy the target into "Label", and cache.
    // The trainer is deliberately not appended here because it is invoked separately below
    // so that it can be handed the validation data.
    var featurizer = mlContext.Transforms.Concatenate("Features", featureColumns)
        .Append(mlContext.Transforms.CopyColumns("Label", "MedianHomeValue"))
        .AppendCacheCheckpoint(mlContext) as IEstimator<ITransformer>;

    // Fit the featurizer on the training split only, then apply it to both splits.
    var fittedFeaturizer = featurizer.Fit(trainingSet);
    var featurizedTrainingSet = fittedFeaturizer.Transform(trainingSet);
    var featurizedValidationSet = fittedFeaturizer.Transform(validationSet);

    // Configure FastTree with two trees and an early-stopping rule.
    // NOTE(review): the meaning of EarlyStoppingMetrics = 2 is defined by the FastTree options - confirm.
    var fastTreeOptions = new Trainers.FastTree.FastTreeRegressionTrainer.Options
    {
        NumTrees = 2,
        EarlyStoppingMetrics = 2,
        EarlyStoppingRule = new GLEarlyStoppingCriterion.Arguments()
    };

    // Train on the featurized training data, supplying the featurized validation data.
    var trainer = mlContext.Regression.Trainers.FastTree(fastTreeOptions);
    var predictor = trainer.Train(trainData: featurizedTrainingSet, validationData: featurizedValidationSet);

    // Chain the fitted featurizer and the trained predictor into one model.
    var fullModel = fittedFeaturizer.Append(predictor);

    // Score the raw (un-featurized) splits with the combined model.
    var scoredTrainingSet = fullModel.Transform(trainingSet);
    var scoredValidationSet = fullModel.Transform(validationSet);

    // Evaluate both scored splits and validate the resulting metrics.
    var trainingMetrics = mlContext.Regression.Evaluate(scoredTrainingSet);
    var validationMetrics = mlContext.Regression.Evaluate(scoredValidationSet);

    Common.CheckMetrics(trainingMetrics);
    Common.CheckMetrics(validationMetrics);
}
}
}

This file was deleted.