Skip to content

Commit 8fd2aa8

Browse files
authored
Add estimator to public API iteration result (dotnet#248)
1 parent 3f77e59 commit 8fd2aa8

File tree

6 files changed

+26
-21
lines changed

6 files changed

+26
-21
lines changed

src/Microsoft.ML.Auto/API/RunResult.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ public sealed class RunResult<T>
1414
public Exception Exception { get; private set; }
1515
public string TrainerName { get; private set; }
1616
public int RuntimeInSeconds { get; private set; }
17+
public IEstimator<ITransformer> Estimator { get; private set; }
1718

1819
internal Pipeline Pipeline { get; private set; }
1920
internal int PipelineInferenceTimeInSeconds { get; private set; }
2021

2122
internal RunResult(
2223
ITransformer model,
2324
T metrics,
25+
IEstimator<ITransformer> estimator,
2426
Pipeline pipeline,
2527
Exception exception,
2628
int runtimeInSeconds,
@@ -29,6 +31,7 @@ internal RunResult(
2931
Model = model;
3032
ValidationMetrics = metrics;
3133
Pipeline = pipeline;
34+
Estimator = estimator;
3235
Exception = exception;
3336
RuntimeInSeconds = runtimeInSeconds;
3437
PipelineInferenceTimeInSeconds = pipelineInferenceTimeInSeconds;

src/Microsoft.ML.Auto/Experiment/Experiment.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ public List<RunResult<T>> Execute()
9797
// evaluate pipeline
9898
runResult = ProcessPipeline(pipeline);
9999

100-
if (preprocessorTransform != null)
100+
if (_preFeaturizers != null)
101101
{
102+
runResult.Estimator = _preFeaturizers.Append(runResult.Estimator);
102103
runResult.Model = preprocessorTransform.Append(runResult.Model);
103104
}
104105

@@ -108,7 +109,7 @@ public List<RunResult<T>> Execute()
108109
catch (Exception ex)
109110
{
110111
WriteDebugLog(DebugStream.Exception, $"{pipeline?.Trainer} Crashed {ex}");
111-
runResult = new SuggestedPipelineResult<T>(null, null, pipeline, -1, ex);
112+
runResult = new SuggestedPipelineResult<T>(null, null, null, pipeline, -1, ex);
112113
}
113114

114115
var iterationResult = runResult.ToIterationResult();
@@ -149,19 +150,22 @@ private SuggestedPipelineResult<T> ProcessPipeline(SuggestedPipeline pipeline)
149150

150151
WriteDebugLog(DebugStream.RunResult, $"Processing pipeline {commandLineStr}.");
151152

153+
var pipelineEstimator = pipeline.ToEstimator();
154+
152155
SuggestedPipelineResult<T> runResult;
156+
153157
try
154158
{
155-
var pipelineModel = pipeline.Fit(_trainData);
159+
var pipelineModel = pipelineEstimator.Fit(_trainData);
156160
var scoredValidationData = pipelineModel.Transform(_validationData);
157161
var metrics = GetEvaluatedMetrics(scoredValidationData);
158162
var score = _metricsAgent.GetScore(metrics);
159-
runResult = new SuggestedPipelineResult<T>(metrics, pipelineModel, pipeline, score, null);
163+
runResult = new SuggestedPipelineResult<T>(metrics, pipelineEstimator, pipelineModel, pipeline, score, null);
160164
}
161165
catch(Exception ex)
162166
{
163167
WriteDebugLog(DebugStream.Exception, $"{pipeline.Trainer} Crashed {ex}");
164-
runResult = new SuggestedPipelineResult<T>(null, null, pipeline, 0, ex);
168+
runResult = new SuggestedPipelineResult<T>(null, pipelineEstimator, null, pipeline, 0, ex);
165169
}
166170

167171
// save pipeline run

src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,6 @@ public IEstimator<ITransformer> ToEstimator()
113113
return pipeline;
114114
}
115115

116-
public ITransformer Fit(IDataView trainData)
117-
{
118-
var estimator = ToEstimator();
119-
return estimator.Fit(trainData);
120-
}
121-
122116
private void AddNormalizationTransforms()
123117
{
124118
// get learner

src/Microsoft.ML.Auto/Experiment/SuggestedPipelineResult.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,26 @@ public IRunResult ToRunResult(bool isMetricMaximizing)
3333
internal class SuggestedPipelineResult<T> : SuggestedPipelineResult
3434
{
3535
public readonly T EvaluatedMetrics;
36+
public IEstimator<ITransformer> Estimator { get; set; }
3637
public ITransformer Model { get; set; }
3738
public Exception Exception { get; set; }
3839

3940
public int RuntimeInSeconds { get; set; }
4041
public int PipelineInferenceTimeInSeconds { get; set; }
4142

42-
public SuggestedPipelineResult(T evaluatedMetrics, ITransformer model, SuggestedPipeline pipeline, double score, Exception exception)
43+
public SuggestedPipelineResult(T evaluatedMetrics, IEstimator<ITransformer> estimator,
44+
ITransformer model, SuggestedPipeline pipeline, double score, Exception exception)
4345
: base(pipeline, score, exception == null)
4446
{
4547
EvaluatedMetrics = evaluatedMetrics;
48+
Estimator = estimator;
4649
Model = model;
4750
Exception = exception;
4851
}
4952

5053
public RunResult<T> ToIterationResult()
5154
{
52-
return new RunResult<T>(Model, EvaluatedMetrics, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds);
55+
return new RunResult<T>(Model, EvaluatedMetrics, Estimator, Pipeline.ToPipeline(), Exception, RuntimeInSeconds, PipelineInferenceTimeInSeconds);
5356
}
5457
}
5558
}

src/Test/AutoFitTests.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ public void AutoFitBinaryTest()
2020
var trainData = textLoader.Load(dataPath);
2121
var validationData = context.Data.TakeRows(trainData, 100);
2222
trainData = context.Data.SkipRows(trainData, 100);
23-
var result = context.Auto()
23+
var results = context.Auto()
2424
.CreateBinaryClassificationExperiment(0)
2525
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = DatasetUtil.UciAdultLabel });
26-
27-
Assert.IsTrue(result.Max(i => i.ValidationMetrics.Accuracy) > 0.80);
26+
var best = results.Best();
27+
Assert.IsTrue(best.ValidationMetrics.Accuracy > 0.80);
28+
Assert.IsNotNull(best.Estimator);
2829
}
2930

