2727using System . IO ;
2828using System . CodeDom ;
2929using System . Runtime . CompilerServices ;
30+ using Microsoft . ML . Data . IO ;
3031
3132[ assembly: LoadableClass ( typeof ( TextClassificationTransformer ) , null , typeof ( SignatureLoadModel ) ,
3233 TextClassificationTransformer . UserName , TextClassificationTransformer . LoaderSignature ) ]
@@ -70,7 +71,6 @@ public sealed class TextClassificationTrainer : IEstimator<TextClassificationTra
7071 {
7172 private readonly IHost _host ;
7273 private readonly Options _options ;
73- private TextClassificationTransformer _transformer ;
7474 private const string ModelUrl = "models/NasBert2000000.tsm" ;
7575
7676 internal sealed class Options : TransformInputBase
@@ -290,6 +290,10 @@ internal TextClassificationTrainer(IHostEnvironment env, Options options)
290290
291291 public TextClassificationTransformer Fit ( IDataView input )
292292 {
293+ CheckInputSchema ( SchemaShape . Create ( input . Schema ) ) ;
294+
295+ TextClassificationTransformer transformer = default ;
296+
293297 using ( var ch = _host . Start ( "TrainModel" ) )
294298 using ( var pch = _host . StartProgressChannel ( "Training model" ) )
295299 {
@@ -304,11 +308,13 @@ public TextClassificationTransformer Fit(IDataView input)
304308 if ( _options . ValidationSet != null )
305309 trainer . Validate ( pch , ch , i ) ;
306310 }
307- _transformer = new TextClassificationTransformer ( _host , _options , trainer . Model , trainer . Tokenizer . Vocabulary ) ;
311+ var labelCol = input . Schema . GetColumnOrNull ( _options . LabelColumnName ) ;
312+
313+ transformer = new TextClassificationTransformer ( _host , _options , trainer . Model , new DataViewSchema . DetachedColumn ( labelCol . Value ) ) ;
308314
309- _transformer . GetOutputSchema ( input . Schema ) ;
315+ transformer . GetOutputSchema ( input . Schema ) ;
310316 }
311- return _transformer ;
317+ return transformer ;
312318 }
313319
314320 private class Trainer
@@ -668,31 +674,32 @@ public sealed class TextClassificationTransformer : RowToRowTransformerBase
668674
669675 private readonly Device _device ;
670676 private readonly TextClassificationModel _model ;
671- private readonly Vocabulary _vocabulary ;
672677 private readonly TextClassificationTrainer . Options _options ;
673678
674679 private readonly string _predictedLabelColumnName ;
675680 private readonly string _scoreColumnName ;
676681
677682 public readonly SchemaShape . Column SentenceColumn ;
678683 public readonly SchemaShape . Column SentenceColumn2 ;
679- public readonly SchemaShape . Column LabelColumn ;
684+ public readonly DataViewSchema . DetachedColumn LabelColumn ;
680685
681686 internal const string LoaderSignature = "NASBERT" ;
682687
683- internal TextClassificationTransformer ( IHostEnvironment env , TextClassificationTrainer . Options options , TextClassificationModel model , Vocabulary vocabulary )
688+ private static readonly FuncStaticMethodInfo1 < object , Delegate > _decodeInitMethodInfo
689+ = new FuncStaticMethodInfo1 < object , Delegate > ( DecodeInit < int > ) ;
690+
691+ internal TextClassificationTransformer ( IHostEnvironment env , TextClassificationTrainer . Options options , TextClassificationModel model , DataViewSchema . DetachedColumn labelColumn )
684692 : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( TextClassificationTransformer ) ) )
685693 {
686694 _device = TorchUtils . InitializeDevice ( env ) ;
687695
688696 _options = options ;
689- LabelColumn = new SchemaShape . Column ( _options . LabelColumnName , SchemaShape . Column . VectorKind . Scalar , NumberDataViewType . UInt32 , true ) ;
697+ LabelColumn = labelColumn ;
690698 SentenceColumn = new SchemaShape . Column ( _options . Sentence1ColumnName , SchemaShape . Column . VectorKind . Scalar , TextDataViewType . Instance , false ) ;
691699 SentenceColumn2 = _options . Sentence2ColumnName == default ? default : new SchemaShape . Column ( _options . Sentence2ColumnName , SchemaShape . Column . VectorKind . Scalar , TextDataViewType . Instance , false ) ;
692700 _predictedLabelColumnName = _options . PredictionColumnName ;
693701 _scoreColumnName = _options . ScoreColumnName ;
694702
695- _vocabulary = vocabulary ;
696703 _model = model ;
697704
698705 if ( _device == CUDA )
@@ -736,24 +743,45 @@ private static TextClassificationTransformer Create(IHostEnvironment env, ModelL
736743 if ( ! ctx . TryLoadBinaryStream ( "TSModel" , r => model . load ( r ) ) )
737744 throw env . ExceptDecode ( ) ;
738745
739- return new TextClassificationTransformer ( env , options , model , vocabulary ) ;
746+ BinarySaver saver = new BinarySaver ( env , new BinarySaver . Arguments ( ) ) ;
747+ DataViewType type ;
748+ object value ;
749+ env . CheckDecode ( saver . TryLoadTypeAndValue ( ctx . Reader . BaseStream , out type , out value ) ) ;
750+ var vecType = type as VectorDataViewType ;
751+ env . CheckDecode ( vecType != null ) ;
752+ env . CheckDecode ( value != null ) ;
753+ var labelGetter = Microsoft . ML . Internal . Utilities . Utils . MarshalInvoke ( _decodeInitMethodInfo , vecType . ItemType . RawType , value ) ;
754+
755+ var meta = new DataViewSchema . Annotations . Builder ( ) ;
756+ meta . Add ( AnnotationUtils . Kinds . KeyValues , type , labelGetter ) ;
757+
758+ var labelCol = new DataViewSchema . DetachedColumn ( options . LabelColumnName , type , meta . ToAnnotations ( ) ) ;
759+
760+ return new TextClassificationTransformer ( env , options , model , labelCol ) ;
761+ }
762+
763+ private static Delegate DecodeInit < T > ( object value )
764+ {
765+ VBuffer < T > buffValue = ( VBuffer < T > ) value ;
766+ ValueGetter < VBuffer < T > > buffGetter = ( ref VBuffer < T > dst ) => buffValue . CopyTo ( ref dst ) ;
767+ return buffGetter ;
740768 }
741769
742770 public SchemaShape GetOutputSchema ( SchemaShape inputSchema )
743771 {
744772 Host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
745773
746774 CheckInputSchema ( inputSchema ) ;
747- inputSchema . TryFindColumn ( LabelColumn . Name , out var labelCol ) ;
748- var predLabelMetadata = new SchemaShape ( labelCol . Annotations . Where ( x => x . Name == AnnotationUtils . Kinds . KeyValues )
775+ var labelAnnotationsColumn = new SchemaShape . Column ( AnnotationUtils . Kinds . SlotNames , SchemaShape . Column . VectorKind . Vector , LabelColumn . Annotations . Schema [ AnnotationUtils . Kinds . SlotNames ] . Type , false ) ;
776+ var predLabelMetadata = new SchemaShape ( new SchemaShape . Column [ ] { labelAnnotationsColumn }
749777 . Concat ( AnnotationUtils . GetTrainerOutputAnnotation ( ) ) ) ;
750778
751779 var outColumns = inputSchema . ToDictionary ( x => x . Name ) ;
752780 outColumns [ _predictedLabelColumnName ] = new SchemaShape . Column ( _predictedLabelColumnName , SchemaShape . Column . VectorKind . Scalar ,
753781 NumberDataViewType . UInt32 , true , predLabelMetadata ) ;
754782
755783 outColumns [ _scoreColumnName ] = new SchemaShape . Column ( _scoreColumnName , SchemaShape . Column . VectorKind . Vector ,
756- NumberDataViewType . Single , false , new SchemaShape ( AnnotationUtils . AnnotationsForMulticlassScoreColumn ( labelCol ) ) ) ;
784+ NumberDataViewType . Single , false , new SchemaShape ( AnnotationUtils . AnnotationsForMulticlassScoreColumn ( labelAnnotationsColumn ) ) ) ;
757785
758786 return new SchemaShape ( outColumns . Values ) ;
759787 }
@@ -775,12 +803,6 @@ private void CheckInputSchema(SchemaShape inputSchema)
775803 throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "sentence2" , SentenceColumn2 . Name ,
776804 SentenceColumn2 . GetTypeString ( ) , sentenceCol2 . GetTypeString ( ) ) ;
777805 }
778-
779- if ( ! inputSchema . TryFindColumn ( LabelColumn . Name , out var labelCol ) )
780- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "label" , LabelColumn . Name ) ;
781- if ( ! LabelColumn . IsCompatibleWith ( labelCol ) )
782- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "label" , LabelColumn . Name ,
783- LabelColumn . GetTypeString ( ) , labelCol . GetTypeString ( ) ) ;
784806 }
785807
786808 private static VersionInfo GetVersionInfo ( )
@@ -819,6 +841,22 @@ private protected override void SaveModel(ModelSaveContext ctx)
819841 {
820842 _model . save ( w ) ;
821843 } ) ;
844+
845+ var labelColType = LabelColumn . Annotations . Schema [ AnnotationUtils . Kinds . KeyValues ] . Type as VectorDataViewType ;
846+ Microsoft . ML . Internal . Utilities . Utils . MarshalActionInvoke ( SaveLabelValues < int > , labelColType . ItemType . RawType , ctx ) ;
847+ }
848+
849+ private void SaveLabelValues < T > ( ModelSaveContext ctx )
850+ {
851+ ValueGetter < VBuffer < T > > getter = LabelColumn . Annotations . GetGetter < VBuffer < T > > ( LabelColumn . Annotations . Schema [ AnnotationUtils . Kinds . KeyValues ] ) ;
852+ var val = default ( VBuffer < T > ) ;
853+ getter ( ref val ) ;
854+
855+ BinarySaver saver = new BinarySaver ( Host , new BinarySaver . Arguments ( ) ) ;
856+ int bytesWritten ;
857+ var labelColType = LabelColumn . Annotations . Schema [ AnnotationUtils . Kinds . KeyValues ] . Type as VectorDataViewType ;
858+ if ( ! saver . TryWriteTypeAndValue < VBuffer < T > > ( ctx . Writer . BaseStream , labelColType , ref val , out bytesWritten ) )
859+ throw Host . Except ( "We do not know how to serialize label names of type '{0}'" , labelColType . ItemType ) ;
822860 }
823861
824862 private protected override IRowMapper MakeRowMapper ( DataViewSchema schema ) => new Mapper ( this , schema ) ;
@@ -827,25 +865,24 @@ private sealed class Mapper : MapperBase
827865 {
828866 private readonly TextClassificationTransformer _parent ;
829867 private readonly HashSet < int > _inputColIndices ;
830- private readonly DataViewSchema . Column _labelCol ;
831868 private readonly DataViewSchema _inputSchema ;
832- private static readonly FuncInstanceMethodInfo1 < Mapper , Delegate > _makeLabelAnnotationGetter
833- = FuncInstanceMethodInfo1 < Mapper , Delegate > . Create ( target => target . GetLabelAnnotations < int > ) ;
869+
870+ private static readonly FuncInstanceMethodInfo1 < Mapper , DataViewSchema . DetachedColumn , Delegate > _makeLabelAnnotationGetter
871+ = FuncInstanceMethodInfo1 < Mapper , DataViewSchema . DetachedColumn , Delegate > . Create ( target => target . GetLabelAnnotations < int > ) ;
872+
834873
835874 public Mapper ( TextClassificationTransformer parent , DataViewSchema inputSchema ) :
836875 base ( Contracts . CheckRef ( parent , nameof ( parent ) ) . Host . Register ( nameof ( Mapper ) ) , inputSchema , parent )
837876 {
838877 _parent = parent ;
839878 _inputColIndices = new HashSet < int > ( ) ;
840- int col = 0 ;
841- if ( inputSchema . TryGetColumnIndex ( parent . _options . Sentence1ColumnName , out col ) )
879+ if ( inputSchema . TryGetColumnIndex ( parent . _options . Sentence1ColumnName , out var col ) )
842880 _inputColIndices . Add ( col ) ;
843881
844882 if ( parent . _options . Sentence2ColumnName != default )
845883 if ( inputSchema . TryGetColumnIndex ( parent . _options . Sentence2ColumnName , out col ) )
846884 _inputColIndices . Add ( col ) ;
847885
848- _labelCol = inputSchema [ _parent . _options . LabelColumnName ] ;
849886 _inputSchema = inputSchema ;
850887
851888 torch . random . manual_seed ( 1 ) ;
@@ -855,8 +892,9 @@ public Mapper(TextClassificationTransformer parent, DataViewSchema inputSchema)
855892 protected override DataViewSchema . DetachedColumn [ ] GetOutputColumnsCore ( )
856893 {
857894 var info = new DataViewSchema . DetachedColumn [ 2 ] ;
858- var keyType = _labelCol . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . KeyValues ) ? . Type as VectorDataViewType ;
859- var getter = Microsoft . ML . Internal . Utilities . Utils . MarshalInvoke ( _makeLabelAnnotationGetter , this , keyType . ItemType . RawType ) ;
895+ var keyType = _parent . LabelColumn . Annotations . Schema . GetColumnOrNull ( AnnotationUtils . Kinds . KeyValues ) ? . Type as VectorDataViewType ;
896+ var getter = Microsoft . ML . Internal . Utilities . Utils . MarshalInvoke ( _makeLabelAnnotationGetter , this , keyType . ItemType . RawType , _parent . LabelColumn ) ;
897+
860898
861899 var meta = new DataViewSchema . Annotations . Builder ( ) ;
862900 meta . Add ( AnnotationUtils . Kinds . ScoreColumnKind , TextDataViewType . Instance , ( ref ReadOnlyMemory < char > value ) => { value = AnnotationUtils . Const . ScoreColumnKind . MulticlassClassification . AsMemory ( ) ; } ) ;
@@ -865,15 +903,18 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
865903 meta . Add ( AnnotationUtils . Kinds . TrainingLabelValues , keyType , getter ) ;
866904 meta . Add ( AnnotationUtils . Kinds . SlotNames , keyType , getter ) ;
867905
868- info [ 0 ] = new DataViewSchema . DetachedColumn ( _parent . _options . PredictionColumnName , new KeyDataViewType ( typeof ( uint ) , _parent . _options . NumberOfClasses ) , _labelCol . Annotations ) ;
906+ var labelBuilder = new DataViewSchema . Annotations . Builder ( ) ;
907+ labelBuilder . Add ( AnnotationUtils . Kinds . KeyValues , keyType , getter ) ;
908+
909+ info [ 0 ] = new DataViewSchema . DetachedColumn ( _parent . _options . PredictionColumnName , new KeyDataViewType ( typeof ( uint ) , _parent . _options . NumberOfClasses ) , labelBuilder . ToAnnotations ( ) ) ;
869910
870911 info [ 1 ] = new DataViewSchema . DetachedColumn ( _parent . _options . ScoreColumnName , new VectorDataViewType ( NumberDataViewType . Single , _parent . _options . NumberOfClasses ) , meta . ToAnnotations ( ) ) ;
871912 return info ;
872913 }
873914
874- private Delegate GetLabelAnnotations < T > ( )
915+ private Delegate GetLabelAnnotations < T > ( DataViewSchema . DetachedColumn labelCol )
875916 {
876- return _labelCol . Annotations . GetGetter < VBuffer < T > > ( _labelCol . Annotations . Schema [ AnnotationUtils . Kinds . KeyValues ] ) ;
917+ return labelCol . Annotations . GetGetter < VBuffer < T > > ( labelCol . Annotations . Schema [ AnnotationUtils . Kinds . KeyValues ] ) ;
877918 }
878919
879920 private ValueGetter < uint > GetScoreColumnSetId ( DataViewSchema schema )
0 commit comments