Skip to content

Commit a284b0f

Browse files
committed
Scrub n-gram hashing
1 parent 8bcc03c commit a284b0f

File tree

2 files changed

+46
-164
lines changed

2 files changed

+46
-164
lines changed

src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs

Lines changed: 29 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ internal sealed class Column : ManyToOneColumn
5252

5353
[Argument(ArgumentType.AtMostOnce,
5454
HelpText = "Number of bits to hash into. Must be between 1 and 30, inclusive.",
55-
ShortName = "bits")]
56-
public int? HashBits;
55+
Name = "HashBits", ShortName = "bits")]
56+
public int? NumberOfBits;
5757

5858
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
5959
public uint? Seed;
@@ -91,7 +91,7 @@ private protected override bool TryParse(string str)
9191

9292
if (!int.TryParse(extra, out int bits))
9393
return false;
94-
HashBits = bits;
94+
NumberOfBits = bits;
9595
return true;
9696
}
9797

@@ -103,10 +103,10 @@ internal bool TryUnparse(StringBuilder sb)
103103
{
104104
return false;
105105
}
106-
if (HashBits == null)
106+
if (NumberOfBits == null)
107107
return TryUnparseCore(sb);
108108

109-
string extra = HashBits.Value.ToString();
109+
string extra = NumberOfBits.Value.ToString();
110110
return TryUnparseCore(sb, extra);
111111
}
112112
}
@@ -133,8 +133,8 @@ internal sealed class Options
133133

134134
[Argument(ArgumentType.AtMostOnce,
135135
HelpText = "Number of bits to hash into. Must be between 1 and 30, inclusive.",
136-
ShortName = "bits", SortOrder = 2)]
137-
public int HashBits = NgramHashingEstimator.Defaults.HashBits;
136+
Name = "HashBits", ShortName = "bits", SortOrder = 2)]
137+
public int NumberOfBits = NgramHashingEstimator.Defaults.NumberOfBits;
138138

139139
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
140140
public uint Seed = NgramHashingEstimator.Defaults.Seed;
@@ -353,7 +353,7 @@ private static IDataTransform Create(IHostEnvironment env, Options options, IDat
353353
item.NgramLength ?? options.NgramLength,
354354
item.SkipLength ?? options.SkipLength,
355355
item.AllLengths ?? options.AllLengths,
356-
item.HashBits ?? options.HashBits,
356+
item.NumberOfBits ?? options.NumberOfBits,
357357
item.Seed ?? options.Seed,
358358
item.Ordered ?? options.Ordered,
359359
item.InvertHash ?? options.InvertHash,
@@ -408,13 +408,13 @@ public Mapper(NgramHashingTransformer parent, DataViewSchema inputSchema, Finder
408408
_srcTypes[i][j] = srcType;
409409
}
410410

411-
_types[i] = new VectorType(NumberDataViewType.Single, 1 << _parent._columns[i].HashBits);
411+
_types[i] = new VectorType(NumberDataViewType.Single, 1 << _parent._columns[i].NumberOfBits);
412412
}
413413
}
414414

415415
private NgramIdFinder GetNgramIdFinder(int iinfo)
416416
{
417-
uint mask = (1U << _parent._columns[iinfo].HashBits) - 1;
417+
uint mask = (1U << _parent._columns[iinfo].NumberOfBits) - 1;
418418
int ngramLength = _parent._columns[iinfo].NgramLength;
419419
bool rehash = _parent._columns[iinfo].RehashUnigrams;
420420
bool ordered = _parent._columns[iinfo].Ordered;
@@ -819,7 +819,7 @@ public NgramIdFinder Decorate(int iinfo, NgramIdFinder finder)
819819
}
820820

821821
var collector = _iinfoToCollector[iinfo] = new InvertHashCollector<NGram>(
822-
1 << _parent._columns[iinfo].HashBits, _invertHashMaxCounts[iinfo],
822+
1 << _parent._columns[iinfo].NumberOfBits, _invertHashMaxCounts[iinfo],
823823
stringMapper, EqualityComparer<NGram>.Default, (in NGram src, ref NGram dst) => dst = src.Clone());
824824

