Skip to content

Commit 970ad76

Browse files
committed
Add API for saving/loading input schema
1 parent 297a677 commit 970ad76

File tree

17 files changed

+174
-41
lines changed

17 files changed

+174
-41
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/IidChangePointDetectorTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ public static void IidChangePointDetectorPrediction()
138138

139139
// Load the model.
140140
using (var file = File.OpenRead(modelPath))
141-
model = ml.Model.Load(file);
141+
model = ml.Model.Load(file, out var schema);
142142

143143
// Create a time series prediction engine from the checkpointed model.
144144
engine = model.CreateTimeSeriesPredictionFunction<IidChangePointData, ChangePointPrediction>(ml);

docs/samples/Microsoft.ML.Samples/Dynamic/IidSpikeDetectorTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public static void IidSpikeDetectorPrediction()
122122

123123
// Load the model.
124124
using (var file = File.OpenRead(modelPath))
125-
model = ml.Model.Load(file);
125+
model = ml.Model.Load(file, out var schema);
126126

127127
for (int index = 0; index < 5; index++)
128128
{

docs/samples/Microsoft.ML.Samples/Dynamic/SsaChangePointDetectorTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public static void SsaChangePointDetectorPrediction()
142142

143143
// Load the model.
144144
using (var file = File.OpenRead(modelPath))
145-
model = ml.Model.Load(file);
145+
model = ml.Model.Load(file, out var schema);
146146

147147
// We must create a new prediction engine from the persisted model.
148148
engine = model.CreateTimeSeriesPredictionFunction<SsaChangePointData, ChangePointPrediction>(ml);

docs/samples/Microsoft.ML.Samples/Dynamic/SsaSpikeDetectorTransform.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public static void SsaSpikeDetectorPrediction()
150150

151151
// Load the model.
152152
using (var file = File.OpenRead(modelPath))
153-
model = ml.Model.Load(file);
153+
model = ml.Model.Load(file, out var schema);
154154

155155
// We must create a new prediction engine from the persisted model.
156156
engine = model.CreateTimeSeriesPredictionFunction<SsaSpikeData, SsaSpikePrediction>(ml);

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ namespace Microsoft.ML.Data
1919
public sealed class CompositeDataLoader<TSource, TLastTransformer> : IDataLoader<TSource>
2020
where TLastTransformer : class, ITransformer
2121
{
22-
private const string LoaderDirectory = "Loader";
23-
private const string LegacyLoaderDirectory = "Reader";
24-
private const string TransformerDirectory = TransformerChain.LoaderSignature;
25-
2622
/// <summary>
2723
/// The underlying data loader.
2824
/// </summary>
@@ -43,9 +39,9 @@ public CompositeDataLoader(IDataLoader<TSource> loader, TransformerChain<TLastTr
4339

4440
private CompositeDataLoader(IHost host, ModelLoadContext ctx)
4541
{
46-
if (!ctx.LoadModelOrNull<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LegacyLoaderDirectory))
47-
ctx.LoadModel<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, LoaderDirectory);
48-
ctx.LoadModel<TransformerChain<TLastTransformer>, SignatureLoadModel>(host, out Transformer, TransformerDirectory);
42+
if (!ctx.LoadModelOrNull<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, ModelOperationsCatalog.LegacyLoaderDirectory))
43+
ctx.LoadModel<IDataLoader<TSource>, SignatureLoadModel>(host, out Loader, ModelOperationsCatalog.LoaderDirectory);
44+
ctx.LoadModel<TransformerChain<TLastTransformer>, SignatureLoadModel>(host, out Transformer, ModelOperationsCatalog.TransformerDirectory);
4945
}
5046

5147
private static CompositeDataLoader<TSource, TLastTransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -94,8 +90,8 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
9490
ctx.CheckAtModel();
9591
ctx.SetVersionInfo(GetVersionInfo());
9692

97-
ctx.SaveModel(Loader, LoaderDirectory);
98-
ctx.SaveModel(Transformer, TransformerDirectory);
93+
ctx.SaveModel(Loader, ModelOperationsCatalog.LoaderDirectory);
94+
ctx.SaveModel(Transformer, ModelOperationsCatalog.TransformerDirectory);
9995
}
10096

10197
internal const string Summary = "A loader that encapsulates a loader and a transformer chain.";

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, SchemaDefiniti
5757
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schemaDefinition);
5858
}
5959

