44
55using System ;
66using System . IO ;
7+ using System . Linq ;
78using Microsoft . Data . DataView ;
89using Microsoft . ML . Data ;
910using Microsoft . ML . Data . IO ;
@@ -32,41 +33,44 @@ internal ModelOperationsCatalog(IHostEnvironment env)
3233 Explainability = new ExplainabilityTransforms ( this ) ;
3334 }
3435
35- private void Save < TSource > ( DataViewSchema schema , IDataLoader < TSource > model , Stream stream )
36+ /// <summary>
37+ /// Save the model to the stream.
38+ /// </summary>
39+ /// <param name="model">The trained model to be saved.</param>
40+ /// <param name="stream">A writeable, seekable stream to save to.</param>
41+ public void Save < TSource > ( IDataLoader < TSource > model , Stream stream )
3642 {
43+ _env . CheckValue ( model , nameof ( model ) ) ;
44+ _env . CheckValue ( stream , nameof ( stream ) ) ;
45+
3746 using ( var rep = RepositoryWriter . CreateNew ( stream ) )
3847 {
3948 ModelSaveContext . SaveModel ( rep , model , null ) ;
40- SaveInputSchema ( schema , rep ) ;
4149 rep . Commit ( ) ;
4250 }
4351 }
4452
45- /// <summary>
46- /// Save the model to the stream.
47- /// </summary>
48- /// <param name="model">The trained model to be saved.</param>
49- /// <param name="stream">A writeable, seekable stream to save to.</param>
50- public void Save < TSource > ( IDataLoader < TSource > model , Stream stream )
51- => Save ( model . GetOutputSchema ( ) , model , stream ) ;
52-
5353 /// <summary>
5454 /// Save a transformer model and the loader used to create its input data to the stream.
5555 /// </summary>
5656 /// <param name="loader">The loader that was used to create data to train the model</param>
5757 /// <param name="model">The trained model to be saved</param>
5858 /// <param name="stream">A writeable, seekable stream to save to.</param>
5959 public void Save < TSource > ( IDataLoader < TSource > loader , ITransformer model , Stream stream ) =>
60- Save ( loader . GetOutputSchema ( ) , new CompositeDataLoader < TSource , ITransformer > ( loader , new TransformerChain < ITransformer > ( model ) ) , stream ) ;
60+ Save ( new CompositeDataLoader < TSource , ITransformer > ( loader , new TransformerChain < ITransformer > ( model ) ) , stream ) ;
6161
6262 /// <summary>
6363 /// Save a transformer model and the schema of the data that was used to train it to the stream.
6464 /// </summary>
65- /// <param name="inputSchema">The schema of the input to the transformer.</param>
6665 /// <param name="model">The trained model to be saved.</param>
66+ /// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
6767 /// <param name="stream">A writeable, seekable stream to save to.</param>
68- public void Save ( DataViewSchema inputSchema , ITransformer model , Stream stream )
68+ public void Save ( ITransformer model , DataViewSchema inputSchema , Stream stream )
6969 {
70+ _env . CheckValue ( model , nameof ( model ) ) ;
71+ _env . CheckValueOrNull ( inputSchema ) ;
72+ _env . CheckValue ( stream , nameof ( stream ) ) ;
73+
7074 using ( var rep = RepositoryWriter . CreateNew ( stream ) )
7175 {
7276 ModelSaveContext . SaveModel ( rep , model , CompositeDataLoader < object , ITransformer > . TransformerDirectory ) ;
@@ -77,6 +81,12 @@ public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream)
7781
7882 private void SaveInputSchema ( DataViewSchema inputSchema , RepositoryWriter rep )
7983 {
84+ _env . AssertValueOrNull ( inputSchema ) ;
85+ _env . AssertValue ( rep ) ;
86+
87+ if ( inputSchema == null )
88+ return ;
89+
8090 using ( var ch = _env . Start ( "Saving Schema" ) )
8191 {
8292 var entry = rep . CreateEntry ( SchemaEntryName ) ;
@@ -94,30 +104,50 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
94104 /// <returns>The loaded model.</returns>
95105 public ITransformer Load ( Stream stream , out DataViewSchema inputSchema )
96106 {
107+ _env . CheckValue ( stream , nameof ( stream ) ) ;
108+
97109 using ( var rep = RepositoryReader . Open ( stream , _env ) )
98110 {
99111 var entry = rep . OpenEntryOrNull ( SchemaEntryName ) ;
100112 if ( entry != null )
101113 {
102114 var loader = new BinaryLoader ( _env , new BinaryLoader . Arguments ( ) , entry . Stream ) ;
103115 inputSchema = loader . Schema ;
116+ ModelLoadContext . LoadModel < ITransformer , SignatureLoadModel > ( _env , out var transformerChain , rep ,
117+ CompositeDataLoader < object , ITransformer > . TransformerDirectory ) ;
118+ return transformerChain ;
104119 }
105- else
120+
121+ ModelLoadContext . LoadModelOrNull < IDataLoader < IMultiStreamSource > , SignatureLoadModel > ( _env , out var dataLoader , rep , null ) ;
122+ if ( dataLoader == null )
106123 {
124+ // Try to see if the model was saved without a loader or a schema.
125+ if ( ModelLoadContext . LoadModelOrNull < ITransformer , SignatureLoadModel > ( _env , out var transformerChain , rep ,
126+ CompositeDataLoader < object , ITransformer > . TransformerDirectory ) )
127+ {
128+ inputSchema = null ;
129+ return transformerChain ;
130+ }
131+
107132 // Try to load from legacy model format.
108133 try
109134 {
110135 var loader = ModelFileUtils . LoadLoader ( _env , rep , new MultiFileSource ( null ) , false ) ;
111136 inputSchema = loader . Schema ;
137+ return TransformerChain . LoadFromLegacy ( _env , stream ) ;
112138 }
113139 catch ( Exception ex )
114140 {
115- if ( ! ex . IsMarked ( ) )
116- throw ;
117- inputSchema = null ;
141+ throw _env . Except ( ex , "Could not load legacy format model" ) ;
118142 }
119143 }
120- return TransformerChain . LoadFrom ( _env , stream ) ;
144+ if ( dataLoader is CompositeDataLoader < IMultiStreamSource , ITransformer > composite )
145+ {
146+ inputSchema = composite . Loader . GetOutputSchema ( ) ;
147+ return composite . Transformer ;
148+ }
149+ inputSchema = dataLoader . GetOutputSchema ( ) ;
150+ return new TransformerChain < ITransformer > ( ) ;
121151 }
122152 }
123153
@@ -127,12 +157,21 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
127157 /// <param name="stream">A readable, seekable stream to load from.</param>
128158 /// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
129159 /// and the transformer chain.</returns>
130- public CompositeDataLoader < IMultiStreamSource , ITransformer > Load ( Stream stream )
160+ public IDataLoader < IMultiStreamSource > Load ( Stream stream )
131161 {
162+ _env . CheckValue ( stream , nameof ( stream ) ) ;
163+
132164 using ( var rep = RepositoryReader . Open ( stream ) )
133165 {
134- ModelLoadContext . LoadModel < CompositeDataLoader < IMultiStreamSource , ITransformer > , SignatureLoadModel > ( _env , out var model , rep , null ) ;
135- return model ;
166+ try
167+ {
168+ ModelLoadContext . LoadModel < IDataLoader < IMultiStreamSource > , SignatureLoadModel > ( _env , out var model , rep , null ) ;
169+ return model ;
170+ }
171+ catch ( Exception ex )
172+ {
173+ throw _env . Except ( ex , "Model does not contain an IDataLoader" ) ;
174+ }
136175 }
137176 }
138177
@@ -144,6 +183,8 @@ public CompositeDataLoader<IMultiStreamSource, ITransformer> Load(Stream stream)
144183 /// <returns>The transformer model from the model stream.</returns>
145184 public ITransformer Load ( Stream stream , out IDataLoader < IMultiStreamSource > loader )
146185 {
186+ _env . CheckValue ( stream , nameof ( stream ) ) ;
187+
147188 loader = Load ( stream ) ;
148189 if ( loader is CompositeDataLoader < IMultiStreamSource , ITransformer > composite )
149190 {
0 commit comments