Skip to content

Commit 26913ac

Browse files
author
Pete Luferenko
committed
ML Context and a couple extensions
1 parent c016807 commit 26913ac

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+1426
-506
lines changed

docs/samples/Microsoft.ML.Samples/Trainers.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// the alignment of the usings with the methods is intentional so they can display on the same level in the docs site.
66
using Microsoft.ML.Runtime.Data;
77
using Microsoft.ML.Runtime.Learners;
8-
using Microsoft.ML.Trainers;
8+
using Microsoft.ML.StaticPipe;
99
using System;
1010

1111
// NOTE: WHEN ADDING TO THE FILE, ALWAYS APPEND TO THE END OF IT.
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Microsoft.ML.Runtime
6+
{
7+
/// <summary>
/// A catalog of operations to load and save data.
/// The catalog holds the host environment it was created with.
/// </summary>
public sealed class DataLoadSaveOperations
{
    /// <summary>
    /// The host environment supplied at construction time.
    /// </summary>
    internal IHostEnvironment Environment { get; }

    /// <summary>
    /// Creates the catalog over the given host environment.
    /// </summary>
    /// <param name="env">The host environment; asserted to be non-null.</param>
    internal DataLoadSaveOperations(IHostEnvironment env)
    {
        Contracts.AssertValue(env);
        Environment = env;
    }
}
20+
}

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ public Column() { }
4646
/// <summary>
/// Creates a column backed by a single source index.
/// </summary>
/// <param name="name">The column name.</param>
/// <param name="type">The optional data kind of the column.</param>
/// <param name="index">The source index; it is wrapped in a one-element <see cref="Range"/> array.</param>
public Column(string name, DataKind? type, int index)
    : this(name, type, new[] { new Range(index) })
{
}
4848

49+
/// <summary>
/// Creates a column backed by one range of source indices.
/// </summary>
/// <param name="name">The column name.</param>
/// <param name="type">The optional data kind of the column.</param>
/// <param name="minIndex">The first argument passed to the <see cref="Range"/> constructor.</param>
/// <param name="maxIndex">The second argument passed to the <see cref="Range"/> constructor.</param>
public Column(string name, DataKind? type, int minIndex, int maxIndex)
    : this(name, type, new[] { new Range(minIndex, maxIndex) }) { }
