Skip to content

Commit 9f26edd

Browse files
Don't need label column for inference TextClassification. (#6259)
1 parent de9afb5 commit 9f26edd

File tree

2 files changed

+117
-32
lines changed

2 files changed

+117
-32
lines changed

src/Microsoft.ML.TorchSharp/NasBert/TextClassificationTrainer.cs

Lines changed: 71 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
using System.IO;
2828
using System.CodeDom;
2929
using System.Runtime.CompilerServices;
30+
using Microsoft.ML.Data.IO;
3031

3132
[assembly: LoadableClass(typeof(TextClassificationTransformer), null, typeof(SignatureLoadModel),
3233
TextClassificationTransformer.UserName, TextClassificationTransformer.LoaderSignature)]
@@ -70,7 +71,6 @@ public sealed class TextClassificationTrainer : IEstimator<TextClassificationTra
7071
{
7172
private readonly IHost _host;
7273
private readonly Options _options;
73-
private TextClassificationTransformer _transformer;
7474
private const string ModelUrl = "models/NasBert2000000.tsm";
7575

7676
internal sealed class Options : TransformInputBase
@@ -290,6 +290,10 @@ internal TextClassificationTrainer(IHostEnvironment env, Options options)
290290

291291
public TextClassificationTransformer Fit(IDataView input)
292292
{
293+
CheckInputSchema(SchemaShape.Create(input.Schema));
294+
295+
TextClassificationTransformer transformer = default;
296+
293297
using (var ch = _host.Start("TrainModel"))
294298
using (var pch = _host.StartProgressChannel("Training model"))
295299
{
@@ -304,11 +308,13 @@ public TextClassificationTransformer Fit(IDataView input)
304308
if (_options.ValidationSet != null)
305309
trainer.Validate(pch, ch, i);
306310
}
307-
_transformer = new TextClassificationTransformer(_host, _options, trainer.Model, trainer.Tokenizer.Vocabulary);
311+
var labelCol = input.Schema.GetColumnOrNull(_options.LabelColumnName);
312+
313+
transformer = new TextClassificationTransformer(_host, _options, trainer.Model, new DataViewSchema.DetachedColumn(labelCol.Value));
308314

309-
_transformer.GetOutputSchema(input.Schema);
315+
transformer.GetOutputSchema(input.Schema);
310316
}
311-
return _transformer;
317+
return transformer;
312318
}
313319

314320
private class Trainer
@@ -668,31 +674,32 @@ public sealed class TextClassificationTransformer : RowToRowTransformerBase
668674

669675
private readonly Device _device;
670676
private readonly TextClassificationModel _model;
671-
private readonly Vocabulary _vocabulary;
672677
private readonly TextClassificationTrainer.Options _options;
673678

674679
private readonly string _predictedLabelColumnName;
675680
private readonly string _scoreColumnName;
676681

677682
public readonly SchemaShape.Column SentenceColumn;
678683
public readonly SchemaShape.Column SentenceColumn2;
679-
public readonly SchemaShape.Column LabelColumn;
684+
public readonly DataViewSchema.DetachedColumn LabelColumn;
680685

681686
internal const string LoaderSignature = "NASBERT";
682687

