Skip to content

Commit b929023

Browse files
committed
added an extension method for saving statically typed model - adjust api changes & documentation example
1 parent c92f0e4 commit b929023

File tree

4 files changed

+44
-25
lines changed

4 files changed

+44
-25
lines changed

docs/code/experimental/MlNetCookBookStaticApi.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ Here's what you do to save the model to a file, and reload it (potentially in a
397397

398398
```csharp
399399
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
400-
mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);
400+
mlContext.Model.Save(model, trainData, modelPath);
401401

402402
// Potentially, the lines below can be in a different process altogether.
403403

src/Microsoft.ML.StaticPipe/ModelOperationsCatalog.cs

Lines changed: 0 additions & 23 deletions
This file was deleted.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.IO;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.StaticPipe
9+
{
10+
public static class ModelOperationsCatalogExtensions
11+
{
12+
/// <summary>
13+
/// Save statically typed model to the stream.
14+
/// </summary>
15+
/// <param name="catalog">The model explainability operations catalog.</param>
16+
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
17+
/// for an empty transformer chain. Upon loading with <see cref="ML.ModelOperationsCatalog.Load(Stream, out DataViewSchema)"/> the returned value will
18+
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
19+
/// <param name="dataView">The data view with the schema of the input to the transformer. This can be <see langword="null"/>.</param>
20+
/// <param name="stream">A writeable, seekable stream to save to.</param>
21+
public static void Save<TInShape, TOutShape, TTransformer>(this ML.ModelOperationsCatalog catalog, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TInShape> dataView, Stream stream)
22+
where TTransformer : class, ITransformer
23+
{
24+
catalog.Save(model?.AsDynamic, dataView?.AsDynamic.Schema, stream);
25+
}
26+
27+
/// <summary>
28+
/// Save statically typed model to the stream.
29+
/// </summary>
30+
/// <param name="catalog">The model explainability operations catalog.</param>
31+
/// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
32+
/// for an empty transformer chain. Upon loading with <see cref="ML.ModelOperationsCatalog.Load(Stream, out DataViewSchema)"/> the returned value will
33+
/// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
34+
/// <param name="dataView">The data view with the schema of the input to the transformer. This can be <see langword="null"/>.</param>
35+
/// <param name="filePath">Path where model should be saved.</param>
36+
public static void Save<TInShape, TOutShape, TTransformer>(this ML.ModelOperationsCatalog catalog, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TInShape> dataView, string filePath)
37+
where TTransformer : class, ITransformer
38+
{
39+
catalog.Save(model?.AsDynamic, dataView?.AsDynamic.Schema, filePath);
40+
}
41+
}
42+
}

test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamples.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ private void TrainRegression(string trainDataPath, string testDataPath, string m
147147
var metrics = mlContext.Regression.Evaluate(model.Transform(testData), label: r => r.Target, score: r => r.Prediction);
148148

149149
// Saving and loading happens to 'dynamic' models, so the static typing is lost in the process.
150-
mlContext.Model.Save(model.AsDynamic, trainData.AsDynamic.Schema, modelPath);
150+
mlContext.Model.Save(model, trainData, modelPath);
151151

152152
// Potentially, the lines below can be in a different process altogether.
153153

0 commit comments

Comments
 (0)