Skip to content

Commit d625b4c

Browse files
committed
Add unit tests, and address some code review comments
1 parent 9d21951 commit d625b4c

File tree

13 files changed

+297
-115
lines changed

13 files changed

+297
-115
lines changed

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -250,54 +250,41 @@ private static TransformerChain<ITransformer> Create(IHostEnvironment env, Model
250250
public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream)
251251
=> new TransformerChain<ITransformer>(transformer).SaveTo(env, outputStream);
252252

253-
public static ITransformer LoadFrom(IHostEnvironment env, Stream stream)
253+
public static ITransformer LoadFromLegacy(IHostEnvironment env, Stream stream)
254254
{
255-
using (var rep = RepositoryReader.Open(stream, env))
256-
{
257-
try
258-
{
259-
ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
260-
return transformerChain;
261-
}
262-
catch (FormatException ex)
263-
{
264-
if (!ex.IsMarked())
265-
throw;
266-
var chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: false);
267-
TransformerChain<ITransformer> transformChain = (chain as LegacyCompositeDataLoader).GetTransformer();
268-
var predictor = ModelFileUtils.LoadPredictorOrNull(env, stream);
269-
if (predictor == null)
270-
return transformChain;
271-
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream);
272-
env.CheckDecode(roles != null, "Predictor model must contain role mappings");
273-
var roleMappings = roles.ToArray();
274-
275-
ITransformer pred = null;
276-
if (predictor.PredictionKind == PredictionKind.BinaryClassification)
277-
pred = new BinaryPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
278-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
279-
else if (predictor.PredictionKind == PredictionKind.MulticlassClassification)
280-
pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env,
281-
predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
282-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value,
283-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value);
284-
else if (predictor.PredictionKind == PredictionKind.Clustering)
285-
pred = new ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env, predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
286-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
287-
else if (predictor.PredictionKind == PredictionKind.Regression)
288-
pred = new RegressionPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
289-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
290-
else if (predictor.PredictionKind == PredictionKind.AnomalyDetection)
291-
pred = new AnomalyPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
292-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
293-
else if (predictor.PredictionKind == PredictionKind.Ranking)
294-
pred = new RankingPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
295-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
296-
else
297-
throw env.Except("Don't know how to map prediction kind {0}", predictor.PredictionKind);
298-
return transformChain.Append(pred);
299-
}
300-
}
255+
var chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: false);
256+
TransformerChain<ITransformer> transformChain = (chain as LegacyCompositeDataLoader).GetTransformer();
257+
var predictor = ModelFileUtils.LoadPredictorOrNull(env, stream);
258+
if (predictor == null)
259+
return transformChain;
260+
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream);
261+
env.CheckDecode(roles != null, "Predictor model must contain role mappings");
262+
var roleMappings = roles.ToArray();
263+
264+
ITransformer pred = null;
265+
if (predictor.PredictionKind == PredictionKind.BinaryClassification)
266+
pred = new BinaryPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
267+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
268+
else if (predictor.PredictionKind == PredictionKind.MulticlassClassification)
269+
pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env,
270+
predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
271+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value,
272+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value);
273+
else if (predictor.PredictionKind == PredictionKind.Clustering)
274+
pred = new ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env, predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
275+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
276+
else if (predictor.PredictionKind == PredictionKind.Regression)
277+
pred = new RegressionPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
278+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
279+
else if (predictor.PredictionKind == PredictionKind.AnomalyDetection)
280+
pred = new AnomalyPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
281+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
282+
else if (predictor.PredictionKind == PredictionKind.Ranking)
283+
pred = new RankingPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
284+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
285+
else
286+
throw env.Except("Don't know how to map prediction kind {0}", predictor.PredictionKind);
287+
return transformChain.Append(pred);
301288
}
302289
}
303290
}

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.IO;
7+
using System.Linq;
78
using Microsoft.Data.DataView;
89
using Microsoft.ML.Data;
910
using Microsoft.ML.Data.IO;
@@ -32,41 +33,44 @@ internal ModelOperationsCatalog(IHostEnvironment env)
3233
Explainability = new ExplainabilityTransforms(this);
3334
}
3435

35-
private void Save<TSource>(DataViewSchema schema, IDataLoader<TSource> model, Stream stream)
36+
/// <summary>
37+
/// Save the model to the stream.
38+
/// </summary>
39+
/// <param name="model">The trained model to be saved.</param>
40+
/// <param name="stream">A writeable, seekable stream to save to.</param>
41+
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
3642
{
43+
_env.CheckValue(model, nameof(model));
44+
_env.CheckValue(stream, nameof(stream));
45+
3746
using (var rep = RepositoryWriter.CreateNew(stream))
3847
{
3948
ModelSaveContext.SaveModel(rep, model, null);
40-
SaveInputSchema(schema, rep);
4149
rep.Commit();
4250
}
4351
}
4452

