Skip to content

Commit 4c29a09

Browse files
Ivanidzo4kaTomFinley
authored andcommitted
Get rid of value tuples n the public API (#2950)
1 parent 6a4df7c commit 4c29a09

File tree

6 files changed

+48
-81
lines changed

6 files changed

+48
-81
lines changed

src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ public sealed class LatentDirichletAllocationFitResult
2020
/// <param name="result"></param>
2121
public delegate void OnFit(LatentDirichletAllocationFitResult result);
2222

23-
public LatentDirichletAllocationTransformer.LdaSummary LdaTopicSummary;
24-
public LatentDirichletAllocationFitResult(LatentDirichletAllocationTransformer.LdaSummary ldaTopicSummary)
23+
public LatentDirichletAllocationTransformer.ModelParameters LdaTopicSummary;
24+
public LatentDirichletAllocationFitResult(LatentDirichletAllocationTransformer.ModelParameters ldaTopicSummary)
2525
{
2626
LdaTopicSummary = ldaTopicSummary;
2727
}
@@ -43,11 +43,11 @@ private struct Config
4343
public readonly int NumberOfBurninIterations;
4444
public readonly bool ResetRandomGenerator;
4545

46-
public readonly Action<LatentDirichletAllocationTransformer.LdaSummary> OnFit;
46+
public readonly Action<LatentDirichletAllocationTransformer.ModelParameters> OnFit;
4747

4848
public Config(int numberOfTopics, Single alphaSum, Single beta, int samplingStepCount, int maximumNumberOfIterations, int likelihoodInterval,
4949
int numberOfThreads, int maximumTokenCountPerDocument, int numberOfSummaryTermsPerTopic, int numberOfBurninIterations, bool resetRandomGenerator,
50-
Action<LatentDirichletAllocationTransformer.LdaSummary> onFit)
50+
Action<LatentDirichletAllocationTransformer.ModelParameters> onFit)
5151
{
5252
NumberOfTopics = numberOfTopics;
5353
AlphaSum = alphaSum;
@@ -65,7 +65,7 @@ public Config(int numberOfTopics, Single alphaSum, Single beta, int samplingStep
6565
}
6666
}
6767

68-
private static Action<LatentDirichletAllocationTransformer.LdaSummary> Wrap(LatentDirichletAllocationFitResult.OnFit onFit)
68+
private static Action<LatentDirichletAllocationTransformer.ModelParameters> Wrap(LatentDirichletAllocationFitResult.OnFit onFit)
6969
{
7070
if (onFit == null)
7171
return null;

src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs

Lines changed: 2 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
namespace Microsoft.ML.Transforms
3131
{
3232
/// <include file='doc.xml' path='doc/members/member[@name="NADrop"]'/>
33-
public sealed class MissingValueDroppingTransformer : OneToOneTransformerBase
33+
internal sealed class MissingValueDroppingTransformer : OneToOneTransformerBase
3434
{
3535
internal sealed class Options : TransformInputBase
3636
{
@@ -76,7 +76,7 @@ private static VersionInfo GetVersionInfo()
7676
/// <summary>
7777
/// The names of the input columns of the transformation and the corresponding names for the output columns.
7878
/// </summary>
79-
public IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();
79+
internal IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();
8080

8181
/// <summary>
8282
/// Initializes a new instance of <see cref="MissingValueDroppingTransformer"/>
@@ -339,55 +339,4 @@ private void DropNAs<TDst>(ref VBuffer<TDst> src, ref VBuffer<TDst> dst, InPredi
339339
}
340340
}
341341
}
342-
/// <summary>
343-
/// Drops missing values from columns.
344-
/// </summary>
345-
public sealed class MissingValueDroppingEstimator : TrivialEstimator<MissingValueDroppingTransformer>
346-
{
347-
/// <summary>
348-
/// Drops missing values from columns.
349-
/// </summary>
350-
/// <param name="env">The environment to use.</param>
351-
/// <param name="columns">The names of the input columns of the transformation and the corresponding names for the output columns.</param>
352-
internal MissingValueDroppingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns)
353-
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingEstimator)), new MissingValueDroppingTransformer(env, columns))
354-
{
355-
Contracts.CheckValue(env, nameof(env));
356-
}
357-
358-
/// <summary>
359-
/// Drops missing values from columns.
360-
/// </summary>
361-
/// <param name="env">The environment to use.</param>
362-
/// <param name="outputColumnName">Name of the column resulting from the transformation of <paramref name="inputColumnName"/>.</param>
363-
/// <param name="inputColumnName">Name of the column to transform. If set to <see langword="null"/>, the value of the <paramref name="outputColumnName"/> will be used as source.</param>
364-
internal MissingValueDroppingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null)
365-
: this(env, (outputColumnName, inputColumnName ?? outputColumnName))
366-
{
367-
}
368-
369-
/// <summary>
370-
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
371-
/// Used for schema propagation and verification in a pipeline.
372-
/// </summary>
373-
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
374-
{
375-
Host.CheckValue(inputSchema, nameof(inputSchema));
376-
var result = inputSchema.ToDictionary(x => x.Name);
377-
foreach (var colPair in Transformer.Columns)
378-
{
379-
if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col) || !Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del))
380-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName);
381-
if (!(col.Kind == SchemaShape.Column.VectorKind.Vector || col.Kind == SchemaShape.Column.VectorKind.VariableVector))
382-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName, "known-size vector", col.GetTypeString());
383-
var metadata = new List<SchemaShape.Column>();
384-
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var keyMeta))
385-
metadata.Add(keyMeta);
386-
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.IsNormalized, out var normMeta))
387-
metadata.Add(normMeta);
388-
result[colPair.outputColumnName] = new SchemaShape.Column(colPair.outputColumnName, SchemaShape.Column.VectorKind.VariableVector, col.ItemType, false, new SchemaShape(metadata.ToArray()));
389-
}
390-
return new SchemaShape(result.Values);
391-
}
392-
}
393342
}

