Skip to content

Commit 927a61a

Browse files
authored
Onnx Export for ValueMapping estimator (#5577)
1 parent c2ddff1 commit 927a61a

File tree

4 files changed

+396
-4
lines changed

4 files changed

+396
-4
lines changed

src/Microsoft.ML.Data/Transforms/ValueMapping.cs

Lines changed: 236 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.ML.Data;
1313
using Microsoft.ML.Data.IO;
1414
using Microsoft.ML.Internal.Utilities;
15+
using Microsoft.ML.Model.OnnxConverter;
1516
using Microsoft.ML.Runtime;
1617
using Microsoft.ML.Transforms;
1718

@@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke<TKey, TValue>(DataViewSchema.Column
818819
/// <summary>
/// Returns a getter delegate for <paramref name="input"/> and <paramref name="index"/>.
/// </summary>
/// <remarks>
/// NOTE(review): <paramref name="index"/> is presumably a column index into <paramref name="input"/>
/// (the mapper passes a source-column index); the concrete delegate type is chosen by the
/// implementing class, which is not visible here — confirm.
/// </remarks>
public abstract Delegate GetGetter(DataViewRow input, int index);
819820

820821
/// <summary>
/// Returns an <see cref="IDataView"/> for this map, given the host environment <paramref name="env"/>.
/// </summary>
/// <remarks>
/// NOTE(review): presumably the view exposes the map's key/value pairs — the implementation is not
/// visible here, confirm.
/// </remarks>
public abstract IDataView GetDataView(IHostEnvironment env);
822+
/// <summary>
/// Returns the keys of this map as an array of <typeparamref name="TKey"/>.
/// </summary>
/// <typeparam name="TKey">Element type of the returned array, chosen by the caller.</typeparam>
public abstract TKey[] GetKeys<TKey>();
823+
/// <summary>
/// Returns the values of this map as an array of <typeparamref name="TValue"/>.
/// </summary>
/// <typeparam name="TValue">Element type of the returned array, chosen by the caller.</typeparam>
public abstract TValue[] GetValues<TValue>();
821824
}
822825

823826
/// <summary>
@@ -962,6 +965,16 @@ private static TValue GetVector<T>(TValue value)
962965
}
963966

964967
/// <summary>
/// Identity conversion: hands back <paramref name="value"/> unchanged.
/// </summary>
private static TValue GetValue<T>(TValue value)
{
    return value;
}
968+
969+
/// <summary>
/// Returns every key of the underlying mapping, cast to <typeparamref name="T"/>, as a new array.
/// </summary>
/// <typeparam name="T">Type each key is cast to.</typeparam>
public override T[] GetKeys<T>()
    => Enumerable.ToArray(Enumerable.Cast<T>(_mapping.Keys));
973+
/// <summary>
/// Returns every value of the underlying mapping, cast to <typeparamref name="T"/>, as a new array.
/// </summary>
/// <typeparam name="T">Type each value is cast to.</typeparam>
public override T[] GetValues<T>()
    => Enumerable.ToArray(Enumerable.Cast<T>(_mapping.Values));
977+
965978
}
966979

967980
/// <summary>
@@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
10121025
return new Mapper(this, schema, _valueMap, ColumnPairs);
10131026
}
10141027

