Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/Microsoft.ML.AutoML/Tuner/SmacTuner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ public Parameter Propose(TrialSettings settings)
}
}

// test purpose
internal Queue<Parameter> Candidates => _candidates;

// test purpose
internal List<TrialResult> Histories => _histories;

private FastForestRegressionModelParameters FitModel(IEnumerable<TrialResult> history)
{
Single[] losses = new Single[history.Count()];
Expand Down Expand Up @@ -357,7 +363,10 @@ private double ComputeEI(double bestLoss, double[] forestStatistics)

public void Update(TrialResult result)
{
_histories.Add(result);
if (!double.IsNaN(result.Loss) && !double.IsInfinity(result.Loss))
{
_histories.Add(result);
}
}
}
}
33 changes: 33 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/TunerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,39 @@ public void tuner_e2e_test()
}
}

[Fact]
public void Smac_should_ignore_fail_trials_during_initialize()
{
// fix for https://github.com/dotnet/machinelearning-modelbuilder/issues/2721
var context = new MLContext(1);
var searchSpace = new SearchSpace<LbfgsOption>();
var tuner = new SmacTuner(context, searchSpace, seed: 1);
for (int i = 0; i != 1000; ++i)
{
var trialSettings = new TrialSettings()
{
TrialId = i,
};

var param = tuner.Propose(trialSettings);
trialSettings.Parameter = param;
var option = param.AsType<LbfgsOption>();

option.L1Regularization.Should().BeInRange(0.03125f, 32768.0f);
option.L2Regularization.Should().BeInRange(0.03125f, 32768.0f);

tuner.Update(new TrialResult()
{
DurationInMilliseconds = i * 1000,
Loss = double.NaN,
TrialSettings = trialSettings,
});
}

tuner.Candidates.Count.Should().Be(0);
tuner.Histories.Count.Should().Be(0);
}

[Fact]
public void CFO_should_be_recoverd_if_history_provided()
{
Expand Down