60+
public IDataView LoadFromEnumerable<TRow>(IEnumerable<TRow> data, DataViewSchema schema)
61+
where TRow : class
62+
{
63+
_env.CheckValue(data, nameof(data));
64+
_env.CheckValue(schema, nameof(schema));
65+
return DataViewConstructionUtils.CreateFromEnumerable(_env, data, schema);
66+
}
67+
6068
/// <summary>
6169
/// Convert an <see cref="IDataView"/> into a strongly-typed <see cref="IEnumerable{TRow}"/>.
6270
/// </summary>

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,15 @@ private static TransformerChain<ITransformer> Create(IHostEnvironment env, Model
249249
public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream)
250250
=> new TransformerChain<ITransformer>(transformer).SaveTo(env, outputStream);
251251

252-
public static TransformerChain<ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
252+
public static ITransformer LoadFrom(IHostEnvironment env, Stream stream)
253253
{
254254
using (var rep = RepositoryReader.Open(stream, env))
255255
{
256256
try
257257
{
258-
ModelLoadContext.LoadModelOrNull<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
258+
ModelLoadContext.LoadModelOrNull<ITransformer, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
259259
if (transformerChain == null)
260-
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}");
260+
ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(env, out transformerChain, rep, $@"Model\{LoaderSignature}");
261261
return transformerChain;
262262
}
263263
catch (FormatException ex)

src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,50 @@ public static StreamingDataView<TRow> CreateFromEnumerable<TRow>(IHostEnvironmen
4545
return new StreamingDataView<TRow>(env, data, internalSchemaDefn);
4646
}
4747

48+
public static StreamingDataView<TRow> CreateFromEnumerable<TRow>(IHostEnvironment env, IEnumerable<TRow> data,
49+
DataViewSchema schema)
50+
where TRow : class
51+
{
52+
Contracts.AssertValue(env);
53+
env.AssertValue(data);
54+
env.AssertValueOrNull(schema);
55+
schema = schema ?? new DataViewSchema.Builder().ToSchema();
56+
return new StreamingDataView<TRow>(env, data, GetInternalSchemaDefinition<TRow>(env, schema));
57+
}
58+
59+
private static InternalSchemaDefinition GetInternalSchemaDefinition<TRow>(IHostEnvironment env, DataViewSchema schema)
60+
where TRow : class
61+
{
62+
Contracts.AssertValue(env);
63+
env.AssertValue(schema);
64+
65+
var isd = InternalSchemaDefinition.Create(typeof(TRow), SchemaDefinition.Direction.Read);
66+
foreach (var col in schema)
67+
{
68+
var name = col.Name;
69+
var isdCol = isd.Columns.FirstOrDefault(c => c.ColumnName == name);
70+
if (isdCol == null)
71+
throw env.Except($"Type should contain a member named {isdCol.ColumnName}");
72+
var annotations = col.Annotations;
73+
if (annotations != null)
74+
{
75+
foreach (var annotation in annotations.Schema)
76+
{
77+
var info = Utils.MarshalInvoke(GetAnnotationInfo<int>, annotation.Type.RawType, annotation.Name, annotations);
78+
isdCol.Annotations.Add(annotation.Name, info);
79+
}
80+
}
81+
}
82+
return isd;
83+
}
84+
85+
private static AnnotationInfo GetAnnotationInfo<T>(string kind, DataViewSchema.Annotations annotations)
86+
{
87+
T value = default;
88+
annotations.GetValue(kind, ref value);
89+
return new AnnotationInfo<T>(kind, value);
90+
}
91+
4892
public static InputRow<TRow> CreateInputRow<TRow>(IHostEnvironment env, SchemaDefinition schemaDefinition = null)
4993
where TRow : class
5094
{
@@ -604,7 +648,7 @@ public StreamingDataView(IHostEnvironment env, IEnumerable<TRow> data, InternalS
604648
public override DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
605649
{
606650
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
607-
return new WrappedCursor (new Cursor(Host, this, predicate));
651+
return new WrappedCursor(new Cursor(Host, this, predicate));
608652
}
609653

610654
private sealed class Cursor : DataViewCursorBase
@@ -674,7 +718,7 @@ public override DataViewRowCursor GetRowCursor(IEnumerable<DataViewSchema.Column
674718
{
675719
Contracts.Assert(_current != null, "The current object must be set prior to cursoring");
676720
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
677-
return new WrappedCursor (new Cursor(Host, this, predicate));
721+
return new WrappedCursor(new Cursor(Host, this, predicate));
678722
}
679723

680724
private sealed class Cursor : DataViewCursorBase

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
56
using System.IO;
7+
using Microsoft.Data.DataView;
68
using Microsoft.ML.Data;
9+
using Microsoft.ML.Data.IO;
10+
using Microsoft.ML.Model;
711

812
namespace Microsoft.ML
913
{
@@ -12,6 +16,11 @@ namespace Microsoft.ML
1216
/// </summary>
1317
public sealed class ModelOperationsCatalog : IInternalCatalog
1418
{
19+
internal const string LoaderDirectory = "Loader";
20+
internal const string LegacyLoaderDirectory = "Reader";
21+
internal const string TransformerDirectory = TransformerChain.LoaderSignature;
22+
internal const string SchemaEntryName = "Schema";
23+
1524
IHostEnvironment IInternalCatalog.Environment => _env;
1625
private readonly IHostEnvironment _env;
1726

@@ -30,28 +39,93 @@ internal ModelOperationsCatalog(IHostEnvironment env)
3039
/// </summary>
3140
/// <param name="model">The trained model to be saved.</param>
3241
/// <param name="stream">A writeable, seekable stream to save to.</param>
33-
public void Save(ITransformer model, Stream stream) => model.SaveTo(_env, stream);
34-
3542
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
3643
{
3744
using (var rep = RepositoryWriter.CreateNew(stream))
3845
{
3946
ModelSaveContext.SaveModel(rep, model, "Model");
47+
SaveInputSchema(model.GetOutputSchema(), rep);
4048
rep.Commit();
4149
}
4250
}
4351

52+
/// <summary>
53+
/// Save a transformer model and the loader used to create its input data to the stream.
54+
/// </summary>
55+
/// <param name="loader">The loader that was used to create data to train the model</param>
56+
/// <param name="model">The trained model to be saved</param>
57+
/// <param name="stream">A writeable, seekable stream to save to.</param>
4458
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) =>
4559
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
4660

4761
/// <summary>
48-
/// Load the model from the stream.
62+
/// Save a transformer model and the schema of the data that was used to train it to the stream.
63+
/// </summary>
64+
/// <param name="inputSchema">The schema of the input to the transformer.</param>
65+
/// <param name="model">The trained model to be saved.</param>
66+
/// <param name="stream">A writeable, seekable stream to save to.</param>
67+
public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream)
68+
{
69+
using (var rep = RepositoryWriter.CreateNew(stream))
70+
{
71+
ModelSaveContext.SaveModel(rep, model, TransformerDirectory);
72+
SaveInputSchema(inputSchema, rep);
73+
rep.Commit();
74+
}
75+
}
76+
77+
private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
78+
{
79+
using (var ch = _env.Start("Saving Schema"))
80+
{
81+
var entry = rep.CreateEntry(SchemaEntryName);
82+
var saver = new BinarySaver(_env, new BinarySaver.Arguments { Silent = true });
83+
DataSaverUtils.SaveDataView(ch, saver, new EmptyDataView(_env, inputSchema), entry.Stream, keepHidden: true);
84+
}
85+
}
86+
87+
/// <summary>
88+
/// Load the model and its input schema from the stream.
4989
/// </summary>
5090
/// <param name="stream">A readable, seekable stream to load from.</param>
91+
/// <param name="inputSchema">Will contain the input schema for the model.</param>
5192
/// <returns>The loaded model.</returns>
52-
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream);
93+
public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
94+
{
95+
using (var rep = RepositoryReader.Open(stream, _env))
96+
{
97+
var entry = rep.OpenEntryOrNull(SchemaEntryName);
98+
if (entry != null)
99+
{
100+
var loader = new BinaryLoader(_env, new BinaryLoader.Arguments(), entry.Stream);
101+
inputSchema = loader.Schema;
102+
}
103+
else
104+
{
105+
// Try to load from legacy model format.
106+
try
107+
{
108+
var loader = ModelFileUtils.LoadLoader(_env, rep, new MultiFileSource(null), false);
109+
inputSchema = loader.Schema;
110+
}
111+
catch (Exception ex)
112+
{
113+
if (!ex.IsMarked())
114+
throw;
115+
inputSchema = null;
116+
}
117+
}
118+
return TransformerChain.LoadFrom(_env, stream);
119+
}
120+
}
53121

54-
public CompositeDataLoader<IMultiStreamSource, ITransformer> LoadAsCompositeDataLoader(Stream stream)
122+
/// <summary>
123+
/// Load the model and its input schema from the stream.
124+
/// </summary>
125+
/// <param name="stream">A readable, seekable stream to load from.</param>
126+
/// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
127+
/// and the transformer chain.</returns>
128+
public CompositeDataLoader<IMultiStreamSource, ITransformer> Load(Stream stream)
55129
{
56130
using (var rep = RepositoryReader.Open(stream))
57131
{

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5615,7 +5615,7 @@ public void LoadEntryPointModel()
56155615
ITransformer loadedModel;
56165616
using (var stream = File.OpenRead(modelPath))
56175617
{
5618-
loadedModel = ml.Model.Load(stream);
5618+
loadedModel = ml.Model.Load(stream, out var inputSchema);
56195619
}
56205620

56215621
}

0 commit comments

Comments
 (0)