Skip to content

Commit 7879849

Browse files
authored
StopWordsRemovingEstimator export to Onnx (#5279)
* StopWordsRemoving transformer export to onnx * format changes * adding types
1 parent 1b44be7 commit 7879849

File tree

2 files changed

+135
-6
lines changed

2 files changed

+135
-6
lines changed

src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using Microsoft.ML.Data.IO;
1616
using Microsoft.ML.EntryPoints;
1717
using Microsoft.ML.Internal.Utilities;
18+
using Microsoft.ML.Model.OnnxConverter;
1819
using Microsoft.ML.Runtime;
1920
using Microsoft.ML.Transforms.Text;
2021

@@ -343,14 +344,16 @@ private static Stream GetResourceFileStreamOrNull(StopWordsRemovingEstimator.Lan
343344
return assembly.GetManifestResourceStream($"{assembly.GetName().Name}.Text.StopWords.{lang.ToString()}.txt");
344345
}
345346

346-
private sealed class Mapper : MapperBase
347+
private sealed class Mapper : MapperBase, ISaveAsOnnx
347348
{
348349
private readonly DataViewType[] _types;
349350
private readonly StopWordsRemovingTransformer _parent;
350351
private readonly int[] _languageColumns;
351352
private readonly bool?[] _resourcesExist;
352353
private readonly Dictionary<int, int> _colMapNewToOld;
353354

355+
/// <summary>
/// Reports whether this mapper can be exported to ONNX; it unconditionally answers <see langword="true"/>.
/// </summary>
/// <param name="ctx">The ONNX export context (not inspected).</param>
/// <returns>Always <see langword="true"/>.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
356+
354357
public Mapper(StopWordsRemovingTransformer parent, DataViewSchema inputSchema)
355358
: base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
356359
{
@@ -438,6 +441,45 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
438441
return del;
439442
}
440443

444+
/// <summary>
/// Exports every column pair of the parent transformer into the ONNX graph held by <paramref name="ctx"/>.
/// Column pairs whose input variable is not present in the context are skipped.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the generated variables and nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // SaveAsOnnxCore emits a StringNormalizer node, and that operator only exists
    // from opset 10 of the default ONNX domain onwards, so opset 9 must be rejected.
    const int minimumOpSetVersion = 10;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int i = 0; i < _parent.ColumnPairs.Length; i++)
    {
        var srcVariableName = ctx.GetVariableName(_parent.ColumnPairs[i].inputColumnName);
        if (!ctx.ContainsColumn(srcVariableName))
        {
            continue;
        }

        var dstVariableName = ctx.AddIntermediateVariable(_types[i], _parent.ColumnPairs[i].outputColumnName);
        SaveAsOnnxCore(ctx, i, srcVariableName, dstVariableName);
    }
}
458+
459+
/// <summary>
/// Emits the ONNX nodes for a single column pair: Squeeze -> StringNormalizer -> Unsqueeze.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the nodes.</param>
/// <param name="iinfo">Index of the column pair being exported (indexes <c>_types</c> and <c>_parent._columns</c>).</param>
/// <param name="srcVariableName">ONNX variable name of the input column.</param>
/// <param name="dstVariableName">ONNX variable name of the output column.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    // Axis 0 is removed before StringNormalizer and restored afterwards by the Unsqueeze below.
    var opType = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true);
    var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });

    opType = "StringNormalizer";
    var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true);
    node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), "");

    // Resolve the language configured for this column. No per-row language getter is
    // available at export time, so null is passed for it.
    var langToUse = _parent._columns[iinfo].Language;
    var lang = default(ReadOnlyMemory<char>);
    UpdateLanguage(ref langToUse, null, ref lang);

    // Look the stop words up by the resolved language rather than by the column index,
    // which previously left langToUse unused.
    // NOTE(review): assumes StopWords is indexed by the integer value of the language - confirm against its definition.
    node.AddAttribute("stopwords", StopWords[(int)langToUse].Select(item => Convert.ToString(item.Value)));

    // The former extra "SqueezeOutput" registration was dropped: no node ever consumed it.
    opType = "Unsqueeze";
    node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });
}
482+
441483
private void UpdateLanguage(ref StopWordsRemovingEstimator.Language langToUse, ValueGetter<ReadOnlyMemory<char>> getLang, ref ReadOnlyMemory<char> langTxt)
442484
{
443485
if (getLang != null)
@@ -490,7 +532,7 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
490532
/// | Does this estimator need to look at the data to train its parameters? | No |
491533
/// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
492534
/// | Output column data type | Variable-sized vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
493-
/// | Exportable to ONNX | No |
535+
/// | Exportable to ONNX | Yes |
494536
///
495537
/// The resulting <xref:Microsoft.ML.Transforms.Text.StopWordsRemovingTransformer> creates a new column, named as specified in the output column name parameter,
496538
/// and fills it with a vector of words containing all of the words in the input column **except the predefined list of stopwords for the specified language.
@@ -1016,11 +1058,13 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat
10161058

10171059
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
10181060

1019-
private sealed class Mapper : OneToOneMapperBase
1061+
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
10201062
{
10211063
private readonly DataViewType[] _types;
10221064
private readonly CustomStopWordsRemovingTransformer _parent;
10231065

1066+
/// <summary>
/// Reports whether this mapper can be exported to ONNX; it unconditionally answers <see langword="true"/>.
/// </summary>
/// <param name="ctx">The ONNX export context (not inspected).</param>
/// <returns>Always <see langword="true"/>.</returns>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
1067+
10241068
public Mapper(CustomStopWordsRemovingTransformer parent, DataViewSchema inputSchema)
10251069
: base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), parent, inputSchema)
10261070
{
@@ -1084,6 +1128,43 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10841128

10851129
return del;
10861130
}
1131+
1132+
/// <summary>
/// Exports every column pair of the parent transformer into the ONNX graph held by <paramref name="ctx"/>.
/// Column pairs whose input variable is not present in the context are skipped.
/// </summary>
/// <param name="ctx">The ONNX export context that receives the generated variables and nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // SaveAsOnnxCore emits a StringNormalizer node, and that operator only exists
    // from opset 10 of the default ONNX domain onwards, so opset 9 must be rejected.
    const int minimumOpSetVersion = 10;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int i = 0; i < _parent.ColumnPairs.Length; i++)
    {
        var srcVariableName = ctx.GetVariableName(_parent.ColumnPairs[i].inputColumnName);
        if (!ctx.ContainsColumn(srcVariableName))
        {
            continue;
        }

        var dstVariableName = ctx.AddIntermediateVariable(_types[i], _parent.ColumnPairs[i].outputColumnName);
        SaveAsOnnxCore(ctx, i, srcVariableName, dstVariableName);
    }
}
1147+
1148+
/// <summary>
/// Emits the ONNX nodes for a single column pair: Squeeze -> StringNormalizer -> Unsqueeze.
/// </summary>
/// <remarks>
/// StringNormalizer only accepts inputs of shape [C] or [1,C], so axis 0 is squeezed away
/// before the node and restored by the trailing Unsqueeze.
/// </remarks>
/// <param name="ctx">The ONNX export context that receives the nodes.</param>
/// <param name="iinfo">Index of the column pair being exported (indexes <c>_types</c>).</param>
/// <param name="srcVariableName">ONNX variable name of the input column.</param>
/// <param name="dstVariableName">ONNX variable name of the output column.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    var opType = "Squeeze";
    var squeezeOutput = ctx.AddIntermediateVariable(_types[iinfo], "SqueezeOutput", true);
    var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });

    opType = "StringNormalizer";
    var stringNormalizerOutput = ctx.AddIntermediateVariable(_types[iinfo], "StringNormalizerOutput", true);
    node = ctx.CreateNode(opType, squeezeOutput, stringNormalizerOutput, ctx.GetNodeName(opType), "");
    // Project the parent's stop-word entries straight to strings; no intermediate list copy is needed.
    node.AddAttribute("stopwords", _parent._stopWordsMap.Select(item => Convert.ToString(item.Value)));

    // The former extra "SqueezeOutput" registration was dropped: no node ever consumed it.
    opType = "Unsqueeze";
    node = ctx.CreateNode(opType, stringNormalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", new long[] { 0 });
}
10871168
}
10881169
}
10891170