src/Microsoft.ML.Transforms/Text/LdaTransform.cs

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -163,27 +163,50 @@ internal bool TryUnparse(StringBuilder sb)
163163
/// <summary>
164164
/// Provide details about the topics discovered by <a href="https://arxiv.org/abs/1412.1576">LightLDA.</a>
165165
/// </summary>
166-
public sealed class LdaSummary
166+
public sealed class ModelParameters
167167
{
168+
public struct ItemScore
169+
{
170+
public readonly int Item;
171+
public readonly float Score;
172+
public ItemScore(int item, float score)
173+
{
174+
Item = item;
175+
Score = score;
176+
}
177+
}
178+
public struct WordItemScore
179+
{
180+
public readonly int Item;
181+
public readonly string Word;
182+
public readonly float Score;
183+
public WordItemScore(int item, string word, float score)
184+
{
185+
Item = item;
186+
Word = word;
187+
Score = score;
188+
}
189+
}
190+
168191
// For each topic, provide information about the (item, score) pairs.
169-
public readonly ImmutableArray<List<(int Item, float Score)>> ItemScoresPerTopic;
192+
public readonly IReadOnlyList<IReadOnlyList<ItemScore>> ItemScoresPerTopic;
170193

171194
// For each topic, provide information about the (item, word, score) tuple.
172-
public readonly ImmutableArray<List<(int Item, string Word, float Score)>> WordScoresPerTopic;
195+
public readonly IReadOnlyList<IReadOnlyList<WordItemScore>> WordScoresPerTopic;
173196

174-
internal LdaSummary(ImmutableArray<List<(int Item, float Score)>> itemScoresPerTopic)
197+
internal ModelParameters(IReadOnlyList<IReadOnlyList<ItemScore>> itemScoresPerTopic)
175198
{
176199
ItemScoresPerTopic = itemScoresPerTopic;
177200
}
178201

179-
internal LdaSummary(ImmutableArray<List<(int Item, string Word, float Score)>> wordScoresPerTopic)
202+
internal ModelParameters(IReadOnlyList<IReadOnlyList<WordItemScore>> wordScoresPerTopic)
180203
{
181204
WordScoresPerTopic = wordScoresPerTopic;
182205
}
183206
}
184207

185208
[BestFriend]
186-
internal LdaSummary GetLdaDetails(int iinfo)
209+
internal ModelParameters GetLdaDetails(int iinfo)
187210
{
188211
Contracts.Assert(0 <= iinfo && iinfo < _ldas.Length);
189212

@@ -302,40 +325,40 @@ internal LdaState(IExceptionContext ectx, ModelLoadContext ctx)
302325
}
303326
}
304327