825825
return
@@ -852,7 +852,7 @@ public VBuffer<ReadOnlyMemory<char>>[] SlotNamesMetadata(out VectorType[] types)
852852
if (_iinfoToCollector[iinfo] != null)
853853
{
854854
var vec = values[iinfo] = _iinfoToCollector[iinfo].GetMetadata();
855-
Contracts.Assert(vec.Length == 1 << _parent._columns[iinfo].HashBits);
855+
Contracts.Assert(vec.Length == 1 << _parent._columns[iinfo].NumberOfBits);
856856
types[iinfo] = new VectorType(TextDataViewType.Instance, vec.Length);
857857
}
858858
}
@@ -887,7 +887,7 @@ public sealed class ColumnOptions
887887
/// <summary>Whether to store all ngram lengths up to <see cref="NgramLength"/>, or only <see cref="NgramLength"/>.</summary>
888888
public readonly bool AllLengths;
889889
/// <summary>Number of bits to hash into. Must be between 1 and 31, inclusive.</summary>
890-
public readonly int HashBits;
890+
public readonly int NumberOfBits;
891891
/// <summary>Hashing seed.</summary>
892892
public readonly uint Seed;
893893
/// <summary>Whether the position of each term should be included in the hash.</summary>
@@ -907,14 +907,14 @@ public sealed class ColumnOptions
907907
internal string[] FriendlyNames;
908908

