diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs index 5f84712625..ef181244ab 100644 --- a/src/Microsoft.ML.Api/SchemaDefinition.cs +++ b/src/Microsoft.ML.Api/SchemaDefinition.cs @@ -67,8 +67,9 @@ public VectorTypeAttribute(params int[] dims) /// column encapsulates. /// [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)] - public sealed class ColumnAttribute : Attribute + public class ColumnAttribute : Attribute { + public ColumnAttribute(string ordinal, string name = null) { Name = name; @@ -93,6 +94,61 @@ public ColumnAttribute(string ordinal, string name = null) public string Ordinal { get; } } + /// + /// Describes 'Label' column with indicies. + /// + public sealed class LabelColumnAttribute : ColumnAttribute + { + public LabelColumnAttribute(string ordinal): + base(ordinal, "Label") + { + } + } + + /// + /// Describes 'Features' column with indicies. + /// + public sealed class FeaturesColumnAttribute : ColumnAttribute + { + public FeaturesColumnAttribute(string ordinal) : + base(ordinal, "Features") + { + } + } + + /// + /// Describes 'Weight' column with indicies. + /// + public sealed class WeightColumnAttribute : ColumnAttribute + { + public WeightColumnAttribute(string ordinal) : + base(ordinal, "Weight") + { + } + } + + /// + /// Describes 'GroupId' column with indicies. + /// + public sealed class GroupColumnAttribute : ColumnAttribute + { + public GroupColumnAttribute(string ordinal) : + base(ordinal, "GroupId") + { + } + } + + /// + /// Describes 'Name' column with indicies. + /// + public sealed class NameColumnAttribute : ColumnAttribute + { + public NameColumnAttribute(string ordinal) : + base(ordinal, "Name") + { + } + } + /// /// Allows a member to specify its column name directly, as opposed to the default /// behavior of using the member name as the column name. diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs index 392462a0eb..ba8bdb28df 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs @@ -60,7 +60,7 @@ public class HousePriceData [Column(ordinal: "1")] public string Date; - [Column(ordinal: "2", name: "Label")] + [LabelColumn(ordinal: "2")] public float Price; [Column(ordinal: "3")] diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithFeatureVector.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithFeatureVector.cs new file mode 100644 index 0000000000..fd22c3dc22 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithFeatureVector.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.Models; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Trainers; +using Microsoft.ML.Transforms; +using Xunit; + +namespace Microsoft.ML.Scenarios +{ + public partial class ScenariosTests + { + [Fact] + public void TrainAndPredictIrisModelWithFeatureVectorTest() + { + string dataPath = GetDataPath("iris.data"); + + var pipeline = new LearningPipeline(); + + pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false, separator: ',')); + + pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field. + + pipeline.Add(new StochasticDualCoordinateAscentClassifier()); + + PredictionModel model = pipeline.Train(); + + IrisPrediction prediction = model.Predict(new IrisDataWithFeatureVector() + { + Feat = new float[] { 5.1f, 3.3f, 1.6f, 0.2f } + }); + + Assert.Equal(1, prediction.PredictedLabels[0], 2); + Assert.Equal(0, prediction.PredictedLabels[1], 2); + Assert.Equal(0, prediction.PredictedLabels[2], 2); + + prediction = model.Predict(new IrisDataWithFeatureVector() + { + Feat = new float[] { 6.4f, 3.1f, 5.5f, 2.2f } + }); + + Assert.Equal(0, prediction.PredictedLabels[0], 2); + Assert.Equal(0, prediction.PredictedLabels[1], 2); + Assert.Equal(1, prediction.PredictedLabels[2], 2); + + prediction = model.Predict(new IrisDataWithFeatureVector() + { + Feat = new float[] { 4.4f, 3.1f, 2.5f, 1.2f } + }); + + Assert.Equal(.2, prediction.PredictedLabels[0], 1); + Assert.Equal(.8, prediction.PredictedLabels[1], 1); + Assert.Equal(0, prediction.PredictedLabels[2], 2); + + // Note: Testing against the same data set as a simple way to test evaluation. + // This isn't appropriate in real-world scenarios. + string testDataPath = GetDataPath("iris.data"); + var testData = new TextLoader(testDataPath).CreateFrom(useHeader: false, separator: ','); + + var evaluator = new ClassificationEvaluator(); + evaluator.OutputTopKAcc = 3; + ClassificationMetrics metrics = evaluator.Evaluate(model, testData); + + Assert.Equal(.98, metrics.AccuracyMacro); + Assert.Equal(.98, metrics.AccuracyMicro, 2); + Assert.Equal(.06, metrics.LogLoss, 2); + Assert.InRange(metrics.LogLossReduction, 94, 96); + Assert.Equal(1, metrics.TopKAccuracy); + + Assert.Equal(3, metrics.PerClassLogLoss.Length); + Assert.Equal(0, metrics.PerClassLogLoss[0], 1); + Assert.Equal(.1, metrics.PerClassLogLoss[1], 1); + Assert.Equal(.1, metrics.PerClassLogLoss[2], 1); + + ConfusionMatrix matrix = metrics.ConfusionMatrix; + Assert.Equal(3, matrix.Order); + Assert.Equal(3, matrix.ClassNames.Count); + Assert.Equal("Iris-setosa", matrix.ClassNames[0]); + Assert.Equal("Iris-versicolor", matrix.ClassNames[1]); + Assert.Equal("Iris-virginica", matrix.ClassNames[2]); + + Assert.Equal(50, matrix[0, 0]); + Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]); + Assert.Equal(0, matrix[0, 1]); + Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]); + Assert.Equal(0, matrix[0, 2]); + Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]); + + Assert.Equal(0, matrix[1, 0]); + Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]); + Assert.Equal(48, matrix[1, 1]); + Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]); + Assert.Equal(2, matrix[1, 2]); + Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]); + + Assert.Equal(0, matrix[2, 0]); + Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]); + Assert.Equal(1, matrix[2, 1]); + Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]); + Assert.Equal(49, matrix[2, 2]); + Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]); + } + + public class IrisDataWithFeatureVector + { + [FeaturesColumn("0-3")] + [VectorType(4)] + public float[] Feat; + + [LabelColumn("4")] + public string IrisPlantType; + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index ebddc33b03..286771d719 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -130,7 +130,7 @@ public class IrisDataWithStringLabel [Column("3")] public float PetalLength; - [Column("4", name: "Label")] + [LabelColumn("4")] public string IrisPlantType; } } diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 80947644e9..235fc7432c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -145,7 +145,7 @@ public void TrainAndPredictSentimentModelTest() public class SentimentData { - [Column(ordinal: "0", name: "Label")] + [LabelColumn(ordinal: "0")] public float Sentiment; [Column(ordinal: "1")] public string SentimentText; diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 40c0b6525f..2df3636bdd 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -228,6 +228,79 @@ public void ThrowsExceptionWithPropertyName() Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message); } + + [Fact] + public void CanSuccessfullyNamedColumns() + { + string dataPath = GetDataPath("SparseData.txt"); + var loader = new Data.TextLoader(dataPath).CreateFrom(useHeader: true, allowQuotedStrings: false, supportSparse: true); + + using (var environment = new TlcEnvironment()) + { + Experiment experiment = environment.CreateExperiment(); + ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as ILearningPipelineDataStep; + + experiment.Compile(); + loader.SetInput(environment, experiment); + experiment.Run(); + + IDataLoader data = experiment.GetOutput(output.Data) as IDataLoader; + Assert.NotNull(data); + + Assert.Equal(5, data.Schema.ColumnCount); + Assert.Equal("Name", data.Schema.GetColumnName(0)); + Assert.Equal("GroupId", data.Schema.GetColumnName(1)); + Assert.Equal("Weight", data.Schema.GetColumnName(2)); + Assert.Equal("Features", data.Schema.GetColumnName(3)); + Assert.Equal("Label", data.Schema.GetColumnName(4)); + + using (var cursor = data.GetRowCursor((a => true))) + { + var getters = new ValueGetter[]{ + cursor.GetGetter(0), + cursor.GetGetter(1), + cursor.GetGetter(2), + cursor.GetGetter(3), + cursor.GetGetter(4) + }; + + + Assert.True(cursor.MoveNext()); + + float[] targets = new float[] { 1, 2, 3, 4, 5 }; + for (int i = 0; i < getters.Length; i++) + { + float value = 0; + getters[i](ref value); + Assert.Equal(targets[i], value); + } + + Assert.True(cursor.MoveNext()); + + targets = new float[] { 0, 0, 0, 4, 5 }; + for (int i = 0; i < getters.Length; i++) + { + float value = 0; + getters[i](ref value); + Assert.Equal(targets[i], value); + } + + Assert.True(cursor.MoveNext()); + + targets = new float[] { 0, 2, 0, 0, 0 }; + for (int i = 0; i < getters.Length; i++) + { + float value = 0; + getters[i](ref value); + Assert.Equal(targets[i], value); + } + + Assert.False(cursor.MoveNext()); + } + } + + } + public class QuoteInput { [Column("0")] @@ -268,5 +341,23 @@ public class ModelWithoutColumnAttribute { public string String1; } + + public class SparseInputWithNamedColumns + { + [NameColumn("0")] + public float C1; + + [GroupColumn("1")] + public float C2; + + [WeightColumn("2")] + public float C3; + + [FeaturesColumn("3")] + public float C4; + + [LabelColumn("4")] + public float C5; + } } }