@@ -29,63 +29,60 @@ internal ModelOperationsCatalog(IHostEnvironment env)
2929 }
3030
3131 /// <summary>
32- /// Save the model to the stream.
32+ /// Save a transformer model and the loader used to create its input data to the stream.
3333 /// </summary>
34- /// <param name="model">The trained model to be saved.</param>
34+ /// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
35+ /// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
36+ /// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
37+ /// <param name="loader">The loader that was used to create data to train the model.</param>
3538 /// <param name="stream">A writeable, seekable stream to save to.</param>
36- public void Save < TSource > ( IDataLoader < TSource > model , Stream stream )
39+ public void Save < TSource > ( ITransformer model , IDataLoader < TSource > loader , Stream stream )
3740 {
38- _env . CheckValue ( model , nameof ( model ) ) ;
41+ _env . CheckValue ( loader , nameof ( loader ) ) ;
42+ _env . CheckValueOrNull ( model ) ;
3943 _env . CheckValue ( stream , nameof ( stream ) ) ;
4044
45+ // For the sake of consistency of this API specifically, when called upon we save any transformer
46+ // in a single element transformer chain.
47+ var chainedModel = model == null ? null : new TransformerChain < ITransformer > ( model ) ;
48+ var compositeLoader = new CompositeDataLoader < TSource , ITransformer > ( loader , chainedModel ) ;
49+
4150 using ( var rep = RepositoryWriter . CreateNew ( stream ) )
4251 {
43- ModelSaveContext . SaveModel ( rep , model , null ) ;
52+ ModelSaveContext . SaveModel ( rep , compositeLoader , null ) ;
4453 rep . Commit ( ) ;
4554 }
4655 }
4756
48- /// <summary>
49- /// Save the model to the file.
50- /// </summary>
51- /// <param name="model">The trained model to be saved.</param>
52- /// <param name="filePath">Path where model should be saved.</param>
53- public void Save < TSource > ( IDataLoader < TSource > model , string filePath )
54- {
55- using ( var stream = File . Create ( filePath ) )
56- Save ( model , stream ) ;
57- }
58-
59- /// <summary>
60- /// Save a transformer model and the loader used to create its input data to the stream.
61- /// </summary>
62- /// <param name="loader">The loader that was used to create data to train the model</param>
63- /// <param name="model">The trained model to be saved</param>
64- /// <param name="stream">A writeable, seekable stream to save to.</param>
65- public void Save < TSource > ( IDataLoader < TSource > loader , ITransformer model , Stream stream ) =>
66- Save ( new CompositeDataLoader < TSource , ITransformer > ( loader , new TransformerChain < ITransformer > ( model ) ) , stream ) ;
67-
6857 /// <summary>
6958 /// Save a transformer model and the loader used to create its input data to the file.
7059 /// </summary>
71- /// <param name="loader">The loader that was used to create data to train the model</param>
72- /// <param name="model">The trained model to be saved</param>
60+ /// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
61+ /// for an empty transformer chain. Upon loading with <see cref="LoadWithDataLoader(Stream, out IDataLoader{IMultiStreamSource})"/>
62+ /// the returned value will be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
63+ /// <param name="loader">The loader that was used to create data to train the model.</param>
7364 /// <param name="filePath">Path where model should be saved.</param>
74- public void Save < TSource > ( IDataLoader < TSource > loader , ITransformer model , string filePath )
65+ public void Save < TSource > ( ITransformer model , IDataLoader < TSource > loader , string filePath )
7566 {
67+ _env . CheckValueOrNull ( model ) ;
68+ _env . CheckValue ( loader , nameof ( loader ) ) ;
69+ _env . CheckNonEmpty ( filePath , nameof ( filePath ) ) ;
70+
7671 using ( var stream = File . Create ( filePath ) )
77- Save ( loader , model , stream ) ;
72+ Save ( model , loader , stream ) ;
7873 }
7974
8075 /// <summary>
8176 /// Save a transformer model and the schema of the data that was used to train it to the stream.
8277 /// </summary>
83- /// <param name="model">The trained model to be saved.</param>
84- /// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
78+ /// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
79+ /// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will
80+ /// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
81+ /// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
8582 /// <param name="stream">A writeable, seekable stream to save to.</param>
8683 public void Save ( ITransformer model , DataViewSchema inputSchema , Stream stream )
8784 {
88- _env . CheckValue ( model , nameof ( model ) ) ;
85+ _env . CheckValueOrNull ( model ) ;
8986 _env . CheckValueOrNull ( inputSchema ) ;
9087 _env . CheckValue ( stream , nameof ( stream ) ) ;
9188
@@ -100,11 +97,17 @@ public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
10097 /// <summary>
10198 /// Save a transformer model and the schema of the data that was used to train it to the file.
10299 /// </summary>
103- /// <param name="model">The trained model to be saved.</param>
104- /// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
100+ /// <param name="model">The trained model to be saved. Note that this can be <see langword="null"/>, as a shorthand
101+ /// for an empty transformer chain. Upon loading with <see cref="Load(Stream, out DataViewSchema)"/> the returned value will
102+ /// be an empty <see cref="TransformerChain{TLastTransformer}"/>.</param>
103+ /// <param name="inputSchema">The schema of the input to the transformer. This can be <see langword="null"/>.</param>
105104 /// <param name="filePath">Path where model should be saved.</param>
106105 public void Save ( ITransformer model , DataViewSchema inputSchema , string filePath )
107106 {
107+ _env . CheckValueOrNull ( model ) ;
108+ _env . CheckValueOrNull ( inputSchema ) ;
109+ _env . CheckNonEmpty ( filePath , nameof ( filePath ) ) ;
110+
108111 using ( var stream = File . Create ( filePath ) )
109112 Save ( model , inputSchema , stream ) ;
110113 }
@@ -126,11 +129,11 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
126129 }
127130
128131 /// <summary>
129- /// Load the model and its input schema from the stream.
132+ /// Load the model and its input schema from a stream.
130133 /// </summary>
131134 /// <param name="stream">A readable, seekable stream to load from.</param>
132- /// <param name="inputSchema">Will contain the input schema for the model. If the model was saved using older APIs
133- /// it may not contain an input schema, in this case <paramref name="inputSchema "/> will be null .</param>
135+ /// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
136+ /// any description of the input, there will be no input schema. In this case this can be <see langword="null "/>.</param>
134137 /// <returns>The loaded model.</returns>
135138 public ITransformer Load ( Stream stream , out DataViewSchema inputSchema )
136139 {
@@ -171,57 +174,100 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
171174 throw _env . Except ( ex , "Could not load legacy format model" ) ;
172175 }
173176 }
174- if ( dataLoader is CompositeDataLoader < IMultiStreamSource , ITransformer > composite )
175- {
176- inputSchema = composite . Loader . GetOutputSchema ( ) ;
177- return composite . Transformer ;
178- }
177+ var transformer = DecomposeLoader ( ref dataLoader ) ;
179178 inputSchema = dataLoader . GetOutputSchema ( ) ;
180- return new TransformerChain < ITransformer > ( ) ;
179+ return transformer ;
180+ }
181+ }
182+
183+ /// <summary>
184+ /// Load the model and its input schema from a file.
185+ /// </summary>
186+ /// <param name="filePath">Path to a file where the model should be read from.</param>
187+ /// <param name="inputSchema">Will contain the input schema for the model. If the model was saved without
188+ /// any description of the input, there will be no input schema. In this case this can be <see langword="null"/>.</param>
189+ /// <returns>The loaded model.</returns>
190+ public ITransformer Load ( string filePath , out DataViewSchema inputSchema )
191+ {
192+ _env . CheckNonEmpty ( filePath , nameof ( filePath ) ) ;
193+
194+ using ( var stream = File . OpenRead ( filePath ) )
195+ return Load ( stream , out inputSchema ) ;
196+ }
197+
198+ /// <summary>
199+ /// Given a loader, test try to "decompose" it into a source loader, and its transform if any.
200+ /// If necessary an empty chain will be created to stand in for the trivial transformation; it
201+ /// should never return <see langword="null"/>.
202+ /// </summary>
203+ private ITransformer DecomposeLoader ( ref IDataLoader < IMultiStreamSource > loader )
204+ {
205+ _env . AssertValue ( loader ) ;
206+
207+ if ( loader is CompositeDataLoader < IMultiStreamSource , ITransformer > composite )
208+ {
209+ loader = composite . Loader ;
210+ var chain = composite . Transformer ;
211+ // The save method corresponding to this load method encapsulates the input ITransformer
212+ // into a single-element transformer chain. If it is that sort, we guess that it is in fact
213+ // that sort, and so return it.
214+ var accessor = ( ITransformerChainAccessor ) chain ;
215+ if ( accessor . Transformers . Length == 1 )
216+ return accessor . Transformers [ 0 ] ;
217+ // If it is some other length than 1 due to, say, some legacy model saving, just return that
218+ // chain. Using the above API this is not possible, since the chain saved will always be of length
219+ // one, but older APIs behaved differently so we should retain flexibility with those schemes.
220+ // (Those schemes are BTW by no means incorrect, they just aren't what the API in this particular
221+ // class will specifically do.)
222+ return chain ;
181223 }
224+ // Maybe we have no transformer stored. Rather than return null, we prefer to return the
225+ // empty "trivial" transformer chain.
226+ return new TransformerChain < ITransformer > ( ) ;
182227 }
183228
184229 /// <summary>
185- /// Load the model and its input schema from the stream.
230+ /// Load a transformer model and a data loader model from a stream.
186231 /// </summary>
187232 /// <param name="stream">A readable, seekable stream to load from.</param>
188- /// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
189- /// and the transformer chain.</returns>
190- public IDataLoader < IMultiStreamSource > Load ( Stream stream )
233+ /// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
234+ /// this method will throw an exception. The scenario where no loader is stored in the stream should
235+ /// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
236+ /// <returns>The transformer model from the model stream.</returns>
237+ public ITransformer LoadWithDataLoader ( Stream stream , out IDataLoader < IMultiStreamSource > loader )
191238 {
192239 _env . CheckValue ( stream , nameof ( stream ) ) ;
193240
194241 using ( var rep = RepositoryReader . Open ( stream ) )
195242 {
196243 try
197244 {
198- ModelLoadContext . LoadModel < IDataLoader < IMultiStreamSource > , SignatureLoadModel > ( _env , out var model , rep , null ) ;
199- return model ;
245+ ModelLoadContext . LoadModel < IDataLoader < IMultiStreamSource > , SignatureLoadModel > ( _env , out loader , rep , null ) ;
246+ return DecomposeLoader ( ref loader ) ;
200247 }
201248 catch ( Exception ex )
202249 {
203- throw _env . Except ( ex , "Model does not contain an IDataLoader" ) ;
250+ throw _env . Except ( ex , "Model does not contain an " + nameof ( IDataLoader < IMultiStreamSource > ) +
251+ ". Perhaps this was saved with an " + nameof ( DataViewSchema ) + ", or even no information on its input at all. " +
252+ "Consider using the " + nameof ( Load ) + " method instead." ) ;
204253 }
205254 }
206255 }
207256
208257 /// <summary>
209- /// Load a transformer model and a data loader model from the stream .
258+ /// Load a transformer model and a data loader model from a file .
210259 /// </summary>
211- /// <param name="stream">A readable, seekable stream to load from.</param>
212- /// <param name="loader">The data loader from the model stream.</param>
213- /// <returns>The transformer model from the model stream.</returns>
214- public ITransformer LoadWithDataLoader ( Stream stream , out IDataLoader < IMultiStreamSource > loader )
260+ /// <param name="filePath">Path to a file where the model should be read from.</param>
261+ /// <param name="loader">The data loader from the model stream. Note that if there is no data loader,
262+ /// this method will throw an exception. The scenario where no loader is stored in the stream should
263+ /// be handled instead using the <see cref="Load(Stream, out DataViewSchema)"/> method.</param>
264+ /// <returns>The transformer model from the model file.</returns>
265+ public ITransformer LoadWithDataLoader ( string filePath , out IDataLoader < IMultiStreamSource > loader )
215266 {
216- _env . CheckValue ( stream , nameof ( stream ) ) ;
267+ _env . CheckNonEmpty ( filePath , nameof ( filePath ) ) ;
217268
218- loader = Load ( stream ) ;
219- if ( loader is CompositeDataLoader < IMultiStreamSource , ITransformer > composite )
220- {
221- loader = composite . Loader ;
222- return composite . Transformer ;
223- }
224- return new TransformerChain < ITransformer > ( ) ;
269+ using ( var stream = File . OpenRead ( filePath ) )
270+ return LoadWithDataLoader ( stream , out loader ) ;
225271 }
226272
227273 /// <summary>
0 commit comments