diff --git a/src/Microsoft.ML.Tokenizers/AssemblyInfo.cs b/src/Microsoft.ML.Tokenizers/AssemblyInfo.cs
new file mode 100644
index 0000000000..4b89b383d5
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/AssemblyInfo.cs
@@ -0,0 +1,7 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#if NET5_0_OR_GREATER
+[module: System.Runtime.CompilerServices.SkipLocalsInit]
+#endif
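The module-level `SkipLocalsInit` pairs with the `stackalloc` buffers introduced in `Tiktoken.Decode` below: without it, the JIT zero-initializes every stack-allocated local before first use. A minimal sketch of the effect (hypothetical demo code, not part of this PR):

```csharp
using System;
using System.Runtime.CompilerServices;

static class StackBufferDemo
{
    [SkipLocalsInit] // can be applied per method, or per module as in AssemblyInfo.cs above
    static int ChecksumPrefix(ReadOnlySpan<byte> source)
    {
        Span<byte> buffer = stackalloc byte[256]; // not zeroed first under SkipLocalsInit
        int count = Math.Min(source.Length, buffer.Length);
        source.Slice(0, count).CopyTo(buffer);    // write before any read, so no stale data is observed
        int sum = 0;
        for (int i = 0; i < count; i++)
        {
            sum += buffer[i];
        }
        return sum;
    }
}
```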
diff --git a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
index e50c62889b..d370145bad 100644
--- a/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
+++ b/src/Microsoft.ML.Tokenizers/Microsoft.ML.Tokenizers.csproj
@@ -5,6 +5,7 @@
<TargetFrameworks>netstandard2.0;net8.0</TargetFrameworks>
<Nullable>enable</Nullable>
<Description>Microsoft.ML.Tokenizers contains the implementation of the tokenization used in the NLP transforms.</Description>
+ <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index e750a5df6c..ad98ed917c 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -8,10 +8,7 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
using System.Text.Json;
-using System.Text.Json.Serialization;
namespace Microsoft.ML.Tokenizers
{
@@ -27,7 +24,7 @@ public sealed class EnglishRoberta : Model
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
- private readonly Cache<string, IReadOnlyList<Token>> _cache;
+ private readonly Cache<string, List<Token>> _cache;
/// <summary>
/// Construct tokenizer object to use with the English Roberta model.
@@ -72,7 +69,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc
}
_unicodeToByte = _byteToUnicode.Reverse();
- _cache = new Cache<string, IReadOnlyList<Token>>();
+ _cache = new Cache<string, List<Token>>();
}
///
@@ -110,7 +107,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
}
_unicodeToByte = _byteToUnicode.Reverse();
- _cache = new Cache<string, IReadOnlyList<Token>>();
+ _cache = new Cache<string, List<Token>>();
}
//
@@ -226,17 +223,17 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
- return Bpe.EmptyTokensList;
+ return Array.Empty<Token>();
}
- if (_cache.TryGet(sequence, out IReadOnlyList<Token>? hit))
+ if (_cache.TryGet(sequence, out List<Token>? hit))
{
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
return ModifyTokenListOffsets(hit, indexMapping);
}
- IReadOnlyList<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
+ List<Token> result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping);
_cache.Set(sequence, result);
ArrayPool<char>.Shared.Return(token);
ArrayPool<int>.Shared.Return(indexMapping);
@@ -261,7 +258,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
{
- if (_cache.TryGet(sequence, out IReadOnlyList<Token>? hit))
+ if (_cache.TryGet(sequence, out List<Token>? hit))
{
if (accumulatedIds is not null)
{
@@ -299,7 +296,7 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
return 0;
}
- IReadOnlyList<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
+ List<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
_cache.Set(sequence, result);
return result.Count;
}
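The `Tokenize`/`TokenizeToIds` changes above keep the existing pooled-buffer discipline: `char[]` and `int[]` scratch buffers are rented up front and must be returned on every exit path (early return, cache hit, and the normal path alike). A hedged sketch of that pattern, with hypothetical names; `try`/`finally` is the safer variant when the body can throw:

```csharp
using System;
using System.Buffers;

static class PooledBufferSketch
{
    // Hypothetical helper illustrating the rent/return discipline used above.
    static int CountLetters(string sequence)
    {
        char[] token = ArrayPool<char>.Shared.Rent(sequence.Length);
        try
        {
            sequence.AsSpan().CopyTo(token);
            int count = 0;
            for (int i = 0; i < sequence.Length; i++)
            {
                if (char.IsLetter(token[i]))
                {
                    count++;
                }
            }
            return count;
        }
        finally
        {
            ArrayPool<char>.Shared.Return(token); // every exit path hands the buffer back
        }
    }
}
```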
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index bd9a376e20..9935dd6428 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -1,13 +1,15 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
+using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
+using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
@@ -16,14 +18,12 @@ namespace Microsoft.ML.Tokenizers
///
public sealed class Tiktoken : Model
{
- private Dictionary<byte[], int> _encoder = null!;
- private IReadOnlyDictionary<int, byte[]> _decoder = null!;
+ private readonly Dictionary<ReadOnlyMemory<byte>, int> _encoder = null!;
+ private readonly IReadOnlyDictionary<int, byte[]> _decoder = null!;
private readonly LruCache<int[]> _cache;
- private IReadOnlyDictionary<string, int>? _specialTokensEncoder;
- private Dictionary<int, string>? _specialTokensDecoder;
-
- private Dictionary<string, int> _vocab = null!;
- private static readonly List<Token> _emptyTokenList = new();
+ private readonly IReadOnlyDictionary<string, int>? _specialTokensEncoder;
+ private readonly Dictionary<int, string>? _specialTokensDecoder;
+ private readonly Dictionary<string, int> _vocab = null!;
/// <summary>
/// Create a new Tiktoken tokenizer object.
@@ -33,17 +33,9 @@ public sealed class Tiktoken : Model
/// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFile"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
- public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
+ public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
+ this(string.IsNullOrEmpty(tikTokenBpeFile) ? throw new ArgumentNullException(nameof(tikTokenBpeFile)) : File.OpenRead(tikTokenBpeFile), specialTokensEncoder, cacheSize, disposeStream: true)
{
- if (string.IsNullOrEmpty(tikTokenBpeFile))
- {
- throw new ArgumentNullException(nameof(tikTokenBpeFile));
- }
-
- using (Stream stream = File.OpenRead(tikTokenBpeFile))
- {
- Initialize(stream, specialTokensEncoder);
- }
}
///
@@ -54,17 +46,17 @@ public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specia
/// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFileStream"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
- public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
+ public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<int[]>.DefaultCacheSize) :
+ this(tikTokenBpeFileStream ?? throw new ArgumentNullException(nameof(tikTokenBpeFileStream)), specialTokensEncoder, cacheSize, disposeStream: false)
{
- Initialize(tikTokenBpeFileStream, specialTokensEncoder);
}
internal Tiktoken(
- Dictionary<byte[], int> encoder,
- IReadOnlyDictionary<int, byte[]> decoder,
- Dictionary<string, int> vocab,
- IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
- int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
+ Dictionary<ReadOnlyMemory<byte>, int> encoder,
+ IReadOnlyDictionary<int, byte[]> decoder,
+ Dictionary<string, int> vocab,
+ IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
+ int cacheSize = LruCache<int[]>.DefaultCacheSize) : this(cacheSize)
{
Debug.Assert(encoder is not null);
Debug.Assert(decoder is not null);
@@ -81,36 +73,42 @@ internal Tiktoken(
}
}
- private Tiktoken(int cacheSize)
- {
- _cache = new LruCache<int[]>(cacheSize);
- }
-
- private void Initialize(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null)
+ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder, int cacheSize, bool disposeStream) : this(cacheSize)
{
- if (tikTokenBpeFileStream is null)
+ try
{
- throw new ArgumentNullException(nameof(tikTokenBpeFileStream));
- }
-
- (_encoder, _vocab, _decoder) = LoadTikTokenBpe(tikTokenBpeFileStream);
+ (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: false).GetAwaiter().GetResult();
- _specialTokensEncoder = specialTokensEncoder;
- if (_specialTokensEncoder is not null)
+ _specialTokensEncoder = specialTokensEncoder;
+ if (_specialTokensEncoder is not null)
+ {
+ _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
+ }
+ }
+ finally
{
- _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
+ if (disposeStream)
+ {
+ tikTokenBpeFileStream.Dispose();
+ }
}
}
+ private Tiktoken(int cacheSize)
+ {
+ _cache = new LruCache<int[]>(cacheSize);
+ }
+
/// <summary>
/// Load BPE rank dictionary from a stream.
/// </summary>
/// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
+ /// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
/// <returns>Map of byte[] to integer token id</returns>
///
- internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>) LoadTikTokenBpe(Stream tikTokenBpeFileStream)
+ internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
{
- var encoder = new Dictionary<byte[], int>(new ByteArrayComparer());
+ var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<string, int>();
var decoder = new Dictionary<int, byte[]>();
@@ -118,11 +116,17 @@ internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDict
{
using (StreamReader reader = new StreamReader(tikTokenBpeFileStream))
{
- while (!reader.EndOfStream)
+ while (true)
{
- string? line = reader.ReadLine();
+ string? line = useAsync ?
+ await reader.ReadLineAsync().ConfigureAwait(false) :
+ reader.ReadLine();
if (string.IsNullOrWhiteSpace(line))
{
+ if (line is null)
+ {
+ break;
+ }
continue;
}
@@ -172,11 +176,11 @@ internal static (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDict
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken)
{
- List<Token> tokens;
+ Token[] tokens;
if (string.IsNullOrEmpty(sequence))
{
- return _emptyTokenList;
+ return Array.Empty<Token>();
}
if (isSpecialToken)
@@ -196,12 +200,12 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
if (_cache.Lookup(sequence, out int[] ids))
{
- tokens = new(ids.Length);
- tokens.Add(new Token(ids[0], sequence, (0, sequence.Length)));
+ tokens = new Token[ids.Length];
+ tokens[0] = new Token(ids[0], sequence, (0, sequence.Length));
for (int i = 1; i < ids.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
- tokens.Add(new Token(ids[i], "", (sequence.Length, sequence.Length)));
+ tokens[i] = new Token(ids[i], "", (sequence.Length, sequence.Length));
}
return tokens;
@@ -213,17 +217,22 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
}
- int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder);
+ byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
+ int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+
+ int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
+ Debug.Assert(encodedIds.Length > 0);
_cache.Add(sequence, encodedIds);
- tokens = new List<Token>(encodedIds.Length);
- tokens.Add(new Token(encodedIds[0], sequence, (0, sequence.Length)));
+ tokens = new Token[encodedIds.Length];
+ tokens[0] = new Token(encodedIds[0], sequence, (0, sequence.Length));
for (int i = 1; i < encodedIds.Length; i++)
{
// One word split mapped to multiple Ids. Make the offset of the remaining token point at the end with zero width.
- tokens.Add(new Token(encodedIds[i], "", (sequence.Length, sequence.Length)));
+ tokens[i] = new Token(encodedIds[i], "", (sequence.Length, sequence.Length));
}
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
return tokens;
}
@@ -262,10 +271,15 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
- int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder);
+ byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
+ int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+
+ int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache.Add(sequence, encodedIds);
accumulatedIds.AddRange(encodedIds);
+
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
return;
}
@@ -284,7 +298,7 @@ public override int CountTokens(string sequence, bool isSpecialToken)
if (isSpecialToken && _specialTokensEncoder is not null)
{
- return _specialTokensEncoder.TryGetValue(sequence, out int id) ? 1 : 0;
+ return _specialTokensEncoder.TryGetValue(sequence, out _) ? 1 : 0;
}
if (_cache.Lookup(sequence, out int[] ids))
@@ -292,14 +306,18 @@ public override int CountTokens(string sequence, bool isSpecialToken)
return ids.Length;
}
- if (_vocab.TryGetValue(sequence, out int mappedId))
+ if (_vocab.TryGetValue(sequence, out _))
{
return 1;
}
- int[] encodedIds = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(sequence), _encoder);
+ byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
+ int encodedLength = GetUtf8Bytes(sequence.AsSpan(), arrayPoolArray);
+
+ int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
_cache.Add(sequence, encodedIds);
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
return encodedIds.Length;
}
@@ -343,15 +361,25 @@ public override int CountTokens(string sequence, bool isSpecialToken)
return id;
}
- int[] idsToCache = BytePairEncoder.BytePairEncode(Encoding.UTF8.GetBytes(token), _encoder);
- _cache.Add(token, idsToCache);
+ byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length));
+ try
+ {
+ int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray);
+
+ int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder);
+ _cache.Add(token, idsToCache);
+
+ if (idsToCache.Length == 1)
+ {
+ return idsToCache[0];
+ }
- if (idsToCache.Length == 1)
+ return null;
+ }
+ finally
{
- return idsToCache[0];
- ArrayPool<byte>.Shared.Return(arrayPoolArray);
}
-
- return null;
}
///
@@ -382,26 +410,66 @@ public override int CountTokens(string sequence, bool isSpecialToken)
return null;
}
- List<byte> utf8Bytes = new();
- bool useSpecialTokens = !skipSpecialTokens && _specialTokensDecoder is not null;
-
- foreach (int id in ids)
+ byte[]? arrayPoolArray = null;
+ try
{
- if (_decoder.TryGetValue(id, out byte[]? tokenBytes))
+ Span<byte> utf8Bytes = stackalloc byte[256];
+ int utf8ByteCount = 0;
+
+ bool useSpecialTokens = !skipSpecialTokens && _specialTokensDecoder is not null;
+
+ foreach (int id in ids)
{
- utf8Bytes.AddRange(tokenBytes);
+ if (_decoder.TryGetValue(id, out byte[]? tokenBytes))
+ {
+ if ((uint)utf8ByteCount + (uint)tokenBytes.Length > (uint)utf8Bytes.Length)
+ {
+ ArrayPoolGrow(ref utf8Bytes, ref arrayPoolArray, utf8ByteCount + tokenBytes.Length);
+ }
+
+ tokenBytes.AsSpan().CopyTo(utf8Bytes.Slice(utf8ByteCount));
+ utf8ByteCount += tokenBytes.Length;
+ }
+ else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
+ {
+ while (true)
+ {
+ if (TryGetUtf8Bytes(token.AsSpan(), utf8Bytes.Slice(utf8ByteCount), out int bytesWritten))
+ {
+ utf8ByteCount += bytesWritten;
+ break;
+ }
+
+ ArrayPoolGrow(ref utf8Bytes, ref arrayPoolArray, utf8ByteCount + Encoding.UTF8.GetByteCount(token));
+ }
+ }
+ else
+ {
+ return null;
+ }
}
- else if (useSpecialTokens && _specialTokensDecoder!.TryGetValue(id, out string? token))
+
+ return GetString(utf8Bytes.Slice(0, utf8ByteCount));
+ }
+ finally
+ {
+ if (arrayPoolArray is not null)
{
- utf8Bytes.AddRange(Encoding.UTF8.GetBytes(token));
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
}
- else
+ }
+
+ static void ArrayPoolGrow(ref Span<byte> utf8Bytes, ref byte[]? arrayPoolArray, int requiredCapacity)
+ {
+ byte[] tmp = ArrayPool<byte>.Shared.Rent(Math.Max(utf8Bytes.Length * 2, requiredCapacity));
+ utf8Bytes.CopyTo(tmp.AsSpan());
+ byte[]? toReturn = arrayPoolArray;
+ utf8Bytes = arrayPoolArray = tmp;
+ if (toReturn is not null)
{
- return null;
+ ArrayPool<byte>.Shared.Return(toReturn);
}
}
-
- return utf8Bytes.Count > 0 ? Encoding.UTF8.GetString(utf8Bytes.ToArray()) : string.Empty;
}
///
@@ -426,5 +494,50 @@ public override int CountTokens(string sequence, bool isSpecialToken)
/// Gets a trainer object to use in training the model.
///
public override Trainer? GetTrainer() => throw new NotImplementedException();
+
+ private static unsafe int GetUtf8Bytes(ReadOnlySpan<char> source, Span<byte> destination)
+ {
+#if NETCOREAPP
+ return Encoding.UTF8.GetBytes(source, destination);
+#else
+ fixed (char* sourcePtr = source)
+ fixed (byte* destPtr = destination)
+ {
+ return Encoding.UTF8.GetBytes(sourcePtr, source.Length, destPtr, destination.Length);
+ }
+#endif
+ }
+
+ private static unsafe bool TryGetUtf8Bytes(ReadOnlySpan<char> source, Span<byte> destination, out int bytesWritten)
+ {
+#if NET8_0_OR_GREATER
+ return Encoding.UTF8.TryGetBytes(source, destination, out bytesWritten);
+#else
+ fixed (char* sourcePtr = source)
+ fixed (byte* destPtr = destination)
+ {
+ if (Encoding.UTF8.GetByteCount(sourcePtr, source.Length) <= destination.Length)
+ {
+ bytesWritten = Encoding.UTF8.GetBytes(sourcePtr, source.Length, destPtr, destination.Length);
+ return true;
+ }
+
+ bytesWritten = 0;
+ return false;
+ }
+#endif
+ }
+
+ private static unsafe string GetString(ReadOnlySpan<byte> utf8Bytes)
+ {
+#if NETCOREAPP
+ return Encoding.UTF8.GetString(utf8Bytes);
+#else
+ fixed (byte* sourcePtr = utf8Bytes)
+ {
+ return Encoding.UTF8.GetString(sourcePtr, utf8Bytes.Length);
+ }
+#endif
+ }
}
-}
\ No newline at end of file
+}
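`LoadTikTokenBpeAsync` folds the old synchronous loader and a new asynchronous one into a single code path: the `useAsync` flag picks `ReadLineAsync` or `ReadLine`, so when `useAsync` is `false` every `await` completes synchronously and the constructor's `GetAwaiter().GetResult()` never blocks on a pending task. A simplified sketch of the same pattern (hypothetical method, not from this PR):

```csharp
using System.IO;
using System.Threading.Tasks;

static class SyncAsyncLoader
{
    // Counts non-empty lines; one implementation serves both sync and async callers.
    internal static async ValueTask<int> CountLinesAsync(Stream stream, bool useAsync)
    {
        int count = 0;
        using (StreamReader reader = new StreamReader(stream))
        {
            while (true)
            {
                string? line = useAsync ?
                    await reader.ReadLineAsync().ConfigureAwait(false) : // true asynchronous I/O
                    reader.ReadLine();                                   // completes synchronously
                if (line is null)
                {
                    break;
                }
                if (line.Length > 0)
                {
                    count++;
                }
            }
        }
        return count;
    }
}
```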
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
index 94acfcb96f..aef8b13c42 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/PreTokenizer.cs
@@ -3,9 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections;
using System.Collections.Generic;
-using System.Diagnostics;
using System.Text.RegularExpressions;
namespace Microsoft.ML.Tokenizers
@@ -15,14 +13,22 @@ namespace Microsoft.ML.Tokenizers
/// in the original string. These offsets are in the `original` referential.
/// It also contains any `Token` associated with the current split.
///
- public readonly struct Split : IEquatable<Split>
+ public struct Split : IEquatable<Split>
{
+ private readonly string? _originalString;
+ private string? _tokenString;
+
///
/// Gets the underlying split token. Each SubString is represented by a token
/// and in the end we might be carrying a lot of SubString representing various parts of the
/// original input string.
///
- public string TokenString { get; }
+ public string TokenString => _tokenString ??= _originalString!.Substring(Offset.Index, Offset.End - Offset.Index);
+
+ /// <summary>
+ /// Gets the underlying split token as a span.
+ /// </summary>
+ public ReadOnlySpan<char> TokenSpan => _tokenString is string s ? s.AsSpan() : _originalString.AsSpan(Offset.Index, Offset.End - Offset.Index);
///
/// Returns the offset mapping to the original string
@@ -37,7 +43,15 @@ namespace Microsoft.ML.Tokenizers
/// <param name="isSpecialToken">Indicates whether the token is a special token</param>
public Split(string token, (int Index, int End) offset, bool isSpecialToken = false)
{
- TokenString = token;
+ _tokenString = token;
+ Offset = offset;
+ IsSpecialToken = isSpecialToken;
+ }
+
+ internal Split(string originalString, string? token, (int Index, int End) offset, bool isSpecialToken = false)
+ {
+ _originalString = originalString;
+ _tokenString = token;
Offset = offset;
IsSpecialToken = isSpecialToken;
}
@@ -52,21 +66,18 @@ public Split(string token, (int Index, int End) offset, bool isSpecialToken = fa
///
/// <param name="other">The Split object to compare with the current object.</param>
public bool Equals(Split other) =>
- TokenString == other.TokenString &&
+ (_originalString == other._originalString || TokenString == other.TokenString) &&
IsSpecialToken == other.IsSpecialToken &&
Offset.Index == other.Offset.Index &&
Offset.End == other.Offset.End;
}
-
///
/// Base class for all pre-tokenizers classes.
/// The PreTokenizer is in charge of doing the pre-segmentation step.
///
public abstract class PreTokenizer
{
- internal static readonly IReadOnlyList<Split> EmptyList = new List<Split>();
-
///
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
///
@@ -74,89 +85,36 @@ public abstract class PreTokenizer
/// <param name="skipSpecialTokens">Indicates whether to skip the special tokens.</param>
/// <returns>The list of the splits containing the tokens and the token's offsets to the original string.</returns>
public abstract IEnumerable<Split> PreTokenize(string sentence, bool skipSpecialTokens = false);
- }
- internal sealed class RegexSplitEnumerable : IEnumerable<Split>
- {
- private readonly static Dictionary<string, Regex> _regexCache = new(StringComparer.Ordinal);
- private readonly Regex _regex;
- private readonly string _sentence;
-
- public RegexSplitEnumerable(string sentence, string pattern)
+ internal static IEnumerable<Split> SplitSentence(string sentence, Regex regex)
{
- Debug.Assert(sentence is not null);
- Debug.Assert(pattern is not null);
-
- Regex? regex;
- lock (_regexCache)
+ (int Offset, int Length) match;
+ int beginning = 0;
+ while (TryGetMatch(regex, sentence, beginning, sentence.Length - beginning, out match))
{
- if (!_regexCache.TryGetValue(pattern!, out regex))
- {
- regex = new Regex(pattern, RegexOptions.Compiled);
- _regexCache[pattern!] = regex;
- }
+ yield return new Split(sentence, null, (match.Offset, match.Offset + match.Length));
+ beginning = match.Offset + match.Length;
}
-
- _regex = regex;
- _sentence = sentence!;
}
- public IEnumerator<Split> GetEnumerator() => new RegexSplitEnumerator(_regex, _sentence);
-
- IEnumerator IEnumerable.GetEnumerator() => new RegexSplitEnumerator(_regex, _sentence);
-
- private sealed class RegexSplitEnumerator : IEnumerator<Split>
+ internal static bool TryGetMatch(Regex regex, string sentence, int beginning, int length, out (int offset, int length) match)
{
- private Split _current = default;
- private readonly Regex _regex;
- private Match? _tokenMatch;
- private readonly string _sentence;
-
- public RegexSplitEnumerator(Regex regex, string sentence)
- {
- Debug.Assert(sentence is not null);
- Debug.Assert(regex is not null);
-
- _regex = regex!;
- _sentence = sentence!;
- }
-
- public Split Current => _current;
-
- object IEnumerator.Current => _current;
-
- public bool MoveNext()
+#if NET7_0_OR_GREATER
+ foreach (ValueMatch m in regex.EnumerateMatches(sentence.AsSpan(beginning, length)))
{
- if (_tokenMatch is null)
- {
- _tokenMatch = _regex.Match(_sentence);
- }
- else if (!_tokenMatch.Success)
- {
- return false;
- }
- else
- {
- _tokenMatch = _tokenMatch.NextMatch();
- }
-
- if (!_tokenMatch.Success)
- {
- return false;
- }
-
- _current = new Split(_tokenMatch.Value, (_tokenMatch.Index, _tokenMatch.Index + _tokenMatch.Length));
+ match = (beginning + m.Index, m.Length);
return true;
}
-
- public void Reset()
- {
- _tokenMatch = null;
- }
-
- public void Dispose()
+#else
+ Match m = regex.Match(sentence, beginning, length);
+ if (m.Success)
{
+ match = (m.Index, m.Length);
+ return true;
}
+#endif
+ match = default;
+ return false;
}
}
}
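`TryGetMatch` leans on `Regex.EnumerateMatches` where available: on .NET 7+ it yields `ValueMatch` structs over a span, so no `Match` objects are allocated, but the returned indices are relative to the span, which is why the code adds `beginning` back in. An illustrative .NET 7+ snippet (hypothetical pattern and names):

```csharp
using System;
using System.Text.RegularExpressions;

static class EnumerateMatchesDemo
{
    static void PrintWordOffsets(string sentence, int start)
    {
        var regex = new Regex(@"\w+", RegexOptions.Compiled);
        foreach (ValueMatch m in regex.EnumerateMatches(sentence.AsSpan(start)))
        {
            // ValueMatch.Index is relative to the sliced span, so translate it
            // back into the coordinates of the original string.
            Console.WriteLine($"({start + m.Index}, {m.Length})");
        }
    }
}
```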
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
index e07e755c29..8fd748d838 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
@@ -4,20 +4,28 @@
using System;
using System.Collections.Generic;
+using System.Text.RegularExpressions;
namespace Microsoft.ML.Tokenizers
{
///
/// The pre-tokenizer for Roberta English tokenizer.
///
- public sealed class RobertaPreTokenizer : PreTokenizer
+ public sealed partial class RobertaPreTokenizer : PreTokenizer
{
///
/// Gets a singleton instance of the Roberta pre-tokenizer.
///
- public static readonly RobertaPreTokenizer Instance = new RobertaPreTokenizer();
+ public static RobertaPreTokenizer Instance { get; } = new RobertaPreTokenizer();
- private const string Pattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
+ private const string PretokenizePattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
+#if NET7_0_OR_GREATER
+ [GeneratedRegex(PretokenizePattern)]
+ private static partial Regex PretokenizeRegex();
+#else
+ private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled);
+ private static Regex PretokenizeRegex() => _regex;
+#endif
///
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
@@ -29,10 +37,10 @@ public override IEnumerable<Split> PreTokenize(string sentence, bool skipSpecial
{
if (string.IsNullOrEmpty(sentence))
{
- return EmptyList;
+ return Array.Empty<Split>();
}
- return new RegexSplitEnumerable(sentence, Pattern);
+ return SplitSentence(sentence, PretokenizeRegex());
}
}
}
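With the pre-tokenizer now backed by the lazily enumerated `SplitSentence`, callers only pay for the splits they consume. A small usage sketch, assuming the `Split` members shown earlier in this diff:

```csharp
using System;
using Microsoft.ML.Tokenizers;

foreach (Split split in RobertaPreTokenizer.Instance.PreTokenize("What's up?"))
{
    Console.WriteLine($"'{split.TokenString}' at {split.Offset}");
}
```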
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs
index b64096de71..7651de599d 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/TikTokenPreTokenizer.cs
@@ -1,11 +1,9 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections;
using System.Collections.Generic;
-using System.Diagnostics;
using System.Linq;
using System.Text.RegularExpressions;
@@ -34,7 +32,7 @@ public TikTokenPreTokenizer(Regex regex, IReadOnlyDictionary<string, int>? speci
_regex = regex;
- if (specialTokensEncoder is not null && specialTokensEncoder.Count > 0)
+ if (specialTokensEncoder is { Count: > 0 })
{
_specialTokensRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
}
@@ -50,131 +48,41 @@ public override IEnumerable<Split> PreTokenize(string sentence, bool skipSpecial
{
if (string.IsNullOrEmpty(sentence))
{
- return EmptyList;
+ return Array.Empty<Split>();
}
- return new TokenizationEnumerable(sentence, _regex, skipSpecialTokens ? null : _specialTokensRegex);
- }
-
- private sealed class TokenizationEnumerable : IEnumerable<Split>
- {
- private readonly string _sentence;
- private readonly Regex _regex;
- private readonly Regex? _specialTokensRegex;
-
- public TokenizationEnumerable(string sentence, Regex regex, Regex? specialTokensRegex)
- {
- if (sentence is null)
- {
- throw new ArgumentNullException(nameof(sentence));
- }
-
- if (regex is null)
- {
- throw new ArgumentNullException(nameof(regex));
- }
+ return SplitSentences(sentence, _regex, skipSpecialTokens ? null : _specialTokensRegex);
- _sentence = sentence;
- _regex = regex;
- _specialTokensRegex = specialTokensRegex;
- }
-
- public IEnumerator<Split> GetEnumerator() => new TokenizationEnumerator(_sentence, _regex, _specialTokensRegex);
- IEnumerator IEnumerable.GetEnumerator() => new TokenizationEnumerator(_sentence, _regex, _specialTokensRegex);
-
- private sealed class TokenizationEnumerator : IEnumerator<Split>
+ static IEnumerable<Split> SplitSentences(string sentence, Regex regex, Regex? specialTokensRegex)
{
- private Split _current = default;
- private int _startIndex;
- private int _offset;
- private MatchCollection? _matches;
- private int _matchIndex;
- private Match? _specialTokenMatch;
- private readonly Regex _regex;
- private readonly string _sentence;
- private readonly Regex? _specialTokensRegex;
-
- public TokenizationEnumerator(string sentence, Regex regex, Regex? specialTokensRegex)
- {
- Debug.Assert(sentence is not null);
- Debug.Assert(regex is not null);
+ (int Offset, int Length) match;
+ int beginning = 0;
- _sentence = sentence!;
- _regex = regex!;
- _specialTokensRegex = specialTokensRegex;
- _startIndex = 0;
- _offset = 0;
- }
-
- object IEnumerator.Current => _current;
-
- Split IEnumerator<Split>.Current => _current;
-
- public bool MoveNext()
+ if (specialTokensRegex is not null)
{
- if (_matches is not null && _matchIndex < _matches.Count)
+ while (true)
{
- Match match = _matches[_matchIndex];
- _current = new Split(match.Value, (match.Index + _offset, match.Index + _offset + match.Length), false);
- _startIndex += match.Length;
- _matchIndex++;
- return true;
+ (int Offset, int Length) specialMatch;
+ if (!TryGetMatch(specialTokensRegex, sentence, beginning, sentence.Length - beginning, out specialMatch))
+ {
+ break;
+ }
+
+ while (TryGetMatch(regex, sentence, beginning, specialMatch.Offset - beginning, out match))
+ {
+ yield return new Split(sentence, null, (match.Offset, match.Offset + match.Length));
+ beginning = match.Offset + match.Length;
+ }
+
+ yield return new Split(sentence, null, (specialMatch.Offset, specialMatch.Offset + specialMatch.Length), isSpecialToken: true);
+ beginning = specialMatch.Offset + specialMatch.Length;
}
-
- if (_specialTokenMatch is not null && _specialTokenMatch.Success)
- {
- _current = new Split(_specialTokenMatch.Value, (_specialTokenMatch.Index, _specialTokenMatch.Index + _specialTokenMatch.Length), true);
- _startIndex += _specialTokenMatch.Length;
- _specialTokenMatch = null;
- return true;
- }
-
- if (_startIndex >= _sentence.Length)
- {
- return false;
- }
-
- if (_specialTokensRegex is not null)
- {
- _specialTokenMatch = _specialTokensRegex.Match(_sentence, _startIndex);
- _offset = _startIndex;
- _matches = _regex.Matches(_sentence.Substring(_startIndex, _specialTokenMatch.Success ? _specialTokenMatch.Index - _startIndex : _sentence.Length - _startIndex));
- }
- else
- {
- _matches = _regex.Matches(_sentence);
- }
-
- if (_matches.Count > 0)
- {
- Match match = _matches[0];
- _current = new Split(match.Value, (match.Index + _startIndex, match.Index + _startIndex + match.Length), false);
- _startIndex += match.Length;
- _matchIndex = 1;
- return true;
- }
- else if (_specialTokenMatch is not null && _specialTokenMatch.Success)
- {
- _current = new Split(_specialTokenMatch.Value, (_specialTokenMatch.Index, _specialTokenMatch.Index + _specialTokenMatch.Length), true);
- _startIndex += _specialTokenMatch.Length;
- _specialTokenMatch = null;
- return true;
- }
-
- return false;
- }
-
- public void Reset()
- {
- _current = default;
- _startIndex = 0;
- _matches = null;
- _matchIndex = -1;
- _specialTokenMatch = null;
}
- public void Dispose()
+ while (TryGetMatch(regex, sentence, beginning, sentence.Length - beginning, out match))
{
+ yield return new Split(sentence, null, (match.Offset, match.Offset + match.Length));
+ beginning = match.Length + match.Offset;
}
}
}
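The iterator-based `SplitSentences` interleaves the two regexes: everything before the next special-token match is split with the main pattern, then the special token itself is yielded as a single `Split`, and the scan resumes after it. A usage sketch with hypothetical patterns and token ids:

```csharp
using System;
using System.Collections.Generic;
using System.Text.RegularExpressions;
using Microsoft.ML.Tokenizers;

var pre = new TikTokenPreTokenizer(
    new Regex(@"\w+|\s+", RegexOptions.Compiled),               // assumed word/whitespace pattern
    new Dictionary<string, int> { ["<|endoftext|>"] = 50256 }); // assumed special token

foreach (Split split in pre.PreTokenize("hello world<|endoftext|>bye"))
{
    Console.WriteLine($"'{split.TokenString}' {split.Offset} special={split.IsSpecialToken}");
}
```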
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs
index d2d0158885..2a53bec814 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Whitespace.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
+using System.Text.RegularExpressions;
namespace Microsoft.ML.Tokenizers
{
@@ -11,14 +12,21 @@ namespace Microsoft.ML.Tokenizers
/// The pre-tokenizer which splits the text at the word boundary.
/// A word is a sequence of alphabetic, numeric, and underscore characters.
///
- public sealed class WhiteSpace : PreTokenizer
+ public sealed partial class WhiteSpace : PreTokenizer
{
///
/// Gets a singleton instance of the WhiteSpace pre-tokenizer.
///
- public static readonly WhiteSpace Instance = new WhiteSpace();
+ public static WhiteSpace Instance { get; } = new WhiteSpace();
- private const string Pattern = @"\w+|[^\w\s]+";
+ private const string PretokenizePattern = @"\w+|[^\w\s]+";
+#if NET7_0_OR_GREATER
+ [GeneratedRegex(PretokenizePattern)]
+ private static partial Regex PretokenizeRegex();
+#else
+ private static readonly Regex _regex = new Regex(PretokenizePattern, RegexOptions.Compiled);
+ private static Regex PretokenizeRegex() => _regex;
+#endif
///
/// Splits the given string in multiple substrings at the word boundary, keeping track of the offsets of said substrings from the original string.
@@ -30,10 +38,10 @@ public override IEnumerable<Split> PreTokenize(string sentence, bool skipSpecial
{
if (string.IsNullOrEmpty(sentence))
{
- return EmptyList;
+ return Array.Empty<Split>();
}
- return new RegexSplitEnumerable(sentence, Pattern);
+ return SplitSentence(sentence, PretokenizeRegex());
}
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index aee4c84bcb..d002f55833 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -3,10 +3,10 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
-using System.Linq;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
@@ -16,7 +16,7 @@ namespace Microsoft.ML.Tokenizers
///
/// A Tokenizer works as a pipeline. It processes some raw text as input and outputs a TokenizerResult object.
///
- public class Tokenizer
+ public partial class Tokenizer
{
///
/// Create a new Tokenizer object.
@@ -282,15 +282,14 @@ private enum ModelEncoding
GPT2
}
- private static readonly IReadOnlyDictionary<string, ModelEncoding> _modelPrefixToEncoding =
- new Dictionary<string, ModelEncoding>()
- {
+ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding =
+ [
// chat
- { "gpt-4-", ModelEncoding.Cl100kBase }, // e.g., gpt-4-0314, etc., plus gpt-4-32k
- { "gpt-3.5-turbo-", ModelEncoding.Cl100kBase } // e.g, gpt-3.5-turbo-0301, -0401, etc.
- };
+ ( "gpt-4-", ModelEncoding.Cl100kBase ), // e.g., gpt-4-0314, etc., plus gpt-4-32k
+ ( "gpt-3.5-turbo-", ModelEncoding.Cl100kBase ) // e.g, gpt-3.5-turbo-0301, -0401, etc.
+ ];
- private static readonly IReadOnlyDictionary<string, ModelEncoding> _modelToEncoding =
+ private static readonly Dictionary<string, ModelEncoding> _modelToEncoding =
new Dictionary<string, ModelEncoding>(StringComparer.OrdinalIgnoreCase)
{
// chat
@@ -353,15 +352,15 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
Normalizer? normalizer = null)
{
- var encoder = ModelEncoding.None;
+ ModelEncoding encoder;
if (!_modelToEncoding.TryGetValue(modelName, out encoder))
{
- foreach (KeyValuePair<string, ModelEncoding> kvp in _modelPrefixToEncoding)
+ foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
{
- if (modelName.StartsWith(kvp.Key, StringComparison.OrdinalIgnoreCase))
+ if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
{
- encoder = kvp.Value;
+ encoder = Encoding;
break;
}
}
@@ -372,16 +371,30 @@ public static async Task CreateByModelNameAsync(
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
}
- return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer);
+ return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
}
private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
private const string P50kBaseRegexPattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
- const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
- const string P50RegexUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
- const string R50RegexUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
- const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken";
+ private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
+ private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
+ private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
+ private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken";
+
+#if NET7_0_OR_GREATER
+ [GeneratedRegex(Cl100kBaseRegexPattern)]
+ private static partial Regex Cl100kBaseRegex();
+
+ [GeneratedRegex(P50kBaseRegexPattern)]
+ private static partial Regex P50kBaseRegex();
+#else
+ private static Regex? _cl100kBaseRegex;
+ private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled);
+
+ private static Regex? _p50kBaseRegex;
+ private static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled);
+#endif
///
/// Create tokenizer based on encoder name and extra special tokens
@@ -401,24 +414,24 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
case ModelEncoding.Cl100kBase:
var specialTokens = new Dictionary<string, int>
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
- return await CreateTikTokenTokenizerAsync(Cl100kBaseRegexPattern, Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer);
+ return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
case ModelEncoding.P50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, P50RegexUrl, specialTokens, extraSpecialTokens, normalizer);
+ return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
case ModelEncoding.P50kEdit:
specialTokens = new Dictionary<string, int>
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, P50RegexUrl, specialTokens, extraSpecialTokens, normalizer);
+ return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
case ModelEncoding.R50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, R50RegexUrl, specialTokens, extraSpecialTokens, normalizer);
+ return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
case ModelEncoding.GPT2:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegexPattern, GPT2Url, specialTokens, extraSpecialTokens, normalizer);
+ return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
default:
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -426,39 +439,43 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
}
}
- private static readonly Dictionary<string, (Dictionary<byte[], int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+ private static readonly ConcurrentDictionary<string, (Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
/// <summary>
/// Create tokenizer based on regex pattern, BPE rank file and special tokens
/// </summary>
- /// <param name="regexPatternStr">Regex pattern to break a long string</param>
+ /// <param name="regex">Regex to break a long string</param>
/// <param name="mergeableRanksFileUrl">BPE rank file</param>
- /// <param name="specialTokens">Special tokens mapping</param>
+ /// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
- string regexPatternStr,
- string mergeableRanksFileUrl,
- Dictionary<string, int> specialTokens,
- IReadOnlyDictionary<string, int>? extraSpecialTokens,
- Normalizer? normalizer)
+ Regex regex,
+ string mergeableRanksFileUrl,
+ Dictionary<string, int> specialTokens,
+ IReadOnlyDictionary<string, int>? extraSpecialTokens,
+ Normalizer? normalizer)
{
if (extraSpecialTokens is not null)
{
- specialTokens = specialTokens.Concat(extraSpecialTokens).ToDictionary(pair => pair.Key, pair => pair.Value);
+ foreach (var extraSpecialToken in extraSpecialTokens)
+ {
+ specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
+ }
}
- if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<byte[], int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
+ if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
{
- using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl))
+ using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
{
- cache = Tiktoken.LoadTikTokenBpe(stream);
+ cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
}
- _tiktokenCache.Add(mergeableRanksFileUrl, cache);
+
+ _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
}
- return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(new Regex(regexPatternStr, RegexOptions.Compiled), specialTokens), normalizer);
+ return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer);
}
}
}
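End to end, the factory path is unchanged for callers. A usage sketch; the `Encode`/`TokenizerResult` surface is assumed from the rest of this PR:

```csharp
using System;
using Microsoft.ML.Tokenizers;

// Maps the model name to its encoding (exact match first, then prefixes such
// as "gpt-4-"), downloads and caches the BPE ranks, and builds the pipeline.
Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync("gpt-3.5-turbo");
TokenizerResult result = tokenizer.Encode("Hello, world!");
Console.WriteLine(string.Join(", ", result.Ids));
```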
diff --git a/src/Microsoft.ML.Tokenizers/TokenizerResult.cs b/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
index 192c215eec..6b8e434878 100644
--- a/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
+++ b/src/Microsoft.ML.Tokenizers/TokenizerResult.cs
@@ -63,8 +63,6 @@ internal void AddTokens(IReadOnlyList<Token> addedTokens)
}
}
- private static readonly IReadOnlyList<int> _emptyIds = new List<int>();
-
///
/// Gets the list of token Ids.
/// The Ids are the main input to a Language Model. They are the token indices, the numerical representations that an LM understands.
@@ -80,7 +78,7 @@ public IReadOnlyList<int> Ids
if (_tokens is null)
{
- return _emptyIds;
+ return Array.Empty<int>();
}
_ids = new List<int>(_tokens.Count);
@@ -94,8 +92,6 @@ public IReadOnlyList Ids
}
}
- private static readonly IReadOnlyList<string> _emptyTokens = new List<string>();
-
///
/// Gets the generated tokens. They are the string representation of the Ids.
///
@@ -110,7 +106,7 @@ public IReadOnlyList<string> Tokens
if (_tokens is null)
{
- return _emptyTokens;
+ return Array.Empty<string>();
}
_tokensWords = new List<string>(_tokens.Count);
@@ -124,8 +120,6 @@ public IReadOnlyList Tokens
}
}
- private static readonly IReadOnlyList<(int, int)> _emptyOffsets = new List<(int, int)>();
-
///
/// Gets the list of offsets. These offsets let you slice the input string and thus retrieve
/// the original part that led to producing the corresponding token.
@@ -141,7 +135,7 @@ public IReadOnlyList<string> Tokens
if (_tokens is null)
{
- return _emptyOffsets;
+ return Array.Empty<(int, int)>();
}
_offsets = new List<(int Index, int End)>(_tokens.Count);
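The dropped `_emptyIds`/`_emptyTokens`/`_emptyOffsets` fields were redundant because `Array.Empty<T>()` already returns a cached per-type singleton:

```csharp
using System;

Console.WriteLine(ReferenceEquals(Array.Empty<int>(), Array.Empty<int>())); // True: no allocation per call
```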
diff --git a/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs b/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs
index 9ccca49b89..a3f418317d 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/ByteArrayComparer.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
@@ -8,27 +8,17 @@
namespace Microsoft.ML.Tokenizers
{
- internal class ByteArrayComparer : IEqualityComparer<byte[]>
+ internal sealed class ReadOnlyMemoryByteComparer : IEqualityComparer<ReadOnlyMemory<byte>>
{
- public bool Equals(byte[]? x, byte[]? y)
- {
- if (x is null || y is null)
- {
- return x == y;
- }
+ public static ReadOnlyMemoryByteComparer Instance { get; } = new();
- return x.SequenceEqual(y);
- }
+ public bool Equals(ReadOnlyMemory<byte> x, ReadOnlyMemory<byte> y) =>
+ x.Span.SequenceEqual(y.Span);
- public int GetHashCode(byte[] bytes)
+ public int GetHashCode(ReadOnlyMemory<byte> x)
{
- if (bytes == null)
- {
- throw new ArgumentNullException(nameof(bytes));
- }
-
int hash = 17;
- foreach (byte b in bytes)
+ foreach (byte b in x.Span)
{
hash = hash * 31 + b;
}
@@ -36,4 +26,4 @@ public int GetHashCode(byte[] bytes)
return hash;
}
}
-}
\ No newline at end of file
+}
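The comparer change matters because `ReadOnlyMemory<byte>`'s default equality compares the underlying object, offset, and length rather than the bytes, so two identical keys backed by different arrays would miss each other. An illustrative check (usable from inside the assembly, since the comparer is internal):

```csharp
using System;
using System.Collections.Generic;

byte[] a = { 1, 2, 3 };
byte[] b = { 1, 2, 3 };

var byDefault = new Dictionary<ReadOnlyMemory<byte>, int> { [a.AsMemory()] = 42 };
Console.WriteLine(byDefault.ContainsKey(b.AsMemory())); // False: different backing arrays

var byContent = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance) { [a.AsMemory()] = 42 };
Console.WriteLine(byContent.ContainsKey(b.AsMemory())); // True: bytes are compared
```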
diff --git a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
index 07a4c7db3b..523db677ee 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
@@ -1,10 +1,9 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
-using System.Runtime.CompilerServices;
namespace Microsoft.ML.Tokenizers
{
@@ -13,11 +12,11 @@ namespace Microsoft.ML.Tokenizers
///
internal static class BytePairEncoder
{
- public static int[] BytePairEncode(byte[] mergingBytes, IReadOnlyDictionary<byte[], int> ranks)
+ public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, Dictionary<ReadOnlyMemory<byte>, int> ranks)
{
if (mergingBytes.Length == 1)
{
- return new int[] { ranks[mergingBytes] };
+ return [ranks[mergingBytes]];
}
var byteIndicesAndRanks = new List<(int Index, int Rank)>();
@@ -29,7 +28,7 @@ int GetRank(int startIndex, int skip = 0)
{
if (startIndex + skip + 2 < byteIndicesAndRanks.Count)
{
- var slice = mergingBytes.Slice(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index);
+ var slice = mergingBytes.SliceStartEnd(byteIndicesAndRanks[startIndex].Index, byteIndicesAndRanks[startIndex + skip + 2].Index);
if (ranks.TryGetValue(slice, out var rank))
{
return rank;
@@ -74,17 +73,11 @@ int GetRank(int startIndex, int skip = 0)
var outList = new int[byteIndicesAndRanks.Count - 1];
for (int i = 0; i < byteIndicesAndRanks.Count - 1; i++)
{
- outList[i] = ranks[mergingBytes.Slice(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
+ outList[i] = ranks[mergingBytes.SliceStartEnd(byteIndicesAndRanks[i].Index, byteIndicesAndRanks[i + 1].Index)];
}
return outList;
}
- private static T[] Slice<T>(this T[] array, int start, int end)
- {
- var length = end - start;
- var result = new T[length];
- Array.Copy(array, start, result, 0, length);
- return result;
- }
+ private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end) => memory.Slice(start, end - start);
}
-}
\ No newline at end of file
+}
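For reference, byte-pair encoding greedily merges the adjacent pair with the lowest rank until no entry in the rank table applies; `BytePairEncode` above does this over byte memory without allocating substrings. A simplified string-based restatement (hypothetical ranks, for illustration only):

```csharp
using System.Collections.Generic;
using System.Linq;

static class BpeSketch
{
    static List<string> Encode(string word, Dictionary<string, int> ranks)
    {
        List<string> parts = word.Select(c => c.ToString()).ToList();
        while (parts.Count > 1)
        {
            int best = -1, bestRank = int.MaxValue;
            for (int i = 0; i < parts.Count - 1; i++)
            {
                if (ranks.TryGetValue(parts[i] + parts[i + 1], out int rank) && rank < bestRank)
                {
                    (best, bestRank) = (i, rank);
                }
            }
            if (best < 0)
            {
                break; // no adjacent pair is in the rank table
            }
            parts[best] += parts[best + 1]; // merge the lowest-ranked pair
            parts.RemoveAt(best + 1);
        }
        return parts;
    }
}
```

With ranks `{"ab": 0, "abc": 1}`, `Encode("abc", ...)` first merges `a`+`b`, then `ab`+`c`, producing the single token `abc`.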
diff --git a/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs b/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs
index 061f4cc876..feb913158f 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/IListExtensions.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
@@ -14,6 +14,14 @@ public static void AddRange<T>(this IList<T> list, IEnumerable<T> items)
{
concreteList.AddRange(items);
}
+ else if (items is IList<T> listToAdd)
+ {
+ int count = listToAdd.Count;
+ for (int i = 0; i < count; i++)
+ {
+ list.Add(listToAdd[i]);
+ }
+ }
else
{
foreach (var item in items)
@@ -23,4 +31,4 @@ public static void AddRange<T>(this IList<T> list, IEnumerable<T> items)
}
}
}
-}
\ No newline at end of file
+}