@@ -1098,7 +1179,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10981179
/// | Does this estimator need to look at the data to train its parameters? | No |
10991180
/// | Input column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
11001181
/// | Output column data type | Vector of [Text](xref:Microsoft.ML.Data.TextDataViewType) |
1101-
/// | Exportable to ONNX | No |
1182+
/// | Exportable to ONNX | Yes |
11021183
///
11031184
/// The resulting <xref:Microsoft.ML.Transforms.Text.CustomStopWordsRemovingTransformer> creates a new column, named as specified by the output column name parameter, and
11041185
/// fills it with a vector of words containing all of the words in the input column except those given by the stopwords parameter.

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,8 @@ public void OneHotHashEncodingOnnxConversionTest()
974974
var mlContext = new MLContext();
975975
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
976976

977-
var dataView = ML.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
978-
var pipeline = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
977+
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
978+
var pipeline = mlContext.Transforms.Categorical.OneHotHashEncoding(new[]{
979979
new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false),
980980
});
981981
var onnxFileName = "OneHotHashEncoding.onnx";
@@ -1343,6 +1343,54 @@ public void NgramOnnxConversionTest(
13431343
Done();
13441344
}
13451345

1346+
/// <summary>
/// Runs a tokenize + custom stop-word removal pipeline through <c>TestPipeline</c>,
/// comparing the "WordsWithoutStopWords" column.
/// </summary>
[Fact]
public void CustomStopWordsRemovingEstimatorOnnxTest()
{
    var context = new MLContext();

    // Input rows: each sentence contains at least one of the custom stop words.
    var rows = new List<TextData>
    {
        new TextData { Text = "cat sat on mat" },
        new TextData { Text = "mat not fit cat" },
        new TextData { Text = "a cat think mat bad" },
    };
    var data = context.Data.LoadFromEnumerable(rows);

    // Tokenize the raw text, then strip the user-supplied stop words from the tokens.
    var customStopWords = new[] { "cat", "sat", "on" };
    var tokenizer = context.Transforms.Text.TokenizeIntoWords("Words", "Text");
    var estimator = tokenizer.Append(
        context.Transforms.Text.RemoveStopWords("WordsWithoutStopWords", "Words", stopwords: customStopWords));

    string modelFileName = "CustomStopWordsRemovingEstimator.onnx";
    var comparedColumns = new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") };

    TestPipeline(estimator, data, modelFileName, comparedColumns);

    Done();
}
1369+
1370+
/// <summary>
/// Runs a tokenize + default (English) stop-word removal pipeline through <c>TestPipeline</c>,
/// comparing the "WordsWithoutStopWords" column.
/// </summary>
[Fact]
public void StopWordsRemovingEstimatorOnnxTest()
{
    var context = new MLContext();

    // Input rows for the pipeline.
    var rows = new List<TextData>
    {
        new TextData { Text = "a go cat sat on mat" },
        new TextData { Text = "a mat not fit go cat" },
        new TextData { Text = "cat think mat bad a" },
    };
    var data = context.Data.LoadFromEnumerable(rows);

    // Tokenize the raw text, then strip the predefined English stop words from the tokens.
    var tokenizer = context.Transforms.Text.TokenizeIntoWords("Words", "Text");
    var estimator = tokenizer.Append(
        context.Transforms.Text.RemoveDefaultStopWords(
            "WordsWithoutStopWords", "Words", language: StopWordsRemovingEstimator.Language.English));

    string modelFileName = "StopWordsRemovingEstimator.onnx";
    var comparedColumns = new ColumnComparison[] { new ColumnComparison("WordsWithoutStopWords") };

    TestPipeline(estimator, data, modelFileName, comparedColumns);

    Done();
}
1393+
13461394
[Theory]
13471395
[InlineData(DataKind.Boolean)]
13481396
[InlineData(DataKind.SByte)]

0 commit comments

Comments
 (0)