Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using System.Text.Json.Serialization;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.Runtime;
using Microsoft.ML.SearchSpace.Option;
using Newtonsoft.Json;
using static Microsoft.ML.DataOperationsCatalog;

Expand All @@ -24,14 +25,18 @@ public static class AutoMLExperimentExtension
/// <param name="experiment"><see cref="AutoMLExperiment"/></param>
/// <param name="train">dataset for training a model.</param>
/// <param name="validation">dataset for validating a model during training.</param>
/// <param name="subSamplingTrainDataset">determine if subsampling <paramref name="train"/> to train. This will be useful if <paramref name="train"/> is too large to be held in memory.</param>
/// <returns><see cref="AutoMLExperiment"/></returns>
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation)
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView train, IDataView validation, bool subSamplingTrainDataset = false)
{
var datasetManager = new TrainValidateDatasetManager()
var datasetManager = new TrainValidateDatasetManager(train, validation);

if (subSamplingTrainDataset)
{
TrainDataset = train,
ValidateDataset = validation
};
var searchSpace = new SearchSpace.SearchSpace();
searchSpace.Add(datasetManager.SubSamplingKey, new UniformSingleOption(0, 1, false, 0.1f));
experiment.AddSearchSpace(nameof(TrainValidateDatasetManager), searchSpace);
}

experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
experiment.ServiceCollection.AddSingleton(datasetManager);
Expand Down Expand Up @@ -62,13 +67,7 @@ public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, Trai
/// <returns><see cref="AutoMLExperiment"/></returns>
public static AutoMLExperiment SetDataset(this AutoMLExperiment experiment, IDataView dataset, int fold = 10, string samplingKeyColumnName = null)
{
var datasetManager = new CrossValidateDatasetManager()
{
Dataset = dataset,
Fold = fold,
SamplingKeyColumnName = samplingKeyColumnName,
};

var datasetManager = new CrossValidateDatasetManager(dataset, fold, samplingKeyColumnName);
experiment.ServiceCollection.AddSingleton<IDatasetManager>(datasetManager);
experiment.ServiceCollection.AddSingleton(datasetManager);

Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ public TrialResult Run(TrialSettings settings)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var fold = datasetManager.Fold ?? 5;
var fold = datasetManager.Fold;
var metrics = _context.BinaryClassification.CrossValidateNonCalibrated(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);

// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
Expand Down Expand Up @@ -420,8 +420,8 @@ public TrialResult Run(TrialSettings settings)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
var metrics = _context.BinaryClassification.EvaluateNonCalibrated(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ public TrialResult Run(TrialSettings settings)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var fold = datasetManager.Fold ?? 5;
var fold = datasetManager.Fold;
var metrics = _context.MulticlassClassification.CrossValidate(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);

// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
Expand Down Expand Up @@ -398,8 +398,8 @@ public TrialResult Run(TrialSettings settings)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
var metrics = _context.MulticlassClassification.Evaluate(eval, metricManager.LabelColumn, predictedLabelColumnName: metricManager.PredictedColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.ML.AutoML/API/RegressionExperiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var fold = datasetManager.Fold ?? 5;
var fold = datasetManager.Fold;
var metrics = _context.Regression.CrossValidate(datasetManager.Dataset, pipeline, fold, metricManager.LabelColumn);

// now we just randomly pick a model, but a better way is to provide option to pick a model which score is the cloest to average or the best.
Expand Down Expand Up @@ -425,8 +425,8 @@ public Task<TrialResult> RunAsync(TrialSettings settings, CancellationToken ct)
{
var stopWatch = new Stopwatch();
stopWatch.Start();
var model = pipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var model = pipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_context, settings));
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_context, settings));
var metrics = _context.Regression.Evaluate(eval, metricManager.LabelColumn, scoreColumnName: metricManager.ScoreColumn);
var metric = GetMetric(metricManager.Metric, metrics);
var loss = metricManager.IsMaximize ? -metric : metric;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ Abandoning Trial {trialSettings.TrialId} and continue training.
trialResultManager?.AddOrUpdateTrialResult(trialResult);
aggregateTrainingStopManager.Update(trialResult);

