Skip to content

Commit 75fc055

Browse files
authored
Remove model saving/loading inconsistencies (#3044)
* Change the model load/save API to always have ITransformer as central object. * Keep the with loader save order the same as with schema overload, with ITransformer always first. * Change ModelLoadingTests to use the MLContext of its new base class.
1 parent b5a8d99 commit 75fc055

File tree

7 files changed

+268
-221
lines changed

7 files changed

+268
-221
lines changed

docs/code/MlNetCookBook.md

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -383,19 +383,18 @@ var metrics = mlContext.Regression.Evaluate(model.Transform(testData), labelColu
383383

384384
Assuming that the model metrics look good to you, it's time to 'operationalize' the model. This is where ML.NET really shines: the `model` object you just built is ready for immediate consumption, it will apply all the same steps that it has 'learned' during training, and it can be persisted and reused in different environments.
385385

386-
Here's what you do to save the model to a file, and reload it (potentially in a different context).
386+
Here's what you do to save the model as well as its input schema to a file, and reload it (potentially in a different context).
387387

388388
```csharp
389-
using (var stream = File.Create(modelPath))
390-
{
391-
mlContext.Model.Save(model, stream);
392-
}
389+
// Saving and loading happens to transformers. We save the input schema with this model.
390+
mlContext.Model.Save(model, trainData.Schema, modelPath);
393391

394392
// Potentially, the lines below can be in a different process altogether.
395-
ITransformer loadedModel;
396-
using (var stream = File.OpenRead(modelPath))
397-
loadedModel = mlContext.Model.Load(stream);
393+
// When you load the model, it's a non-specific ITransformer. We also recover
394+
// the original schema.
395+
ITransformer loadedModel = mlContext.Model.Load(modelPath, out var schema);
398396
```
397+
399398
## How do I use the model to make one prediction?
400399

401400
Since any ML.NET model is a transformer, you can of course use `model.Transform` to apply the model to the 'data view' and obtain predictions this way.
@@ -1018,7 +1017,5 @@ using (var fs = File.Create(modelPath))
10181017
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);
10191018

10201019
// Now we can load the model.
1021-
ITransformer loadedModel;
1022-
using (var fs = File.OpenRead(modelPath))
1023-
loadedModel = newContext.Model.Load(fs);
1020+
ITransformer loadedModel = newContext.Model.Load(modelPath, out var schema);
10241021
```

docs/code/experimental/MlNetCookBookStaticApi.md

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -396,18 +396,13 @@ This is where ML.NET really shines: the `model` object you just built is ready f
396396
Here's what you do to save the model to a file, and reload it (potentially in a different context).
397397

398398
```csharp
399-
using (var stream = File.Create(modelPath))
400-
{
401-
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
402-
mlContext.Model.Save(model.AsDynamic, stream);
403-
}
399+
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
400+
mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);
404401

