Commit fa68003
Address the feedback
1 parent 35e2cbc

16 files changed, +212 -79 lines

src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj

Lines changed: 9 additions & 1 deletion
@@ -2,11 +2,19 @@
   <Import Project="$(RepoRoot)eng/pkg/Pack.props" />
 
   <PropertyGroup>
-    <TargetFramework>netstandard2.0</TargetFramework>
+    <TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
     <Nullable>enable</Nullable>
     <PackageDescription>Microsoft.ML.Tokenizers contains the implmentation of the tokenization used in the NLP transforms.</PackageDescription>
   </PropertyGroup>
 
+  <ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
+    <Compile Remove="Utils/Helpers.netcoreapp.cs" />
+  </ItemGroup>
+
+  <ItemGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
+    <Compile Remove="Utils/Helpers.netfx.cs" />
+  </ItemGroup>
+
   <ItemGroup>
     <PackageReference Include="System.Text.Json" Version="$(SystemTextJsonVersion)" />
   </ItemGroup>
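The two Utils files named in the conditional Compile items are not part of this commit page. As a hedged sketch of the usual per-TFM source split, assuming the (offset, length)-style signatures implied by the Helpers call sites in the Tiktoken.cs hunk below:

    // Utils/Helpers.netfx.cs (sketch, not the actual file) -- compiled only for
    // netstandard2.0, where the span-based parsing overloads are unavailable.
    internal static class Helpers
    {
        internal static byte[] FromBase64String(string s, int offset, int length)
            => System.Convert.FromBase64String(s.Substring(offset, length));

        internal static bool TryParseInt32(string s, int offset, out int result)
            => int.TryParse(s.Substring(offset), out result);
    }

    // Utils/Helpers.netcoreapp.cs (sketch) -- the net8.0 flavor can avoid the
    // intermediate strings, e.g. int.TryParse(s.AsSpan(offset), out result) and
    // Convert.TryFromBase64Chars(s.AsSpan(offset, length), bytes, out bytesWritten).

Each target framework compiles exactly one of the two files, so both can declare the same internal surface without #if blocks.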

src/Microsoft.ML.Tokenizers/Model/BPE.cs

Lines changed: 5 additions & 5 deletions
@@ -36,7 +36,7 @@ public string? UnknownToken
 
         if (value is null)
         {
-            if (VocabReverse.TryGetValue(0, out string v))
+            if (VocabReverse.TryGetValue(0, out string? v))
             {
                 VocabReverse.Remove(0);
                 if (Vocab.TryGetValue(v, out int id))
@@ -103,7 +103,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
             VocabReverse.Add(kvp.Value, kvp.Key);
         }
 
-        if (unknownToken is null && VocabReverse.TryGetValue(0, out string unkToken))
+        if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
         {
             unknownToken = unkToken;
         }
@@ -187,7 +187,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
     /// <returns>The mapped token of the Id.</returns>
     public override string? IdToToken(int id, bool skipSpecialTokens = false)
     {
-        if (VocabReverse.TryGetValue(id, out string value))
+        if (VocabReverse.TryGetValue(id, out string? value))
         {
             return value;
         }
@@ -253,7 +253,7 @@ public override string[] Save(string path, string? prefix = null)
     }
 
     /// Read the given files to extract the vocab and merges
-    internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string? vocab, string? merges)
+    internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string vocab, string? merges)
     {
         Dictionary<string, int>? dic;
         using (Stream stream = File.OpenRead(vocab))
@@ -320,7 +320,7 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(strin
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     internal string CharToString(char c)
     {
-        if (_charToString.TryGetValue(c, out string v))
+        if (_charToString.TryGetValue(c, out string? v))
         {
             return v;
         }
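These out-parameter annotations are the standard consequence of <Nullable>enable</Nullable> in the csproj: Dictionary<TKey, TValue>.TryGetValue marks its out parameter [MaybeNullWhen(false)], so the receiving local must be declared nullable to compile warning-free. A minimal illustration with hypothetical values:

    using System.Collections.Generic;

    Dictionary<int, string> vocabReverse = new() { [0] = "<unk>" };

    // Under nullable reference types, `out string v` here warns: the value may be
    // null when TryGetValue returns false. Declaring it `out string? v` is clean,
    // and inside the success branch the compiler still treats v as non-null.
    if (vocabReverse.TryGetValue(0, out string? v))
    {
        System.Console.WriteLine(v.Length); // no null warning: guarded by the branch
    }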

src/Microsoft.ML.Tokenizers/Model/BpeTrainer.cs

Lines changed: 9 additions & 4 deletions
@@ -83,7 +83,12 @@ public BpeTrainer(
         MinFrequency = minFrequency;
         VocabSize = vocabSize;
         Progress = progress;
-        SpecialTokens = new List<AddedToken>(specialTokens);
+
+        if (specialTokens is not null)
+        {
+            SpecialTokens = new List<AddedToken>(specialTokens);
+        }
+
         LimitAlphabet = limitAlphabet;
         InitialAlphabet = initialAlphabet;
         ContinuingSubwordPrefix = continuingSubwordPrefix;
@@ -172,7 +177,7 @@ private void ComputeAlphabet(Dictionary<string, int> wc, Dictionary<string, int>
     [MethodImpl(MethodImplOptions.AggressiveInlining)]
     internal string CharToString(char c)
     {
-        if (_charToString.TryGetValue(c, out string v))
+        if (_charToString.TryGetValue(c, out string? v))
         {
             return v;
         }
@@ -259,7 +264,7 @@ internal string CharToString(char c)
         // Then update counts
         int count = counts[i];
 
-        if (!whereToUpdate.TryGetValue(curPair, out HashSet<int> h))
+        if (!whereToUpdate.TryGetValue(curPair, out HashSet<int>? h))
         {
             h = new HashSet<int>();
             whereToUpdate[curPair] = h;
@@ -398,7 +403,7 @@ internal string CharToString(char c)
 
         if (change > 0)
         {
-            if (!whereToUpdate.TryGetValue(p, out HashSet<int> h))
+            if (!whereToUpdate.TryGetValue(p, out HashSet<int>? h))
             {
                 h = new();
                 whereToUpdate[p] = h;
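The null guard in the constructor matters because the List<T> copy constructor throws rather than tolerating a null source; with the check, a null specialTokens simply leaves SpecialTokens at its default. For example (using string in place of AddedToken for a self-contained snippet):

    using System.Collections.Generic;

    IEnumerable<string>? specialTokens = null;

    // Old behavior: throws ArgumentNullException for a null source.
    // List<string> tokens = new List<string>(specialTokens);

    // New behavior: only copy when a source was actually provided.
    List<string>? tokens = specialTokens is not null
        ? new List<string>(specialTokens)
        : null;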

src/Microsoft.ML.Tokenizers/Model/Cache.cs

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@
 
 namespace Microsoft.ML.Tokenizers
 {
-    internal sealed class Cache<TKey, TValue>
+    internal sealed class Cache<TKey, TValue> where TKey : notnull
     {
         internal Cache() : this(Bpe.DefaultCacheCapacity) { }
 
@@ -39,13 +39,13 @@ internal void Clear()
 
     internal List<TValue> GetValues(IEnumerable<TKey> keys)
     {
-        List<TValue>? values = new();
+        List<TValue> values = new();
         _cacheLock.EnterReadLock();
         try
         {
             foreach (TKey key in keys)
             {
-                if (Map.TryGetValue(key, out TValue value))
+                if (Map.TryGetValue(key, out TValue? value))
                 {
                     values.Add(value);
                 }
@@ -61,7 +61,7 @@ internal List<TValue> GetValues(IEnumerable<TKey> keys)
         _cacheLock.EnterReadLock();
         try
         {
-            if (Map.TryGetValue(key, out TValue value))
+            if (Map.TryGetValue(key, out TValue? value))
            {
                 return value;
             }
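The new `where TKey : notnull` constraint mirrors the constraint on Dictionary<TKey, TValue> itself; without it, an unconstrained TKey flowing into the inner dictionary warns under nullable reference types. Sketch:

    using System.Collections.Generic;

    // Dictionary<TKey, TValue> declares `where TKey : notnull`, so a wrapper
    // whose type parameter feeds it needs the same constraint to stay warning-free.
    internal sealed class Cache<TKey, TValue> where TKey : notnull
    {
        internal Dictionary<TKey, TValue> Map { get; } = new();
    }

    // A nullable key type is now flagged at the call site:
    // var bad = new Cache<string?, int>();  // warning CS8714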

src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs

Lines changed: 6 additions & 2 deletions
@@ -429,7 +429,7 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
         using StreamReader reader = new StreamReader(mergeStream);
         while (reader.Peek() >= 0)
         {
-            splitContents.Add(reader.ReadLine());
+            splitContents.Add(reader.ReadLine()!);
         }
     }
     catch (Exception e)
@@ -761,7 +761,11 @@ public void AddFromStream(Stream stream)
 
         while (reader.Peek() >= 0)
         {
-            string line = reader.ReadLine();
+            string? line = reader.ReadLine();
+            if (line is null)
+            {
+                continue;
+            }
 
             var splitLine = line.Trim().Split(' ');
             if (splitLine.Length != 2)

src/Microsoft.ML.Tokenizers/Model/Model.cs

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ public abstract class Model
     /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
     /// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
     /// <returns>True if the operation succeeded, false otherwise.</returns>
-    public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, List<int> accumulatedIds)
+    public virtual bool TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
     {
         if (accumulatedIds is null)
         {

src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs

Lines changed: 12 additions & 13 deletions
@@ -120,22 +120,21 @@ internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDict
     {
         while (!reader.EndOfStream)
         {
-            string line = reader.ReadLine();
+            string? line = reader.ReadLine();
             if (string.IsNullOrWhiteSpace(line))
             {
                 continue;
             }
 
-            var tokens = line.Split(' ');
-            if (tokens.Length != 2)
+            int spaceIndex = line.IndexOf(' ');
+            if (spaceIndex <= 0 || spaceIndex >= line.Length - 1 || line.IndexOf(' ', spaceIndex + 1) >= 0)
             {
                 throw new FormatException($"Invalid format in the BPE encoder file stream");
             }
 
-            byte[] tokenBytes = Convert.FromBase64String(tokens[0]);
-            int rank = 0;
+            byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
 
-            if (int.TryParse(tokens[1], out rank))
+            if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
             {
                 encoder[tokenBytes] = rank;
                 decoder[rank] = tokenBytes;
@@ -146,7 +145,7 @@ internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDict
             }
             else
             {
-                throw new FormatException($"Can't parse {tokens[1]} to integer");
+                throw new FormatException($"Can't parse {line.Substring(spaceIndex)} to integer");
             }
         }
     }
@@ -242,7 +241,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
     /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
     /// <param name="accumulatedIds">The list of accumulated Ids.</param>
     /// <returns>True if the operation succeeded, false otherwise.</returns>
-    public override bool TokenizeToIds(string sequence, bool isSpecialToken, List<int> accumulatedIds)
+    public override bool TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
     {
         if (string.IsNullOrEmpty(sequence))
         {
@@ -320,7 +319,7 @@ public override bool TokenizeToIds(string sequence, bool isSpecialToken, List<in
         }
 
         int[] idsToCache = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(token), _encoder);
-        _cache.Add(token, idsToCache.ToArray());
+        _cache.Add(token, idsToCache);
 
         if (idsToCache.Length == 1)
         {
@@ -338,12 +337,12 @@ public override bool TokenizeToIds(string sequence, bool isSpecialToken, List<in
     /// <returns>The mapped token of the Id.</returns>
     public override string? IdToToken(int id, bool skipSpecialTokens = false)
     {
-        if (!skipSpecialTokens && _specialTokensDecoder is not null && _specialTokensDecoder.TryGetValue(id, out string token))
+        if (!skipSpecialTokens && _specialTokensDecoder is not null && _specialTokensDecoder.TryGetValue(id, out string? token))
         {
             return token;
         }
 
-        if (_decoder.TryGetValue(id, out byte[] tokenBytes))
+        if (_decoder.TryGetValue(id, out byte[]? tokenBytes))
        {
             return Encoding.UTF8.GetString(tokenBytes);
         }
@@ -363,11 +362,11 @@ public override bool TokenizeToIds(string sequence, bool isSpecialToken, List<in
 
         foreach (int id in ids)
         {
-            if (_decoder.TryGetValue(id, out byte[] tokenBytes))
+            if (_decoder.TryGetValue(id, out byte[]? tokenBytes))
             {
                 utf8Bytes.AddRange(tokenBytes);
             }
-            else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string token))
+            else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
             {
                 utf8Bytes.AddRange(Encoding.UTF8.GetBytes(token));
             }
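The rewritten loop parses each "base64-token rank" line in a single pass instead of allocating a string[] via Split, and the index checks enforce exactly one interior space. A standalone sketch of that validation (illustrative, with a hypothetical helper name):

    // Returns true only for lines of the form "token rank":
    // one space, non-empty text on both sides of it.
    static bool TryLocateSeparator(string line, out int spaceIndex)
    {
        spaceIndex = line.IndexOf(' ');
        return spaceIndex > 0                          // token part is non-empty
            && spaceIndex < line.Length - 1            // rank part is non-empty
            && line.IndexOf(' ', spaceIndex + 1) < 0;  // no second space
    }

    // TryLocateSeparator("IA== 0", out _)   -> true  (token "IA==", rank 0)
    // TryLocateSeparator(" IA==0", out _)   -> false (empty token part)
    // TryLocateSeparator("IA== 0 x", out _) -> false (more than one space)

Helpers.FromBase64String and Helpers.TryParseInt32 then decode the two halves directly from the line, with the span-based net8.0 flavor avoiding intermediate strings (see the csproj sketch above).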

src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs

Lines changed: 50 additions & 15 deletions
@@ -5,6 +5,7 @@
 using System;
 using System.Collections;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Text.RegularExpressions;
 
 namespace Microsoft.ML.Tokenizers
@@ -75,53 +76,87 @@ public abstract class PreTokenizer
         public abstract IEnumerable<Split> PreTokenize(string sentence, bool skipSpecialTokens = false);
     }
 
-    internal readonly struct RegexSplitEnumerable : IEnumerable<Split>
+    internal sealed class RegexSplitEnumerable : IEnumerable<Split>
     {
-        private readonly MatchCollection _matches;
+        private readonly static Dictionary<string, Regex> _regexCache = new(StringComparer.Ordinal);
+        private readonly Regex _regex;
+        private readonly string _sentence;
 
         public RegexSplitEnumerable(string sentence, string pattern)
         {
-            _matches = Regex.Matches(sentence, pattern);
+            Debug.Assert(sentence is not null);
+            Debug.Assert(pattern is not null);
+
+            Regex? regex;
+            lock (_regexCache)
+            {
+                if (!_regexCache.TryGetValue(pattern!, out regex))
+                {
+                    regex = new Regex(pattern, RegexOptions.Compiled);
+                    _regexCache[pattern!] = regex;
+                }
+            }
+
+            _regex = regex;
+            _sentence = sentence!;
         }
 
-        public IEnumerator<Split> GetEnumerator() => new RegexSplitEnumerator(_matches);
+        public IEnumerator<Split> GetEnumerator() => new RegexSplitEnumerator(_regex, _sentence);
 
-        IEnumerator IEnumerable.GetEnumerator() => new RegexSplitEnumerator(_matches);
+        IEnumerator IEnumerable.GetEnumerator() => new RegexSplitEnumerator(_regex, _sentence);
 
-        private struct RegexSplitEnumerator : IEnumerator<Split>
+        private sealed class RegexSplitEnumerator : IEnumerator<Split>
         {
             private Split _current = default;
-            private int _matchIndex = 0;
-            private readonly MatchCollection _matches;
+            private readonly Regex _regex;
+            private Match? _tokenMatch;
+            private readonly string _sentence;
 
-            public RegexSplitEnumerator(MatchCollection matches) => _matches = matches;
+            public RegexSplitEnumerator(Regex regex, string sentence)
+            {
+                Debug.Assert(sentence is not null);
+                Debug.Assert(regex is not null);
+
+                _regex = regex!;
+                _sentence = sentence!;
+            }
 
             public Split Current => _current;
 
             object IEnumerator.Current => _current;
 
             public bool MoveNext()
             {
-                if (_matchIndex >= _matches.Count)
+                if (_tokenMatch is null)
+                {
+                    _tokenMatch = _regex.Match(_sentence);
+                }
+                else if (!_tokenMatch.Success)
                 {
                     return false;
                 }
+                else
+                {
+                    _tokenMatch = _tokenMatch.NextMatch();
+                }
 
-                var match = _matches[_matchIndex++];
-                _current = new Split(match.Value, (match.Index, match.Index + match.Length));
+                if (!_tokenMatch.Success)
+                {
+                    return false;
+                }
+
+                _current = new Split(_tokenMatch.Value, (_tokenMatch.Index, _tokenMatch.Index + _tokenMatch.Length));
                 return true;
             }
 
             public void Reset()
             {
-                _matchIndex = 0;
+                _tokenMatch = null;
             }
 
             public void Dispose()
             {
             }
         }
     }
-
-
 }
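The enumerator was also switched from eagerly materializing matches to streaming them with Regex.Match/NextMatch: the old MoveNext consulted MatchCollection.Count, which forces every match in the input to be computed before the first Split is yielded. A minimal sketch of the streaming idiom, outside the tokenizer types:

    using System.Collections.Generic;
    using System.Text.RegularExpressions;

    // Yields one match at a time; nothing past the current match is computed.
    static IEnumerable<(string Value, int Offset, int End)> StreamMatches(Regex regex, string text)
    {
        for (Match m = regex.Match(text); m.Success; m = m.NextMatch())
        {
            yield return (m.Value, m.Index, m.Index + m.Length);
        }
    }

    // StreamMatches(new Regex(@"\S+", RegexOptions.Compiled), "a bb ccc")
    // yields ("a", 0, 1), ("bb", 2, 4), ("ccc", 5, 8).

Caching one compiled Regex per pattern under a lock amortizes the one-time cost of RegexOptions.Compiled across tokenizer instances that reuse the same pattern.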
