diff --git a/src/Microsoft.ML.Api/SchemaDefinition.cs b/src/Microsoft.ML.Api/SchemaDefinition.cs
index 5f84712625..ef181244ab 100644
--- a/src/Microsoft.ML.Api/SchemaDefinition.cs
+++ b/src/Microsoft.ML.Api/SchemaDefinition.cs
@@ -67,8 +67,9 @@ public VectorTypeAttribute(params int[] dims)
/// column encapsulates.
///
[AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = true)]
- public sealed class ColumnAttribute : Attribute
+ public class ColumnAttribute : Attribute
{
+
public ColumnAttribute(string ordinal, string name = null)
{
Name = name;
@@ -93,6 +94,61 @@ public ColumnAttribute(string ordinal, string name = null)
public string Ordinal { get; }
}
+ ///
+ /// Describes 'Label' column with indicies.
+ ///
+ public sealed class LabelColumnAttribute : ColumnAttribute
+ {
+ public LabelColumnAttribute(string ordinal):
+ base(ordinal, "Label")
+ {
+ }
+ }
+
+ ///
+ /// Describes 'Features' column with indicies.
+ ///
+ public sealed class FeaturesColumnAttribute : ColumnAttribute
+ {
+ public FeaturesColumnAttribute(string ordinal) :
+ base(ordinal, "Features")
+ {
+ }
+ }
+
+ ///
+ /// Describes 'Weight' column with indicies.
+ ///
+ public sealed class WeightColumnAttribute : ColumnAttribute
+ {
+ public WeightColumnAttribute(string ordinal) :
+ base(ordinal, "Weight")
+ {
+ }
+ }
+
+ ///
+ /// Describes 'GroupId' column with indicies.
+ ///
+ public sealed class GroupColumnAttribute : ColumnAttribute
+ {
+ public GroupColumnAttribute(string ordinal) :
+ base(ordinal, "GroupId")
+ {
+ }
+ }
+
+ ///
+ /// Describes 'Name' column with indicies.
+ ///
+ public sealed class NameColumnAttribute : ColumnAttribute
+ {
+ public NameColumnAttribute(string ordinal) :
+ base(ordinal, "Name")
+ {
+ }
+ }
+
///
/// Allows a member to specify its column name directly, as opposed to the default
/// behavior of using the member name as the column name.
diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs
index 392462a0eb..ba8bdb28df 100644
--- a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs
@@ -60,7 +60,7 @@ public class HousePriceData
[Column(ordinal: "1")]
public string Date;
- [Column(ordinal: "2", name: "Label")]
+ [LabelColumn(ordinal: "2")]
public float Price;
[Column(ordinal: "3")]
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithFeatureVector.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithFeatureVector.cs
new file mode 100644
index 0000000000..fd22c3dc22
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithFeatureVector.cs
@@ -0,0 +1,117 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Models;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Trainers;
+using Microsoft.ML.Transforms;
+using Xunit;
+
+namespace Microsoft.ML.Scenarios
+{
+ public partial class ScenariosTests
+ {
+ [Fact]
+ public void TrainAndPredictIrisModelWithFeatureVectorTest()
+ {
+ string dataPath = GetDataPath("iris.data");
+
+ var pipeline = new LearningPipeline();
+
+ pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false, separator: ','));
+
+ pipeline.Add(new Dictionarizer("Label")); // "IrisPlantType" is used as "Label" because of column attribute name on the field.
+
+ pipeline.Add(new StochasticDualCoordinateAscentClassifier());
+
+ PredictionModel model = pipeline.Train();
+
+ IrisPrediction prediction = model.Predict(new IrisDataWithFeatureVector()
+ {
+ Feat = new float[] { 5.1f, 3.3f, 1.6f, 0.2f }
+ });
+
+ Assert.Equal(1, prediction.PredictedLabels[0], 2);
+ Assert.Equal(0, prediction.PredictedLabels[1], 2);
+ Assert.Equal(0, prediction.PredictedLabels[2], 2);
+
+ prediction = model.Predict(new IrisDataWithFeatureVector()
+ {
+ Feat = new float[] { 6.4f, 3.1f, 5.5f, 2.2f }
+ });
+
+ Assert.Equal(0, prediction.PredictedLabels[0], 2);
+ Assert.Equal(0, prediction.PredictedLabels[1], 2);
+ Assert.Equal(1, prediction.PredictedLabels[2], 2);
+
+ prediction = model.Predict(new IrisDataWithFeatureVector()
+ {
+ Feat = new float[] { 4.4f, 3.1f, 2.5f, 1.2f }
+ });
+
+ Assert.Equal(.2, prediction.PredictedLabels[0], 1);
+ Assert.Equal(.8, prediction.PredictedLabels[1], 1);
+ Assert.Equal(0, prediction.PredictedLabels[2], 2);
+
+ // Note: Testing against the same data set as a simple way to test evaluation.
+ // This isn't appropriate in real-world scenarios.
+ string testDataPath = GetDataPath("iris.data");
+ var testData = new TextLoader(testDataPath).CreateFrom(useHeader: false, separator: ',');
+
+ var evaluator = new ClassificationEvaluator();
+ evaluator.OutputTopKAcc = 3;
+ ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
+
+ Assert.Equal(.98, metrics.AccuracyMacro);
+ Assert.Equal(.98, metrics.AccuracyMicro, 2);
+ Assert.Equal(.06, metrics.LogLoss, 2);
+ Assert.InRange(metrics.LogLossReduction, 94, 96);
+ Assert.Equal(1, metrics.TopKAccuracy);
+
+ Assert.Equal(3, metrics.PerClassLogLoss.Length);
+ Assert.Equal(0, metrics.PerClassLogLoss[0], 1);
+ Assert.Equal(.1, metrics.PerClassLogLoss[1], 1);
+ Assert.Equal(.1, metrics.PerClassLogLoss[2], 1);
+
+ ConfusionMatrix matrix = metrics.ConfusionMatrix;
+ Assert.Equal(3, matrix.Order);
+ Assert.Equal(3, matrix.ClassNames.Count);
+ Assert.Equal("Iris-setosa", matrix.ClassNames[0]);
+ Assert.Equal("Iris-versicolor", matrix.ClassNames[1]);
+ Assert.Equal("Iris-virginica", matrix.ClassNames[2]);
+
+ Assert.Equal(50, matrix[0, 0]);
+ Assert.Equal(50, matrix["Iris-setosa", "Iris-setosa"]);
+ Assert.Equal(0, matrix[0, 1]);
+ Assert.Equal(0, matrix["Iris-setosa", "Iris-versicolor"]);
+ Assert.Equal(0, matrix[0, 2]);
+ Assert.Equal(0, matrix["Iris-setosa", "Iris-virginica"]);
+
+ Assert.Equal(0, matrix[1, 0]);
+ Assert.Equal(0, matrix["Iris-versicolor", "Iris-setosa"]);
+ Assert.Equal(48, matrix[1, 1]);
+ Assert.Equal(48, matrix["Iris-versicolor", "Iris-versicolor"]);
+ Assert.Equal(2, matrix[1, 2]);
+ Assert.Equal(2, matrix["Iris-versicolor", "Iris-virginica"]);
+
+ Assert.Equal(0, matrix[2, 0]);
+ Assert.Equal(0, matrix["Iris-virginica", "Iris-setosa"]);
+ Assert.Equal(1, matrix[2, 1]);
+ Assert.Equal(1, matrix["Iris-virginica", "Iris-versicolor"]);
+ Assert.Equal(49, matrix[2, 2]);
+ Assert.Equal(49, matrix["Iris-virginica", "Iris-virginica"]);
+ }
+
+ public class IrisDataWithFeatureVector
+ {
+ [FeaturesColumn("0-3")]
+ [VectorType(4)]
+ public float[] Feat;
+
+ [LabelColumn("4")]
+ public string IrisPlantType;
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
index ebddc33b03..286771d719 100644
--- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs
@@ -130,7 +130,7 @@ public class IrisDataWithStringLabel
[Column("3")]
public float PetalLength;
- [Column("4", name: "Label")]
+ [LabelColumn("4")]
public string IrisPlantType;
}
}
diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
index 80947644e9..235fc7432c 100644
--- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs
@@ -145,7 +145,7 @@ public void TrainAndPredictSentimentModelTest()
public class SentimentData
{
- [Column(ordinal: "0", name: "Label")]
+ [LabelColumn(ordinal: "0")]
public float Sentiment;
[Column(ordinal: "1")]
public string SentimentText;
diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs
index 40c0b6525f..2df3636bdd 100644
--- a/test/Microsoft.ML.Tests/TextLoaderTests.cs
+++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs
@@ -228,6 +228,79 @@ public void ThrowsExceptionWithPropertyName()
Assert.StartsWith("String1 is missing ColumnAttribute", ex.Message);
}
+
+ [Fact]
+ public void CanSuccessfullyNamedColumns()
+ {
+ string dataPath = GetDataPath("SparseData.txt");
+ var loader = new Data.TextLoader(dataPath).CreateFrom(useHeader: true, allowQuotedStrings: false, supportSparse: true);
+
+ using (var environment = new TlcEnvironment())
+ {
+ Experiment experiment = environment.CreateExperiment();
+ ILearningPipelineDataStep output = loader.ApplyStep(null, experiment) as ILearningPipelineDataStep;
+
+ experiment.Compile();
+ loader.SetInput(environment, experiment);
+ experiment.Run();
+
+ IDataLoader data = experiment.GetOutput(output.Data) as IDataLoader;
+ Assert.NotNull(data);
+
+ Assert.Equal(5, data.Schema.ColumnCount);
+ Assert.Equal("Name", data.Schema.GetColumnName(0));
+ Assert.Equal("GroupId", data.Schema.GetColumnName(1));
+ Assert.Equal("Weight", data.Schema.GetColumnName(2));
+ Assert.Equal("Features", data.Schema.GetColumnName(3));
+ Assert.Equal("Label", data.Schema.GetColumnName(4));
+
+ using (var cursor = data.GetRowCursor((a => true)))
+ {
+ var getters = new ValueGetter[]{
+ cursor.GetGetter(0),
+ cursor.GetGetter(1),
+ cursor.GetGetter(2),
+ cursor.GetGetter(3),
+ cursor.GetGetter(4)
+ };
+
+
+ Assert.True(cursor.MoveNext());
+
+ float[] targets = new float[] { 1, 2, 3, 4, 5 };
+ for (int i = 0; i < getters.Length; i++)
+ {
+ float value = 0;
+ getters[i](ref value);
+ Assert.Equal(targets[i], value);
+ }
+
+ Assert.True(cursor.MoveNext());
+
+ targets = new float[] { 0, 0, 0, 4, 5 };
+ for (int i = 0; i < getters.Length; i++)
+ {
+ float value = 0;
+ getters[i](ref value);
+ Assert.Equal(targets[i], value);
+ }
+
+ Assert.True(cursor.MoveNext());
+
+ targets = new float[] { 0, 2, 0, 0, 0 };
+ for (int i = 0; i < getters.Length; i++)
+ {
+ float value = 0;
+ getters[i](ref value);
+ Assert.Equal(targets[i], value);
+ }
+
+ Assert.False(cursor.MoveNext());
+ }
+ }
+
+ }
+
public class QuoteInput
{
[Column("0")]
@@ -268,5 +341,23 @@ public class ModelWithoutColumnAttribute
{
public string String1;
}
+
+ public class SparseInputWithNamedColumns
+ {
+ [NameColumn("0")]
+ public float C1;
+
+ [GroupColumn("1")]
+ public float C2;
+
+ [WeightColumn("2")]
+ public float C3;
+
+ [FeaturesColumn("3")]
+ public float C4;
+
+ [LabelColumn("4")]
+ public float C5;
+ }
}
}