@@ -786,6 +786,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
786786 OnnxNode node ;
787787 long [ ] termIds ;
788788 string opType = "LabelEncoder" ;
789+ OnnxNode castNode ;
789790 var labelEncoderOutput = ctx . AddIntermediateVariable ( _types [ iinfo ] , "LabelEncoderOutput" , true ) ;
790791
791792 if ( info . TypeSrc . GetItemType ( ) . Equals ( TextDataViewType . Instance ) )
@@ -800,6 +801,26 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
800801 var terms = GetTermsAndIds < float > ( iinfo , out termIds ) ;
801802 node . AddAttribute ( "keys_floats" , terms ) ;
802803 }
804+ else if ( info . TypeSrc . GetItemType ( ) . Equals ( NumberDataViewType . Double ) )
805+ {
806+ var castOutput = ctx . AddIntermediateVariable ( null , "castOutput" , true ) ;
807+ castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
808+ var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
809+ castNode . AddAttribute ( "to" , t ) ;
810+ node = ctx . CreateNode ( opType , castOutput , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
811+ var terms = GetTermsAndIds < double > ( iinfo , out termIds ) ;
812+ node . AddAttribute ( "keys_floats" , terms ) ;
813+ }
814+ else if ( info . TypeSrc . GetItemType ( ) . Equals ( NumberDataViewType . Int64 ) )
815+ {
816+ var castOutput = ctx . AddIntermediateVariable ( null , "castOutput" , true ) ;
817+ castNode = ctx . CreateNode ( "Cast" , srcVariableName , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
818+ var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . String ) . ToType ( ) ;
819+ castNode . AddAttribute ( "to" , t ) ;
820+ node = ctx . CreateNode ( opType , castOutput , labelEncoderOutput , ctx . GetNodeName ( opType ) ) ;
821+ var terms = GetTermsAndIds < long > ( iinfo , out termIds ) ;
822+ node . AddAttribute ( "keys_strings" , terms . Select ( item => item . ToString ( ) ) ) ;
823+ }
803824 else
804825 {
805826 // LabelEncoder-2 in ORT v1 only supports the following mappings
@@ -822,7 +843,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src
822843 InternalDataKindExtensions . TryGetDataKind ( _parent . _unboundMaps [ iinfo ] . OutputType . RawType , out dataKind ) ;
823844
824845 opType = "Cast" ;
825- var castNode = ctx . CreateNode ( opType , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( opType ) , "" ) ;
846+ castNode = ctx . CreateNode ( opType , labelEncoderOutput , dstVariableName , ctx . GetNodeName ( opType ) , "" ) ;
826847 castNode . AddAttribute ( "to" , dataKind . ToType ( ) ) ;
827848
828849 return true ;
0 commit comments