diff --git a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
index 35d0a679ad..727e443852 100644
--- a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
+++ b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
@@ -27,10 +27,10 @@ public static class AutoMLExperimentExtension
///
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation)
{
- var datasetManager = new TrainTestDatasetManager()
+ var datasetManager = new TrainValidateDatasetManager()
{
TrainDataset = train,
- TestDataset = validation
+ ValidateDataset = validation
};
experiment.ServiceCollection.AddSingleton(datasetManager);
diff --git a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
index bbf4ef6531..6db1da1dec 100644
--- a/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
+++ b/src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
@@ -400,12 +400,12 @@ public TrialResult Run(TrialSettings settings)
};
}
- if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
+ if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
- var eval = model.Transform(trainTestDatasetManager.TestDataset);
+ var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
@@ -426,7 +426,7 @@ public TrialResult Run(TrialSettings settings)
}
}
- throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
+ throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}
public Task RunAsync(TrialSettings settings, CancellationToken ct)
diff --git a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs
index 975381f2c6..b855c9e71a 100644
--- a/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs
+++ b/src/Microsoft.ML.AutoML/API/MulticlassClassificationExperiment.cs
@@ -394,12 +394,12 @@ public TrialResult Run(TrialSettings settings)
};
}
- if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
+ if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
- var eval = model.Transform(trainTestDatasetManager.TestDataset);
+ var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
@@ -420,7 +420,7 @@ public TrialResult Run(TrialSettings settings)
}
}
- throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
+ throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}
public Task RunAsync(TrialSettings settings, CancellationToken ct)
diff --git a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
index 3adad3fddc..99f9e9800f 100644
--- a/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
+++ b/src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
@@ -421,12 +421,12 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct)
} as TrialResult);
}
- if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
+ if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
- var eval = model.Transform(trainTestDatasetManager.TestDataset);
+ var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
@@ -447,7 +447,7 @@ public Task RunAsync(TrialSettings settings, CancellationToken ct)
}
}
- throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainTestDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
+ throw new ArgumentException($"The runner metric manager is of type {_metricManager.GetType()} which expected to be of type {typeof(ITrainValidateDatasetManager)} or {typeof(ICrossValidateDatasetManager)}");
}
}
catch (Exception ex) when (ct.IsCancellationRequested)
diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
index ee752f3e07..e1cac220bc 100644
--- a/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
+++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
@@ -19,18 +19,18 @@ internal interface ICrossValidateDatasetManager
IDataView Dataset { get; set; }
}
- internal interface ITrainTestDatasetManager
+ internal interface ITrainValidateDatasetManager
{
IDataView TrainDataset { get; set; }
- IDataView TestDataset { get; set; }
+ IDataView ValidateDataset { get; set; }
}
- internal class TrainTestDatasetManager : IDatasetManager, ITrainTestDatasetManager
+ internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
{
public IDataView TrainDataset { get; set; }
- public IDataView TestDataset { get; set; }
+ public IDataView ValidateDataset { get; set; }
}
internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager
diff --git a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs
index ff3298c291..c0b2b7d6a8 100644
--- a/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs
+++ b/src/Microsoft.ML.AutoML/AutoMLExperiment/Runner/SweepablePipelineRunner.cs
@@ -66,10 +66,10 @@ public TrialResult Run(TrialSettings settings)
};
}
- if (_datasetManager is ITrainTestDatasetManager trainTestDatasetManager)
+ if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset);
- var eval = model.Transform(trainTestDatasetManager.TestDataset);
+ var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var metric = _metricManager.Evaluate(_mLContext, eval);
stopWatch.Stop();
var loss = _metricManager.IsMaximize ? -metric : metric;