909909
/// <summary>
910-
/// Describes how the transformer handles one column pair.
910+
/// Describes how the transformer maps several input columns, <paramref name="inputColumnNames"/>, to a output column, <paramref name="name"/>.
911911
/// </summary>
912912
/// <param name="name">Name of the column resulting from the transformation of <paramref name="inputColumnNames"/>.</param>
913913
/// <param name="inputColumnNames">Names of the columns to transform. </param>
914914
/// <param name="ngramLength">Maximum ngram length.</param>
915915
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
916916
/// <param name="allLengths">Whether to store all ngram lengths up to <paramref name="ngramLength"/>, or only <paramref name="ngramLength"/>.</param>
917-
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
917+
/// <param name="numberOfBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
918918
/// <param name="seed">Hashing seed.</param>
919919
/// <param name="ordered">Whether the position of each term should be included in the hash.</param>
920920
/// <param name="invertHash">During hashing we constuct mappings between original values and the produced hash values.
@@ -928,7 +928,7 @@ public ColumnOptions(string name,
928928
int ngramLength = NgramHashingEstimator.Defaults.NgramLength,
929929
int skipLength = NgramHashingEstimator.Defaults.SkipLength,
930930
bool allLengths = NgramHashingEstimator.Defaults.AllLengths,
931-
int hashBits = NgramHashingEstimator.Defaults.HashBits,
931+
int numberOfBits = NgramHashingEstimator.Defaults.NumberOfBits,
932932
uint seed = NgramHashingEstimator.Defaults.Seed,
933933
bool ordered = NgramHashingEstimator.Defaults.Ordered,
934934
int invertHash = NgramHashingEstimator.Defaults.InvertHash,
@@ -942,8 +942,8 @@ public ColumnOptions(string name,
942942
throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger");
943943
// If the bits is 31 or higher, we can't declare a KeyValues of the appropriate length,
944944
// this requiring a VBuffer of length 1u << 31 which exceeds int.MaxValue.
945-
if (invertHash != 0 && hashBits >= 31)
946-
throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", hashBits);
945+
if (invertHash != 0 && numberOfBits >= 31)
946+
throw Contracts.ExceptParam(nameof(numberOfBits), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", numberOfBits);
947947

948948
if (NgramLength + SkipLength > NgramBufferBuilder.MaxSkipNgramLength)
949949
{
@@ -956,7 +956,7 @@ public ColumnOptions(string name,
956956
NgramLength = ngramLength;
957957
SkipLength = skipLength;
958958
AllLengths = allLengths;
959-
HashBits = hashBits;
959+
NumberOfBits = numberOfBits;
960960
Seed = seed;
961961
Ordered = ordered;
962962
InvertHash = invertHash;
@@ -988,8 +988,8 @@ internal ColumnOptions(ModelLoadContext ctx)
988988
SkipLength = ctx.Reader.ReadInt32();
989989
Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
990990
Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
991-
HashBits = ctx.Reader.ReadInt32();
992-
Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
991+
NumberOfBits = ctx.Reader.ReadInt32();
992+
Contracts.CheckDecode(1 <= NumberOfBits && NumberOfBits <= 30);
993993
Seed = ctx.Reader.ReadUInt32();
994994
RehashUnigrams = ctx.Reader.ReadBoolByte();
995995
Ordered = ctx.Reader.ReadBoolByte();
@@ -1018,8 +1018,8 @@ internal ColumnOptions(ModelLoadContext ctx, string name, string[] inputColumnNa
10181018
SkipLength = ctx.Reader.ReadInt32();
10191019
Contracts.CheckDecode(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
10201020
Contracts.CheckDecode(SkipLength <= NgramBufferBuilder.MaxSkipNgramLength - NgramLength);
1021-
HashBits = ctx.Reader.ReadInt32();
1022-
Contracts.CheckDecode(1 <= HashBits && HashBits <= 30);
1021+
NumberOfBits = ctx.Reader.ReadInt32();
1022+
Contracts.CheckDecode(1 <= NumberOfBits && NumberOfBits <= 30);
10231023
Seed = ctx.Reader.ReadUInt32();
10241024
RehashUnigrams = ctx.Reader.ReadBoolByte();
10251025
Ordered = ctx.Reader.ReadBoolByte();
@@ -1052,8 +1052,8 @@ internal void Save(ModelSaveContext ctx)
10521052
Contracts.Assert(0 <= SkipLength && SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
10531053
Contracts.Assert(NgramLength + SkipLength <= NgramBufferBuilder.MaxSkipNgramLength);
10541054
ctx.Writer.Write(SkipLength);
1055-
Contracts.Assert(1 <= HashBits && HashBits <= 30);
1056-
ctx.Writer.Write(HashBits);
1055+
Contracts.Assert(1 <= NumberOfBits && NumberOfBits <= 30);
1056+
ctx.Writer.Write(NumberOfBits);
10571057
ctx.Writer.Write(Seed);
10581058
ctx.Writer.WriteBoolByte(RehashUnigrams);
10591059
ctx.Writer.WriteBoolByte(Ordered);
@@ -1066,7 +1066,7 @@ internal static class Defaults
10661066
internal const int NgramLength = 2;
10671067
internal const bool AllLengths = true;
10681068
internal const int SkipLength = 0;
1069-
internal const int HashBits = 16;
1069+
internal const int NumberOfBits = 16;
10701070
internal const uint Seed = 314489979;
10711071
internal const bool RehashUnigrams = false;
10721072
internal const bool Ordered = true;
@@ -1086,7 +1086,7 @@ internal static class Defaults
10861086
/// <param name="env">The environment.</param>
10871087
/// <param name="outputColumnName">Name of output column, will contain the ngram vector. Null means <paramref name="inputColumnName"/> is replaced.</param>
10881088
/// <param name="inputColumnName">Name of input column containing tokenized text.</param>
1089-
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
1089+
/// <param name="numberOfBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
10901090
/// <param name="ngramLength">Ngram length.</param>
10911091
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
10921092
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
@@ -1099,84 +1099,17 @@ internal static class Defaults
10991099
internal NgramHashingEstimator(IHostEnvironment env,
11001100
string outputColumnName,
11011101
string inputColumnName = null,
1102-
int hashBits = 16,
1102+
int numberOfBits = 16,
11031103
int ngramLength = 2,
11041104
int skipLength = 0,
11051105
bool allLengths = true,
11061106
uint seed = 314489979,
11071107
bool ordered = true,
11081108
int invertHash = 0)
1109-
: this(env, new[] { (outputColumnName, new[] { inputColumnName ?? outputColumnName }) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash)
1109+
: this(env, new ColumnOptions(outputColumnName, new[] { inputColumnName ?? outputColumnName }, ngramLength, skipLength, allLengths, numberOfBits, seed, ordered, invertHash))
11101110
{
11111111
}
11121112

1113-
/// <summary>
1114-
/// Produces a bag of counts of hashed ngrams in <paramref name="inputColumnNames"/>
1115-
/// and outputs ngram vector as <paramref name="outputColumnName"/>
1116-
///
1117-
/// <see cref="NgramHashingEstimator"/> is different from <see cref="WordHashBagEstimator"/> in a way that <see cref="NgramHashingEstimator"/>
1118-
/// takes tokenized text as input while <see cref="WordHashBagEstimator"/> tokenizes text internally.
1119-
/// </summary>
1120-
/// <param name="env">The environment.</param>
1121-
/// <param name="outputColumnName">Name of output column, will contain the ngram vector.</param>
1122-
/// <param name="inputColumnNames">Name of input columns containing tokenized text.</param>
1123-
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
1124-
/// <param name="ngramLength">Ngram length.</param>
1125-
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
1126-
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
1127-
/// <param name="seed">Hashing seed.</param>
1128-
/// <param name="ordered">Whether the position of each source column should be included in the hash (when there are multiple source columns).</param>
1129-
/// <param name="invertHash">During hashing we constuct mappings between original values and the produced hash values.
1130-
/// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one.
1131-
/// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
1132-
/// <value>0</value> does not retain any input values. <value>-1</value> retains all input values mapping to each hash.</param>
1133-
internal NgramHashingEstimator(IHostEnvironment env,
1134-
string outputColumnName,
1135-
string[] inputColumnNames,
1136-
int hashBits = 16,
1137-
int ngramLength = 2,
1138-
int skipLength = 0,
1139-
bool allLengths = true,
1140-
uint seed = 314489979,
1141-
bool ordered = true,
1142-
int invertHash = 0)
1143-
: this(env, new[] { (outputColumnName, inputColumnNames) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash)
1144-
{
1145-
}
1146-
1147-
/// <summary>
1148-
/// Produces a bag of counts of hashed ngrams in <paramref name="columns.inputs"/>
1149-
/// and outputs ngram vector for each output in <paramref name="columns.output"/>
1150-
///
1151-
/// <see cref="NgramHashingEstimator"/> is different from <see cref="WordHashBagEstimator"/> in a way that <see cref="NgramHashingEstimator"/>
1152-
/// takes tokenized text as input while <see cref="WordHashBagEstimator"/> tokenizes text internally.
1153-
/// </summary>
1154-
/// <param name="env">The environment.</param>
1155-
/// <param name="columns">Pairs of input columns to output column mappings on which to compute ngram vector.</param>
1156-
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 30, inclusive.</param>
1157-
/// <param name="ngramLength">Ngram length.</param>
1158-
/// <param name="skipLength">Maximum number of tokens to skip when constructing an ngram.</param>
1159-
/// <param name="allLengths">Whether to include all ngram lengths up to <paramref name="ngramLength"/> or only <paramref name="ngramLength"/>.</param>
1160-
/// <param name="seed">Hashing seed.</param>
1161-
/// <param name="ordered">Whether the position of each source column should be included in the hash (when there are multiple source columns).</param>
1162-
/// <param name="invertHash">During hashing we constuct mappings between original values and the produced hash values.
1163-
/// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one.
1164-
/// <paramref name="invertHash"/> specifies the upper bound of the number of distinct input values mapping to a hash that should be retained.
1165-
/// <value>0</value> does not retain any input values. <value>-1</value> retains all input values mapping to each hash.</param>
1166-
internal NgramHashingEstimator(IHostEnvironment env,
1167-
(string outputColumnName, string[] inputColumnName)[] columns,
1168-
int hashBits = 16,
1169-
int ngramLength = 2,
1170-
int skipLength = 0,
1171-
bool allLengths = true,
1172-
uint seed = 314489979,
1173-
bool ordered = true,
1174-
int invertHash = 0)
1175-
: this(env, columns.Select(x => new ColumnOptions(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray())
1176-
{
1177-
1178-
}
1179-
11801113
/// <summary>
11811114
/// Produces a bag of counts of hashed ngrams in <paramref name="columns.inputs"/>
11821115
/// and outputs ngram vector for each output in <paramref name="columns.output"/>

0 commit comments

Comments
 (0)