305-
internal LdaSummary GetLdaSummary(VBuffer<ReadOnlyMemory<char>> mapping)
328+
internal ModelParameters GetLdaSummary(VBuffer<ReadOnlyMemory<char>> mapping)
306329
{
307330
if (mapping.Length == 0)
308331
{
309-
var itemScoresPerTopicBuilder = ImmutableArray.CreateBuilder<List<(int Item, float Score)>>();
332+
var itemScoresPerTopicBuilder = ImmutableArray.CreateBuilder<List<ModelParameters.ItemScore>>();
310333
for (int i = 0; i < _ldaTrainer.NumTopic; i++)
311334
{
312335
var scores = _ldaTrainer.GetTopicSummary(i);
313-
var itemScores = new List<(int, float)>();
336+
var itemScores = new List<ModelParameters.ItemScore>();
314337
foreach (KeyValuePair<int, float> p in scores)
315338
{
316-
itemScores.Add((p.Key, p.Value));
339+
itemScores.Add(new ModelParameters.ItemScore(p.Key, p.Value));
317340
}
318341

319342
itemScoresPerTopicBuilder.Add(itemScores);
320343
}
321-
return new LdaSummary(itemScoresPerTopicBuilder.ToImmutable());
344+
return new ModelParameters(itemScoresPerTopicBuilder.ToImmutable());
322345
}
323346
else
324347
{
325348
ReadOnlyMemory<char> slotName = default;
326-
var wordScoresPerTopicBuilder = ImmutableArray.CreateBuilder<List<(int Item, string Word, float Score)>>();
349+
var wordScoresPerTopicBuilder = ImmutableArray.CreateBuilder<List<ModelParameters.WordItemScore>>();
327350
for (int i = 0; i < _ldaTrainer.NumTopic; i++)
328351
{
329352
var scores = _ldaTrainer.GetTopicSummary(i);
330-
var wordScores = new List<(int, string, float)>();
353+
var wordScores = new List<ModelParameters.WordItemScore>();
331354
foreach (KeyValuePair<int, float> p in scores)
332355
{
333356
mapping.GetItemOrDefault(p.Key, ref slotName);
334-
wordScores.Add((p.Key, slotName.ToString(), p.Value));
357+
wordScores.Add(new ModelParameters.WordItemScore(p.Key, slotName.ToString(), p.Value));
335358
}
336359
wordScoresPerTopicBuilder.Add(wordScores);
337360
}
338-
return new LdaSummary(wordScoresPerTopicBuilder.ToImmutable());
361+
return new ModelParameters(wordScoresPerTopicBuilder.ToImmutable());
339362
}
340363
}
341364

src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,6 @@ internal static VersionInfo GetVersionInfo()
9393
private readonly Model _currentVocab;
9494
private static Dictionary<string, WeakReference<Model>> _vocab = new Dictionary<string, WeakReference<Model>>();
9595

96-
/// <summary>
97-
/// The names of the output and input column pairs on which the transformation is applied.
98-
/// </summary>
99-
private IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly();
100-
10196
private sealed class Model
10297
{
10398
public readonly BigArray<float> WordVectors;
@@ -335,7 +330,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
335330

336331
public void SaveAsOnnx(OnnxContext ctx)
337332
{
338-
foreach (var (outputColumnName, inputColumnName) in _parent.Columns)
333+
foreach (var (outputColumnName, inputColumnName) in _parent.ColumnPairs)
339334
{
340335
var srcVariableName = ctx.GetVariableName(inputColumnName);
341336
var schema = _parent.GetOutputSchema(InputSchema);

test/Microsoft.ML.Benchmarks/Helpers/EnvironmentFactory.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ internal static MLContext CreateRankingEnvironment<TEvaluator, TLoader, TTransfo
4141
environment.ComponentCatalog.RegisterAssembly(typeof(TTransformer).Assembly);
4242
environment.ComponentCatalog.RegisterAssembly(typeof(TTrainer).Assembly);
4343

44-
environment.ComponentCatalog.RegisterAssembly(typeof(MissingValueDroppingTransformer).Assembly);
44+
environment.ComponentCatalog.RegisterAssembly(typeof(OneHotEncodingTransformer).Assembly);
4545

4646
return ctx;
4747
}

test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ public void LdaTopicModel()
670670
var data = reader.Load(dataSource);
671671

672672
// This will be populated once we call fit.
673-
LdaSummary ldaSummary;
673+
ModelParameters ldaSummary;
674674

675675
var est = data.MakeNewEstimator()
676676
.Append(r => (

0 commit comments

Comments
 (0)