Skip to content

Commit 73beb28

Browse files
use seed from AutoMLExperiment.setting in eci_cfo tuner (#6406)
1 parent c69acbe commit 73beb28

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/Microsoft.ML.AutoML/Tuner/EciCfoTuner.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public EciCostFrugalTuner(SweepablePipeline sweepablePipeline, AutoMLExperiment.
3232
_tuners = pipelineSchemas.ToDictionary(schema => schema, schema =>
3333
{
3434
var searchSpace = sweepablePipeline.BuildSweepableEstimatorPipeline(schema).SearchSpace;
35-
return new CostFrugalTuner(searchSpace, searchSpace.SampleFromFeatureSpace(searchSpace.Default)) as ITuner;
35+
return new CostFrugalTuner(searchSpace, searchSpace.SampleFromFeatureSpace(searchSpace.Default), seed: settings.Seed) as ITuner;
3636
});
3737

3838
if (trialResultManager != null)

test/Microsoft.ML.AutoML.Tests/AutoMLExperimentTests.cs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ public async Task AutoMLExperiment_Iris_CV_5_Test()
228228
var context = new MLContext(1);
229229
context.Log += (o, e) =>
230230
{
231-
if (e.Source.StartsWith("AutoMLExperiment"))
231+
if (e.RawMessage.Contains("Trial"))
232232
{
233233
this.Output.WriteLine(e.RawMessage);
234234
}
@@ -328,6 +328,20 @@ public async Task AutoMLExperiment_Taxi_Fare_CV_5_Test()
328328
var result = await experiment.RunAsync();
329329
result.Metric.Should().BeGreaterThan(0.5);
330330
}
331+
332+
[Fact]
333+
public void AutoMLExperiment_should_use_seed_from_context_if_provided()
334+
{
335+
var context = new MLContext();
336+
var experiment = context.Auto().CreateExperiment();
337+
var settings = experiment.ServiceCollection.BuildServiceProvider().GetRequiredService<AutoMLExperiment.AutoMLExperimentSettings>();
338+
settings.Seed.Should().BeNull();
339+
340+
context = new MLContext(1);
341+
experiment = context.Auto().CreateExperiment();
342+
settings = experiment.ServiceCollection.BuildServiceProvider().GetRequiredService<AutoMLExperiment.AutoMLExperimentSettings>();
343+
settings.Seed.Should().Be(1);
344+
}
331345
}
332346

333347
class DummyTrialRunner : ITrialRunner

0 commit comments

Comments
 (0)