1515using Microsoft . ML . Data . IO ;
1616using Microsoft . ML . EntryPoints ;
1717using Microsoft . ML . Internal . Utilities ;
18+ using Microsoft . ML . Model . OnnxConverter ;
1819using Microsoft . ML . Runtime ;
1920using Microsoft . ML . Transforms . Text ;
2021
@@ -343,14 +344,16 @@ private static Stream GetResourceFileStreamOrNull(StopWordsRemovingEstimator.Lan
343344 return assembly . GetManifestResourceStream ( $ "{ assembly . GetName ( ) . Name } .Text.StopWords.{ lang . ToString ( ) } .txt") ;
344345 }
345346
346- private sealed class Mapper : MapperBase
347+ private sealed class Mapper : MapperBase , ISaveAsOnnx
347348 {
348349 private readonly DataViewType [ ] _types ;
349350 private readonly StopWordsRemovingTransformer _parent ;
350351 private readonly int [ ] _languageColumns ;
351352 private readonly bool ? [ ] _resourcesExist ;
352353 private readonly Dictionary < int , int > _colMapNewToOld ;
353354
/// <summary>
/// Reports whether this mapper can be exported to ONNX. The answer is unconditionally
/// <see langword="true"/>; <paramref name="ctx"/> is not inspected.
/// </summary>
/// <param name="ctx">The ONNX export context (unused).</param>
/// <returns>Always <see langword="true"/>.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
356+
354357 public Mapper ( StopWordsRemovingTransformer parent , DataViewSchema inputSchema )
355358 : base ( Contracts . CheckRef ( parent , nameof ( parent ) ) . Host . Register ( nameof ( Mapper ) ) , inputSchema , parent )
356359 {
@@ -438,6 +441,45 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
438441 return del ;
439442 }
440443
/// <summary>
/// Exports each input/output column pair of the parent transformer into the ONNX graph
/// being built by <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the generated variables and nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // SaveAsOnnxCore emits a StringNormalizer node, and that operator is only defined from
    // opset 10 of the default ONNX operator set, so opset 9 cannot represent this graph.
    const int minimumOpSetVersion = 10;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int i = 0; i < _parent.ColumnPairs.Length; i++)
    {
        var inputColumnName = _parent.ColumnPairs[i].inputColumnName;

        // Membership is tested with the *column* name, and the variable name is resolved
        // only after the column is known to be present in the context.
        if (!ctx.ContainsColumn(inputColumnName))
            continue;

        var srcVariableName = ctx.GetVariableName(inputColumnName);
        var dstVariableName = ctx.AddIntermediateVariable(_types[i], _parent.ColumnPairs[i].outputColumnName);
        SaveAsOnnxCore(ctx, i, srcVariableName, dstVariableName);
    }
}
458+
/// <summary>
/// Emits the Squeeze -> StringNormalizer -> Unsqueeze node chain that removes the predefined
/// stop words for column <paramref name="iinfo"/>.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes.</param>
/// <param name="iinfo">Index of the column pair being exported.</param>
/// <param name="srcVariableName">ONNX variable feeding the first node of the chain.</param>
/// <param name="dstVariableName">ONNX variable written by the last node of the chain.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    // Drop axis 0 before StringNormalizer; it is restored by the Unsqueeze node below.
    var opType = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true);
    var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });

    opType = "StringNormalizer";
    var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true);
    node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), "");

    // Resolve the language configured for this column. No per-row language getter exists at
    // export time, so null is passed for it.
    var langToUse = _parent._columns[iinfo].Language;
    var lang = default(ReadOnlyMemory<char>);
    UpdateLanguage(ref langToUse, null, ref lang);

    // Pick the stop-word list by the resolved language rather than by the column index.
    // NOTE(review): assumes StopWords is indexed by the Language enum value, as the
    // per-language resource files suggest - confirm against the StopWords definition.
    var words = StopWords[(int)langToUse].Select(item => Convert.ToString(item.Value));
    node.AddAttribute("stopwords", words);

    // The Unsqueeze node writes straight into dstVariableName, so no additional
    // intermediate variable needs to be registered for it.
    opType = "Unsqueeze";
    node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });
}
482+
441483 private void UpdateLanguage ( ref StopWordsRemovingEstimator . Language langToUse , ValueGetter < ReadOnlyMemory < char > > getLang , ref ReadOnlyMemory < char > langTxt )
442484 {
443485 if ( getLang != null )
@@ -490,7 +532,7 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
490532 /// | Does this estimator need to look at the data to train its parameters? | No |
491533 /// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
492534 /// | Output column data type | Variable-sized vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
493- /// | Exportable to ONNX | No |
535+ /// | Exportable to ONNX | Yes |
494536 ///
495537 /// The resulting <xref:Microsoft.ML.Transforms.Text.StopWordsRemovingTransformer> creates a new column, named as specified in the output column name parameter,
496538 /// and fills it with a vector of words containing all of the words in the input column **except the predefined list of stopwords for the specified language**.
@@ -1016,11 +1058,13 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat
10161058
10171059 private protected override IRowMapper MakeRowMapper ( DataViewSchema schema ) => new Mapper ( this , schema ) ;
10181060
1019- private sealed class Mapper : OneToOneMapperBase
1061+ private sealed class Mapper : OneToOneMapperBase , ISaveAsOnnx
10201062 {
10211063 private readonly DataViewType [ ] _types ;
10221064 private readonly CustomStopWordsRemovingTransformer _parent ;
10231065
/// <summary>
/// Reports whether this mapper can be exported to ONNX. The answer is unconditionally
/// <see langword="true"/>; <paramref name="ctx"/> is not inspected.
/// </summary>
/// <param name="ctx">The ONNX export context (unused).</param>
/// <returns>Always <see langword="true"/>.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
1067+
10241068 public Mapper ( CustomStopWordsRemovingTransformer parent , DataViewSchema inputSchema )
10251069 : base ( Contracts . CheckRef ( parent , nameof ( parent ) ) . Host . Register ( nameof ( Mapper ) ) , parent , inputSchema )
10261070 {
@@ -1084,6 +1128,43 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10841128
10851129 return del ;
10861130 }
1131+
/// <summary>
/// Exports each input/output column pair of the parent transformer into the ONNX graph
/// being built by <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the generated variables and nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // SaveAsOnnxCore emits a StringNormalizer node, and that operator is only defined from
    // opset 10 of the default ONNX operator set, so opset 9 cannot represent this graph.
    const int minimumOpSetVersion = 10;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int i = 0; i < _parent.ColumnPairs.Length; i++)
    {
        var inputColumnName = _parent.ColumnPairs[i].inputColumnName;

        // Membership is tested with the *column* name, and the variable name is resolved
        // only after the column is known to be present in the context.
        if (!ctx.ContainsColumn(inputColumnName))
            continue;

        var srcVariableName = ctx.GetVariableName(inputColumnName);
        var dstVariableName = ctx.AddIntermediateVariable(_types[i], _parent.ColumnPairs[i].outputColumnName);

        SaveAsOnnxCore(ctx, i, srcVariableName, dstVariableName);
    }
}
1147+
// Note: StringNormalizer only accepts inputs of shape [C] or [1,C], so axis 0 is squeezed
// away before the node and restored with Unsqueeze afterwards.
// NOTE(review): Squeeze with axes = [0] requires that dimension to be exactly 1 - confirm
// how inputs whose leading dimension exceeds 1 are expected to behave.
/// <summary>
/// Emits the Squeeze -> StringNormalizer -> Unsqueeze node chain that removes the parent's
/// custom stop words for column <paramref name="iinfo"/>.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes.</param>
/// <param name="iinfo">Index of the column pair being exported.</param>
/// <param name="srcVariableName">ONNX variable feeding the first node of the chain.</param>
/// <param name="dstVariableName">ONNX variable written by the last node of the chain.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    var opType = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true);
    var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });

    opType = "StringNormalizer";
    var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true);
    node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), "");
    // Snapshot the parent's stop-word map and hand its values to the node as strings.
    var words = _parent._stopWordsMap.ToList();
    node.AddAttribute("stopwords", words.Select(item => Convert.ToString(item.Value)));

    // The Unsqueeze node writes straight into dstVariableName, so no additional
    // intermediate variable needs to be registered for it.
    opType = "Unsqueeze";
    node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });
}
10871168 }
10881169 }
10891170
@@ -1098,7 +1179,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10981179 /// | Does this estimator need to look at the data to train its parameters? | No |
10991180 /// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
11001181 /// | Output column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
1101- /// | Exportable to ONNX | No |
1182+ /// | Exportable to ONNX | Yes |
11021183 ///
11031184 /// The resulting <xref:Microsoft.ML.Transforms.Text.CustomStopWordsRemovingTransformer> creates a new column, named as specified by the output column name parameter, and
11041185 /// fills it with a vector of words containing all of the words in the input column except those given by the stopwords parameter.
0 commit comments