683-
internal TextClassificationTransformer(IHostEnvironment env, TextClassificationTrainer.Options options, TextClassificationModel model, Vocabulary vocabulary)
688+
private static readonly FuncStaticMethodInfo1<object, Delegate> _decodeInitMethodInfo
689+
= new FuncStaticMethodInfo1<object, Delegate>(DecodeInit<int>);
690+
691+
internal TextClassificationTransformer(IHostEnvironment env, TextClassificationTrainer.Options options, TextClassificationModel model, DataViewSchema.DetachedColumn labelColumn)
684692
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextClassificationTransformer)))
685693
{
686694
_device = TorchUtils.InitializeDevice(env);
687695

688696
_options = options;
689-
LabelColumn = new SchemaShape.Column(_options.LabelColumnName, SchemaShape.Column.VectorKind.Scalar, NumberDataViewType.UInt32, true);
697+
LabelColumn = labelColumn;
690698
SentenceColumn = new SchemaShape.Column(_options.Sentence1ColumnName, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false);
691699
SentenceColumn2 = _options.Sentence2ColumnName == default ? default : new SchemaShape.Column(_options.Sentence2ColumnName, SchemaShape.Column.VectorKind.Scalar, TextDataViewType.Instance, false);
692700
_predictedLabelColumnName = _options.PredictionColumnName;
693701
_scoreColumnName = _options.ScoreColumnName;
694702

695-
_vocabulary = vocabulary;
696703
_model = model;
697704

698705
if (_device == CUDA)
@@ -736,24 +743,45 @@ private static TextClassificationTransformer Create(IHostEnvironment env, ModelL
736743
if (!ctx.TryLoadBinaryStream("TSModel", r => model.load(r)))
737744
throw env.ExceptDecode();
738745

739-
return new TextClassificationTransformer(env, options, model, vocabulary);
746+
BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
747+
DataViewType type;
748+
object value;
749+
env.CheckDecode(saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out type, out value));
750+
var vecType = type as VectorDataViewType;
751+
env.CheckDecode(vecType != null);
752+
env.CheckDecode(value != null);
753+
var labelGetter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_decodeInitMethodInfo, vecType.ItemType.RawType, value);
754+
755+
var meta = new DataViewSchema.Annotations.Builder();
756+
meta.Add(AnnotationUtils.Kinds.KeyValues, type, labelGetter);
757+
758+
var labelCol = new DataViewSchema.DetachedColumn(options.LabelColumnName, type, meta.ToAnnotations());
759+
760+
return new TextClassificationTransformer(env, options, model, labelCol);
761+
}
762+
763+
/// <summary>
/// Wraps an already-loaded label-values buffer in a getter delegate so it can be
/// attached to a column as an annotation.
/// </summary>
/// <typeparam name="T">Item type of the buffer; the caller in Create selects it at run time
/// from the loaded vector type's <c>ItemType.RawType</c> via <c>MarshalInvoke</c>.</typeparam>
/// <param name="value">The loaded value; it is cast directly to <see cref="VBuffer{T}"/>,
/// so a value of any other runtime type fails the cast.</param>
/// <returns>A <see cref="ValueGetter{T}"/> over <see cref="VBuffer{T}"/>, returned as a plain
/// <see cref="Delegate"/> so that it can flow through the non-generic call site.</returns>
private static Delegate DecodeInit<T>(object value)
764+
{
765+
// Cast once up front; the lambda below captures this local.
VBuffer<T> buffValue = (VBuffer<T>)value;
766+
// Every invocation copies the captured buffer into the caller-supplied destination.
ValueGetter<VBuffer<T>> buffGetter = (ref VBuffer<T> dst) => buffValue.CopyTo(ref dst);
767+
return buffGetter;
740768
}
741769

742770
/// <summary>
/// Builds the output <see cref="SchemaShape"/>: all input columns, with the predicted-label
/// and score columns added (or replaced, if a column of the same name already exists).
/// </summary>
/// <param name="inputSchema">Shape of the incoming data; passed through <c>Host.CheckValue</c>
/// and <c>CheckInputSchema</c> before use.</param>
/// <returns>A new <see cref="SchemaShape"/> made from the input columns plus the two output columns.</returns>
/// <remarks>
/// This text is a rendered diff: a line that follows a "NNN-" marker is the pre-change version
/// and a line that follows a "NNN+" marker is its replacement.
/// </remarks>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
743771
{
744772
Host.CheckValue(inputSchema, nameof(inputSchema));
745773

746774
CheckInputSchema(inputSchema);
747-
// Pre-change: the label column was looked up in the input schema, so it had to be present.
inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
748-
var predLabelMetadata = new SchemaShape(labelCol.Annotations.Where(x => x.Name == AnnotationUtils.Kinds.KeyValues)
775+
// Post-change: the annotation shape comes from the LabelColumn stored on the transformer,
// so the input schema no longer has to contain a label column.
// NOTE(review): this indexes LabelColumn.Annotations.Schema by SlotNames, whereas Create,
// SaveModel and Mapper.GetOutputColumnsCore store and read the label values under KeyValues.
// If the label column carries only KeyValues this lookup presumably fails - confirm which
// annotation kind is intended here.
var labelAnnotationsColumn = new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.SlotNames].Type, false);
776+
var predLabelMetadata = new SchemaShape(new SchemaShape.Column[] { labelAnnotationsColumn }
749777
.Concat(AnnotationUtils.GetTrainerOutputAnnotation()));
750778

