Skip to content

Commit 0f6972b

Browse files
author
Rogan Carr
committed
work in progress
1 parent a100505 commit 0f6972b

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using Microsoft.ML.Data;
8+
using Microsoft.ML.Functional.Tests.Datasets;
9+
using Microsoft.ML.RunTests;
10+
using Microsoft.ML.TestFramework;
11+
using Microsoft.ML.Transforms;
12+
using Xunit;
13+
using Xunit.Abstractions;
14+
15+
namespace Microsoft.ML.Functional.Tests
16+
{
17+
public class IntrospectiveTraining : BaseTestClass
18+
{
19+
public IntrospectiveTraining(ITestOutputHelper output): base(output)
20+
{
21+
}
22+
23+
/// <summary>
24+
/// Introspective Training: Map hashed values back to the original value.
25+
/// </summary>
26+
[Fact]
27+
public void InspectSlotNamesForReversibleHash()
28+
{
29+
var mlContext = new MLContext(seed: 1, conc: 1);
30+
31+
// Load the Adult dataset.
32+
var data = mlContext.Data.LoadFromTextFile<Adult>(GetDataPath(TestDatasets.adult.trainFilename),
33+
hasHeader: TestDatasets.adult.fileHasHeader,
34+
separatorChar: TestDatasets.adult.fileSeparator);
35+
36+
// Create the learning pipeline.
37+
var pipeline = mlContext.Transforms.Concatenate("NumericalFeatures", Adult.NumericalFeatures)
38+
.Append(mlContext.Transforms.Concatenate("CategoricalFeatures", Adult.CategoricalFeatures))
39+
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CategoricalFeatures", hashBits: 8, // get collisions!
40+
invertHash: -1, outputKind: OneHotEncodingTransformer.OutputKind.Bag));
41+
42+
// Train the model.
43+
var model = pipeline.Fit(data);
44+
45+
// Transform the data.
46+
var transformedData = model.Transform(data);
47+
48+
// Verify that the slotnames cane be used to backtrack by confirming that
49+
// all unique values in the input data are in the output data slot names.
50+
// First get a list of the unique values.
51+
VBuffer<ReadOnlyMemory<char>> categoricalSlotNames = new VBuffer<ReadOnlyMemory<char>>();
52+
transformedData.Schema["CategoricalFeatures"].GetSlotNames(ref categoricalSlotNames);
53+
var uniqueValues = new HashSet<string>();
54+
foreach (var slotName in categoricalSlotNames.GetValues())
55+
{
56+
var slotNameString = slotName.ToString();
57+
if (slotNameString.StartsWith('{'))
58+
{
59+
// Values look like this: {3:Exec-managerial,2:Widowed}.
60+
slotNameString = slotNameString.Substring(1, slotNameString.Length - 2);
61+
foreach (var name in slotNameString.Split(','))
62+
uniqueValues.Add(name);
63+
}
64+
else
65+
uniqueValues.Add(slotNameString);
66+
}
67+
68+
// Now validate that all values in the dataset are there
69+
var transformedRows = mlContext.Data.CreateEnumerable<Adult>(data, false);
70+
foreach (var row in transformedRows)
71+
{
72+
for (int i = 0; i < Adult.CategoricalFeatures.Length; i++)
73+
{
74+
// Fetch the categorical value.
75+
string value = (string) row.GetType().GetProperty(Adult.CategoricalFeatures[i]).GetValue(row, null);
76+
Assert.Contains($"{i}:{value}", uniqueValues);
77+
}
78+
}
79+
80+
float x = (float)double.MinValue;
81+
Output.WriteLine($"{x}");
82+
}
83+
84+
//private void BooYa()
85+
//{
86+
// // Create the learning pipeline
87+
// var nestedPipeline = mlContext.Transforms.Concatenate("NumericalFeatures", Adult.NumericalFeatures)
88+
// .Append(mlContext.Transforms.Concatenate("CategoricalFeatures", Adult.CategoricalFeatures))
89+
// .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CategoricalFeatures",
90+
// invertHash: 2, outputKind: OneHotEncodingTransformer.OutputKind.Bag)
91+
// .Append(mlContext.Transforms.Concatenate("Features", "NumericalFeatures", "CategoricalFeatures"))
92+
// .Append(mlContext.BinaryClassification.Trainers.LogisticRegression()));
93+
94+
// // Train the model.
95+
// var nestedModel = nestedPipeline.Fit(data);
96+
// var nestedPredictor = nestedModel.LastTransformer.LastTransformer;
97+
// var nestedTransformedData = nestedModel.Transform(data);
98+
99+
// Assert.Equal(predictor.Model.SubModel.Bias, nestedPredictor.Model.SubModel.Bias);
100+
// int nFeatures = predictor.Model.SubModel.Weights.Count;
101+
// for (int i = 0; i<nFeatures; i++ )
102+
// Assert.Equal(predictor.Model.SubModel.Weights[i], nestedPredictor.Model.SubModel.Weights[i]);
103+
104+
// var transformedRows = mlContext.Data.CreateEnumerable<BinaryPrediction>(transformedData, false).ToArray();
105+
// var nestedTransformedRows = mlContext.Data.CreateEnumerable<BinaryPrediction>(nestedTransformedData, false).ToArray();
106+
// for (int i = 0; i<transformedRows.Length; i++)
107+
// Assert.Equal(transformedRows[i].Score, nestedTransformedRows[i].Score);
108+
//}
109+
110+
//private class BinaryPrediction
111+
//{
112+
// public float Score { get; set; }
113+
// public float Probability { get; set; }
114+
//}
115+
}
116+
}

test/Microsoft.ML.TestFramework/Datasets.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ public static class TestDatasets
276276
name = "Census",
277277
trainFilename = "adult.tiny.with-schema.txt",
278278
testFilename = "adult.tiny.with-schema.txt",
279+
fileHasHeader = true,
280+
fileSeparator = '\t',
279281
loaderSettings = "loader=Text{header+ col=Label:0 col=Num:9-14 col=Cat:TX:1-8}",
280282
mamlExtraSettings = new[] { "xf=Cat{col=Cat}", "xf=Concat{col=Features:Num,Cat}" },
281283
extraSettings = @"/inst Text{header+ sep=, label=14 handler=Categorical{cols=5-9,1,13,3}}",

0 commit comments

Comments
 (0)