3031
[TestMethod]

src/Test/RunResultTests.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ public void FindBestResultWithSomeNullMetrics()
2020

2121
var runResults = new List<RunResult<RegressionMetrics>>()
2222
{
23-
new RunResult<RegressionMetrics>(null, null, null, null, 0, 0),
24-
new RunResult<RegressionMetrics>(null, metrics1, null, null, 0, 0),
25-
new RunResult<RegressionMetrics>(null, metrics2, null, null, 0, 0),
26-
new RunResult<RegressionMetrics>(null, metrics3, null, null, 0, 0),
23+
new RunResult<RegressionMetrics>(null, null, null, null, null, 0, 0),
24+
new RunResult<RegressionMetrics>(null, metrics1, null, null, null, 0, 0),
25+
new RunResult<RegressionMetrics>(null, metrics2, null, null, null, 0, 0),
26+
new RunResult<RegressionMetrics>(null, metrics3, null, null, null, 0, 0),
2727
};
2828

2929
var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);
@@ -36,7 +36,7 @@ public void FindBestResultWithAllNullMetrics()
3636
{
3737
var runResults = new List<RunResult<RegressionMetrics>>()
3838
{
39-
new RunResult<RegressionMetrics>(null, null, null, null, 0, 0),
39+
new RunResult<RegressionMetrics>(null, null, null, null, null, 0, 0),
4040
};
4141

4242
var metricsAgent = new RegressionMetricsAgent(RegressionMetric.RSquared);

0 commit comments

Comments
 (0)