751779
// Keyed by column name so the assignments below overwrite any same-named input column.
var outColumns = inputSchema.ToDictionary(x => x.Name);
752780
// Predicted label: scalar UInt32 declared as a key (the 'true' argument), with the metadata built above.
outColumns[_predictedLabelColumnName] = new SchemaShape.Column(_predictedLabelColumnName, SchemaShape.Column.VectorKind.Scalar,
753781
NumberDataViewType.UInt32, true, predLabelMetadata);
754782

755783
// Score: vector of Single, annotated via AnnotationsForMulticlassScoreColumn.
outColumns[_scoreColumnName] = new SchemaShape.Column(_scoreColumnName, SchemaShape.Column.VectorKind.Vector,
756-
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
784+
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelAnnotationsColumn)));
757785

758786
return new SchemaShape(outColumns.Values);
759787
}
@@ -775,12 +803,6 @@ private void CheckInputSchema(SchemaShape inputSchema)
775803
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "sentence2", SentenceColumn2.Name,
776804
SentenceColumn2.GetTypeString(), sentenceCol2.GetTypeString());
777805
}
778-
779-
if (!inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol))
780-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name);
781-
if (!LabelColumn.IsCompatibleWith(labelCol))
782-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", LabelColumn.Name,
783-
LabelColumn.GetTypeString(), labelCol.GetTypeString());
784806
}
785807