45-
/// <summary>
46-
/// Save the model to the stream.
47-
/// </summary>
48-
/// <param name="model">The trained model to be saved.</param>
49-
/// <param name="stream">A writeable, seekable stream to save to.</param>
50-
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
51-
=> Save(model.GetOutputSchema(), model, stream);
52-
5353
/// <summary>
5454
/// Save a transformer model and the loader used to create its input data to the stream.
5555
/// </summary>
5656
/// <param name="loader">The loader that was used to create data to train the model</param>
5757
/// <param name="model">The trained model to be saved</param>
5858
/// <param name="stream">A writeable, seekable stream to save to.</param>
5959
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) =>
60-
Save(loader.GetOutputSchema(), new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
60+
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
6161

6262
/// <summary>
6363
/// Save a transformer model and the schema of the data that was used to train it to the stream.
6464
/// </summary>
65-
/// <param name="inputSchema">The schema of the input to the transformer.</param>
6665
/// <param name="model">The trained model to be saved.</param>
66+
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
6767
/// <param name="stream">A writeable, seekable stream to save to.</param>
68-
public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream)
68+
public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
6969
{
70+
_env.CheckValue(model, nameof(model));
71+
_env.CheckValueOrNull(inputSchema);
72+
_env.CheckValue(stream, nameof(stream));
73+
7074
using (var rep = RepositoryWriter.CreateNew(stream))
7175
{
7276
ModelSaveContext.SaveModel(rep, model, CompositeDataLoader<object, ITransformer>.TransformerDirectory);
@@ -77,6 +81,12 @@ public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream)
7781

7882
private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
7983
{
84+
_env.AssertValueOrNull(inputSchema);
85+
_env.AssertValue(rep);
86+
87+
if (inputSchema == null)
88+
return;
89+
8090
using (var ch = _env.Start("Saving Schema"))
8191
{
8292
var entry = rep.CreateEntry(SchemaEntryName);
@@ -94,30 +104,50 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
94104
/// <returns>The loaded model.</returns>
95105
public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
96106
{
107+
_env.CheckValue(stream, nameof(stream));
108+
97109
using (var rep = RepositoryReader.Open(stream, _env))
98110
{
99111
var entry = rep.OpenEntryOrNull(SchemaEntryName);
100112
if (entry != null)
101113
{
102114
var loader = new BinaryLoader(_env, new BinaryLoader.Arguments(), entry.Stream);
103115
inputSchema = loader.Schema;
116+
ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(_env, out var transformerChain, rep,
117+
CompositeDataLoader<object, ITransformer>.TransformerDirectory);
118+
return transformerChain;
104119
}
105-
else
120+
121+
ModelLoadContext.LoadModelOrNull<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var dataLoader, rep, null);
122+
if (dataLoader == null)
106123
{
124+
// Try to see if the model was saved without a loader or a schema.
125+
if (ModelLoadContext.LoadModelOrNull<ITransformer, SignatureLoadModel>(_env, out var transformerChain, rep,
126+
CompositeDataLoader<object, ITransformer>.TransformerDirectory))
127+
{
128+
inputSchema = null;
129+
return transformerChain;
130+
}
131+
107132
// Try to load from legacy model format.
108133
try
109134
{
110135
var loader = ModelFileUtils.LoadLoader(_env, rep, new MultiFileSource(null), false);
111136
inputSchema = loader.Schema;
137+
return TransformerChain.LoadFromLegacy(_env, stream);
112138
}
113139
catch (Exception ex)
114140
{
115-
if (!ex.IsMarked())
116-
throw;
117-
inputSchema = null;
141+
throw _env.Except(ex, "Could not load legacy format model");
118142
}
119143
}
120-
return TransformerChain.LoadFrom(_env, stream);
144+
if (dataLoader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
145+
{
146+
inputSchema = composite.Loader.GetOutputSchema();
147+
return composite.Transformer;
148+
}
149+
inputSchema = dataLoader.GetOutputSchema();
150+
return new TransformerChain<ITransformer>();
121151
}
122152
}
123153

@@ -127,12 +157,21 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
127157
/// <param name="stream">A readable, seekable stream to load from.</param>
128158
/// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
129159
/// and the transformer chain.</returns>
130-
public CompositeDataLoader<IMultiStreamSource, ITransformer> Load(Stream stream)
160+
public IDataLoader<IMultiStreamSource> Load(Stream stream)
131161
{
162+
_env.CheckValue(stream, nameof(stream));
163+
132164
using (var rep = RepositoryReader.Open(stream))
133165
{
134-
ModelLoadContext.LoadModel<CompositeDataLoader<IMultiStreamSource, ITransformer>, SignatureLoadModel>(_env, out var model, rep, null);
135-
return model;
166+
try
167+
{
168+
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, null);
169+
return model;
170+
}
171+
catch (Exception ex)
172+
{
173+
throw _env.Except(ex, "Model does not contain an IDataLoader");
174+
}
136175
}
137176
}
138177

@@ -144,6 +183,8 @@ public CompositeDataLoader<IMultiStreamSource, ITransformer> Load(Stream stream)
144183
/// <returns>The transformer model from the model stream.</returns>
145184
public ITransformer Load(Stream stream, out IDataLoader<IMultiStreamSource> loader)
146185
{
186+
_env.CheckValue(stream, nameof(stream));
187+
147188
loader = Load(stream);
148189
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
149190
{

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5647,7 +5647,6 @@ public void LoadEntryPointModel()
56475647
{
56485648
loadedModel = ml.Model.Load(stream, out DataViewSchema inputSchema);
56495649
}
5650-
56515650
}
56525651
}
56535652
}

0 commit comments

Comments
 (0)