1015-
private sealed class Mapper : OneToOneMapperBase
1028+
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
10161029
{
10171030
private readonly DataViewSchema _inputSchema;
10181031
private readonly ValueMap _valueMap;
10191032
private readonly (string outputColumnName, string inputColumnName)[] _columns;
10201033
private readonly ValueMappingTransformer _parent;
1034+
/// <summary>
/// Always returns <see langword="true"/>; <paramref name="ctx"/> is not inspected.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
10211035

10221036
internal Mapper(ValueMappingTransformer transform,
10231037
DataViewSchema inputSchema,
@@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10401054
return _valueMap.GetGetter(input, ColMapNewToOld[iinfo]);
10411055
}
10421056

1057+
/// <summary>
/// Exports each column pair of the parent transformer to ONNX by registering an output variable
/// and delegating node creation to <see cref="SaveAsOnnxCore"/>.
/// </summary>
/// <param name="ctx">The ONNX context that receives the variables and nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    // Validate the argument before its first dereference; previously the op-set check
    // ran on ctx before ctx itself had been checked.
    Host.CheckValue(ctx, nameof(ctx));

    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo)
    {
        string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
        string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;

        if (!_inputSchema.TryGetColumnIndex(inputColumnName, out int colSrc))
            throw Host.ExceptSchemaMismatch(nameof(_inputSchema), "input", inputColumnName);

        // A vector input produces a vector output with the same dimensions whose item type
        // is the map's value type; any other input produces the value type itself.
        var type = _inputSchema[colSrc].Type;
        DataViewType colType;
        if (type is VectorDataViewType vectorType)
            colType = new VectorDataViewType((PrimitiveDataViewType)_parent.ValueColumnType, vectorType.Dimensions);
        else
            colType = _parent.ValueColumnType;

        string dstVariableName = ctx.AddIntermediateVariable(colType, outputColumnName);
        if (!ctx.ContainsColumn(inputColumnName))
            continue;

        // NOTE(review): on failure the *input* column is removed; presumably the output column
        // registered above is the one that should be dropped — confirm against OnnxContext.RemoveColumn.
        if (!SaveAsOnnxCore(ctx, ctx.GetVariableName(inputColumnName), dstVariableName))
            ctx.RemoveColumn(inputColumnName, true);
    }
}
1084+
1085+
/// <summary>
/// Adds a <c>Cast</c> node that converts <paramref name="srcVariableName"/> to <paramref name="itemType"/>
/// and creates the <paramref name="opType"/> node that consumes the cast output, with the value map's
/// keys converted to the same type and attached as the matching <c>keys_*</c> attribute.
/// </summary>
/// <typeparam name="T">Type used to read the keys from the value map.</typeparam>
/// <param name="ctx">The ONNX context that receives the variables and nodes.</param>
/// <param name="node">Receives the created <paramref name="opType"/> node.</param>
/// <param name="srcVariableName">Name of the ONNX variable to cast.</param>
/// <param name="opType">Operator type of the node created after the cast.</param>
/// <param name="labelEncoderOutput">Name of the output variable of the created node.</param>
/// <param name="itemType">Cast target; must be text, <see cref="float"/> or <see cref="long"/>.</param>
/// <exception cref="NotSupportedException"><paramref name="itemType"/> is not one of the three supported targets.</exception>
/// <exception cref="InvalidOperationException">No shape is known for <paramref name="srcVariableName"/>.</exception>
private void CastInputTo<T>(OnnxContext ctx, out OnnxNode node, string srcVariableName, string opType, string labelEncoderOutput, PrimitiveDataViewType itemType)
{
    bool toText = itemType == TextDataViewType.Instance;
    bool toSingle = itemType == NumberDataViewType.Single;
    bool toInt64 = itemType == NumberDataViewType.Int64;

    // Fail before touching the graph instead of emitting an encoder without any keys attribute.
    if (!toText && !toSingle && !toInt64)
        throw new NotSupportedException("Keys can only be cast to text, Single or Int64 for the '" + opType + "' operator.");

    var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
    if (srcShape == null)
        throw new InvalidOperationException("The shape of ONNX variable '" + srcVariableName + "' is not available.");

    var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(itemType, (int)srcShape[1]), "castOutput");
    var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
    castNode.AddAttribute("to", itemType.RawType);
    node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));

    // The exported model is machine-readable, so keys are formatted with the invariant culture
    // rather than whatever culture the exporting thread happens to run under.
    var invariant = System.Globalization.CultureInfo.InvariantCulture;
    if (toText)
        node.AddAttribute("keys_strings", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToString(item, invariant)));
    else if (toSingle)
        node.AddAttribute("keys_floats", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToSingle(item, invariant)));
    else
        node.AddAttribute("keys_int64s", Array.ConvertAll(_valueMap.GetKeys<T>(), item => Convert.ToInt64(item, invariant)));
}
1100+
1101+
/// <summary>
/// Writes the ONNX <c>LabelEncoder</c> node, plus any <c>Cast</c> nodes needed around it, that maps
/// <paramref name="srcVariableName"/> to <paramref name="dstVariableName"/> using the keys and values
/// of the value map.
/// </summary>
/// <param name="ctx">The ONNX context that receives the variables and nodes.</param>
/// <param name="srcVariableName">Name of the ONNX variable holding the input column.</param>
/// <param name="dstVariableName">Name of the ONNX variable that receives the mapped values.</param>
/// <returns>
/// <see langword="true"/> when the mapping was written; <see langword="false"/> when the key/value
/// type combination is not handled (for example text to text) or when the input shape required
/// for an intermediate variable is not available.
/// </returns>
private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

    const string opType = "LabelEncoder";
    var srcShape = ctx.RetrieveShapeOrNull(srcVariableName);
    var typeValue = _valueMap.ValueColumn.Type;
    var typeKey = _valueMap.KeyColumn.Type;
    var kind = typeValue.GetRawKind();

    bool valueIsFloatLike = typeValue == NumberDataViewType.Single
        || typeValue == NumberDataViewType.Double
        || typeValue == BooleanDataViewType.Instance;

    // Single, text and Int64 values are produced by the encoder directly. Every other value type
    // is encoded as Single (Double, Boolean) or Int64 (the rest) into an intermediate variable and
    // cast to the real value type afterwards.
    string labelEncoderOutput;
    if (typeValue == NumberDataViewType.Single || typeValue == TextDataViewType.Instance || typeValue == NumberDataViewType.Int64)
    {
        labelEncoderOutput = dstVariableName;
    }
    else
    {
        // The intermediate variable is sized from the input shape; without it the column cannot be exported.
        if (srcShape == null)
            return false;
        var encodedItemType = (typeValue == NumberDataViewType.Double || typeValue == BooleanDataViewType.Instance)
            ? NumberDataViewType.Single
            : NumberDataViewType.Int64;
        labelEncoderOutput = ctx.AddIntermediateVariable(new VectorDataViewType(encodedItemType, (int)srcShape[1]), "LabelEncoderOutput");
    }

    // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings
    // between int64s, floats, and strings. As a result, we need to cast most inputs and outputs. To avoid
    // as many unsupported mappings as possible, keys of NumberDataViewType are cast to strings and values
    // of NumberDataViewType to int64s. String -> String mappings can't be supported.

    // Integer keys narrower than / different from Int64 are always cast: to int64s when the values are
    // text (avoiding string -> string), otherwise to strings (avoiding int64 -> int64).
    PrimitiveDataViewType integerKeyTarget = typeValue is TextDataViewType
        ? (PrimitiveDataViewType)NumberDataViewType.Int64
        : TextDataViewType.Instance;

    OnnxNode node;
    if (typeKey == NumberDataViewType.Int64)
    {
        // To avoid an int64 -> int64 mapping, keys are cast to strings when the values are numeric.
        if (typeValue is NumberDataViewType)
        {
            CastInputTo<long>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
        }
        else
        {
            node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
            node.AddAttribute("keys_int64s", _valueMap.GetKeys<long>());
        }
    }
    else if (typeKey == NumberDataViewType.Int32)
        CastInputTo<int>(ctx, out node, srcVariableName, opType, labelEncoderOutput, integerKeyTarget);
    else if (typeKey == NumberDataViewType.Int16)
        CastInputTo<short>(ctx, out node, srcVariableName, opType, labelEncoderOutput, integerKeyTarget);
    else if (typeKey == NumberDataViewType.UInt64)
        CastInputTo<ulong>(ctx, out node, srcVariableName, opType, labelEncoderOutput, integerKeyTarget);
    else if (typeKey == NumberDataViewType.UInt32)
        CastInputTo<uint>(ctx, out node, srcVariableName, opType, labelEncoderOutput, integerKeyTarget);
    else if (typeKey == NumberDataViewType.UInt16)
        CastInputTo<ushort>(ctx, out node, srcVariableName, opType, labelEncoderOutput, integerKeyTarget);
    else if (typeKey == NumberDataViewType.Single)
    {
        // Values encoded as floats would give a float -> float mapping, so the keys become strings.
        if (valueIsFloatLike)
        {
            CastInputTo<float>(ctx, out node, srcVariableName, opType, labelEncoderOutput, TextDataViewType.Instance);
        }
        else
        {
            node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
            node.AddAttribute("keys_floats", _valueMap.GetKeys<float>());
        }
    }
    else if (typeKey == NumberDataViewType.Double)
    {
        PrimitiveDataViewType doubleKeyTarget = valueIsFloatLike
            ? (PrimitiveDataViewType)TextDataViewType.Instance
            : NumberDataViewType.Single;
        CastInputTo<double>(ctx, out node, srcVariableName, opType, labelEncoderOutput, doubleKeyTarget);
    }
    else if (typeKey == TextDataViewType.Instance)
    {
        if (typeValue == TextDataViewType.Instance)
            return false;
        node = ctx.CreateNode(opType, srcVariableName, labelEncoderOutput, ctx.GetNodeName(opType));
        node.AddAttribute("keys_strings", _valueMap.GetKeys<ReadOnlyMemory<char>>());
    }
    else if (typeKey == BooleanDataViewType.Instance)
    {
        if (valueIsFloatLike)
        {
            // The cast output is sized from the input shape; without it the column cannot be exported.
            if (srcShape == null)
                return false;
            var castOutput = ctx.AddIntermediateVariable(new VectorDataViewType(TextDataViewType.Instance, (int)srcShape[1]), "castOutput");
            var castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName("Cast"), "");
            var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType();
            castNode.AddAttribute("to", t);
            node = ctx.CreateNode(opType, castOutput, labelEncoderOutput, ctx.GetNodeName(opType));
            // Boolean keys are written as the strings "0" and "1".
            var values = Array.ConvertAll(_valueMap.GetKeys<bool>(), item => Convert.ToString(Convert.ToByte(item)));
            node.AddAttribute("keys_strings", values);
        }
        else
            CastInputTo<bool>(ctx, out node, srcVariableName, opType, labelEncoderOutput, NumberDataViewType.Single);
    }
    else
        return false;

    if (typeValue == NumberDataViewType.Int64)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<long>());
    }
    else if (typeValue == NumberDataViewType.Int32)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<int>().Select(item => Convert.ToInt64(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else if (typeValue == NumberDataViewType.Int16)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<short>().Select(item => Convert.ToInt64(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else if (typeValue == NumberDataViewType.UInt64 || kind == InternalDataKind.U8)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<ulong>().Select(item => Convert.ToInt64(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else if (typeValue == NumberDataViewType.UInt32 || kind == InternalDataKind.U4)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<uint>().Select(item => Convert.ToInt64(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else if (typeValue == NumberDataViewType.UInt16)
    {
        node.AddAttribute("values_int64s", _valueMap.GetValues<ushort>().Select(item => Convert.ToInt64(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else if (typeValue == NumberDataViewType.Single)
    {
        node.AddAttribute("values_floats", _valueMap.GetValues<float>());
    }
    else if (typeValue == NumberDataViewType.Double)
    {
        node.AddAttribute("values_floats", _valueMap.GetValues<double>().Select(item => Convert.ToSingle(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else if (typeValue == TextDataViewType.Instance)
    {
        node.AddAttribute("values_strings", _valueMap.GetValues<ReadOnlyMemory<char>>());
    }
    else if (typeValue == BooleanDataViewType.Instance)
    {
        node.AddAttribute("values_floats", _valueMap.GetValues<bool>().Select(item => Convert.ToSingle(item)));
        AddValueTypeCast(ctx, labelEncoderOutput, dstVariableName, typeValue.RawType);
    }
    else
        return false;

    //Unknown keys should map to 0
    node.AddAttribute("default_int64", 0);
    node.AddAttribute("default_string", "");
    node.AddAttribute("default_float", 0f);
    return true;
}

// Adds the Cast node that converts the encoder's intermediate output to the map's real value type.
private static void AddValueTypeCast(OnnxContext ctx, string labelEncoderOutput, string dstVariableName, Type valueRawType)
{
    var castNode = ctx.CreateNode("Cast", labelEncoderOutput, dstVariableName, ctx.GetNodeName("Cast"), "");
    castNode.AddAttribute("to", valueRawType);
}
1277+
10431278
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
10441279
{
10451280
var result = new DataViewSchema.DetachedColumn[_columns.Length];

test/Microsoft.ML.TestFramework/BaseTestBaseline.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,9 +674,7 @@ private static double Round(double value, int digitsOfPrecision)
674674
public void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6, bool isRightColumnOnnxScalar = false)
675675
{
676676
var leftColumn = left.Schema[leftColumnName];
677-
var rightColumn = right.Schema[rightColumnName];
678677
var leftType = leftColumn.Type.GetItemType();
679-
var rightType = rightColumn.Type.GetItemType();
680678

681679
if (leftType == NumberDataViewType.SByte)
682680
CompareSelectedColumns<sbyte>(leftColumnName, rightColumnName, left, right, isRightColumnOnnxScalar: isRightColumnOnnxScalar);

0 commit comments

Comments
 (0)