786808
private static VersionInfo GetVersionInfo()
@@ -819,6 +841,22 @@ private protected override void SaveModel(ModelSaveContext ctx)
819841
{
820842
_model.save(w);
821843
});
844+
845+
var labelColType = LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.KeyValues].Type as VectorDataViewType;
846+
Microsoft.ML.Internal.Utilities.Utils.MarshalActionInvoke(SaveLabelValues<int>, labelColType.ItemType.RawType, ctx);
847+
}
848+
849+
/// <summary>
/// Writes the label column's KeyValues annotation (its type and its value) to the model
/// stream so it can be read back when the model is loaded.
/// </summary>
/// <typeparam name="T">Item type of the KeyValues vector; SaveModel selects it at run time
/// from the annotation type's <c>ItemType.RawType</c> via <c>MarshalActionInvoke</c>.</typeparam>
/// <param name="ctx">Save context; the bytes go to <c>ctx.Writer.BaseStream</c>.</param>
/// <remarks>Throws the exception built by <c>Host.Except</c> when <c>TryWriteTypeAndValue</c> returns false.</remarks>
private void SaveLabelValues<T>(ModelSaveContext ctx)
850+
{
851+
// Fetch the getter for the KeyValues annotation and materialize its current value.
ValueGetter<VBuffer<T>> getter = LabelColumn.Annotations.GetGetter<VBuffer<T>>(LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
852+
var val = default(VBuffer<T>);
853+
getter(ref val);
854+
855+
BinarySaver saver = new BinarySaver(Host, new BinarySaver.Arguments());
856+
// Required by the out parameter below; the value is not read afterwards.
int bytesWritten;
857+
// NOTE(review): 'as' yields null if the annotation type is not a VectorDataViewType, and the
// result is used without a null check (including in the error message) - confirm that the
// KeyValues annotation is always vector-typed here.
var labelColType = LabelColumn.Annotations.Schema[AnnotationUtils.Kinds.KeyValues].Type as VectorDataViewType;
858+
if (!saver.TryWriteTypeAndValue<VBuffer<T>>(ctx.Writer.BaseStream, labelColType, ref val, out bytesWritten))
859+
throw Host.Except("We do not know how to serialize label names of type '{0}'", labelColType.ItemType);
822860
}
823861

824862
private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
@@ -827,25 +865,24 @@ private sealed class Mapper : MapperBase
827865
{
828866
private readonly TextClassificationTransformer _parent;
829867
private readonly HashSet<int> _inputColIndices;
830-
private readonly DataViewSchema.Column _labelCol;
831868
private readonly DataViewSchema _inputSchema;
832-
private static readonly FuncInstanceMethodInfo1<Mapper, Delegate> _makeLabelAnnotationGetter
833-
= FuncInstanceMethodInfo1<Mapper, Delegate>.Create(target => target.GetLabelAnnotations<int>);
869+
870+
private static readonly FuncInstanceMethodInfo1<Mapper, DataViewSchema.DetachedColumn, Delegate> _makeLabelAnnotationGetter
871+
= FuncInstanceMethodInfo1<Mapper, DataViewSchema.DetachedColumn, Delegate>.Create(target => target.GetLabelAnnotations<int>);
872+
834873

835874
public Mapper(TextClassificationTransformer parent, DataViewSchema inputSchema) :
836875
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
837876
{
838877
_parent = parent;
839878
_inputColIndices = new HashSet<int>();
840-
int col = 0;
841-
if (inputSchema.TryGetColumnIndex(parent._options.Sentence1ColumnName, out col))
879+
if (inputSchema.TryGetColumnIndex(parent._options.Sentence1ColumnName, out var col))
842880
_inputColIndices.Add(col);
843881

844882
if (parent._options.Sentence2ColumnName != default)
845883
if (inputSchema.TryGetColumnIndex(parent._options.Sentence2ColumnName, out col))
846884
_inputColIndices.Add(col);
847885

848-
_labelCol = inputSchema[_parent._options.LabelColumnName];
849886
_inputSchema = inputSchema;
850887

851888
torch.random.manual_seed(1);
@@ -855,8 +892,9 @@ public Mapper(TextClassificationTransformer parent, DataViewSchema inputSchema)
855892
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
856893
{
857894
var info = new DataViewSchema.DetachedColumn[2];
858-
var keyType = _labelCol.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
859-
var getter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_makeLabelAnnotationGetter, this, keyType.ItemType.RawType);
895+
var keyType = _parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;
896+
var getter = Microsoft.ML.Internal.Utilities.Utils.MarshalInvoke(_makeLabelAnnotationGetter, this, keyType.ItemType.RawType, _parent.LabelColumn);
897+
860898

861899
var meta = new DataViewSchema.Annotations.Builder();
862900
meta.Add(AnnotationUtils.Kinds.ScoreColumnKind, TextDataViewType.Instance, (ref ReadOnlyMemory<char> value) => { value = AnnotationUtils.Const.ScoreColumnKind.MulticlassClassification.AsMemory(); });
@@ -865,15 +903,18 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
865903
meta.Add(AnnotationUtils.Kinds.TrainingLabelValues, keyType, getter);
866904
meta.Add(AnnotationUtils.Kinds.SlotNames, keyType, getter);
867905

868-
info[0] = new DataViewSchema.DetachedColumn(_parent._options.PredictionColumnName, new KeyDataViewType(typeof(uint), _parent._options.NumberOfClasses), _labelCol.Annotations);
906+
var labelBuilder = new DataViewSchema.Annotations.Builder();
907+
labelBuilder.Add(AnnotationUtils.Kinds.KeyValues, keyType, getter);
908+
909+
info[0] = new DataViewSchema.DetachedColumn(_parent._options.PredictionColumnName, new KeyDataViewType(typeof(uint), _parent._options.NumberOfClasses), labelBuilder.ToAnnotations());
869910

870911
info[1] = new DataViewSchema.DetachedColumn(_parent._options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, _parent._options.NumberOfClasses), meta.ToAnnotations());
871912
return info;
872913
}
873914

874-
// Returns the getter for the KeyValues annotation of a label column, typed as a plain Delegate
// so it can be produced through MarshalInvoke with a run-time item type.
// This text is a rendered diff: the parameterless signature and the '_labelCol' body line are the
// pre-change version (they read the Mapper's own field); the lines after the "NNN+" markers are
// the replacement, which reads the column passed in by the caller instead.
private Delegate GetLabelAnnotations<T>()
915+
private Delegate GetLabelAnnotations<T>(DataViewSchema.DetachedColumn labelCol)
875916
{
876-
return _labelCol.Annotations.GetGetter<VBuffer<T>>(_labelCol.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
917+
// Looks the annotation up by the KeyValues kind on the supplied column.
return labelCol.Annotations.GetGetter<VBuffer<T>>(labelCol.Annotations.Schema[AnnotationUtils.Kinds.KeyValues]);
877918
}
878919

879920
private ValueGetter<uint> GetScoreColumnSetId(DataViewSchema schema)

test/Microsoft.ML.Tests/TextClassificationTests.cs

Lines changed: 46 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,11 @@ private class TestSingleSentenceData
2929
public string Sentiment;
3030
}
3131

32+
// Input row with a sentence and no label field. The test loads a list of these through
// LoadFromEnumerable and scores it with the model returned by GetModelFor(TransformerScope.Scoring),
// i.e. it runs inference on data that has no label column.
private class TestSingleSentenceDataNoLabel
33+
{
34+
// The text to classify.
public string Sentence1;
35+
}
36+
3237
private class TestDoubleSentenceData
3338
{
3439
public string Sentence;
@@ -82,7 +87,8 @@ public void TestSingleSentence2Classes()
8287
Sentiment = "Negative"
8388
}
8489
}));
85-
var estimator = ML.Transforms.Conversion.MapValueToKey("Label", "Sentiment")
90+
var chain = new EstimatorChain<ITransformer>();
91+
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", "Sentiment"), TransformerScope.TrainTest)
8692
.Append(ML.MulticlassClassification.Trainers.TextClassification(outputColumnName: "outputColumn"))
8793
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
8894