53+
4954
public Column(string name, DataKind? type, Range[] source, KeyRange keyRange = null)
5055
{
5156
Contracts.CheckValue(name, nameof(name));
@@ -998,6 +1003,18 @@ private bool HasHeader
9981003
private readonly IHost _host;
9991004
private const string RegistrationName = "TextLoader";
10001005

1006+
/// <summary>
/// Creates a text loader from a set of columns and an optional settings delegate.
/// The arguments object is built by <see cref="MakeArgs"/> and passed to the
/// arguments-based constructor.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="columns">The columns of the schema.</param>
/// <param name="advancedSettings">The delegate to set additional settings; may be <c>null</c>.</param>
/// <param name="dataSample">The optional data sample.</param>
public TextLoader(IHostEnvironment env, Column[] columns, Action<Arguments> advancedSettings, IMultiStreamSource dataSample = null)
    : this(env, MakeArgs(columns, advancedSettings), dataSample) { }
1010+
1011+
/// <summary>
/// Builds an <see cref="Arguments"/> instance whose columns are <paramref name="columns"/>,
/// then lets the caller-supplied delegate (if any) adjust it.
/// </summary>
/// <param name="columns">The columns assigned to <see cref="Arguments.Column"/>.</param>
/// <param name="advancedSettings">Optional delegate invoked on the new arguments object; skipped when <c>null</c>.</param>
/// <returns>The populated arguments object.</returns>
private static Arguments MakeArgs(Column[] columns, Action<Arguments> advancedSettings)
{
    var args = new Arguments();
    args.Column = columns;

    // The delegate runs after the columns are set, so it sees them and may override them.
    if (advancedSettings != null)
        advancedSettings(args);

    return args;
}
1017+
10011018
public TextLoader(IHostEnvironment env, Arguments args, IMultiStreamSource dataSample = null)
10021019
{
10031020
Contracts.CheckValue(env, nameof(env));
@@ -1315,6 +1332,8 @@ public void Save(ModelSaveContext ctx)
13151332

13161333
/// <summary>
/// Creates a data view that binds this loader to the given source.
/// </summary>
/// <param name="source">The multi-stream source to read from.</param>
/// <returns>The data view over <paramref name="source"/>.</returns>
public IDataView Read(IMultiStreamSource source)
{
    return new BoundLoader(this, source);
}
13171334

1335+
/// <summary>
/// Creates a data view over the given path by wrapping it in a <see cref="MultiFileSource"/>
/// and delegating to <see cref="Read(IMultiStreamSource)"/>.
/// </summary>
/// <param name="path">The path passed to <see cref="MultiFileSource"/>.</param>
/// <returns>The data view over the wrapped source.</returns>
public IDataView Read(string path)
{
    var source = new MultiFileSource(path);
    return Read(source);
}
1336+
13181337
private sealed class BoundLoader : IDataLoader
13191338
{
13201339
private readonly TextLoader _reader;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Data.IO;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
9+
using System;
10+
using System.Collections.Generic;
11+
using System.IO;
12+
using System.Linq;
13+
using System.Text;
14+
15+
namespace Microsoft.ML
16+
{
17+
/// <summary>
/// Extension methods on <see cref="DataLoadSaveOperations"/> for creating text readers,
/// reading a text file into a data view, and saving a data view as text.
/// </summary>
public static class TextLoaderSaverCatalog
{
    /// <summary>
    /// Create a text reader.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="args">The arguments to text reader, describing the data schema.</param>
    /// <param name="dataSample">The optional data sample.</param>
    /// <returns>The newly created <see cref="TextLoader"/>.</returns>
    public static TextLoader TextReader(this DataLoadSaveOperations catalog,
        TextLoader.Arguments args, IMultiStreamSource dataSample = null)
        => new TextLoader(catalog.GetEnvironment(), args, dataSample);

    /// <summary>
    /// Create a text reader.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="columns">The columns of the schema.</param>
    /// <param name="advancedSettings">The delegate to set additional settings.</param>
    /// <param name="dataSample">The optional data sample.</param>
    /// <returns>The newly created <see cref="TextLoader"/>.</returns>
    public static TextLoader TextReader(this DataLoadSaveOperations catalog,
        TextLoader.Column[] columns, Action<TextLoader.Arguments> advancedSettings = null, IMultiStreamSource dataSample = null)
        => new TextLoader(catalog.GetEnvironment(), columns, advancedSettings, dataSample);

    /// <summary>
    /// Read a data view from a text file using <see cref="TextLoader"/>.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="columns">The columns of the schema.</param>
    /// <param name="path">The path to the file. Must be non-empty.</param>
    /// <param name="advancedSettings">The delegate to set additional settings.</param>
    /// <returns>The data view.</returns>
    public static IDataView ReadFromTextFile(this DataLoadSaveOperations catalog,
        TextLoader.Column[] columns, string path, Action<TextLoader.Arguments> advancedSettings = null)
    {
        // Validate the catalog up front, the same way SaveAsText does.
        Contracts.CheckValue(catalog, nameof(catalog));
        Contracts.CheckNonEmpty(path, nameof(path));

        var env = catalog.GetEnvironment();

        // REVIEW: it is almost always a mistake to have a 'trainable' text loader here.
        // Therefore, we are going to disallow data sample.
        var reader = new TextLoader(env, columns, advancedSettings, dataSample: null);
        return reader.Read(new MultiFileSource(path));
    }

    /// <summary>
    /// Save the data view as text.
    /// </summary>
    /// <param name="catalog">The catalog.</param>
    /// <param name="data">The data view to save.</param>
    /// <param name="stream">The stream to write to.</param>
    /// <param name="separator">The column separator.</param>
    /// <param name="headerRow">Whether to write the header row.</param>
    /// <param name="schema">Whether to write the header comment with the schema.</param>
    /// <param name="keepHidden">Whether to keep hidden columns in the dataset.</param>
    public static void SaveAsText(this DataLoadSaveOperations catalog, IDataView data, Stream stream,
        char separator = '\t', bool headerRow = true, bool schema = true, bool keepHidden = false)
    {
        Contracts.CheckValue(catalog, nameof(catalog));
        Contracts.CheckValue(data, nameof(data));
        Contracts.CheckValue(stream, nameof(stream));

        var env = catalog.GetEnvironment();
        var saverArgs = new TextSaver.Arguments
        {
            Separator = separator.ToString(),
            OutputHeader = headerRow,
            OutputSchema = schema
        };
        var saver = new TextSaver(env, saverArgs);

        // The channel is disposed as soon as the save completes.
        using (var ch = env.Start("Saving data"))
        {
            DataSaverUtils.SaveDataView(ch, saver, data, stream, keepHidden);
        }
    }
}
87+
}

src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ public static RegressionEvaluator.Result Evaluate<T>(
213213
/// <param name="score">The index delegate for predicted score column.</param>
214214
/// <returns>The evaluation metrics.</returns>
215215
public static RankerEvaluator.Result Evaluate<T, TVal>(
216-
this RankerContext ctx,
216+
this RankingContext ctx,
217217
DataView<T> data,
218218
Func<T, Scalar<float>> label,
219219
Func<T, Key<uint, TVal>> groupId,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Core.Data;
6+
using Microsoft.ML.Runtime.Data;
7+
using System.IO;
8+
9+
namespace Microsoft.ML.Runtime
10+
{
11+
/// <summary>
/// An object serving as a 'catalog' of available model operations:
/// saving a trained model to a stream and loading one back.
/// </summary>
public sealed class ModelOperationsCatalog
{
    /// <summary>
    /// The host environment supplied at construction time; passed to the save and load calls.
    /// </summary>
    internal IHostEnvironment Environment { get; }

    /// <summary>
    /// Creates the catalog over the given host environment.
    /// </summary>
    /// <param name="env">The host environment; asserted to be non-null.</param>
    internal ModelOperationsCatalog(IHostEnvironment env)
    {
        Contracts.AssertValue(env);
        Environment = env;
    }

    /// <summary>
    /// Save the model to the stream.
    /// </summary>
    /// <param name="model">The trained model to be saved.</param>
    /// <param name="stream">A writeable, seekable stream to save to.</param>
    public void Save(ITransformer model, Stream stream)
    {
        model.SaveTo(Environment, stream);
    }

    /// <summary>
    /// Load the model from the stream.
    /// </summary>
    /// <param name="stream">A readable, seekable stream to load from.</param>
    /// <returns>The loaded model.</returns>
    public ITransformer Load(Stream stream)
    {
        return TransformerChain.LoadFrom(Environment, stream);
    }
}
38+
}
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using System;
8+
9+
namespace Microsoft.ML
10+
{
11+
/// <summary>
/// The <see cref="MLContext"/> is the starting point for all ML.NET operations. The user creates it;
/// it exposes the task-specific contexts and catalogs, and a hook for intercepting log messages.
/// It implements <see cref="IHostEnvironment"/> by forwarding every member to an internal
/// <see cref="LocalEnvironment"/>.
/// </summary>
public sealed class MLContext : IHostEnvironment
{
    // The environment that all contexts, catalogs and IHostEnvironment members delegate to.
    private readonly LocalEnvironment _env;

    /// <summary>
    /// Trainers and tasks specific to binary classification problems.
    /// </summary>
    public BinaryClassificationContext BinaryClassification { get; }

    /// <summary>
    /// Trainers and tasks specific to multiclass classification problems.
    /// </summary>
    public MulticlassClassificationContext MulticlassClassification { get; }

    /// <summary>
    /// Trainers and tasks specific to regression problems.
    /// </summary>
    public RegressionContext Regression { get; }

    /// <summary>
    /// Trainers and tasks specific to clustering problems.
    /// </summary>
    public ClusteringContext Clustering { get; }

    /// <summary>
    /// Trainers and tasks specific to ranking problems.
    /// </summary>
    public RankingContext Ranking { get; }

    /// <summary>
    /// Data processing operations.
    /// </summary>
    public TransformsCatalog Transform { get; }

    /// <summary>
    /// Operations with trained models.
    /// </summary>
    public ModelOperationsCatalog Model { get; }

    /// <summary>
    /// Data loading and saving.
    /// </summary>
    public DataLoadSaveOperations Data { get; }

    // REVIEW: I think it's valuable to have the simplest possible interface for logging interception here,
    // and expand if and when necessary. Exposing classes like ChannelMessage, MessageSensitivity and so on
    // looks premature at this point.
    /// <summary>
    /// The handler for the log messages. When <c>null</c>, messages are dropped.
    /// </summary>
    public Action<string> Log { get; set; }

    /// <summary>
    /// Create the ML context.
    /// </summary>
    /// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
    /// <param name="conc">Concurrency level. Set to 1 to run single-threaded. Set to 0 to pick automatically.</param>
    public MLContext(int? seed = null, int conc = 0)
    {
        _env = new LocalEnvironment(seed, conc);
        // Route every message from the environment through the Log handler.
        _env.AddListener(ProcessMessage);

        // Task-specific contexts.
        BinaryClassification = new BinaryClassificationContext(_env);
        MulticlassClassification = new MulticlassClassificationContext(_env);
        Regression = new RegressionContext(_env);
        Clustering = new ClusteringContext(_env);
        Ranking = new RankingContext(_env);

        // Catalogs.
        Transform = new TransformsCatalog(_env);
        Model = new ModelOperationsCatalog(_env);
        Data = new DataLoadSaveOperations(_env);
    }

    /// <summary>
    /// Formats a channel message and hands it to <see cref="Log"/>, if a handler is set.
    /// </summary>
    /// <param name="source">The message source; its full name is included in the text.</param>
    /// <param name="message">The message; its kind and text are included in the output.</param>
    private void ProcessMessage(IMessageSource source, ChannelMessage message)
    {
        // Skip the formatting work entirely when nobody is listening.
        if (Log != null)
        {
            var text = $"[Source={source.FullName}, Kind={message.Kind}] {message.Message}";
            // Log may have been reset from another thread.
            // We don't care which logger we send the message to, just making sure we don't crash.
            Log?.Invoke(text);
        }
    }

    // Explicit interface implementations: each member forwards to the wrapped environment.
    int IHostEnvironment.ConcurrencyFactor => _env.ConcurrencyFactor;
    bool IHostEnvironment.IsCancelled => _env.IsCancelled;
    ComponentCatalog IHostEnvironment.ComponentCatalog => _env.ComponentCatalog;
    string IExceptionContext.ContextDescription => _env.ContextDescription;
    IFileHandle IHostEnvironment.CreateOutputFile(string path) => _env.CreateOutputFile(path);
    IFileHandle IHostEnvironment.CreateTempFile(string suffix, string prefix) => _env.CreateTempFile(suffix, prefix);
    IFileHandle IHostEnvironment.OpenInputFile(string path) => _env.OpenInputFile(path);
    TException IExceptionContext.Process<TException>(TException ex) => _env.Process(ex);
    IHost IHostEnvironment.Register(string name, int? seed, bool? verbose, int? conc) => _env.Register(name, seed, verbose, conc);
    IChannel IChannelProvider.Start(string name) => _env.Start(name);
    IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
    IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
}
107+
}

0 commit comments

Comments
 (0)