From 10748948aa0406668edc379b33a05a171e3fe6d8 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Tue, 27 Feb 2024 18:23:38 -0800 Subject: [PATCH 1/3] Add Span support in tokenizer's Model abstraction --- src/Microsoft.ML.Tokenizers/Model/BPE.cs | 230 ++++++++++-------- src/Microsoft.ML.Tokenizers/Model/Cache.cs | 97 ++------ .../Model/EnglishRoberta.cs | 89 ++++--- src/Microsoft.ML.Tokenizers/Model/Model.cs | 9 +- src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 142 ++++++----- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 8 +- .../Utils/Helpers.netcoreapp.cs | 2 + .../Utils/Helpers.netstandard.cs | 11 + src/Microsoft.ML.Tokenizers/Utils/LruCache.cs | 132 +++++----- .../Utils/StringSpanOrdinalKey.cs | 131 ++++++++++ .../Microsoft.ML.Tokenizers.Tests/BpeTests.cs | 2 +- .../EnglishRobertaTests.cs | 2 +- 12 files changed, 492 insertions(+), 363 deletions(-) create mode 100644 src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index d799d45a39..a98da728b0 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -3,8 +3,10 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Runtime.CompilerServices; using System.Text.Json; using System.Text.Json.Serialization; @@ -34,20 +36,21 @@ private set { _unknownToken = value; - if (value is null) + if (VocabReverse.TryGetValue(0, out string? v)) { - if (VocabReverse.TryGetValue(0, out string? v)) + if (v == value) { - VocabReverse.Remove(0); - if (_vocab.TryGetValue(v, out int id)) - { - _vocab.Remove(v); - } + return; } + + VocabReverse.Remove(0); + _vocab.Remove(new StringSpanOrdinalKey(v)); } - else + + + if (value is not null) { - _vocab[value] = 0; + _vocab[new StringSpanOrdinalKey(value)] = 0; VocabReverse[0] = value; } } @@ -68,7 +71,6 @@ private set /// public bool FuseUnknownTokens { get; } - /// /// Construct a new Bpe model object to use for text encoding. /// @@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri ContinuingSubwordPrefix = continuingSubwordPrefix; EndOfWordSuffix = endOfWordSuffix; - (Dictionary? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream); - _vocab = vocab1 ?? new Dictionary(); - Cache = new Cache(); + (Dictionary? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream); + _vocab = vocab1 ?? new Dictionary(); + Cache = new StringSpanOrdinalKeyCache(); VocabReverse = new(); - foreach (KeyValuePair kvp in Vocab) + foreach (KeyValuePair kvp in _vocab) { - VocabReverse.Add(kvp.Value, kvp.Key); + VocabReverse.Add(kvp.Value, kvp.Key.Data!); } - if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken)) - { - unknownToken = unkToken; - } - UnknownToken = unknownToken; + UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null); int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length; @@ -136,12 +134,12 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri { (string a, string b) mergeValues = merges[i]; - if (!_vocab.TryGetValue(mergeValues.a, out int aId)) + if (!_vocab.TryGetValueUnsafe(mergeValues.a, out int aId)) { throw new InvalidOperationException($"Trying to merge a token '{mergeValues.a}' which not exist in the vocabulary."); } - if (!_vocab.TryGetValue(mergeValues.b, out int bId)) + if (!_vocab.TryGetValueUnsafe(mergeValues.b, out int bId)) { throw new InvalidOperationException($"Trying to merge a token '{mergeValues.b}' which not exist in the vocabulary."); } @@ -152,7 +150,7 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri } string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}"; - if (!_vocab.TryGetValue(newToken, out int newId)) + if (!_vocab.TryGetValueUnsafe(newToken, out int newId)) { throw new InvalidOperationException($"Trying to merge a token '{newToken}' which not exist in the vocabulary."); } @@ -197,7 +195,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// The text to split. /// Indicate if the token is a special token. /// The list of accumulated encoded Ids. - public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds); + public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds); /// /// Get the number of tokens that the input text will be encoded to. @@ -205,7 +203,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// The text to encode. /// Indicate if the token is special token. /// The number of tokens that the input text will be encoded to. - public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null); + public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIdsWithCache(text, null); /// /// Map the token to encoded Id. @@ -213,15 +211,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// The token to map to the Id. /// Indicate if want to consider the special tokens during the encoding. /// The mapped Id of the token. - public override int? MapTokenToId(string token, bool considerSpecialTokens = true) - { - if (_vocab.TryGetValue(token, out int value)) - { - return value; - } - - return null; - } + public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValueUnsafe(token, out int value) ? value : null; /// /// Map the encoded Id to the token. @@ -242,24 +232,27 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Gets the dictionary mapping tokens to Ids. /// - public IReadOnlyDictionary Vocab => _vocab; + public IReadOnlyDictionary Vocab => _vocabOriginal ?? (_vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); /// Read the given files to extract the vocab and merges - internal static (Dictionary?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges) + internal static (Dictionary?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges) { - Dictionary? dic = JsonSerializer.Deserialize>(vocab) as Dictionary; + JsonSerializerOptions options = new() { Converters = { new StringSpanOrdinalKeyConverter() } }; + Dictionary? dic = JsonSerializer.Deserialize>(vocab, options) as Dictionary; return (dic, ConvertMergesToHashmap(merges)); } /// The vocabulary assigns a number to each token. - private readonly Dictionary _vocab; + private readonly Dictionary _vocab; + + private Dictionary? _vocabOriginal; /// Contains the mapping between Pairs and their (rank, newId). internal Dictionary, (int, int)> Merges { get; } /// Contains the cache for optimizing the encoding step. - internal Cache? Cache { get; } + internal StringSpanOrdinalKeyCache? Cache { get; } internal static readonly int DefaultCacheCapacity = 10_000; @@ -309,9 +302,6 @@ internal static (Dictionary?, Vec<(string, string)>) ReadModelData( return merges; } - /// Reset the cache. - internal void ClearCache() => Cache?.Clear(); - private readonly Dictionary _charToString = new Dictionary(); [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -327,80 +317,124 @@ internal string CharToString(char c) return s; } - internal Word MergeWord(string w) + internal Word MergeWord(ReadOnlySpan w) { Word word = Word.WithCapacity(w.Length); (int Id, int Len)? unk = null; int i = 0; - while (i < w.Length) - { - int length; - string s; - - if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1])) - { - length = 2; - s = w.Substring(i, length); - } - else - { - length = 1; - s = CharToString(w[i]); - } + char[]? buffer = null; - // Add the `continuing_subword_prefix` if relevant - if (i > 0 && ContinuingSubwordPrefix is not null) + try + { + while (i < w.Length) { - s = $"{ContinuingSubwordPrefix}{s}"; - } + int length; + ReadOnlySpan s; - // Add the `end_of_word_suffix` if relevant - if (i + length >= w.Length && EndOfWordSuffix is not null) - { - s = $"{s}{EndOfWordSuffix}"; - } + if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1])) + { + length = 2; + s = w.Slice(i, 2); + } + else + { + length = 1; + s = w.Slice(i, 1); + } - if (_vocab.TryGetValue(s, out int id)) - { - if (unk.HasValue) + // Add the `continuing_subword_prefix` if relevant + if (i > 0 && ContinuingSubwordPrefix is not null) { - word.Add(unk.Value.Id, unk.Value.Len); - unk = null; + if (buffer is null) + { + // 60 should be enough for most cases + buffer = ArrayPool.Shared.Rent(60); + } + + if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length) + { + ContinuingSubwordPrefix.AsSpan().CopyTo(buffer.AsSpan()); + s.CopyTo(buffer.AsSpan().Slice(ContinuingSubwordPrefix.Length)); + s = buffer.AsSpan().Slice(0, ContinuingSubwordPrefix.Length + s.Length); + } + else + { + string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); + s = $"{ContinuingSubwordPrefix}{s1}".AsSpan(); + } } - word.Add(id, length); - } - else if (UnknownToken is not null) - { - if (unk.HasValue) + + // Add the `end_of_word_suffix` if relevant + if (i + length >= w.Length && EndOfWordSuffix is not null) { - if (FuseUnknownTokens) + if (buffer is null) { - // Fuse unk - unk = (unk.Value.Id, unk.Value.Len + length); + // 60 should be enough for most cases + buffer = ArrayPool.Shared.Rent(60); + } + + if (s.Length + EndOfWordSuffix.Length <= buffer.Length) + { + s.CopyTo(buffer.AsSpan()); + EndOfWordSuffix.AsSpan().CopyTo(buffer.AsSpan().Slice(s.Length)); + s = buffer.AsSpan().Slice(0, s.Length + EndOfWordSuffix.Length); } else { - // Do not fuse unk, add the previous one + string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); + s = $"{s1}{EndOfWordSuffix}".AsSpan(); + } + } + + if (_vocab.TryGetValueUnsafe(s, out int id)) + { + if (unk.HasValue) + { word.Add(unk.Value.Id, unk.Value.Len); - if (!_vocab.TryGetValue(UnknownToken, out int value)) - { - throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); - } - unk = (value, length); + unk = null; } + word.Add(id, length); } - else + else if (UnknownToken is not null) { - if (!_vocab.TryGetValue(UnknownToken, out int value)) + if (unk.HasValue) { - throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); + if (FuseUnknownTokens) + { + // Fuse unk + unk = (unk.Value.Id, unk.Value.Len + length); + } + else + { + // Do not fuse unk, add the previous one + word.Add(unk.Value.Id, unk.Value.Len); + if (!_vocab.TryGetValueUnsafe(UnknownToken, out int value)) + { + throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); + } + unk = (value, length); + } + } + else + { + if (!_vocab.TryGetValueUnsafe(UnknownToken, out int value)) + { + throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); + } + unk = (value, length); } - unk = (value, length); } - } - i += length; + i += length; + } + } + finally + { + if (buffer is not null) + { + ArrayPool.Shared.Return(buffer); + } } if (unk.HasValue) @@ -419,17 +453,17 @@ internal List EncodeWithCache(string text) Word word; if (Cache is not null) { - if (Cache.TryGet(text, out word)) + if (Cache.TryGetValue(text, out word)) { return WordToTokens(ref word); } - word = MergeWord(text); + word = MergeWord(text.AsSpan()); Cache.Set(text, word); } else { - word = MergeWord(text); + word = MergeWord(text.AsSpan()); } return WordToTokens(ref word); @@ -445,19 +479,19 @@ internal int WordToIds(ref Word word, IList? accumulatedIds) return word.SymbolsCount; } - internal int EncodeToIdsWithCache(string text, IList? accumulatedIds) + internal int EncodeToIdsWithCache(ReadOnlySpan text, IList? accumulatedIds) { Word word; if (Cache is not null) { - if (Cache.TryGet(text, out Word hit)) + if (Cache.TryGetValue(text, out Word hit)) { return WordToIds(ref hit, accumulatedIds); } word = MergeWord(text); - Cache.Set(text, word); + Cache.Set(text.ToString(), word); } else { diff --git a/src/Microsoft.ML.Tokenizers/Model/Cache.cs b/src/Microsoft.ML.Tokenizers/Model/Cache.cs index b10d211ea6..065676621e 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Cache.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Cache.cs @@ -4,112 +4,53 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading; namespace Microsoft.ML.Tokenizers { internal sealed class Cache where TKey : notnull where TValue : notnull { + private readonly int _capacity; + private readonly Dictionary _map; + private object SyncObj => _map; + internal Cache() : this(Bpe.DefaultCacheCapacity) { } internal Cache(int capacity) { - Capacity = capacity; - Map = new Dictionary(Capacity); + _capacity = capacity; + _map = new Dictionary(capacity); } - private readonly object _lock = new(); - - internal Dictionary Map { get; set; } - - internal int Capacity { get; set; } - - internal void Fresh() => Map = new Dictionary(Capacity); - - internal void Clear() + internal bool TryGetValue(TKey key, out TValue value) { - lock (_lock) + lock (SyncObj) { - Map.Clear(); + return _map.TryGetValue(key, out value!); } } - internal List GetValues(IEnumerable keys) - { - List values = new(); - lock (_lock) - { - foreach (TKey key in keys) - { - if (Map.TryGetValue(key, out TValue? value)) - { - values.Add(value); - } - } - } - - return values; - } - - internal bool TryGet(TKey key, out TValue value) - { - lock (_lock) - { - return Map.TryGetValue(key, out value!); - } - } - - internal void SetValues(IEnumerable<(TKey, TValue)> entries) - { - lock (_lock) - { - foreach ((TKey, TValue) entry in entries) - { - if (Capacity <= Map.Count) - { - break; - } - Map[entry.Item1] = entry.Item2; - } - } - } - - internal void Set(TKey k, TValue v) + internal TValue GetOrAdd(TKey key, TValue value) { - lock (_lock) + lock (SyncObj) { - if (Capacity > Map.Count) + if (_map.TryGetValue(key, out TValue? v)) { - Map[k] = v; + return v!; } - } - } - internal KeyValuePair[] ToArray() - { - lock (_lock) - { - return Map.ToArray(); + _map[key] = value; + return value; } } - internal TValue GetOrAdd(TKey key, TValue value) + internal void Set(TKey key, TValue value) { - lock (_lock) + lock (SyncObj) { - if (Map.TryGetValue(key, out TValue? v)) + if (_map.Count < _capacity) { - return v; + _map[key] = value; } - - if (Capacity > Map.Count) - { - Map[key] = value; - } - - return value; } } } diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index ea9fa884a8..ae978eb0fe 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -18,13 +18,14 @@ namespace Microsoft.ML.Tokenizers public sealed class EnglishRoberta : Model { private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence; - private readonly IReadOnlyDictionary _vocab; - private readonly SortedDictionary _vocabReverse; + private readonly IReadOnlyDictionary _vocab; + private Dictionary? _vocabOriginal; + private readonly SortedDictionary _vocabReverse; private readonly Cache<(string, string), int> _mergeRanks; private readonly IReadOnlyDictionary _byteToUnicode; private readonly IReadOnlyDictionary _unicodeToByte; private readonly string[] _charToString; - private readonly Cache> _cache; + private readonly StringSpanOrdinalKeyCache> _cache; /// /// Indicate if want to filter the unsupported characters during the decoding. @@ -77,7 +78,7 @@ public EnglishRoberta(string vocabularyPath, string mergePath, string highestOcc } _unicodeToByte = _byteToUnicode.Reverse(); - _cache = new Cache>(); + _cache = new StringSpanOrdinalKeyCache>(); } /// @@ -118,13 +119,13 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes } _unicodeToByte = _byteToUnicode.Reverse(); - _cache = new Cache>(); + _cache = new StringSpanOrdinalKeyCache>(); } /// /// Gets the dictionary mapping tokens to Ids. /// - public IReadOnlyDictionary Vocab => _vocab; + public IReadOnlyDictionary Vocab => _vocabOriginal ??= (_vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); // // Public Model interfaces implementation @@ -145,14 +146,15 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes if (_vocabReverse.TryGetValue(id, out var value)) { + string v = value.Data!; if (FilterUnsupportedChars) { - char[] buffer = ArrayPool.Shared.Rent(value.Length); + char[] buffer = ArrayPool.Shared.Rent(v.Length); int i = 0; - for (int j = 0; j < value.Length; j++) + for (int j = 0; j < v.Length; j++) { - if (_unicodeToByte.TryGetValue(value[j], out var c)) + if (_unicodeToByte.TryGetValue(v[j], out var c)) { buffer[i++] = c; } @@ -164,7 +166,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes } else { - return value; + return v; } } @@ -205,7 +207,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f return Array.Empty(); } - if (_cache.TryGet(text, out List? hit)) + if (_cache.TryGetValue(text, out List? hit)) { ArrayPool.Shared.Return(token); ArrayPool.Shared.Return(indexMapping); @@ -225,7 +227,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// The text to split. /// Indicate if the token is a special token. /// The list of accumulated encoded Ids. - public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds); + public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds); /// /// Get the number of tokens that the input text will be encoded to. @@ -233,16 +235,16 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// The text to encode. /// Indicate if the token is special token. /// The number of tokens that the input text will be encoded to. - public override int CountTokens(string text, bool isSpecialToken) => EncodeToIds(text, null); + public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIds(text, null); - private int EncodeToIds(string text, IList? accumulatedIds) + private int EncodeToIds(ReadOnlySpan text, IList? accumulatedIds) { - if (string.IsNullOrEmpty(text)) + if (text.IsEmpty) { return 0; } - if (_cache.TryGet(text, out List? hit)) + if (_cache.TryGetValue(text, out List? hit)) { if (accumulatedIds is not null) { @@ -255,17 +257,41 @@ private int EncodeToIds(string text, IList? accumulatedIds) return hit.Count; } - // If the cache doesn't have the text, then encode it and add it to the cache - IReadOnlyList tokens = Encode(text); + char[] token = ArrayPool.Shared.Rent(text.Length); + int[] indexMapping = ArrayPool.Shared.Rent(text.Length); + + int newTokenIndex = 0; + for (int i = 0; i < text.Length; i++) + { + if (_byteToUnicode.TryGetValue(text[i], out var value)) + { + token[newTokenIndex] = value; + indexMapping[newTokenIndex] = i; + newTokenIndex++; + } + } + + if (newTokenIndex == 0) + { + ArrayPool.Shared.Return(token); + ArrayPool.Shared.Return(indexMapping); + return 0; + } + + List result = EncodeToTokens(token.AsSpan().Slice(0, newTokenIndex), indexMapping); + _cache.Set(text.ToString(), result); + ArrayPool.Shared.Return(token); + ArrayPool.Shared.Return(indexMapping); + if (accumulatedIds is not null) { - foreach (var t in tokens) + foreach (var t in result) { accumulatedIds.Add(t.Id); } } - return tokens.Count; + return result.Count; } /// @@ -274,7 +300,7 @@ private int EncodeToIds(string text, IList? accumulatedIds) /// The token to map to the Id. /// Indicate if want to consider the special tokens during the encoding. /// The mapped Id of the token. - public override int? MapTokenToId(string token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out var value) ? value : null; + public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValueUnsafe(token, out int value) ? value : null; /// /// Convert a list of tokens Ids to highest occurrence rankings. @@ -397,12 +423,13 @@ private IReadOnlyList ModifyTokenListOffsets(IReadOnlyList tokens, private static HighestOccurrenceMapping GetHighestOccurrenceMapping(Stream highestOccurrenceMappingStream) => HighestOccurrenceMapping.Load(highestOccurrenceMappingStream); - private Dictionary GetVocabulary(Stream vocabularyStream) + private Dictionary GetVocabulary(Stream vocabularyStream) { - Dictionary? vocab; + Dictionary? vocab; try { - vocab = JsonSerializer.Deserialize>(vocabularyStream) as Dictionary; + JsonSerializerOptions options = new() { Converters = { new StringSpanOrdinalKeyConverter() } }; + vocab = JsonSerializer.Deserialize>(vocabularyStream, options) as Dictionary; } catch (Exception e) { @@ -416,22 +443,22 @@ private Dictionary GetVocabulary(Stream vocabularyStream) if (_vocabIdToHighestOccurrence.BosWord is not null) { - vocab[_vocabIdToHighestOccurrence.BosWord] = -_vocabIdToHighestOccurrence.BosIndex; + vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.BosWord)] = -_vocabIdToHighestOccurrence.BosIndex; } if (_vocabIdToHighestOccurrence.EosWord is not null) { - vocab[_vocabIdToHighestOccurrence.EosWord] = -_vocabIdToHighestOccurrence.EosIndex; + vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.EosWord)] = -_vocabIdToHighestOccurrence.EosIndex; } if (_vocabIdToHighestOccurrence.UnkWord is not null) { - vocab[_vocabIdToHighestOccurrence.UnkWord] = -_vocabIdToHighestOccurrence.UnkIndex; + vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.UnkWord)] = -_vocabIdToHighestOccurrence.UnkIndex; } if (_vocabIdToHighestOccurrence.PadWord is not null) { - vocab[_vocabIdToHighestOccurrence.PadWord] = -_vocabIdToHighestOccurrence.PadIndex; + vocab[new StringSpanOrdinalKey(_vocabIdToHighestOccurrence.PadWord)] = -_vocabIdToHighestOccurrence.PadIndex; } return vocab; @@ -510,7 +537,7 @@ private List EncodeToTokens(Span token, Span indexMapping) if (token.Length == 1) { string tokenValue = _charToString[token[0]]; - return new List { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], 1)) }; + return new List { new Token(_vocab[new StringSpanOrdinalKey(tokenValue)], tokenValue, (indexMapping[0], 1)) }; } List word = new(token.Length); @@ -539,7 +566,7 @@ private List EncodeToTokens(Span token, Span indexMapping) // get the most frequent bi-gram pair var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue)); - if (!_mergeRanks.TryGet((first, second), out int _)) + if (!_mergeRanks.TryGetValue((first, second), out int _)) { break; } @@ -599,7 +626,7 @@ private List EncodeToTokens(Span token, Span indexMapping) foreach (string w in word) { - tokens.Add(new Token(_vocab[w], w, (indexMapping[index], w.Length))); + tokens.Add(new Token(_vocab[new StringSpanOrdinalKey(w)], w, (indexMapping[index], w.Length))); index += w.Length; } diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs index 16eecc4aa4..815bd04a0b 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Model.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs @@ -31,14 +31,15 @@ public abstract class Model /// This method does the default implementation that uses the Encode method to get the token's Ids. /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. /// - public virtual void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) + public virtual void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) { if (accumulatedIds is null) { throw new ArgumentNullException(nameof(accumulatedIds)); } - var tokens = Encode(text); + // Default implementation is not optimized for memory allocation. It is recommended to override this method for the sake of the performance. + var tokens = Encode(text.ToString()); foreach (var token in tokens) { accumulatedIds.Add(token.Id); @@ -55,7 +56,7 @@ public virtual void EncodeToIds(string text, bool isSpecialToken, IList acc /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids. /// Tokenizer's models which care about performance may choose to override this method to provide a more efficient implementation. /// - public virtual int CountTokens(string text, bool isSpecialToken) + public virtual int CountTokens(ReadOnlySpan text, bool isSpecialToken) { var ids = new List(); EncodeToIds(text, isSpecialToken, ids); @@ -68,7 +69,7 @@ public virtual int CountTokens(string text, bool isSpecialToken) /// The token to map to Id /// Indicate if want to consider the special tokens during the encoding. /// The mapped Id of the token. - public abstract int? MapTokenToId(string token, bool considerSpecialTokens = true); + public abstract int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true); /// /// Map the encoded Id to the token. diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 0696efd9b0..afa405f9e7 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -19,12 +19,14 @@ namespace Microsoft.ML.Tokenizers /// public sealed class Tiktoken : Model { - private readonly Dictionary, int> _encoder = null!; - private readonly Dictionary> _decoder = null!; - private readonly LruCache? _cache; - private readonly IReadOnlyDictionary? _specialTokensEncoder; + private readonly Dictionary, int> _encoder; + private readonly Dictionary> _decoder; + private readonly LruCache _cache; + private readonly Dictionary? _specialTokensEncoder; + private Dictionary? _specialTokensEncoderOriginal; private readonly Dictionary? _specialTokensDecoder; - private readonly Dictionary _vocab = null!; + private readonly Dictionary _vocab; + private IReadOnlyDictionary? _vocabOriginal; /// /// Create a new Tiktoken tokenizer's model object. @@ -34,7 +36,7 @@ public sealed class Tiktoken : Model /// The size of the cache to use. /// Thrown when is null or empty. /// Thrown when failed to load the BPE vocab file. - public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) : + public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) : this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true) { } @@ -47,7 +49,7 @@ public Tiktoken(string vocabFilePath, IReadOnlyDictionary? specialT /// The size of the cache to use. /// Thrown when is null or empty. /// Thrown when failed to load the BPE vocab file. - public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) : + public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens = null, int cacheSize = LruCache.DefaultCacheSize) : this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false) { } @@ -63,9 +65,9 @@ public Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTok internal Tiktoken( Dictionary, int> encoder, Dictionary> decoder, - Dictionary vocab, + Dictionary vocab, IReadOnlyDictionary? specialTokens, - int cacheSize = LruCache.DefaultCacheSize) : this(cacheSize) + int cacheSize = LruCache.DefaultCacheSize) { _encoder = encoder ?? throw new ArgumentNullException(nameof(encoder)); _decoder = decoder ?? throw new ArgumentNullException(nameof(decoder)); @@ -73,24 +75,21 @@ internal Tiktoken( Debug.Assert(encoder.Count == decoder.Count); - _specialTokensEncoder = specialTokens; - if (_specialTokensEncoder is not null) - { - _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); - } + _encoder = encoder!; + _decoder = decoder!; + _vocab = vocab!; + _cache = new LruCache(cacheSize); + + (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens); } - private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize) + private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTokens, int cacheSize, bool disposeStream) { try { + _cache = new LruCache(cacheSize); (_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult(); - - _specialTokensEncoder = specialTokens; - if (_specialTokensEncoder is not null) - { - _specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); - } + (_specialTokensEncoder, _specialTokensDecoder) = CreateEncoderDecoder(specialTokens); } finally { @@ -101,17 +100,15 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTo } } - private Tiktoken(int cacheSize) + private static (Dictionary?, Dictionary?) CreateEncoderDecoder(IReadOnlyDictionary? specialTokens) { - if (cacheSize < 0) + if (specialTokens is not null) { - throw new ArgumentOutOfRangeException(nameof(cacheSize)); + var encoder = specialTokens.ToDictionary(e => new StringSpanOrdinalKey(e.Key), e => e.Value); + return (encoder, encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key.Data!)); } - if (cacheSize > 0) - { - _cache = new LruCache(cacheSize); - } + return (null, null); } /// @@ -125,7 +122,7 @@ private Tiktoken(int cacheSize) public static async Task CreateAsync( Stream vocabStream, IReadOnlyDictionary? specialTokens = null, - int cacheSize = LruCache.DefaultCacheSize, + int cacheSize = LruCache.DefaultCacheSize, CancellationToken cancellationToken = default) { if (vocabStream is null) @@ -133,7 +130,7 @@ public static async Task CreateAsync( throw new ArgumentNullException(nameof(vocabStream)); } - (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) = + (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) = await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false); return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize); @@ -150,7 +147,7 @@ public static async Task CreateAsync( public static async Task CreateAsync( string vocabFilePath, IReadOnlyDictionary? specialTokensEncoder = null, - int cacheSize = LruCache.DefaultCacheSize, + int cacheSize = LruCache.DefaultCacheSize, CancellationToken cancellationToken = default) { if (vocabFilePath is null) @@ -170,11 +167,11 @@ public static async Task CreateAsync( /// used to request cancellation of the operation. /// Map of byte[] to integer token id /// - internal static async ValueTask<(Dictionary, int>, Dictionary, Dictionary>)> LoadTikTokenBpeAsync( + internal static async ValueTask<(Dictionary, int>, Dictionary, Dictionary>)> LoadTikTokenBpeAsync( Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default) { var encoder = new Dictionary, int>(ReadOnlyMemoryByteComparer.Instance); - var vocab = new Dictionary(); + var vocab = new Dictionary(); var decoder = new Dictionary>(); try @@ -212,7 +209,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : if (decodedToken.IndexOf('\uFFFD') < 0) { - vocab[decodedToken] = rank; + vocab[new StringSpanOrdinalKey(decodedToken)] = rank; } } else @@ -230,12 +227,6 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : return (encoder, vocab, decoder); } - /// - /// Gets the dictionary mapping special tokens to Ids. - /// - /// The dictionary mapping special tokens to Ids. - public IReadOnlyDictionary? SpecialTokensEncoder => _specialTokensEncoder; - /// /// Encode a split text string to a list of tokens. /// @@ -253,12 +244,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) if (isSpecialToken) { - if (_specialTokensEncoder is null) - { - throw new InvalidOperationException($"The tokenizer doesn't have special tokens"); - } - - if (_specialTokensEncoder.TryGetValue(text, out int id)) + if (_specialTokensEncoder?.TryGetValueUnsafe(text, out int id) is true) { return new List { new(id, text, (0, text.Length)) }; } @@ -266,7 +252,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) throw new InvalidOperationException($"The special token {text} doesn't exist in the tokenizer"); } - if (_cache?.Lookup(text, out int[] ids) is true) + if (_cache.TryGetValue(text, out int[]? ids)) { tokens = new Token[ids.Length]; tokens[0] = new Token(ids[0], text, (0, text.Length)); @@ -280,7 +266,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) } // cache miss - if (_vocab.TryGetValue(text, out int mappedId)) + if (_vocab.TryGetValueUnsafe(text, out int mappedId)) { return new Token[1] { new(mappedId, text, (0, text.Length)) }; } @@ -290,7 +276,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder); Debug.Assert(encodedIds.Length > 0); - _cache?.Add(text, encodedIds); + _cache.Add(text, encodedIds); tokens = new Token[encodedIds.Length]; tokens[0] = new Token(encodedIds[0], text, (0, text.Length)); @@ -305,21 +291,21 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) } /// - /// Encode a split text string to a list of Ids. + /// Encode text to a list of Ids. /// /// The text to encode. /// Indicate if the token is a special token. /// The list of accumulated Ids. - public override void EncodeToIds(string text, bool isSpecialToken, IList accumulatedIds) + public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) { - if (string.IsNullOrEmpty(text)) + if (text.IsEmpty) { return; } if (isSpecialToken) { - if (_specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(text, out int id)) + if (_specialTokensEncoder?.TryGetValueUnsafe(text, out int id) is true) { accumulatedIds.Add(id); } @@ -327,23 +313,23 @@ public override void EncodeToIds(string text, bool isSpecialToken, IList ac return; } - if (_cache?.Lookup(text, out int[] tokenIds) is true) + if (_cache.TryGetValue(text, out int[]? tokenIds)) { accumulatedIds.AddRange(tokenIds); return; } - if (_vocab.TryGetValue(text, out int mappedId)) + if (_vocab.TryGetValueUnsafe(text, out int mappedId)) { accumulatedIds.Add(mappedId); return; } byte[] arrayPoolArray = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length)); - int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray); + int encodedLength = GetUtf8Bytes(text, arrayPoolArray); int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder); - _cache?.Add(text, encodedIds); + _cache.Add(text.ToString(), encodedIds); accumulatedIds.AddRange(encodedIds); @@ -354,36 +340,36 @@ public override void EncodeToIds(string text, bool isSpecialToken, IList ac /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. + /// The text to tokenize. /// Indicate if the token is special token. /// The number of tokens that the input text will be encoded to. - public override int CountTokens(string text, bool isSpecialToken) + public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) { - if (string.IsNullOrEmpty(text)) + if (text.IsEmpty) { return 0; } if (isSpecialToken && _specialTokensEncoder is not null) { - return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0; + return _specialTokensEncoder.TryGetValueUnsafe(text, out _) ? 1 : 0; } - if (_cache?.Lookup(text, out int[] ids) is true) + if (_cache.TryGetValue(text, out int[] ids)) { return ids.Length; } - if (_vocab.TryGetValue(text, out _)) + if (_vocab.TryGetValueUnsafe(text, out _)) { return 1; } byte[] arrayPoolArray = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(text.Length)); - int encodedLength = GetUtf8Bytes(text.AsSpan(), arrayPoolArray); + int encodedLength = GetUtf8Bytes(text, arrayPoolArray); int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder); - _cache?.Add(text, encodedIds); + _cache.Add(text.ToString(), encodedIds); ArrayPool.Shared.Return(arrayPoolArray); return encodedIds.Length; @@ -395,19 +381,22 @@ public override int CountTokens(string text, bool isSpecialToken) /// The token to map to the Id. /// Indicate if want to consider the special tokens during the encoding. /// The mapped Id of the token. - public override int? MapTokenToId(string token, bool considerSpecialTokens = true) + public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) { - if (string.IsNullOrEmpty(token)) + if (token.IsEmpty) { return 0; } - if (considerSpecialTokens && _specialTokensEncoder is not null && _specialTokensEncoder.TryGetValue(token, out int specialTokenId)) + if (considerSpecialTokens && _specialTokensEncoder is not null) { - return specialTokenId; + if (_specialTokensEncoder.TryGetValueUnsafe(token, out int specialTokenId)) + { + return specialTokenId; + } } - if (_cache?.Lookup(token, out int[] ids) is true) + if (_cache.TryGetValue(token, out int[] ids)) { if (ids.Length == 1) { @@ -417,7 +406,7 @@ public override int CountTokens(string text, bool isSpecialToken) return null; } - if (_vocab.TryGetValue(token, out int id)) + if (_vocab.TryGetValueUnsafe(token, out int id)) { return id; } @@ -425,10 +414,10 @@ public override int CountTokens(string text, bool isSpecialToken) byte[] arrayPoolArray = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(token.Length)); try { - int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray); + int encodedLength = GetUtf8Bytes(token, arrayPoolArray); int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder); - _cache?.Add(token, idsToCache); + _cache.Add(token.ToString(), idsToCache); if (idsToCache.Length == 1) { @@ -550,7 +539,12 @@ static void ArrayPoolGrow(ref Span utf8Bytes, ref byte[]? arrayPoolArray, /// Gets the dictionary mapping tokens to Ids. /// /// This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary. - public IReadOnlyDictionary Vocab => _vocab; + public IReadOnlyDictionary Vocab => _vocabOriginal ?? (_vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); + + /// + /// Gets the dictionary mapping special tokens to Ids. + /// + public IReadOnlyDictionary? SpecialTokensEncoder => _specialTokensEncoderOriginal ?? (_specialTokensEncoderOriginal = _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); /// /// Gets the dictionary mapping token bytes to Ids. diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 0826a8b68e..c64ebf256e 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -104,7 +104,7 @@ public IReadOnlyList EncodeToIds(string text, bool considerSpecialTokens = foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens)) { - Model.EncodeToIds(split.TokenString, split.IsSpecialToken, idsList); + Model.EncodeToIds(split.TokenSpan, split.IsSpecialToken, idsList); } return idsList; @@ -130,7 +130,7 @@ public int CountTokens(string text, bool considerSpecialTokens = true) int idsCount = 0; foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens)) { - idsCount += Model.CountTokens(split.TokenString, split.IsSpecialToken); + idsCount += Model.CountTokens(split.TokenSpan, split.IsSpecialToken); } return idsCount; @@ -343,7 +343,7 @@ private static Task CreateByEncoderNameAsync( } } - private static readonly ConcurrentDictionary, int>, Dictionary, Dictionary>)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); + private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); /// /// Create tokenizer based on regex pattern, BPE rank file and special tokens @@ -371,7 +371,7 @@ private static async Task CreateTikTokenTokenizerAsync( } } - if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache)) + if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache)) { using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false)) { diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs index b64531431f..0050c63f3d 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs @@ -37,5 +37,7 @@ public static byte[] FromBase64String(string base64String, int offset, int lengt internal static bool TryParseInt32(string s, int offset, out int result) => int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result); + + internal static int GetHashCode(ReadOnlySpan span) => string.GetHashCode(span); } } diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs index 2979c99b6e..2d739e52e4 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs @@ -48,6 +48,17 @@ internal static bool TryParseInt32(string s, int offset, out int result) return true; } + + internal static int GetHashCode(ReadOnlySpan span) + { + int hash = 17; + foreach (char c in span) + { + hash = hash * 31 + c; + } + + return hash; + } } } diff --git a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs index 9ad88e2f35..c11d79e1f5 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/LruCache.cs @@ -2,47 +2,37 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; namespace Microsoft.ML.Tokenizers { - internal class LruCache where TKey : notnull where TValue : notnull + internal sealed class LruCache { /// /// The default LRU cache size. /// - public const int DefaultCacheSize = 8192; // 4096; + public const int DefaultCacheSize = 8192; - private readonly object _lockObject = new object(); - - private class CacheItem - { - public readonly TKey Key; - public TValue Value; - - public CacheItem(TKey key, TValue value) - { - Key = key; - Value = value; - } - } - - private readonly Dictionary> _cache; - private readonly LinkedList _lruList; + private readonly Dictionary>> _cache = new(); + private readonly LinkedList> _lruList = new(); private readonly int _cacheSize; + private object SyncObj => _cache; + /// - /// Constructs an object. + /// Constructs an object. /// /// - /// The maximum number of to mappings - /// that can be cached. This defaults to , which is set to - /// 4096. + /// The maximum number of mappings that can be cached. This defaults to , which is set to 8192. /// public LruCache(int cacheSize = DefaultCacheSize) { - _cache = new Dictionary>(); - _lruList = new LinkedList(); + if (cacheSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(cacheSize), "Cache size must be a positive number."); + } + _cacheSize = cacheSize; } @@ -54,11 +44,11 @@ public LruCache(int cacheSize = DefaultCacheSize) /// /// true if the cache contains a mapping for key, false otherwise. /// - public bool Lookup(TKey key, out TValue value) + public bool TryGetValue(string key, out TValue value) { - lock (_lockObject) + lock (SyncObj) { - if (_cache.TryGetValue(key, out LinkedListNode? cached)) + if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode>? cached)) { _lruList.Remove(cached); _lruList.AddFirst(cached); @@ -71,16 +61,31 @@ public bool Lookup(TKey key, out TValue value) } } - protected virtual void OnEviction(TValue evictedValue) { } - - private void EvictIfNeeded() + /// + /// Retrieves the value associated with the specified key /> object. + /// + /// The object to be used as a key. + /// An out parameter that is set to the value of the key if key contains a mapping in the cache. + /// + /// true if the cache contains a mapping for key, false otherwise. + /// + public unsafe bool TryGetValue(ReadOnlySpan key, out TValue value) { - while (_cache.Count >= _cacheSize) + lock (SyncObj) { - LinkedListNode? nodeToEvict = _lruList.Last; - _lruList.RemoveLast(); - _cache.Remove(nodeToEvict!.Value.Key); - OnEviction(nodeToEvict.Value.Value); + fixed (char* ptr = key) + { + if (_cache.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out LinkedListNode>? cached)) + { + _lruList.Remove(cached); + _lruList.AddFirst(cached); + value = cached.Value.Value; + return true; + } + } + + value = default!; + return false; } } @@ -89,46 +94,29 @@ private void EvictIfNeeded() /// /// The key whose mapped is to be created or replaced. /// The new value to be mapped to the . - public void Add(TKey key, TValue value) => Replace(key, value, out _); - - public bool Replace(TKey key, TValue value, out TValue oldValue) + public void Add(string key, TValue value) { - lock (_lockObject) + lock (SyncObj) { - return ReplaceInternal(key, value, out oldValue); - } - } - - private bool ReplaceInternal(TKey key, TValue value, out TValue oldValue) - { - if (_cache.TryGetValue(key, out LinkedListNode? cached)) - { - oldValue = cached.Value.Value; - cached.Value.Value = value; - _lruList.Remove(cached); - _lruList.AddFirst(cached); - return true; - } - EvictIfNeeded(); - var node = new LinkedListNode(new CacheItem(key, value)); - _cache[key] = node; - _lruList.AddFirst(node); - oldValue = default!; - return false; - } + if (_cache.TryGetValue(new StringSpanOrdinalKey(key), out LinkedListNode>? cached)) + { + cached.Value = new KeyValuePair(key, value); + _lruList.Remove(cached); + _lruList.AddFirst(cached); + return; + } - /// - /// The number of entries currently present in the cache. - /// - public int Count => _cache.Count; + while (_cache.Count >= _cacheSize) + { + LinkedListNode>? nodeToEvict = _lruList.Last; + _lruList.RemoveLast(); + _cache.Remove(new StringSpanOrdinalKey(nodeToEvict!.Value.Key)); + } - /// - /// Clears the contents of this cache. - /// - public void Clear() - { - _cache.Clear(); - _lruList.Clear(); + var node = new LinkedListNode>(new KeyValuePair(key, value)); + _cache[new StringSpanOrdinalKey(key)] = node; + _lruList.AddFirst(node); + } } } -} \ No newline at end of file +} diff --git a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs new file mode 100644 index 0000000000..646cd53bc7 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs @@ -0,0 +1,131 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.ML.Tokenizers +{ + /// Used as a key in a dictionary to enable querying with either a string or a span. + /// + /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should + /// always be used with a string. + /// + internal unsafe readonly struct StringSpanOrdinalKey : IEquatable + { + public readonly char* Ptr; + public readonly int Length; + public readonly string? Data; + + public StringSpanOrdinalKey(char* ptr, int length) + { + Ptr = ptr; + Length = length; + } + + public StringSpanOrdinalKey(string data) => + Data = data; + + private ReadOnlySpan Span => Ptr is not null ? + new ReadOnlySpan(Ptr, Length) : + Data.AsSpan(); + + public override bool Equals(object? obj) => + obj is StringSpanOrdinalKey wrapper && Equals(wrapper); + + public bool Equals(StringSpanOrdinalKey other) => + Span.SequenceEqual(other.Span); + + public override int GetHashCode() => Helpers.GetHashCode(Span); + } + + internal sealed class StringSpanOrdinalKeyCache + { + private readonly int _capacity; + private readonly Dictionary _map; + + private object SyncObj => _map; + + internal StringSpanOrdinalKeyCache() : this(Bpe.DefaultCacheCapacity) { } + + internal StringSpanOrdinalKeyCache(int capacity) + { + _capacity = capacity; + _map = new Dictionary(capacity); + } + + internal bool TryGetValue(string key, out TValue value) + { + lock (SyncObj) + { + return _map.TryGetValue(new StringSpanOrdinalKey(key), out value!); + } + } + + internal unsafe bool TryGetValue(ReadOnlySpan key, out TValue value) + { + lock (SyncObj) + { + fixed (char* ptr = key) + { + return _map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!); + } + } + } + + internal void Remove(string key) + { + lock (SyncObj) + { + _map.Remove(new StringSpanOrdinalKey(key)); + } + } + + internal void Set(string k, TValue v) + { + lock (SyncObj) + { + if (_map.Count < _capacity) + { + _map[new StringSpanOrdinalKey(k)] = v; + } + } + } + } + + /// + /// Custom JSON converter for . + /// + internal sealed class StringSpanOrdinalKeyConverter : JsonConverter + { + public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => + new StringSpanOrdinalKey(reader.GetString()!); + + public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => + writer.WriteStringValue(value.Data!); + + public override StringSpanOrdinalKey Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!); + public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!); + } + + /// + /// Extension methods for . + /// + internal static class StringSpanOrdinalKeyExtensions + { + public unsafe static bool TryGetValueUnsafe(this IReadOnlyDictionary map, ReadOnlySpan key, out TValue value) + { + fixed (char* ptr = key) + { + return map.TryGetValue(new StringSpanOrdinalKey(ptr, key.Length), out value!); + } + } + + public static bool TryGetValueUnsafe(this IReadOnlyDictionary map, string key, out TValue value) => + map.TryGetValue(new StringSpanOrdinalKey(key), out value!); + } +} diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index 810862322b..2959184b5d 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -156,7 +156,7 @@ public void SimpleTestWithUnknownToken(Dictionary vocab, (string, s Assert.Equal(ids[i], encoding.Ids[i]); Assert.Equal(ids[i], idsList[i]); Assert.Equal(encoding.Tokens[i], tokenizer.Model.MapIdToToken(encoding.Ids[i])); - Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i])); + Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan())); Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i])); } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs index d23f241319..ccf0e66ef9 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs @@ -201,7 +201,7 @@ private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = Call Assert.Equal(unfilteredToken![i], tokenizer.Model.MapIdToToken(encoding.Ids[i], considerSpecialTokens: false)); } - Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i])); + Assert.Equal(encoding.Ids[i], tokenizer.Model.MapTokenToId(encoding.Tokens[i].AsSpan())); } } } From 39f36a126a491da30dafc652b274f0dadd6b9d4e Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Wed, 28 Feb 2024 09:56:58 -0800 Subject: [PATCH 2/3] Address the feedback --- src/Microsoft.ML.Tokenizers/Model/BPE.cs | 38 +++++++++---------- .../Model/EnglishRoberta.cs | 8 ++-- src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 20 +++++----- .../Utils/StringSpanOrdinalKey.cs | 5 ++- 4 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index a98da728b0..9da31a301c 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -134,12 +134,12 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri { (string a, string b) mergeValues = merges[i]; - if (!_vocab.TryGetValueUnsafe(mergeValues.a, out int aId)) + if (!_vocab.TryGetValue(mergeValues.a, out int aId)) { throw new InvalidOperationException($"Trying to merge a token '{mergeValues.a}' which not exist in the vocabulary."); } - if (!_vocab.TryGetValueUnsafe(mergeValues.b, out int bId)) + if (!_vocab.TryGetValue(mergeValues.b, out int bId)) { throw new InvalidOperationException($"Trying to merge a token '{mergeValues.b}' which not exist in the vocabulary."); } @@ -150,7 +150,7 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri } string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}"; - if (!_vocab.TryGetValueUnsafe(newToken, out int newId)) + if (!_vocab.TryGetValue(newToken, out int newId)) { throw new InvalidOperationException($"Trying to merge a token '{newToken}' which not exist in the vocabulary."); } @@ -211,7 +211,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// The token to map to the Id. /// Indicate if want to consider the special tokens during the encoding. /// The mapped Id of the token. - public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValueUnsafe(token, out int value) ? value : null; + public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null; /// /// Map the encoded Id to the token. @@ -232,12 +232,12 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Gets the dictionary mapping tokens to Ids. /// - public IReadOnlyDictionary Vocab => _vocabOriginal ?? (_vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); + public IReadOnlyDictionary Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value); /// Read the given files to extract the vocab and merges internal static (Dictionary?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges) { - JsonSerializerOptions options = new() { Converters = { new StringSpanOrdinalKeyConverter() } }; + JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } }; Dictionary? dic = JsonSerializer.Deserialize>(vocab, options) as Dictionary; return (dic, ConvertMergesToHashmap(merges)); @@ -346,11 +346,7 @@ internal Word MergeWord(ReadOnlySpan w) // Add the `continuing_subword_prefix` if relevant if (i > 0 && ContinuingSubwordPrefix is not null) { - if (buffer is null) - { - // 60 should be enough for most cases - buffer = ArrayPool.Shared.Rent(60); - } + buffer ??= ArrayPool.Shared.Rent(60); // 60 should be enough for most cases if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length) { @@ -360,19 +356,19 @@ internal Word MergeWord(ReadOnlySpan w) } else { +#if NETCOREAPP + s = $"{ContinuingSubwordPrefix}{s}".AsSpan(); +#else string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); s = $"{ContinuingSubwordPrefix}{s1}".AsSpan(); +#endif } } // Add the `end_of_word_suffix` if relevant if (i + length >= w.Length && EndOfWordSuffix is not null) { - if (buffer is null) - { - // 60 should be enough for most cases - buffer = ArrayPool.Shared.Rent(60); - } + buffer ??= ArrayPool.Shared.Rent(60); // 60 should be enough for most cases if (s.Length + EndOfWordSuffix.Length <= buffer.Length) { @@ -382,12 +378,16 @@ internal Word MergeWord(ReadOnlySpan w) } else { +#if NETCOREAPP + s = $"{s}{EndOfWordSuffix}".AsSpan(); +#else string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); s = $"{s1}{EndOfWordSuffix}".AsSpan(); +#endif } } - if (_vocab.TryGetValueUnsafe(s, out int id)) + if (_vocab.TryGetValue(s, out int id)) { if (unk.HasValue) { @@ -409,7 +409,7 @@ internal Word MergeWord(ReadOnlySpan w) { // Do not fuse unk, add the previous one word.Add(unk.Value.Id, unk.Value.Len); - if (!_vocab.TryGetValueUnsafe(UnknownToken, out int value)) + if (!_vocab.TryGetValue(UnknownToken, out int value)) { throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); } @@ -418,7 +418,7 @@ internal Word MergeWord(ReadOnlySpan w) } else { - if (!_vocab.TryGetValueUnsafe(UnknownToken, out int value)) + if (!_vocab.TryGetValue(UnknownToken, out int value)) { throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); } diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index ae978eb0fe..3155c778ec 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Tokenizers public sealed class EnglishRoberta : Model { private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence; - private readonly IReadOnlyDictionary _vocab; + private readonly Dictionary _vocab; private Dictionary? _vocabOriginal; private readonly SortedDictionary _vocabReverse; private readonly Cache<(string, string), int> _mergeRanks; @@ -125,7 +125,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes /// /// Gets the dictionary mapping tokens to Ids. /// - public IReadOnlyDictionary Vocab => _vocabOriginal ??= (_vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); + public IReadOnlyDictionary Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value); // // Public Model interfaces implementation @@ -300,7 +300,7 @@ private int EncodeToIds(ReadOnlySpan text, IList? accumulatedIds) /// The token to map to the Id. /// Indicate if want to consider the special tokens during the encoding. /// The mapped Id of the token. - public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValueUnsafe(token, out int value) ? value : null; + public override int? MapTokenToId(ReadOnlySpan token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null; /// /// Convert a list of tokens Ids to highest occurrence rankings. @@ -428,7 +428,7 @@ private Dictionary GetVocabulary(Stream vocabularyStr Dictionary? vocab; try { - JsonSerializerOptions options = new() { Converters = { new StringSpanOrdinalKeyConverter() } }; + JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } }; vocab = JsonSerializer.Deserialize>(vocabularyStream, options) as Dictionary; } catch (Exception e) diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index afa405f9e7..60e9282a81 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -244,7 +244,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) if (isSpecialToken) { - if (_specialTokensEncoder?.TryGetValueUnsafe(text, out int id) is true) + if (_specialTokensEncoder?.TryGetValue(text, out int id) is true) { return new List { new(id, text, (0, text.Length)) }; } @@ -266,7 +266,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) } // cache miss - if (_vocab.TryGetValueUnsafe(text, out int mappedId)) + if (_vocab.TryGetValue(text, out int mappedId)) { return new Token[1] { new(mappedId, text, (0, text.Length)) }; } @@ -305,7 +305,7 @@ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, I if (isSpecialToken) { - if (_specialTokensEncoder?.TryGetValueUnsafe(text, out int id) is true) + if (_specialTokensEncoder?.TryGetValue(text, out int id) is true) { accumulatedIds.Add(id); } @@ -319,7 +319,7 @@ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, I return; } - if (_vocab.TryGetValueUnsafe(text, out int mappedId)) + if (_vocab.TryGetValue(text, out int mappedId)) { accumulatedIds.Add(mappedId); return; @@ -352,7 +352,7 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) if (isSpecialToken && _specialTokensEncoder is not null) { - return _specialTokensEncoder.TryGetValueUnsafe(text, out _) ? 1 : 0; + return _specialTokensEncoder.TryGetValue(text, out _) ? 1 : 0; } if (_cache.TryGetValue(text, out int[] ids)) @@ -360,7 +360,7 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) return ids.Length; } - if (_vocab.TryGetValueUnsafe(text, out _)) + if (_vocab.TryGetValue(text, out _)) { return 1; } @@ -390,7 +390,7 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) if (considerSpecialTokens && _specialTokensEncoder is not null) { - if (_specialTokensEncoder.TryGetValueUnsafe(token, out int specialTokenId)) + if (_specialTokensEncoder.TryGetValue(token, out int specialTokenId)) { return specialTokenId; } @@ -406,7 +406,7 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) return null; } - if (_vocab.TryGetValueUnsafe(token, out int id)) + if (_vocab.TryGetValue(token, out int id)) { return id; } @@ -539,12 +539,12 @@ static void ArrayPoolGrow(ref Span utf8Bytes, ref byte[]? arrayPoolArray, /// Gets the dictionary mapping tokens to Ids. /// /// This may not contain the full set of vocabulary tokens, use Encoder to get the full set of vocabulary. - public IReadOnlyDictionary Vocab => _vocabOriginal ?? (_vocabOriginal = _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); + public IReadOnlyDictionary Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value); /// /// Gets the dictionary mapping special tokens to Ids. /// - public IReadOnlyDictionary? SpecialTokensEncoder => _specialTokensEncoderOriginal ?? (_specialTokensEncoderOriginal = _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value)); + public IReadOnlyDictionary? SpecialTokensEncoder => _specialTokensEncoderOriginal ??= _specialTokensEncoder?.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value); /// /// Gets the dictionary mapping token bytes to Ids. diff --git a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs index 646cd53bc7..3cee62e318 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs @@ -102,6 +102,7 @@ internal void Set(string k, TValue v) /// internal sealed class StringSpanOrdinalKeyConverter : JsonConverter { + public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter(); public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!); @@ -117,7 +118,7 @@ public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdina /// internal static class StringSpanOrdinalKeyExtensions { - public unsafe static bool TryGetValueUnsafe(this IReadOnlyDictionary map, ReadOnlySpan key, out TValue value) + public unsafe static bool TryGetValue(this Dictionary map, ReadOnlySpan key, out TValue value) { fixed (char* ptr = key) { @@ -125,7 +126,7 @@ public unsafe static bool TryGetValueUnsafe(this IReadOnlyDictionary(this IReadOnlyDictionary map, string key, out TValue value) => + public static bool TryGetValue(this Dictionary map, string key, out TValue value) => map.TryGetValue(new StringSpanOrdinalKey(key), out value!); } } From 40a669f10a8d56a300d30cfb2e2d85a4c48c1346 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Wed, 28 Feb 2024 12:02:09 -0800 Subject: [PATCH 3/3] Use stackalloc instead of the ArrayPool --- src/Microsoft.ML.Tokenizers/Model/BPE.cs | 142 ++++++++++------------- 1 file changed, 64 insertions(+), 78 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 9da31a301c..20cfe7f38b 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -323,101 +323,86 @@ internal Word MergeWord(ReadOnlySpan w) (int Id, int Len)? unk = null; int i = 0; - char[]? buffer = null; + Span buffer = stackalloc char[256]; + scoped ReadOnlySpan s; - try + while (i < w.Length) { - while (i < w.Length) + int length; + + if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1])) + { + length = 2; + s = w.Slice(i, 2); + } + else { - int length; - ReadOnlySpan s; + length = 1; + s = w.Slice(i, 1); + } - if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1])) + // Add the `continuing_subword_prefix` if relevant + if (i > 0 && ContinuingSubwordPrefix is not null) + { + if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length) { - length = 2; - s = w.Slice(i, 2); + ContinuingSubwordPrefix.AsSpan().CopyTo(buffer); + s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length)); + s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length); } else { - length = 1; - s = w.Slice(i, 1); - } - - // Add the `continuing_subword_prefix` if relevant - if (i > 0 && ContinuingSubwordPrefix is not null) - { - buffer ??= ArrayPool.Shared.Rent(60); // 60 should be enough for most cases - - if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length) - { - ContinuingSubwordPrefix.AsSpan().CopyTo(buffer.AsSpan()); - s.CopyTo(buffer.AsSpan().Slice(ContinuingSubwordPrefix.Length)); - s = buffer.AsSpan().Slice(0, ContinuingSubwordPrefix.Length + s.Length); - } - else - { #if NETCOREAPP - s = $"{ContinuingSubwordPrefix}{s}".AsSpan(); + s = $"{ContinuingSubwordPrefix}{s}".AsSpan(); #else - string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); - s = $"{ContinuingSubwordPrefix}{s1}".AsSpan(); + string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); + s = $"{ContinuingSubwordPrefix}{s1}".AsSpan(); #endif - } } + } - // Add the `end_of_word_suffix` if relevant - if (i + length >= w.Length && EndOfWordSuffix is not null) + // Add the `end_of_word_suffix` if relevant + if (i + length >= w.Length && EndOfWordSuffix is not null) + { + if (s.Length + EndOfWordSuffix.Length <= buffer.Length) + { + s.CopyTo(buffer); + EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length)); + s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length); + } + else { - buffer ??= ArrayPool.Shared.Rent(60); // 60 should be enough for most cases - - if (s.Length + EndOfWordSuffix.Length <= buffer.Length) - { - s.CopyTo(buffer.AsSpan()); - EndOfWordSuffix.AsSpan().CopyTo(buffer.AsSpan().Slice(s.Length)); - s = buffer.AsSpan().Slice(0, s.Length + EndOfWordSuffix.Length); - } - else - { #if NETCOREAPP - s = $"{s}{EndOfWordSuffix}".AsSpan(); + s = $"{s}{EndOfWordSuffix}".AsSpan(); #else - string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); - s = $"{s1}{EndOfWordSuffix}".AsSpan(); + string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString(); + s = $"{s1}{EndOfWordSuffix}".AsSpan(); #endif - } } + } - if (_vocab.TryGetValue(s, out int id)) + if (_vocab.TryGetValue(s, out int id)) + { + if (unk.HasValue) { - if (unk.HasValue) - { - word.Add(unk.Value.Id, unk.Value.Len); - unk = null; - } - word.Add(id, length); + word.Add(unk.Value.Id, unk.Value.Len); + unk = null; } - else if (UnknownToken is not null) + word.Add(id, length); + } + else if (UnknownToken is not null) + { + if (unk.HasValue) { - if (unk.HasValue) + if (FuseUnknownTokens) { - if (FuseUnknownTokens) - { - // Fuse unk - unk = (unk.Value.Id, unk.Value.Len + length); - } - else - { - // Do not fuse unk, add the previous one - word.Add(unk.Value.Id, unk.Value.Len); - if (!_vocab.TryGetValue(UnknownToken, out int value)) - { - throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); - } - unk = (value, length); - } + // Fuse unk + unk = (unk.Value.Id, unk.Value.Len + length); } else { + // Do not fuse unk, add the previous one + word.Add(unk.Value.Id, unk.Value.Len); if (!_vocab.TryGetValue(UnknownToken, out int value)) { throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); @@ -425,16 +410,17 @@ internal Word MergeWord(ReadOnlySpan w) unk = (value, length); } } - - i += length; - } - } - finally - { - if (buffer is not null) - { - ArrayPool.Shared.Return(buffer); + else + { + if (!_vocab.TryGetValue(UnknownToken, out int value)) + { + throw new InvalidOperationException($"Unknown Token Out Of Vocabulary."); + } + unk = (value, length); + } } + + i += length; } if (unk.HasValue)