1212using Microsoft . ML . EntryPoints ;
1313using Microsoft . ML . Internal . Utilities ;
1414using Microsoft . ML . Model . Onnx ;
15+ using Microsoft . ML . UniversalModelFormat . Onnx ;
1516using Newtonsoft . Json ;
1617
1718[ assembly: LoadableClass ( SaveOnnxCommand . Summary , typeof ( SaveOnnxCommand ) , typeof ( SaveOnnxCommand . Arguments ) , typeof ( SignatureCommand ) ,
@@ -113,9 +114,10 @@ public override void Run()
113114 }
114115 }
115116
116- private void GetPipe ( OnnxContextImpl ctx , IChannel ch , IDataView end , out IDataView source , out IDataView trueEnd , out LinkedList < ITransformCanSaveOnnx > transforms )
117+ internal static void GetPipe ( OnnxContextImpl ctx , IChannel ch , IDataView end , out IDataView source , out IDataView trueEnd , out LinkedList < ITransformCanSaveOnnx > transforms )
117118 {
118- Host . AssertValue ( end ) ;
119+ ch . AssertValue ( end ) ;
120+
119121 source = trueEnd = ( end as CompositeDataLoader ) ? . View ?? end ;
120122 IDataTransform transform = source as IDataTransform ;
121123 transforms = new LinkedList < ITransformCanSaveOnnx > ( ) ;
@@ -134,7 +136,53 @@ private void GetPipe(OnnxContextImpl ctx, IChannel ch, IDataView end, out IDataV
134136 transform = ( source = transform . Source ) as IDataTransform ;
135137 }
136138
137- Host . AssertValue ( source ) ;
139+ ch . AssertValue ( source ) ;
140+ }
141+
142+ internal static ModelProto ConvertTransformListToOnnxModel ( OnnxContextImpl ctx , IChannel ch , IDataView inputData , IDataView outputData ,
143+ LinkedList < ITransformCanSaveOnnx > transforms , HashSet < string > inputColumnNamesToDrop = null , HashSet < string > outputColumnNamesToDrop = null )
144+ {
145+ inputColumnNamesToDrop = inputColumnNamesToDrop ?? new HashSet < string > ( ) ;
146+ outputColumnNamesToDrop = outputColumnNamesToDrop ?? new HashSet < string > ( ) ;
147+ HashSet < string > inputColumns = new HashSet < string > ( ) ;
148+ // Create graph inputs.
149+ for ( int i = 0 ; i < inputData . Schema . Count ; i ++ )
150+ {
151+ string colName = inputData . Schema [ i ] . Name ;
152+ if ( inputColumnNamesToDrop . Contains ( colName ) )
153+ continue ;
154+
155+ ctx . AddInputVariable ( inputData . Schema [ i ] . Type , colName ) ;
156+ inputColumns . Add ( colName ) ;
157+ }
158+
159+ // Create graph nodes, outputs and intermediate values.
160+ foreach ( var trans in transforms )
161+ {
162+ ch . Assert ( trans . CanSaveOnnx ( ctx ) ) ;
163+ trans . SaveAsOnnx ( ctx ) ;
164+ }
165+
166+ // Add graph outputs.
167+ for ( int i = 0 ; i < outputData . Schema . Count ; ++ i )
168+ {
169+ if ( outputData . Schema [ i ] . IsHidden )
170+ continue ;
171+
172+ var idataviewColumnName = outputData . Schema [ i ] . Name ;
173+
174+ // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
175+ // _inputToDrop should be removed too.
176+ if ( inputColumnNamesToDrop . Contains ( idataviewColumnName ) || outputColumnNamesToDrop . Contains ( idataviewColumnName ) )
177+ continue ;
178+
179+ var variableName = ctx . TryGetVariableName ( idataviewColumnName ) ;
180+ var trueVariableName = ctx . AddIntermediateVariable ( null , idataviewColumnName , true ) ;
181+ ctx . CreateNode ( "Identity" , variableName , trueVariableName , ctx . GetNodeName ( "Identity" ) , "" ) ;
182+ ctx . AddOutputVariable ( outputData . Schema [ i ] . Type , trueVariableName ) ;
183+ }
184+
185+ return ctx . MakeModel ( ) ;
138186 }
139187
140188 private void Run ( IChannel ch )
@@ -210,45 +258,8 @@ private void Run(IChannel ch)
210258 nameof ( Arguments . LoadPredictor ) , "We were explicitly told to load the predictor but one was not present." ) ;
211259 }
212260
213- HashSet < string > inputColumns = new HashSet < string > ( ) ;
214- //Create graph inputs.
215- for ( int i = 0 ; i < source . Schema . Count ; i ++ )
216- {
217- string colName = source . Schema [ i ] . Name ;
218- if ( _inputsToDrop . Contains ( colName ) )
219- continue ;
220-
221- ctx . AddInputVariable ( source . Schema [ i ] . Type , colName ) ;
222- inputColumns . Add ( colName ) ;
223- }
224-
225- //Create graph nodes, outputs and intermediate values.
226- foreach ( var trans in transforms )
227- {
228- Host . Assert ( trans . CanSaveOnnx ( ctx ) ) ;
229- trans . SaveAsOnnx ( ctx ) ;
230- }
231-
232- //Add graph outputs.
233- for ( int i = 0 ; i < end . Schema . Count ; ++ i )
234- {
235- if ( end . Schema [ i ] . IsHidden )
236- continue ;
237-
238- var idataviewColumnName = end . Schema [ i ] . Name ;
239-
240- // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
241- // _inputToDrop should be removed too.
242- if ( _inputsToDrop . Contains ( idataviewColumnName ) || _outputsToDrop . Contains ( idataviewColumnName ) )
243- continue ;
244-
245- var variableName = ctx . TryGetVariableName ( idataviewColumnName ) ;
246- var trueVariableName = ctx . AddIntermediateVariable ( null , idataviewColumnName , true ) ;
247- ctx . CreateNode ( "Identity" , variableName , trueVariableName , ctx . GetNodeName ( "Identity" ) , "" ) ;
248- ctx . AddOutputVariable ( end . Schema [ i ] . Type , trueVariableName ) ;
249- }
261+ var model = ConvertTransformListToOnnxModel ( ctx , ch , source , end , transforms , _inputsToDrop , _outputsToDrop ) ;
250262
251- var model = ctx . MakeModel ( ) ;
252263 using ( var file = Host . CreateOutputFile ( _outputModelPath ) )
253264 using ( var stream = file . CreateWriteStream ( ) )
254265 model . WriteTo ( stream ) ;
0 commit comments