@@ -96,11 +102,49 @@ public void TestSingleSentence2Classes()
96102
var transformer = estimator.Fit(dataView);
97103
var transformerSchema = transformer.GetOutputSchema(dataView.Schema);
98104

105+
var filteredModel = transformer.GetModelFor(TransformerScope.Scoring);
106+
99107
Assert.Equal(6, transformerSchema.Count);
100108
Assert.Equal("outputColumn", transformerSchema[4].Name);
101109
Assert.Equal(TextDataViewType.Instance, transformerSchema[4].Type);
102110

103-
var predictedLabel = transformer.Transform(dataView).GetColumn<ReadOnlyMemory<char>>(transformerSchema[4].Name);
111+
var dataNoLabel = ML.Data.LoadFromEnumerable(
112+
new List<TestSingleSentenceDataNoLabel>(new TestSingleSentenceDataNoLabel[] {
113+
new ()
114+
{ // Testing longer than 512 words.
115+
Sentence1 = "ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community . ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community . ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community . 
ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .ultimately feels as flat as the scruffy sands of its titular community .",
116+
},
117+
new ()
118+
{
119+
Sentence1 = "with a sharp script and strong performances",
120+
},
121+
new ()
122+
{
123+
Sentence1 = "that director m. night shyamalan can weave an eerie spell and",
124+
},
125+
new ()
126+
{
127+
Sentence1 = "comfortable",
128+
},
129+
new ()
130+
{
131+
Sentence1 = "does have its charms .",
132+
},
133+
new ()
134+
{
135+
Sentence1 = "banal as the telling",
136+
},
137+
new ()
138+
{
139+
Sentence1 = "faithful without being forceful , sad without being shrill , `` a walk to remember '' succeeds through sincerity .",
140+
},
141+
new ()
142+
{
143+
Sentence1 = "leguizamo 's best movie work so far",
144+
}
145+
}));
146+
147+
var predictedLabel = filteredModel.Transform(dataNoLabel).GetColumn<ReadOnlyMemory<char>>(transformerSchema[4].Name);
104148

105149
// Make sure that we can use the multiclass evaluate method
106150
var metrics = ML.MulticlassClassification.Evaluate(transformer.Transform(dataView), predictedLabelColumnName: "outputColumn");

0 commit comments

Comments
 (0)