Skip to content

Commit c7b8e87

Browse files
authored
Use KeyTypeAttribute from Schema in CreateTextLoader (#5082)
* Read KeyTypeAttribute during CreateTextLoader * Changed to using Assert.Equal
1 parent 401928a commit c7b8e87

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,13 @@ internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
14891489
var column = new Column();
14901490
column.Name = mappingAttrName?.Name ?? memberInfo.Name;
14911491
column.Source = mappingAttr.Sources.ToArray();
1492+
1493+
var keyTypeAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
1494+
if (keyTypeAttr != null)
1495+
{
1496+
column.KeyCount = keyTypeAttr.KeyCount;
1497+
}
1498+
14921499
InternalDataKind dk;
14931500
switch (memberInfo)
14941501
{

test/Microsoft.ML.Tests/TextLoaderTests.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,5 +877,48 @@ public void TestTextLoaderNoFields()
877877
Assert.StartsWith("Should define at least one public, readable field or property in TInput.", ex.Message);
878878
}
879879
}
880+
881+
public class BreastCancerInputModelWithKeyType
882+
{
883+
[LoadColumn(0)]
884+
public bool IsMalignant { get; set; }
885+
886+
[LoadColumn(1), KeyType(10)]
887+
public uint Thickness { get; set; }
888+
}
889+
890+
public class BreastCancerInputModelWithoutKeyType
891+
{
892+
[LoadColumn(0)]
893+
public bool IsMalignant { get; set; }
894+
895+
[LoadColumn(1)]
896+
public uint Thickness { get; set; }
897+
}
898+
899+
[Fact]
900+
public void TestLoadTextWithKeyTypeAttribute()
901+
{
902+
ulong expectedCount = 10;
903+
904+
var mlContext = new MLContext(seed: 1);
905+
string breastCancerPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
906+
907+
var data = mlContext.Data.CreateTextLoader<BreastCancerInputModelWithKeyType>(separatorChar: ',').Load(breastCancerPath);
908+
909+
Assert.Equal(expectedCount, data.Schema[1].Type.GetKeyCount());
910+
}
911+
912+
[Fact]
913+
public void TestLoadTextWithoutKeyTypeAttribute()
914+
{
915+
ulong expectedCount = 0;
916+
var mlContext = new MLContext(seed: 1);
917+
string breastCancerPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
918+
919+
var data = mlContext.Data.CreateTextLoader<BreastCancerInputModelWithoutKeyType>(separatorChar: ',').Load(breastCancerPath);
920+
921+
Assert.Equal(expectedCount, data.Schema[1].Type.GetKeyCount());
922+
}
880923
}
881924
}

0 commit comments

Comments
 (0)