1212using Microsoft . ML . Data ;
1313using Microsoft . ML . Data . IO ;
1414using Microsoft . ML . Internal . Utilities ;
15+ using Microsoft . ML . Model . OnnxConverter ;
1516using Microsoft . ML . Runtime ;
1617using Microsoft . ML . Transforms ;
1718
@@ -818,6 +819,8 @@ private static ValueMap CreateValueMapInvoke<TKey, TValue>(DataViewSchema.Column
818819 public abstract Delegate GetGetter ( DataViewRow input , int index ) ;
819820
820821 public abstract IDataView GetDataView ( IHostEnvironment env ) ;
822+ public abstract TKey [ ] GetKeys < TKey > ( ) ;
823+ public abstract TValue [ ] GetValues < TValue > ( ) ;
821824 }
822825
823826 /// <summary>
@@ -962,6 +965,16 @@ private static TValue GetVector<T>(TValue value)
962965 }
963966
964967 private static TValue GetValue < T > ( TValue value ) => value ;
968+
969+ public override T [ ] GetKeys < T > ( )
970+ {
971+ return _mapping . Keys . Cast < T > ( ) . ToArray ( ) ;
972+ }
973+ public override T [ ] GetValues < T > ( )
974+ {
975+ return _mapping . Values . Cast < T > ( ) . ToArray ( ) ;
976+ }
977+
965978 }
966979
967980 /// <summary>
@@ -1012,12 +1025,13 @@ private protected override IRowMapper MakeRowMapper(DataViewSchema schema)
10121025 return new Mapper ( this , schema , _valueMap , ColumnPairs ) ;
10131026 }
10141027
1015- private sealed class Mapper : OneToOneMapperBase
1028+ private sealed class Mapper : OneToOneMapperBase , ISaveAsOnnx
10161029 {
10171030 private readonly DataViewSchema _inputSchema ;
10181031 private readonly ValueMap _valueMap ;
10191032 private readonly ( string outputColumnName , string inputColumnName ) [ ] _columns ;
10201033 private readonly ValueMappingTransformer _parent ;
1034+ public bool CanSaveOnnx ( OnnxContext ctx ) => true ;
10211035
10221036 internal Mapper ( ValueMappingTransformer transform ,
10231037 DataViewSchema inputSchema ,
@@ -1040,6 +1054,227 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
10401054 return _valueMap . GetGetter ( input , ColMapNewToOld [ iinfo ] ) ;
10411055 }
10421056
1057+ public void SaveAsOnnx ( OnnxContext ctx )
1058+ {
1059+ const int minimumOpSetVersion = 9 ;
1060+ ctx . CheckOpSetVersion ( minimumOpSetVersion , LoaderSignature ) ;
1061+ Host . CheckValue ( ctx , nameof ( ctx ) ) ;
1062+
1063+ for ( int iinfo = 0 ; iinfo < _parent . ColumnPairs . Length ; ++ iinfo )
1064+ {
1065+ string inputColumnName = _parent . ColumnPairs [ iinfo ] . inputColumnName ;
1066+ string outputColumnName = _parent . ColumnPairs [ iinfo ] . outputColumnName ;
1067+
1068+ if ( ! _inputSchema . TryGetColumnIndex ( inputColumnName , out int colSrc ) )
1069+ throw Host . ExceptSchemaMismatch ( nameof ( _inputSchema ) , "input" , inputColumnName ) ;
1070+ var type = _inputSchema [ colSrc ] . Type ;
1071+ DataViewType colType ;
1072+ if ( type is VectorDataViewType vectorType )
1073+ colType = new VectorDataViewType ( ( PrimitiveDataViewType ) _parent . ValueColumnType , vectorType . Dimensions ) ;
1074+ else
1075+ colType = _parent . ValueColumnType ;
1076+ string dstVariableName = ctx . AddIntermediateVariable ( colType , outputColumnName ) ;
1077+ if ( ! ctx . ContainsColumn ( inputColumnName ) )
1078+ continue ;
1079+
1080+ if ( ! SaveAsOnnxCore ( ctx , ctx . GetVariableName ( inputColumnName ) , dstVariableName ) )
1081+ ctx . RemoveColumn ( inputColumnName , true ) ;
1082+ }
1083+ }
1084+
1085+ private void CastInputTo < T > ( OnnxContext ctx , out OnnxNode node , string srcVariableName , string opType , string labelEncoderOutput , PrimitiveDataViewType itemType )
1086+ {
1087+ var srcShape = ctx . RetrieveShapeOrNull ( srcVariableName ) ;
1088+ var castOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( itemType , ( int ) srcShape [ 1 ] ) , "castOutput" ) ;
1089+ var castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( "Cast" ) , "" ) ;
1090+ castNode . AddAttribute ( "to" , itemType . RawType ) ;
1091+ node = ctx . CreateNode ( opType , castOutput , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
1092+ if ( itemType == TextDataViewType . Instance )
1093+ node . AddAttribute ( "keys_strings" , Array . ConvertAll ( _valueMap . GetKeys < T > ( ) , item => Convert . ToString ( item ) ) ) ;
1094+ else if ( itemType == NumberDataViewType . Single )
1095+ node . AddAttribute ( "keys_floats" , Array . ConvertAll ( _valueMap . GetKeys < T > ( ) , item => Convert . ToSingle ( item ) ) ) ;
1096+ else if ( itemType == NumberDataViewType . Int64 )
1097+ node . AddAttribute ( "keys_int64s" , Array . ConvertAll ( _valueMap . GetKeys < T > ( ) , item => Convert . ToInt64 ( item ) ) ) ;
1098+
1099+ }
1100+
1101+ private bool SaveAsOnnxCore ( OnnxContext ctx , string srcVariableName , string dstVariableName )
1102+ {
1103+ const int minimumOpSetVersion = 9 ;
1104+ ctx . CheckOpSetVersion ( minimumOpSetVersion , LoaderSignature ) ;
1105+ OnnxNode node ;
1106+ string opType = "LabelEncoder" ;
1107+ var labelEncoderInput = srcVariableName ;
1108+ var srcShape = ctx . RetrieveShapeOrNull ( srcVariableName ) ;
1109+ var typeValue = _valueMap . ValueColumn . Type ;
1110+ var typeKey = _valueMap . KeyColumn . Type ;
1111+ var kind = _valueMap . ValueColumn . Type . GetRawKind ( ) ;
1112+
1113+ var labelEncoderOutput = ( typeValue == NumberDataViewType . Single || typeValue == TextDataViewType . Instance || typeValue == NumberDataViewType . Int64 ) ? dstVariableName :
1114+ ( typeValue == NumberDataViewType . Double || typeValue == BooleanDataViewType . Instance ) ? ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Single , ( int ) srcShape [ 1 ] ) , "LabelEncoderOutput" ) :
1115+ ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Int64 , ( int ) srcShape [ 1 ] ) , "LabelEncoderOutput" ) ;
1116+
1117+ // The LabelEncoder operator doesn't support mappings between the same type and only supports mappings between int64s, floats, and strings.
1118+ // As a result, we need to cast most inputs and outputs. In order to avoid as many unsupported mappings, we cast keys that are of NumberDataTypeView
1119+ // to strings and values of NumberDataViewType to int64s.
1120+ // String -> String mappings can't be supported.
1121+ if ( typeKey == NumberDataViewType . Int64 )
1122+ {
1123+ // To avoid a int64 -> int64 mapping, we cast keys to strings
1124+ if ( typeValue is NumberDataViewType )
1125+ {
1126+ CastInputTo < Int64 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1127+ }
1128+ else
1129+ {
1130+ node = ctx . CreateNode ( opType , srcVariableName , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
1131+ node . AddAttribute ( "keys_int64s" , _valueMap . GetKeys < Int64 > ( ) ) ;
1132+ }
1133+ }
1134+ else if ( typeKey == NumberDataViewType . Int32 )
1135+ {
1136+ // To avoid a string -> string mapping, we cast keys to int64s
1137+ if ( typeValue is TextDataViewType )
1138+ CastInputTo < Int32 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Int64 ) ;
1139+ else
1140+ CastInputTo < Int32 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1141+ }
1142+ else if ( typeKey == NumberDataViewType . Int16 )
1143+ {
1144+ if ( typeValue is TextDataViewType )
1145+ CastInputTo < Int16 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Int64 ) ;
1146+ else
1147+ CastInputTo < Int16 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1148+ }
1149+ else if ( typeKey == NumberDataViewType . UInt64 )
1150+ {
1151+ if ( typeValue is TextDataViewType )
1152+ CastInputTo < UInt64 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Int64 ) ;
1153+ else
1154+ CastInputTo < UInt64 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1155+ }
1156+ else if ( typeKey == NumberDataViewType . UInt32 )
1157+ {
1158+ if ( typeValue is TextDataViewType )
1159+ CastInputTo < UInt32 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Int64 ) ;
1160+ else
1161+ CastInputTo < UInt32 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1162+ }
1163+ else if ( typeKey == NumberDataViewType . UInt16 )
1164+ {
1165+ if ( typeValue is TextDataViewType )
1166+ CastInputTo < UInt16 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Int64 ) ;
1167+ else
1168+ CastInputTo < UInt16 > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1169+ }
1170+ else if ( typeKey == NumberDataViewType . Single )
1171+ {
1172+ if ( typeValue == NumberDataViewType . Single || typeValue == NumberDataViewType . Double || typeValue == BooleanDataViewType . Instance )
1173+ {
1174+ CastInputTo < float > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1175+ }
1176+ else
1177+ {
1178+ node = ctx . CreateNode ( opType , srcVariableName , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
1179+ node . AddAttribute ( "keys_floats" , _valueMap . GetKeys < float > ( ) ) ;
1180+ }
1181+ }
1182+ else if ( typeKey == NumberDataViewType . Double )
1183+ {
1184+ if ( typeValue == NumberDataViewType . Single || typeValue == NumberDataViewType . Double || typeValue == BooleanDataViewType . Instance )
1185+ CastInputTo < double > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , TextDataViewType . Instance ) ;
1186+ else
1187+ CastInputTo < double > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Single ) ;
1188+ }
1189+ else if ( typeKey == TextDataViewType . Instance )
1190+ {
1191+ if ( typeValue == TextDataViewType . Instance )
1192+ return false ;
1193+ node = ctx . CreateNode ( opType , srcVariableName , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
1194+ node . AddAttribute ( "keys_strings" , _valueMap . GetKeys < ReadOnlyMemory < char > > ( ) ) ;
1195+ }
1196+ else if ( typeKey == BooleanDataViewType . Instance )
1197+ {
1198+ if ( typeValue == NumberDataViewType . Single || typeValue == NumberDataViewType . Double || typeValue == BooleanDataViewType . Instance )
1199+ {
1200+ var castOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( TextDataViewType . Instance , ( int ) srcShape [ 1 ] ) , "castOutput" ) ;
1201+ var castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( "Cast" ) , "" ) ;
1202+ var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . String ) . ToType ( ) ;
1203+ castNode . AddAttribute ( "to" , t ) ;
1204+ node = ctx . CreateNode ( opType , castOutput , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
1205+ var values = Array . ConvertAll ( _valueMap . GetKeys < bool > ( ) , item => Convert . ToString ( Convert . ToByte ( item ) ) ) ;
1206+ node . AddAttribute ( "keys_strings" , values ) ;
1207+ }
1208+ else
1209+ CastInputTo < bool > ( ctx , out node , srcVariableName , opType , labelEncoderOutput , NumberDataViewType . Single ) ;
1210+ }
1211+ else
1212+ return false ;
1213+
1214+ if ( typeValue == NumberDataViewType . Int64 )
1215+ {
1216+ node . AddAttribute ( "values_int64s" , _valueMap . GetValues < long > ( ) ) ;
1217+ }
1218+ else if ( typeValue == NumberDataViewType . Int32 )
1219+ {
1220+ node . AddAttribute ( "values_int64s" , _valueMap . GetValues < int > ( ) . Select ( item => Convert . ToInt64 ( item ) ) ) ;
1221+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1222+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1223+ }
1224+ else if ( typeValue == NumberDataViewType . Int16 )
1225+ {
1226+ node . AddAttribute ( "values_int64s" , _valueMap . GetValues < short > ( ) . Select ( item => Convert . ToInt64 ( item ) ) ) ;
1227+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1228+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1229+ }
1230+ else if ( typeValue == NumberDataViewType . UInt64 || kind == InternalDataKind . U8 )
1231+ {
1232+ node . AddAttribute ( "values_int64s" , _valueMap . GetValues < ulong > ( ) . Select ( item => Convert . ToInt64 ( item ) ) ) ;
1233+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1234+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1235+ }
1236+ else if ( typeValue == NumberDataViewType . UInt32 || kind == InternalDataKind . U4 )
1237+ {
1238+ node . AddAttribute ( "values_int64s" , _valueMap . GetValues < uint > ( ) . Select ( item => Convert . ToInt64 ( item ) ) ) ;
1239+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1240+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1241+ }
1242+ else if ( typeValue == NumberDataViewType . UInt16 )
1243+ {
1244+ node . AddAttribute ( "values_int64s" , _valueMap . GetValues < ushort > ( ) . Select ( item => Convert . ToInt64 ( item ) ) ) ;
1245+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1246+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1247+ }
1248+ else if ( typeValue == NumberDataViewType . Single )
1249+ {
1250+ node . AddAttribute ( "values_floats" , _valueMap . GetValues < float > ( ) ) ;
1251+ }
1252+ else if ( typeValue == NumberDataViewType . Double )
1253+ {
1254+ node . AddAttribute ( "values_floats" , _valueMap . GetValues < double > ( ) . Select ( item => Convert . ToSingle ( item ) ) ) ;
1255+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1256+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1257+ }
1258+ else if ( typeValue == TextDataViewType . Instance )
1259+ {
1260+ node . AddAttribute ( "values_strings" , _valueMap . GetValues < ReadOnlyMemory < char > > ( ) ) ;
1261+ }
1262+ else if ( typeValue == BooleanDataViewType . Instance )
1263+ {
1264+ node . AddAttribute ( "values_floats" , _valueMap . GetValues < bool > ( ) . Select ( item => Convert . ToSingle ( item ) ) ) ;
1265+ var castNode = ctx . CreateNode ( "Cast" , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( "Cast" ) , "" ) ;
1266+ castNode . AddAttribute ( "to" , typeValue . RawType ) ;
1267+ }
1268+ else
1269+ return false ;
1270+
1271+ //Unknown keys should map to 0
1272+ node . AddAttribute ( "default_int64" , 0 ) ;
1273+ node . AddAttribute ( "default_string" , "" ) ;
1274+ node . AddAttribute ( "default_float" , 0f ) ;
1275+ return true ;
1276+ }
1277+
10431278 protected override DataViewSchema . DetachedColumn [ ] GetOutputColumnsCore ( )
10441279 {
10451280 var result = new DataViewSchema . DetachedColumn [ _columns . Length ] ;
0 commit comments