1212using Microsoft . ML . EntryPoints ;
1313using Microsoft . ML . Internal . Utilities ;
1414using Microsoft . ML . Model . Onnx ;
15+ using Microsoft . ML . UniversalModelFormat . Onnx ;
1516using Newtonsoft . Json ;
1617
1718[ assembly: LoadableClass ( SaveOnnxCommand . Summary , typeof ( SaveOnnxCommand ) , typeof ( SaveOnnxCommand . Arguments ) , typeof ( SignatureCommand ) ,
@@ -113,9 +114,11 @@ public override void Run()
113114 }
114115 }
115116
116- private void GetPipe ( OnnxContextImpl ctx , IChannel ch , IDataView end , out IDataView source , out IDataView trueEnd , out LinkedList < ITransformCanSaveOnnx > transforms )
117+ [ BestFriend ]
118+ internal static void GetPipe ( OnnxContextImpl ctx , IChannel ch , IDataView end , out IDataView source , out IDataView trueEnd , out LinkedList < ITransformCanSaveOnnx > transforms )
117119 {
118- Host . AssertValue ( end ) ;
120+ Contracts . AssertValue ( end ) ;
121+
119122 source = trueEnd = ( end as CompositeDataLoader ) ? . View ?? end ;
120123 IDataTransform transform = source as IDataTransform ;
121124 transforms = new LinkedList < ITransformCanSaveOnnx > ( ) ;
@@ -134,7 +137,51 @@ private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataV
134137 transform = ( source = transform . Source ) as IDataTransform ;
135138 }
136139
137- Host . AssertValue ( source ) ;
140+ Contracts . AssertValue ( source ) ;
141+ }
142+
143+ [ BestFriend ]
144+ internal static ModelProto ConvertTransformListToOnnxModel ( OnnxContextImpl ctx , IDataView inputData , IDataView outputData ,
145+ LinkedList < ITransformCanSaveOnnx > transforms , HashSet < string > inputColumnNamesToDrop = null , HashSet < string > outputColumnNamesToDrop = null )
146+ {
147+ inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet < string > ( ) ;
148+ outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet < string > ( ) ;
149+ HashSet < string > inputColumns = new HashSet < string > ( ) ;
150+ // Create graph inputs.
151+ for ( int i = 0 ; i < inputData . Schema . Count ; i ++ )
152+ {
153+ string colName = inputData . Schema [ i ] . Name ;
154+ if ( inputColumnNamesToDrop . Contains ( colName ) )
155+ continue ;
156+
157+ ctx . AddInputVariable ( inputData . Schema [ i ] . Type , colName ) ;
158+ inputColumns . Add ( colName ) ;
159+ }
160+
161+ // Create graph nodes, outputs and intermediate values.
162+ foreach ( var trans in transforms )
163+ trans . SaveAsOnnx ( ctx ) ;
164+
165+ // Add graph outputs.
166+ for ( int i = 0 ; i < outputData . Schema . Count ; ++ i )
167+ {
168+ if ( outputData . Schema [ i ] . IsHidden )
169+ continue ;
170+
171+ var idataviewColumnName = outputData . Schema [ i ] . Name ;
172+
173+ // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
174+ // _inputToDrop should be removed too.
175+ if ( inputColumnNamesToDrop . Contains ( idataviewColumnName ) || outputColumnNamesToDrop . Contains ( idataviewColumnName ) )
176+ continue ;
177+
178+ var variableName = ctx . TryGetVariableName ( idataviewColumnName ) ;
179+ var trueVariableName = ctx . AddIntermediateVariable ( null , idataviewColumnName , true ) ;
180+ ctx . CreateNode ( "Identity" , variableName , trueVariableName , ctx . GetNodeName ( "Identity" ) , "" ) ;
181+ ctx . AddOutputVariable ( outputData . Schema [ i ] . Type , trueVariableName ) ;
182+ }
183+
184+ return ctx . MakeModel ( ) ;
138185 }
139186
140187 private void Run ( IChannel ch )
@@ -210,45 +257,8 @@ private void Run(IChannel ch)
210257 nameof ( Arguments . LoadPredictor ) , "We were explicitly told to load the predictor but one was not present." ) ;
211258 }
212259
213- HashSet < string > inputColumns = new HashSet < string > ( ) ;
214- //Create graph inputs.
215- for ( int i = 0 ; i < source . Schema . Count ; i ++ )
216- {
217- string colName = source . Schema [ i ] . Name ;
218- if ( _inputsToDrop . Contains ( colName ) )
219- continue ;
220-
221- ctx . AddInputVariable ( source . Schema [ i ] . Type , colName ) ;
222- inputColumns . Add ( colName ) ;
223- }
224-
225- //Create graph nodes, outputs and intermediate values.
226- foreach ( var trans in transforms )
227- {
228- Host . Assert ( trans . CanSaveOnnx ( ctx ) ) ;
229- trans . SaveAsOnnx ( ctx ) ;
230- }
231-
232- //Add graph outputs.
233- for ( int i = 0 ; i < end . Schema . Count ; ++ i )
234- {
235- if ( end . Schema [ i ] . IsHidden )
236- continue ;
237-
238- var idataviewColumnName = end . Schema [ i ] . Name ;
239-
240- // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
241- // _inputToDrop should be removed too.
242- if ( _inputsToDrop . Contains ( idataviewColumnName ) || _outputsToDrop . Contains ( idataviewColumnName ) )
243- continue ;
244-
245- var variableName = ctx . TryGetVariableName ( idataviewColumnName ) ;
246- var trueVariableName = ctx . AddIntermediateVariable ( null , idataviewColumnName , true ) ;
247- ctx . CreateNode ( "Identity" , variableName , trueVariableName , ctx . GetNodeName ( "Identity" ) , "" ) ;
248- ctx . AddOutputVariable ( end . Schema [ i ] . Type , trueVariableName ) ;
249- }
260+ var model = ConvertTransformListToOnnxModel ( ctx , source , end , transforms , _inputsToDrop , _outputsToDrop ) ;
250261
251- var model = ctx . MakeModel ( ) ;
252262 using ( var file = Host . CreateOutputFile ( _outputModelPath ) )
253263 using ( var stream = file . CreateWriteStream ( ) )
254264 model . WriteTo ( stream ) ;
0 commit comments