1919using Tensorflow ;
2020using static Microsoft . ML . TensorFlow . TensorFlowUtils ;
2121using static Tensorflow . Binding ;
22+ using Utils = Microsoft . ML . Internal . Utilities . Utils ;
2223
2324[ assembly: LoadableClass ( DnnRetrainTransformer . Summary , typeof ( IDataTransform ) , typeof ( DnnRetrainTransformer ) ,
2425 typeof ( DnnRetrainEstimator . Options ) , typeof ( SignatureDataTransform ) , DnnRetrainTransformer . UserName , DnnRetrainTransformer . ShortName ) ]
@@ -607,15 +608,15 @@ internal static TensorShape GetTensorShape(TF_Output output, Graph graph, Status
607608 new ObjectDisposedException ( nameof ( graph ) ) ;
608609
609610 var cstatus = status == null ? new Status ( ) : status ;
610- var n = c_api . TF_GraphGetTensorNumDims ( graph , output , cstatus ) ;
611+ var n = c_api . TF_GraphGetTensorNumDims ( graph , output , cstatus . Handle ) ;
611612
612613 cstatus . Check ( ) ;
613614
614615 if ( n == - 1 )
615616 return new TensorShape ( new int [ 0 ] ) ;
616617
617618 var dims = new long [ n ] ;
618- c_api . TF_GraphGetTensorShape ( graph , output , dims , dims . Length , cstatus ) ;
619+ c_api . TF_GraphGetTensorShape ( graph , output , dims , dims . Length , cstatus . Handle ) ;
619620 cstatus . Check ( ) ;
620621 return new TensorShape ( dims . Select ( x => ( int ) x ) . ToArray ( ) ) ;
621622 }
@@ -1040,49 +1041,11 @@ public Tensor GetBufferedBatchTensor()
10401041 }
10411042 else
10421043 {
1043- var tensor = CastDataAndReturnAsTensor ( _bufferedData ) ;
1044+ var tensor = TensorFlowUtils . CastDataAndReturnAsTensor ( _bufferedData , _tfShape ) ;
10441045 _position = 0 ;
10451046 return tensor ;
10461047 }
10471048 }
1048-
1049- private Tensor CastDataAndReturnAsTensor ( T [ ] data )
1050- {
1051- if ( typeof ( T ) == typeof ( sbyte ) )
1052- return new Tensor ( ( sbyte [ ] ) ( object ) data , _dims , TF_DataType . TF_INT8 ) ;
1053- else if ( typeof ( T ) == typeof ( long ) )
1054- return new Tensor ( ( long [ ] ) ( object ) data , _dims , TF_DataType . TF_INT64 ) ;
1055- else if ( typeof ( T ) == typeof ( Int32 ) )
1056- return new Tensor ( ( Int32 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT32 ) ;
1057- else if ( typeof ( T ) == typeof ( Int16 ) )
1058- return new Tensor ( ( Int16 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT16 ) ;
1059- else if ( typeof ( T ) == typeof ( byte ) )
1060- return new Tensor ( ( byte [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT8 ) ;
1061- else if ( typeof ( T ) == typeof ( ulong ) )
1062- return new Tensor ( ( ulong [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT64 ) ;
1063- else if ( typeof ( T ) == typeof ( UInt32 ) )
1064- return new Tensor ( ( UInt32 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT32 ) ;
1065- else if ( typeof ( T ) == typeof ( UInt16 ) )
1066- return new Tensor ( ( UInt16 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT16 ) ;
1067- else if ( typeof ( T ) == typeof ( bool ) )
1068- return new Tensor ( ( bool [ ] ) ( object ) data , _dims , TF_DataType . TF_BOOL ) ;
1069- else if ( typeof ( T ) == typeof ( float ) )
1070- return new Tensor ( ( float [ ] ) ( object ) data , _dims , TF_DataType . TF_FLOAT ) ;
1071- else if ( typeof ( T ) == typeof ( float ) )
1072- return new Tensor ( ( double [ ] ) ( object ) data , _dims , TF_DataType . TF_DOUBLE ) ;
1073- else if ( typeof ( T ) == typeof ( ReadOnlyMemory < char > ) )
1074- {
1075- byte [ ] [ ] bytes = new byte [ _bufferedData . Length ] [ ] ;
1076- for ( int i = 0 ; i < bytes . Length ; i ++ )
1077- {
1078- bytes [ i ] = Encoding . UTF8 . GetBytes ( ( ( ReadOnlyMemory < char > ) ( object ) data [ i ] ) . ToArray ( ) ) ;
1079- }
1080-
1081- return new Tensor ( bytes , _tfShape . dims . Select ( x => ( long ) x ) . ToArray ( ) ) ;
1082- }
1083-
1084- return new Tensor ( new NDArray ( data , _tfShape ) ) ;
1085- }
10861049 }
10871050
10881051 private class TensorValueGetterVec < T > : ITensorValueGetter
@@ -1126,45 +1089,7 @@ public Tensor GetTensor()
11261089 // This is done to reduce memory allocation every time tensor is created.
11271090 _denseData = new T [ _vBuffer . Length ] ;
11281091 _vBuffer . CopyTo ( _denseData ) ;
1129- return CastDataAndReturnAsTensor ( _denseData ) ;
1130- }
1131-
1132- private Tensor CastDataAndReturnAsTensor ( T [ ] data )
1133- {
1134- if ( typeof ( T ) == typeof ( sbyte ) )
1135- return new Tensor ( ( sbyte [ ] ) ( object ) data , _dims , TF_DataType . TF_INT8 ) ;
1136- else if ( typeof ( T ) == typeof ( long ) )
1137- return new Tensor ( ( long [ ] ) ( object ) data , _dims , TF_DataType . TF_INT64 ) ;
1138- else if ( typeof ( T ) == typeof ( Int32 ) )
1139- return new Tensor ( ( Int32 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT32 ) ;
1140- else if ( typeof ( T ) == typeof ( Int16 ) )
1141- return new Tensor ( ( Int16 [ ] ) ( object ) data , _dims , TF_DataType . TF_INT16 ) ;
1142- else if ( typeof ( T ) == typeof ( byte ) )
1143- return new Tensor ( ( byte [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT8 ) ;
1144- else if ( typeof ( T ) == typeof ( ulong ) )
1145- return new Tensor ( ( ulong [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT64 ) ;
1146- else if ( typeof ( T ) == typeof ( UInt32 ) )
1147- return new Tensor ( ( UInt32 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT32 ) ;
1148- else if ( typeof ( T ) == typeof ( UInt16 ) )
1149- return new Tensor ( ( UInt16 [ ] ) ( object ) data , _dims , TF_DataType . TF_UINT16 ) ;
1150- else if ( typeof ( T ) == typeof ( bool ) )
1151- return new Tensor ( ( bool [ ] ) ( object ) data , _dims , TF_DataType . TF_BOOL ) ;
1152- else if ( typeof ( T ) == typeof ( float ) )
1153- return new Tensor ( ( float [ ] ) ( object ) data , _dims , TF_DataType . TF_FLOAT ) ;
1154- else if ( typeof ( T ) == typeof ( double ) )
1155- return new Tensor ( ( double [ ] ) ( object ) data , _dims , TF_DataType . TF_DOUBLE ) ;
1156- else if ( typeof ( T ) == typeof ( ReadOnlyMemory < char > ) )
1157- {
1158- byte [ ] [ ] bytes = new byte [ _vBuffer . Length ] [ ] ;
1159- for ( int i = 0 ; i < bytes . Length ; i ++ )
1160- {
1161- bytes [ i ] = Encoding . UTF8 . GetBytes ( ( ( ReadOnlyMemory < char > ) ( object ) data [ i ] ) . ToArray ( ) ) ;
1162- }
1163-
1164- return new Tensor ( bytes , _tfShape . dims . Select ( x => ( long ) x ) . ToArray ( ) ) ;
1165- }
1166-
1167- return new Tensor ( new NDArray ( data , _tfShape ) ) ;
1092+ return TensorFlowUtils . CastDataAndReturnAsTensor ( _denseData , _tfShape ) ;
11681093 }
11691094
11701095 public void BufferTrainingData ( )
@@ -1177,7 +1102,7 @@ public void BufferTrainingData()
11771102 public Tensor GetBufferedBatchTensor ( )
11781103 {
11791104 _position = 0 ;
1180- var tensor = CastDataAndReturnAsTensor ( _bufferedData ) ;
1105+ var tensor = TensorFlowUtils . CastDataAndReturnAsTensor ( _bufferedData , _tfShape ) ;
11811106 _bufferedData = new T [ _bufferedDataSize ] ;
11821107 return tensor ;
11831108 }
0 commit comments