405402
// Potentially, the lines below can be in a different process altogether.
406403
407404
// When you load the model, it's a 'dynamic' transformer.
408-
ITransformer loadedModel;
409-
using (var stream = File.OpenRead(modelPath))
410-
loadedModel = mlContext.Model.Load(stream);
405+
ITransformer loadedModel = mlContext.Model.Load(modelPath, out var schema);
411406
```
412407

413408
## How do I use the model to make one prediction?

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 109 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -29,63 +29,60 @@ internal ModelOperationsCatalog(IHostEnvironment env)
2929
}
3030

3131
/// <summary>
32-
/// Save the model to the stream.
32+
/// Save a transformer model and the loader used to create its input data to the stream.
3333
/// </summary>
34-
/// <param name="model">The trained model to be saved.</param>
34+
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
35+
/// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
36+
/// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
37+
/// <param name="loader">The loader that was used to create data to train the model.</param>
3538
/// <param name="stream">A writeable, seekable stream to save to.</param>
36-
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
39+
public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, Stream stream)
3740
{
38-
_env.CheckValue(model, nameof(model));
41+
_env.CheckValue(loader, nameof(loader));
42+
_env.CheckValueOrNull(model);
3943
_env.CheckValue(stream, nameof(stream));
4044

45+
// For the sake of consistency of this API specifically, when called upon we save any transformer
46+
// in a single element transformer chain.
47+
var chainedModel = model == null ? null : new TransformerChain<ITransformer>(model);
48+
var compositeLoader = new CompositeDataLoader<TSource, ITransformer>(loader, chainedModel);
49+
4150
using (var rep = RepositoryWriter.CreateNew(stream))
4251
{
43-
ModelSaveContext.SaveModel(rep, model, null);
52+
ModelSaveContext.SaveModel(rep, compositeLoader, null);
4453
rep.Commit();
4554
}
4655
}
4756

48-
/// <summary>
49-
/// Save the model to the file.
50-
/// </summary>
51-
/// <param name="model">The trained model to be saved.</param>
52-
/// <param name="filePath">Path where model should be saved.</param>
53-
public void Save<TSource>(IDataLoader<TSource> model, string filePath)
54-
{
55-
using (var stream = File.Create(filePath))
56-
Save(model, stream);
57-
}
58-
59-
/// <summary>
60-
/// Save a transformer model and the loader used to create its input data to the stream.
61-
/// </summary>
62-
/// <param name="loader">The loader that was used to create data to train the model</param>
63-
/// <param name="model">The trained model to be saved</param>
64-
/// <param name="stream">A writeable, seekable stream to save to.</param>
65-
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) =>
66-
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
67-
6857
/// <summary>
6958
/// Save a transformer model and the loader used to create its input data to the file.
7059
/// </summary>
71-
/// <param name="loader">The loader that was used to create data to train the model</param>
72-
/// <param name="model">The trained model to be saved</param>
60+
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
61+
/// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
62+
/// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
63+
/// <param name="loader">The loader that was used to create data to train the model.</param>
7364
/// <param name="filePath">Path where model should be saved.</param>
74-
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, string filePath)
65+
public void Save<TSource>(ITransformer model, IDataLoader<TSource> loader, string filePath)
7566
{
67+
_env.CheckValueOrNull(model);
68+
_env.CheckValue(loader, nameof(loader));
69+
_env.CheckNonEmpty(filePath, nameof(filePath));
70+
7671
using (var stream = File.Create(filePath))
77-
Save(loader, model, stream);
72+
Save(model, loader, stream);
7873
}
7974

8075
/// <summary>
8176
/// Save a transformer model and the schema of the data that was used to train it to the stream.
8277
/// </summary>
83-
/// <param name="model">The trained model to be saved.</param>
84-
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
78+
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
79+
/// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will
80+
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
81+
/// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
8582
/// <param name="stream">A writeable, seekable stream to save to.</param>
8683
public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
8784
{
88-
_env.CheckValue(model, nameof(model));
85+
_env.CheckValueOrNull(model);
8986
_env.CheckValueOrNull(inputSchema);
9087
_env.CheckValue(stream, nameof(stream));
9188

@@ -100,11 +97,17 @@ public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
10097
/// <summary>
10198
/// Save a transformer model and the schema of the data that was used to train it to the file.
10299
/// </summary>
103-
/// <param name="model">The trained model to be saved.</param>
104-
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
100+
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
101+
/// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will
102+
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
103+
/// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
105104
/// <param name="filePath">Path where model should be saved.</param>
106105
public void Save(ITransformer model, DataViewSchema inputSchema, string filePath)
107106
{
107+
_env.CheckValueOrNull(model);
108+
_env.CheckValueOrNull(inputSchema);
109+
_env.CheckNonEmpty(filePath, nameof(filePath));
110+
108111
using (var stream = File.Create(filePath))
109112
Save(model, inputSchema, stream);
110113
}
@@ -126,11 +129,11 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
126129
}
127130

128131
/// <summary>
129-
/// Load the model and its input schema from the stream.
132+
/// Load the model and its input schema from a stream.
130133
/// </summary>
131134
/// <param name="stream">A readable, seekable stream to load from.</param>
132-
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved using older APIs
133-
/// it may not contain an input schema, in this case <paramref name="inputSchema"/> will be null.</param>
135+
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
136+
/// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
134137
/// <returns>The loaded model.</returns>
135138
public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
136139
{
@@ -171,57 +174,100 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
171174
throw _env.Except(ex, "Could not load legacy format model");
172175
}
173176
}
174-
if (dataLoader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
175-
{
176-
inputSchema = composite.Loader.GetOutputSchema();
177-
return composite.Transformer;
178-
}
177+
var transformer = DecomposeLoader(ref dataLoader);
179178
inputSchema = dataLoader.GetOutputSchema();
180-
return new TransformerChain<ITransformer>();
179+
return transformer;
180+
}
181+
}
182+
183+
/// <summary>
184+
/// Load the model and its input schema from a file.
185+
/// </summary>
186+
/// <param name="filePath">Path to a file where the model should be read from.</param>
187+
/// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
188+
/// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
189+
/// <returns>The loaded model.</returns>
190+
public ITransformer Load(string filePath, out DataViewSchema inputSchema)
191+
{
192+
_env.CheckNonEmpty(filePath, nameof(filePath));
193+
194+
using (var stream = File.OpenRead(filePath))
195+
return Load(stream, out inputSchema);
196+
}
197+
198+
/// <summary>
199+
/// Given a loader, test try to "decompose" it into a source loader, and its transform if any.
200+
/// If necessary an empty chain will be created to stand in for the trivial transformation; it
201+
/// should never return <see langword="null"/>.
202+
/// </summary>
203+
private ITransformer DecomposeLoader(ref IDataLoader<IMultiStreamSource> loader)
204+
{
205+
_env.AssertValue(loader);
206+
207+
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
208+
{
209+
loader = composite.Loader;
210+
var chain = composite.Transformer;
211+
// The save method corresponding to this load method encapsulates the input ITransformer
212+
// into a single-element transformer chain. If it is that sort, we guess that it is in fact
213+
// that sort, and so return it.
214+
var accessor = (ITransformerChainAccessor)chain;
215+
if (accessor.Transformers.Length == 1)
216+
return accessor.Transformers[0];
217+
// If it is some other length than 1 due to, say, some legacy model saving, just return that
218+
// chain. Using the above API this is not possible, since the chain saved will always be of length
219+
// one, but older APIs behaved differently so we should retain flexibility with those schemes.
220+
// (Those schemes are BTW by no means incorrect, they just aren't what the API in this particular
221+
// class will specifically do.)
222+
return chain;
181223
}
224+
// Maybe we have no transformer stored. Rather than return null, we prefer to return the
225+
// empty "trivial" transformer chain.
226+
return new TransformerChain<ITransformer>();
182227
}
183228

184229
/// <summary>
185-
/// Load the model and its input schema from the stream.
230+
/// Load a transformer model and a data loader model from a stream.
186231
/// </summary>
187232
/// <param name="stream">A readable, seekable stream to load from.</param>
188-
/// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
189-
/// and the transformer chain.</returns>
190-
public IDataLoader<IMultiStreamSource> Load(Stream stream)
233+
/// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
234+
/// this method will throw an exception. The scenario where no loader is stored in the stream should
235+
/// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
236+
/// <returns>The transformer model from the model stream.</returns>
237+
public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
191238
{
192239
_env.CheckValue(stream, nameof(stream));
193240

194241
using (var rep = RepositoryReader.Open(stream))
195242
{
196243
try
197244
{
198-
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, null);
199-
return model;
245+
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out loader, rep, null);
246+
return DecomposeLoader(ref loader);
200247
}
201248
catch (Exception ex)
202249
{
203-
throw _env.Except(ex, "Model does not contain an IDataLoader");
250+
throw _env.Except(ex, "Model does not contain an " + nameof(IDataLoader<IMultiStreamSource>) +
251+
". Perhaps this was saved with an " + nameof(DataViewSchema) + ", or even no information on its input at all. " +
252+
"Consider using the " + nameof(Load) + " method instead.");
204253
}
205254
}
206255
}
207256

208257
/// <summary>
209-
/// Load a transformer model and a data loader model from the stream.
258+
/// Load a transformer model and a data loader model from a file.
210259
/// </summary>
211-
/// <param name="stream">A readable, seekable stream to load from.</param>
212-
/// <param name="loader">The data loader from the model stream.</param>
213-
/// <returns>The transformer model from the model stream.</returns>
214-
public ITransformer LoadWithDataLoader(Stream stream, out IDataLoader<IMultiStreamSource> loader)
260+
/// <param name="filePath">Path to a file where the model should be read from.</param>
261+
/// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
262+
/// this method will throw an exception. The scenario where no loader is stored in the stream should
263+
/// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
264+
/// <returns>The transformer model from the model file.</returns>
265+
public ITransformer LoadWithDataLoader(string filePath, out IDataLoader<IMultiStreamSource> loader)
215266
{
216-
_env.CheckValue(stream, nameof(stream));
267+
_env.CheckNonEmpty(filePath, nameof(filePath));
217268

218-
loader = Load(stream);
219-
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
220-
{
221-
loader = composite.Loader;
222-
return composite.Transformer;
223-
}
224-
return new TransformerChain<ITransformer>();
269+
using (var stream = File.OpenRead(filePath))
270+
return LoadWithDataLoader(stream, out loader);
225271
}
226272

227273
/// <summary>

0 commit comments

Comments
 (0)