@@ -1349,20 +1349,38 @@ private void AddMetaKeyValues(int i, DataViewSchema.Annotations.Builder builder)
13491349 private bool SaveAsOnnxCore ( OnnxContext ctx , int iinfo , string srcVariable , string dstVariable )
13501350 {
13511351 string castOutput ;
1352+ string isGreaterThanZeroOutput = "" ;
13521353 OnnxNode castNode ;
13531354 OnnxNode murmurNode ;
1355+ OnnxNode isZeroNode ;
13541356
13551357 var srcType = _srcTypes [ iinfo ] . GetItemType ( ) ;
1356- if ( srcType is KeyDataViewType )
1357- return false ;
13581358 if ( _parent . _columns [ iinfo ] . Combine )
13591359 return false ;
13601360
13611361 var opType = "MurmurHash3" ;
13621362 string murmurOutput = ctx . AddIntermediateVariable ( _dstTypes [ iinfo ] , "MurmurOutput" ) ;
13631363
1364- // Numeric input types are limited to those supported by the Onnxruntime MurmurHash operator, which currently only supports
1365- // uints and ints. Thus, ulongs, longs, doubles and floats are not supported.
1364+ // Get zero value indeces
1365+ if ( _srcTypes [ iinfo ] is KeyDataViewType )
1366+ {
1367+ var optType2 = "Cast" ;
1368+ castOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Int64 , "CastOutput" , true ) ;
1369+ isZeroNode = ctx . CreateNode ( optType2 , srcVariable , castOutput , ctx . GetNodeName ( optType2 ) , "" ) ;
1370+ isZeroNode . AddAttribute ( "to" , NumberDataViewType . Int64 . RawType ) ;
1371+
1372+ var zero = ctx . AddInitializer ( 0 ) ;
1373+ var isGreaterThanZeroOutputBool = ctx . AddIntermediateVariable ( BooleanDataViewType . Instance , "isGreaterThanZeroOutputBool" ) ;
1374+ optType2 = "Greater" ;
1375+ ctx . CreateNode ( optType2 , new [ ] { castOutput , zero } , new [ ] { isGreaterThanZeroOutputBool } , ctx . GetNodeName ( optType2 ) , "" ) ;
1376+
1377+ isGreaterThanZeroOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Int64 , "isGreaterThanZeroOutput" ) ;
1378+ optType2 = "Cast" ;
1379+ isZeroNode = ctx . CreateNode ( optType2 , isGreaterThanZeroOutputBool , isGreaterThanZeroOutput , ctx . GetNodeName ( optType2 ) , "" ) ;
1380+ isZeroNode . AddAttribute ( "to" , NumberDataViewType . Int64 . RawType ) ;
1381+ }
1382+
1383+ // Since these numeric types are not supported by Onnxruntime, we cast them to UInt32.
13661384 if ( srcType == NumberDataViewType . UInt16 || srcType == NumberDataViewType . Int16 ||
13671385 srcType == NumberDataViewType . SByte || srcType == NumberDataViewType . Byte ||
13681386 srcType == BooleanDataViewType . Instance )
@@ -1372,15 +1390,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
13721390 castNode . AddAttribute ( "to" , NumberDataViewType . UInt32 . RawType ) ;
13731391 murmurNode = ctx . CreateNode ( opType , castOutput , murmurOutput , ctx . GetNodeName ( opType ) , "com.microsoft" ) ;
13741392 }
1375- else if ( srcType == NumberDataViewType . UInt32 || srcType == NumberDataViewType . Int32 || srcType == NumberDataViewType . UInt64 ||
1376- srcType == NumberDataViewType . Int64 || srcType == NumberDataViewType . Single || srcType == NumberDataViewType . Double || srcType == TextDataViewType . Instance )
1377-
1378- {
1379- murmurNode = ctx . CreateNode ( opType , srcVariable , murmurOutput , ctx . GetNodeName ( opType ) , "com.microsoft" ) ;
1380- }
13811393 else
13821394 {
1383- return false ;
1395+ murmurNode = ctx . CreateNode ( opType , srcVariable , murmurOutput , ctx . GetNodeName ( opType ) , "com.microsoft" ) ;
13841396 }
13851397
13861398 murmurNode . AddAttribute ( "positive" , 1 ) ;
@@ -1417,10 +1429,17 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
14171429 string one = ctx . AddInitializer ( 1 ) ;
14181430 ctx . CreateNode ( opType , new [ ] { castOutput , one } , new [ ] { addOutput } , ctx . GetNodeName ( opType ) , "" ) ;
14191431
1432+ string mulOutput = ctx . AddIntermediateVariable ( vectorShape , "MulOutput" ) ;
1433+ if ( _srcTypes [ iinfo ] is KeyDataViewType )
1434+ {
1435+ opType = "Mul" ;
1436+ ctx . CreateNode ( opType , new [ ] { isGreaterThanZeroOutput , addOutput } , new [ ] { mulOutput } , ctx . GetNodeName ( opType ) , "" ) ;
1437+ }
1438+
14201439 opType = "Cast" ;
1421- var castNodeFinal = ctx . CreateNode ( opType , addOutput , dstVariable , ctx . GetNodeName ( opType ) , "" ) ;
1440+ var input = ( _srcTypes [ iinfo ] is KeyDataViewType ) ? mulOutput : addOutput ;
1441+ var castNodeFinal = ctx . CreateNode ( opType , input , dstVariable , ctx . GetNodeName ( opType ) , "" ) ;
14221442 castNodeFinal . AddAttribute ( "to" , _dstTypes [ iinfo ] . GetItemType ( ) . RawType ) ;
1423-
14241443 return true ;
14251444 }
14261445
0 commit comments