if (ex is not OperationCanceledException && _bestTrialResult == null)
if (ex is not OperationCanceledException && ex is not OutOfMemoryException && _bestTrialResult == null)
{
logger.Trace($"trial fatal error - {JsonSerializer.Serialize(trialSettings)}, stop training");

Expand Down
91 changes: 81 additions & 10 deletions src/Microsoft.ML.AutoML/AutoMLExperiment/IDatasetManager.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable enable

using Microsoft.ML.SearchSpace;

namespace Microsoft.ML.AutoML
{
Expand All @@ -12,34 +15,102 @@ public interface IDatasetManager
{
}

internal interface ICrossValidateDatasetManager
/// <summary>
/// Inferface for cross validate dataset manager.
/// </summary>
public interface ICrossValidateDatasetManager : IDatasetManager
{
int? Fold { get; set; }
/// <summary>
/// Cross validate fold.
/// </summary>
int Fold { get; set; }

/// <summary>
/// The dataset to cross validate.
/// </summary>
IDataView Dataset { get; set; }

string SamplingKeyColumnName { get; set; }
/// <summary>
/// The dataset column used for grouping rows.
/// </summary>
string? SamplingKeyColumnName { get; set; }
}

internal interface ITrainValidateDatasetManager
public interface ITrainValidateDatasetManager : IDatasetManager
{
IDataView TrainDataset { get; set; }
IDataView LoadTrainDataset(MLContext context, TrialSettings? settings);

IDataView ValidateDataset { get; set; }
IDataView LoadValidateDataset(MLContext context, TrialSettings? settings);
}

internal class TrainValidateDatasetManager : IDatasetManager, ITrainValidateDatasetManager
{
public IDataView TrainDataset { get; set; }
private ulong _rowCount;
private IDataView _trainDataset;
private readonly IDataView _validateDataset;
private readonly string _subSamplingKey = "TrainValidateDatasetSubsamplingKey";
private bool _isInitialized = false;
public TrainValidateDatasetManager(IDataView trainDataset, IDataView validateDataset, string? subSamplingKey = null)
{
_trainDataset = trainDataset;
_validateDataset = validateDataset;
_subSamplingKey = subSamplingKey ?? _subSamplingKey;
}

public string SubSamplingKey => _subSamplingKey;

/// <summary>
/// Load Train Dataset. If <see cref="TrialSettings.Parameter"/> contains <see cref="_subSamplingKey"/> then the train dataset will be subsampled.
/// </summary>
/// <param name="context">MLContext.</param>
/// <param name="settings">trial settings. If null, return entire train dataset.</param>
/// <returns>train dataset.</returns>
public IDataView LoadTrainDataset(MLContext context, TrialSettings? settings)
{
if (!_isInitialized)
{
InitializeTrainDataset(context);
_isInitialized = true;
}
var trainTestSplitParameter = settings?.Parameter.ContainsKey(nameof(TrainValidateDatasetManager)) is true ? settings.Parameter[nameof(TrainValidateDatasetManager)] : null;
if (trainTestSplitParameter is Parameter parameter)
{
var subSampleRatio = parameter.ContainsKey(_subSamplingKey) ? parameter[_subSamplingKey].AsType<double>() : 1;
if (subSampleRatio < 1.0)
{
var subSampledTrainDataset = context.Data.TakeRows(_trainDataset, (long)(subSampleRatio * _rowCount));
return subSampledTrainDataset;
}
}

public IDataView ValidateDataset { get; set; }
return _trainDataset;
}

public IDataView LoadValidateDataset(MLContext context, TrialSettings? settings)
{
return _validateDataset;
}

private void InitializeTrainDataset(MLContext context)
{
_rowCount = DatasetDimensionsUtil.CountRows(_trainDataset, ulong.MaxValue);
_trainDataset = context.Data.ShuffleRows(_trainDataset);
}
}

internal class CrossValidateDatasetManager : IDatasetManager, ICrossValidateDatasetManager
{
public CrossValidateDatasetManager(IDataView dataset, int fold, string? samplingKeyColumnName = null)
{
Dataset = dataset;
Fold = fold;
SamplingKeyColumnName = samplingKeyColumnName;
}

public IDataView Dataset { get; set; }

public int? Fold { get; set; }
public string SamplingKeyColumnName { get; set; }
public int Fold { get; set; }

public string? SamplingKeyColumnName { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public TrialResult Run(TrialSettings settings)
var mlnetPipeline = _pipeline.BuildFromOption(_mLContext, parameter);
if (_datasetManager is ICrossValidateDatasetManager crossValidateDatasetManager)
{
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold ?? 5, crossValidateDatasetManager.SamplingKeyColumnName);
var datasetSplit = _mLContext!.Data.CrossValidationSplit(crossValidateDatasetManager.Dataset, crossValidateDatasetManager.Fold, crossValidateDatasetManager.SamplingKeyColumnName);
var metrics = new List<double>();
var models = new List<ITransformer>();
foreach (var split in datasetSplit)
Expand Down Expand Up @@ -68,8 +68,8 @@ public TrialResult Run(TrialSettings settings)

if (_datasetManager is ITrainValidateDatasetManager trainTestDatasetManager)
{
var model = mlnetPipeline.Fit(trainTestDatasetManager.TrainDataset);
var eval = model.Transform(trainTestDatasetManager.ValidateDataset);
var model = mlnetPipeline.Fit(trainTestDatasetManager.LoadTrainDataset(_mLContext!, settings));
var eval = model.Transform(trainTestDatasetManager.LoadValidateDataset(_mLContext!, settings));
var metric = _metricManager.Evaluate(_mLContext, eval);
stopWatch.Stop();
var loss = _metricManager.IsMaximize ? -metric : metric;
Expand Down
10 changes: 4 additions & 6 deletions src/Microsoft.ML.AutoML/Tuner/EciCfoTuner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ public EciCostFrugalTuner(SweepablePipeline sweepablePipeline, AutoMLExperiment.
_tuners = pipelineSchemas.ToDictionary(schema => schema, schema =>
{
var searchSpace = sweepablePipeline.BuildSweepableEstimatorPipeline(schema).SearchSpace;
return new CostFrugalTuner(searchSpace, searchSpace.SampleFromFeatureSpace(searchSpace.Default), seed: settings.Seed) as ITuner;
var aggregateSearchSpace = new SearchSpace.SearchSpace(settings.SearchSpace);
aggregateSearchSpace[AutoMLExperiment.PipelineSearchspaceName] = searchSpace;
return new CostFrugalTuner(aggregateSearchSpace, aggregateSearchSpace.SampleFromFeatureSpace(aggregateSearchSpace.Default), seed: settings.Seed) as ITuner;
});

if (trialResultManager != null)
Expand All @@ -57,22 +59,18 @@ public Parameter Propose(TrialSettings settings)
parameter[k.Key] = _defaultParameter[k.Key];
}
}
settings.Parameter[AutoMLExperiment.PipelineSearchspaceName] = parameter;
settings.Parameter = parameter;

return settings.Parameter;
}

public void Update(TrialResult result)
{
var originalParameter = result.TrialSettings.Parameter;
var schema = result.TrialSettings.Parameter[AutoMLExperiment.PipelineSearchspaceName]["_SCHEMA_"].AsType<string>();
_pipelineProposer.Update(result, schema);
if (_tuners.TryGetValue(schema, out var tuner))
{
var parameter = result.TrialSettings.Parameter[AutoMLExperiment.PipelineSearchspaceName];
result.TrialSettings.Parameter = parameter;
tuner.Update(result);
result.TrialSettings.Parameter = originalParameter;
}
}
}
Expand Down
15 changes: 11 additions & 4 deletions src/Microsoft.ML.Fairlearn/AutoML/AutoMLExperimentExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;
using Microsoft.ML.SearchSpace;

namespace Microsoft.ML.Fairlearn.AutoML
{
Expand Down Expand Up @@ -55,9 +56,14 @@ public static AutoMLExperiment SetBinaryClassificationMetricWithFairLearn(
{
var datasetManager = serviceProvider.GetRequiredService<TrainValidateDatasetManager>();
var moment = new UtilityParity();
var sensitiveFeature = DataFrameColumn.Create("group_id", datasetManager.TrainDataset.GetColumn<string>(sensitiveColumnName));
var label = DataFrameColumn.Create("label", datasetManager.TrainDataset.GetColumn<bool>(labelColumn));
moment.LoadData(datasetManager.TrainDataset, label, sensitiveFeature);
var context = serviceProvider.GetRequiredService<MLContext>();
var trainData = datasetManager.LoadTrainDataset(context, new TrialSettings
{
Parameter = Parameter.CreateNestedParameter(),
});
var sensitiveFeature = DataFrameColumn.Create("group_id", trainData.GetColumn<string>(sensitiveColumnName));
var label = DataFrameColumn.Create("label", trainData.GetColumn<bool>(labelColumn));
moment.LoadData(trainData, label, sensitiveFeature);
var lambdaSearchSpace = Utilities.GenerateBinaryClassificationLambdaSearchSpace(moment, gridLimit, negativeAllowed);
experiment.AddSearchSpace("_lambda_search_space", lambdaSearchSpace);

Expand All @@ -70,8 +76,9 @@ public static AutoMLExperiment SetBinaryClassificationMetricWithFairLearn(
var moment = serviceProvider.GetRequiredService<ClassificationMoment>();
var datasetManager = serviceProvider.GetRequiredService<TrainValidateDatasetManager>();
var pipeline = serviceProvider.GetRequiredService<SweepablePipeline>();
return new GridSearchTrailRunner(context, datasetManager.TrainDataset, datasetManager.ValidateDataset, labelColumn, sensitiveColumnName, pipeline, moment);
return new GridSearchTrailRunner(context, datasetManager, labelColumn, sensitiveColumnName, pipeline, moment);
});

experiment.SetRandomSearchTuner();

return experiment;
Expand Down
Loading