diff --git a/THIRD-PARTY-NOTICES.TXT b/THIRD-PARTY-NOTICES.TXT
index 3bc1463084..47f9d3cd1d 100644
--- a/THIRD-PARTY-NOTICES.TXT
+++ b/THIRD-PARTY-NOTICES.TXT
@@ -133,6 +133,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.!
+License notice for DoubleArrayTrie (DART)
+--------------------------------------------
+The BSD 2-clause license
+
+https://github.com/s-yata/darts-clone/blob/master/COPYING.md
+
+Copyright (c) 2008-2014, Susumu Yata All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
+OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
 License notice for CodeGen Tokenizer
 --------------------------------------------
diff --git a/eng/Versions.props b/eng/Versions.props
index 0cc944a243..920c9a766e 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -100,7 +100,7 @@
     0.0.13-test
     0.0.6-test
     0.0.7-test
-    2.0.0-beta.24455.2
+    2.0.0-beta.25110.1
     4.9.0
     1.0.118
     1.6.24
diff --git a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
index b3ee022ad3..ac7ea9e6f2 100644
--- a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
@@ -22,6 +22,9 @@ namespace Microsoft.ML.Tokenizers
     ///
     public class CodeGenTokenizer : Tokenizer
     {
+        // The CodeGen tokenizer implementation is primarily adapted from
+        // https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/tokenization_codegen.py,
+        // with modifications to align with C# code style, the API, and the tokenizer library design.
         private readonly Dictionary _vocab;
         private IReadOnlyDictionary? _vocabOriginal;
         private readonly IReadOnlyDictionary _vocabReverse;
diff --git a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
index e5c5ca4e70..57ab57b13a 100644
--- a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
@@ -31,7 +31,7 @@ internal LlamaTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOn
     ///
     /// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
     ///
-        public static LlamaTokenizer Create(
+        public static new LlamaTokenizer Create(
             Stream modelStream,
             bool addBeginOfSentence = true,
             bool addEndOfSentence = false,
@@ -54,13 +54,6 @@ public static LlamaTokenizer Create(
                 throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto));
             }
 
-            SentencePieceNormalizer normalizer = new(
-                            modelProto.NormalizerSpec.RemoveExtraWhitespaces,
-                            modelProto.NormalizerSpec.AddDummyPrefix,
-                            modelProto.NormalizerSpec.EscapeWhitespaces,
-                            modelProto.TrainerSpec.TreatWhitespaceAsSuffix,
-                            specialTokens);
-
             return new LlamaTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
         }
     }
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
new file mode 100644
index 0000000000..a1553ec4cd
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
@@ -0,0 +1,754 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Sentencepiece;
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Text.RegularExpressions;
+
+namespace Microsoft.ML.Tokenizers
+{
+    internal abstract class SentencePieceBaseModel
+    {
+        internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool addEos = false, IReadOnlyDictionary? specialTokens = null)
+        {
+            if (modelProto is null)
+            {
+                throw new ArgumentNullException(nameof(modelProto));
+            }
+
+            AddBeginningOfSentence = addBos;
+            AddEndOfSentence = addEos;
+            BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? "";
+            BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId;
+            EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? "";
+            EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId;
+            UnknownToken = modelProto.TrainerSpec.UnkPiece ?? "";
+            UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId;
+            AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix;
+            EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
+            TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
+            ByteFallback = modelProto.TrainerSpec.ByteFallback;
+            SpecialTokens = specialTokens;
+
+            if (specialTokens is not null && specialTokens.Count > 0)
+            {
+                InternalSpecialTokens = new Dictionary();
+                SpecialTokensReverse = new Dictionary();
+
+                foreach (var item in specialTokens)
+                {
+                    InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
+                    SpecialTokensReverse.Add(item.Value, item.Key);
+                }
+
+                // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
+                SpecialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
+            }
+
+            Normalizer = new SentencePieceNormalizer(
+                modelProto.NormalizerSpec.PrecompiledCharsmap.Span,
+                modelProto.NormalizerSpec.RemoveExtraWhitespaces,
+                AddDummyPrefix, EscapeWhiteSpaces,
+                modelProto.TrainerSpec.TreatWhitespaceAsSuffix,
+                specialTokens);
+        }
+
+        internal Regex? SpecialTokensRegex { get; }
+
+        internal Dictionary?
InternalSpecialTokens { get; } + + internal Dictionary? SpecialTokensReverse { get; } + + internal int MaxByteId { get; set; } // the maximum value of the byte id.; + + internal int ByteCodeToIdOffset { get; set; } // offset of mapping byte code to the to the Ids. + + internal int OneByteUtf8EncodingMaxId { get; set; } // the maximum value of the one byte UTF-8 character. + + public IReadOnlyDictionary? SpecialTokens { get; } + + public bool ByteFallback { get; } + + public bool AddDummyPrefix { get; } + + public bool EscapeWhiteSpaces { get; } + + public bool TreatWhitespaceAsSuffix { get; internal set; } + + public bool AddBeginningOfSentence { get; } + + public bool AddEndOfSentence { get; } + + public string BeginningOfSentenceToken { get; } + + public string EndOfSentenceToken { get; } + + public string UnknownToken { get; } + + public int BeginningOfSentenceId { get; } + + public int EndOfSentenceId { get; } + + public int UnknownId { get; } + + public SentencePieceNormalizer? Normalizer { get; } + + public abstract IReadOnlyDictionary Vocabulary { get; } + + public abstract IReadOnlyList EncodeToTokens( + string? text, + ReadOnlySpan textSpan, + out string? normalizedText, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization); + + public abstract IReadOnlyList EncodeToIds( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount = int.MaxValue); + + public abstract int CountTokens( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount = int.MaxValue); + + public abstract int GetIndexByTokenCountFromEnd( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + int maxTokenCount, + bool considerNormalization, + out string? normalizedText, + out int tokenCount); + + public abstract bool TryMapIdToToken(int id, out string? token); + + private const int ApproximatedMaxEncodedBytesCount = 50; + + public virtual string Decode(IEnumerable ids, bool considerSpecialTokens) + { + if (ids is null) + { + throw new ArgumentNullException(nameof(ids)); + } + + using IEnumerator enumerator = ids.GetEnumerator(); + if (!enumerator.MoveNext()) + { + return string.Empty; + } + + ValueStringBuilder sb = new(stackalloc char[256]); + + int bytesCount = -1; + byte[]? bytesPoolArray = null; + bool prefixRemoved = false; + int suffixIndex = -1; + char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' '; + + int current = enumerator.Current; + if (current <= MaxByteId && ByteFallback) + { + // First token is a byte token. + + while (current < ByteCodeToIdOffset) + { + // It is possible listing some special tokens before the byte tokens in the tokenizer's data. + TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb); + + // Skip control tokens. + if (!enumerator.MoveNext()) + { + return sb.ToString(); + } + + current = enumerator.Current; + } + + if (current <= MaxByteId && ByteFallback) + { + EncodeByte(current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb); + } + else if (!TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb) && TryMapIdToToken(current, out string? 
token)) + { + AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex); + } + } + else if (!TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb) && TryMapIdToToken(current, out string? token)) + { + AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex); + } + + char[]? charPoolArray = null; + + while (enumerator.MoveNext()) + { + current = enumerator.Current; + if (current < ByteCodeToIdOffset) + { + if (bytesCount >= 1) + { + FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); + } + + // It is possible listing some special tokens before the byte tokens in the tokenizer's data. + TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb); + + continue; + } + + if (current <= MaxByteId && ByteFallback) + { + if (bytesCount >= 1) + { + Debug.Assert(bytesPoolArray is not null); + + if (bytesCount >= bytesPoolArray!.Length) + { + Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2); + } + + bytesPoolArray![bytesCount++] = (byte)(current - ByteCodeToIdOffset); + } + else + { + EncodeByte(current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb); + } + } + else + { + if (bytesCount >= 1) + { + FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); + } + + if (!TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb) && TryMapIdToToken(current, out string? token)) + { + AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex); + } + } + } + + if (bytesCount >= 1) + { + FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); + } + + if (AddDummyPrefix && TreatWhitespaceAsSuffix && suffixIndex >= 0 && sb.Length > 0) + { + Debug.Assert(sb[suffixIndex] == SentencePieceNormalizer.DummyPrefix); + Debug.Assert(sb.Length > suffixIndex); + + sb.Remove(suffixIndex, 1); + } + + if (bytesPoolArray is not null) + { + ArrayPool.Shared.Return(bytesPoolArray); + } + + if (charPoolArray is not null) + { + ArrayPool.Shared.Return(charPoolArray); + } + + return EscapeWhiteSpaces ? sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ') : sb.ToString(); + + static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb) + { + Debug.Assert(bytesCount >= 1); + Debug.Assert(bytesPoolArray is not null); + + int len = Encoding.UTF8.GetMaxCharCount(bytesCount); + + charPoolArray ??= ArrayPool.Shared.Rent(Math.Max(len, ApproximatedMaxEncodedBytesCount >> 1)); + + if (len > charPoolArray.Length) + { + Helpers.ArrayPoolGrow(ref charPoolArray, len); + } + + int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray); + + sb.Append(charPoolArray.AsSpan(0, charCount)); + bytesCount = -1; + } + + static void EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, ref byte[]? 
bytesPoolArray, ref ValueStringBuilder sb) + { + if (id <= oneByteUtf8EncodingMaxId) + { + sb.Append((char)(id - byteCodeToIdOffset)); + } + else + { + bytesCount = 1; + bytesPoolArray ??= ArrayPool.Shared.Rent(ApproximatedMaxEncodedBytesCount); + bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset); + } + } + + static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, ref ValueStringBuilder sb, ref bool prefixRemoved, ref int suffixIndex) + { + if (token.Length == 0) + { + return; + } + + if (!addDummyPrefix) + { + sb.Append(token); + return; + } + + if (treatWhitespaceAsSuffix) + { + sb.Append(token); + if (token[token.Length - 1] == prefixSuffixChar) + { + suffixIndex = sb.Length - 1; + } + } + else + { + sb.Append(!prefixRemoved && token[0] == prefixSuffixChar ? token.AsSpan(1) : token.AsSpan()); + } + + prefixRemoved = true; + } + + static bool TryDecodeAsSpecialToken(SentencePieceBaseModel model, int id, bool considerSpecialTokens, ref ValueStringBuilder sb) + { + string? token = null; + + if (id == model.BeginningOfSentenceId) + { + token = model.BeginningOfSentenceToken; + } + else if (id == model.EndOfSentenceId) + { + token = model.EndOfSentenceToken; + } + else if (id == model.UnknownId) + { + token = model.UnknownToken; + } + else if (model.SpecialTokensReverse?.TryGetValue(id, out string? specialToken) is true) + { + token = specialToken; + } + + if (token is not null && considerSpecialTokens) + { + sb.Append(token); + } + + return token is not null; + } + } + + public virtual OperationStatus Decode(IEnumerable ids, Span destination, bool considerSpecialTokens, out int idsConsumed, out int charsWritten) + { + idsConsumed = 0; + charsWritten = 0; + + if (ids is null) + { + throw new ArgumentNullException(nameof(ids)); + } + + using IEnumerator enumerator = ids.GetEnumerator(); + if (!enumerator.MoveNext()) + { + return OperationStatus.Done; + } + + Span buffer = destination; + + int bytesCount = -1; + byte[]? bytesPoolArray = null; + bool prefixRemoved = false; + int suffixIndex = -1; + char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' '; + + int current = enumerator.Current; + if (current <= MaxByteId && ByteFallback) + { + // First token is a byte token. + while (current < ByteCodeToIdOffset) + { + OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken); + if (status != OperationStatus.Done) + { + return status; + } + buffer = destination.Slice(charsWritten); + + // Skip control tokens. + idsConsumed++; + if (!enumerator.MoveNext()) + { + return OperationStatus.Done; + } + + current = enumerator.Current; + } + + if (current <= MaxByteId && ByteFallback) + { + if (!EncodeByte(enumerator.Current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, buffer, ref charsWritten, ref idsConsumed, ref bytesPoolArray)) + { + return OperationStatus.DestinationTooSmall; + } + } + else + { + OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken); + if (status != OperationStatus.Done) + { + return status; + } + + if (!isSpecialToken && TryMapIdToToken(current, out string? 
token)) + { + if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten)) + { + return OperationStatus.DestinationTooSmall; + } + } + else + { + idsConsumed++; + } + } + } + else + { + OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken); + if (status != OperationStatus.Done) + { + return status; + } + + if (!isSpecialToken && TryMapIdToToken(current, out string? token)) + { + if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten)) + { + return OperationStatus.DestinationTooSmall; + } + } + else + { + idsConsumed++; + } + } + + char[]? charPoolArray = null; + + while (enumerator.MoveNext()) + { + current = enumerator.Current; + buffer = destination.Slice(charsWritten); + + if (current < ByteCodeToIdOffset) + { + if (bytesCount >= 1) + { + if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed)) + { + return OperationStatus.DestinationTooSmall; + } + } + + OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken); + if (status != OperationStatus.Done) + { + return status; + } + + idsConsumed++; + continue; + } + + if (current <= MaxByteId && ByteFallback) + { + if (bytesCount >= 1) + { + Debug.Assert(bytesPoolArray is not null); + + if (bytesCount >= bytesPoolArray!.Length) + { + Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2); + } + + bytesPoolArray![bytesCount++] = (byte)(current - ByteCodeToIdOffset); + } + else + { + if (!EncodeByte(current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, buffer, ref charsWritten, ref idsConsumed, ref bytesPoolArray)) + { + return OperationStatus.DestinationTooSmall; + } + } + } + else + { + if (bytesCount >= 1) + { + if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed)) + { + return OperationStatus.DestinationTooSmall; + } + } + + OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken); + if (status != OperationStatus.Done) + { + return status; + } + + if (!isSpecialToken && TryMapIdToToken(current, out string? 
token)) + { + if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten)) + { + return OperationStatus.DestinationTooSmall; + } + } + else + { + idsConsumed++; + } + } + } + + buffer = destination.Slice(charsWritten); + + if (bytesCount >= 1) + { + if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed)) + { + return OperationStatus.DestinationTooSmall; + } + } + + if (suffixIndex >= 0) + { + Debug.Assert(destination[suffixIndex] == ' '); + + if (suffixIndex < charsWritten - 1) + { + destination.Slice(suffixIndex + 1, charsWritten - suffixIndex - 1).CopyTo(destination.Slice(suffixIndex)); + } + + charsWritten--; + } + + if (bytesPoolArray is not null) + { + ArrayPool.Shared.Return(bytesPoolArray); + } + + if (charPoolArray is not null) + { + ArrayPool.Shared.Return(charPoolArray); + } + + return OperationStatus.Done; + + static OperationStatus TryDecodeAsSpecialToken(SentencePieceBaseModel model, int id, bool considerSpecialTokens, Span buffer, ref int charsWritten, out bool isSpecialToken) + { + string? specialToken = null; + + if (id == model.BeginningOfSentenceId) + { + specialToken = model.BeginningOfSentenceToken; + } + else if (id == model.EndOfSentenceId) + { + specialToken = model.EndOfSentenceToken; + } + else if (id == model.UnknownId) + { + specialToken = model.UnknownToken; + } + else + { + model.SpecialTokensReverse?.TryGetValue(id, out specialToken); + } + + isSpecialToken = specialToken is not null; + + if (considerSpecialTokens && isSpecialToken) + { + if (buffer.Length < specialToken!.Length) + { + return OperationStatus.DestinationTooSmall; + } + + specialToken.AsSpan().CopyTo(buffer); + charsWritten += specialToken.Length; + } + + return OperationStatus.Done; + } + + static bool FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, Span buffer, ref int charsWritten, ref int idsConsumed) + { + Debug.Assert(bytesCount >= 1); + Debug.Assert(bytesPoolArray is not null); + + int len = Encoding.UTF8.GetMaxCharCount(bytesCount); + + charPoolArray ??= ArrayPool.Shared.Rent(Math.Max(len, ApproximatedMaxEncodedBytesCount >> 1)); + + if (len > charPoolArray.Length) + { + Helpers.ArrayPoolGrow(ref charPoolArray, len); + } + + int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray); + + if (charCount > buffer.Length) + { + return false; + } + + charPoolArray.AsSpan(0, charCount).CopyTo(buffer); + charsWritten += charCount; + idsConsumed += bytesCount; + bytesCount = -1; + + return true; + } + + static bool EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, Span buffer, ref int charsWritten, ref int idsConsumed, ref byte[]? 
bytesPoolArray) + { + if (id <= oneByteUtf8EncodingMaxId) + { + if (buffer.Length < 1) + { + return false; + } + + buffer[0] = (char)(id - byteCodeToIdOffset); + charsWritten++; + idsConsumed++; + } + else + { + bytesCount = 1; + bytesPoolArray ??= ArrayPool.Shared.Rent(ApproximatedMaxEncodedBytesCount); + bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset); + } + + return true; + } + + static bool AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, Span destination, ref bool prefixRemoved, ref int suffixIndex, ref int idsConsumed, ref int charsConsumed) + { + if (token.Length == 0) + { + return true; + } + + Span buffer = destination.Slice(charsConsumed); + + ReadOnlySpan tokenSpan = token.AsSpan(); + + if (!addDummyPrefix) + { + if (tokenSpan.Length > buffer.Length) + { + return false; + } + + if (prefixSuffixChar != ' ') + { + Helpers.Replace(tokenSpan, buffer, prefixSuffixChar, ' '); + } + else + { + tokenSpan.CopyTo(buffer); + } + + buffer = buffer.Slice(tokenSpan.Length); + charsConsumed += tokenSpan.Length; + idsConsumed++; + return true; + } + + if (treatWhitespaceAsSuffix) + { + if (tokenSpan[tokenSpan.Length - 1] == prefixSuffixChar) + { + suffixIndex = charsConsumed + tokenSpan.Length - 1; + } + + if (tokenSpan.Length > buffer.Length) + { + return false; + } + + if (prefixSuffixChar != ' ') + { + Helpers.Replace(tokenSpan, buffer, prefixSuffixChar, ' '); + } + else + { + tokenSpan.CopyTo(buffer); + } + + charsConsumed += tokenSpan.Length; + + idsConsumed++; + } + else + { + int delta = !prefixRemoved && token[0] == prefixSuffixChar ? 1 : 0; + if (buffer.Length < token.Length - delta) + { + return false; + } + + tokenSpan = tokenSpan.Slice(delta); + if (prefixSuffixChar != ' ') + { + Helpers.Replace(tokenSpan, buffer, prefixSuffixChar, ' '); + } + else + { + tokenSpan.CopyTo(buffer); + } + + charsConsumed += tokenSpan.Length; + idsConsumed++; + + if (!prefixRemoved && delta == 1) + { + prefixRemoved = true; + } + } + + return true; + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs new file mode 100644 index 0000000000..85c85c1677 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs @@ -0,0 +1,1260 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Sentencepiece; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using System.Threading; + +namespace Microsoft.ML.Tokenizers +{ + internal sealed class SentencePieceBpeModel : SentencePieceBaseModel + { + private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id. + private readonly Dictionary _vocab = new(); + private readonly Dictionary _vocabReverse = new(); + private IReadOnlyDictionary? _publicVocab; + + internal SentencePieceBpeModel(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? 
specialTokens = null) : base(modelProto, addBos, addEos, specialTokens) + { + for (int i = 0; i < modelProto.Pieces.Count; i++) + { + var piece = modelProto.Pieces[i]; + _vocab.Add(new StringSpanOrdinalKey(piece.Piece), (i, piece.Score, (byte)piece.Type)); + _vocabReverse.Add(i, piece.Piece); + + if (piece.Type == ModelProto.Types.SentencePiece.Types.Type.Byte) + { + MaxByteId = i; + } + } + + ByteCodeToIdOffset = _vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) value) ? value.Id : MaxByteId; + OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character. + } + + public override IReadOnlyDictionary Vocabulary + { + get + { + IReadOnlyDictionary? publicVocab = Volatile.Read(ref _publicVocab); + if (publicVocab is null) + { + var vocab = new Dictionary(); + foreach (var item in _vocab) + { + vocab.Add(item.Key.ToString(), item.Value.Id); + } + + Interlocked.CompareExchange(ref _publicVocab, new ReadOnlyDictionary(vocab), null); + publicVocab = _publicVocab; + } + + return publicVocab; + } + } + + public override bool TryMapIdToToken(int id, out string? token) => _vocabReverse.TryGetValue(id, out token); + + public override IReadOnlyList EncodeToTokens(string? text, ReadOnlySpan textSpan, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization) + { + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedText = null; + return []; + } + + ReadOnlySpan textToEncode = text is null ? textSpan : text.AsSpan(); + if (considerNormalization && Normalizer is not null) + { + normalizedText = text is not null ? Normalizer.Normalize(text) : Normalizer.Normalize(textSpan); + textToEncode = normalizedText.AsSpan(); + } + else + { + normalizedText = null; + } + + if (textToEncode.Length == 0) + { + return []; + } + + List tokens = new(); + + if (SpecialTokensRegex is not null) + { + EncodeWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens); + } + else + { + EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens); + } + + return tokens; + } + + private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, List tokens) + { + Debug.Assert(SpecialTokensRegex is not null); + + if (addBeginOfSentence) + { + tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + } + + int currentOffset = 0; + + foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!)) + { + if (Offset > currentOffset) + { + EncodeInternal(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens); + } + + if (InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) + { + tokens.Add(new EncodedToken(id, SpecialTokensReverse![id], new Range(Offset, Offset + Length))); + } + + currentOffset = Offset + Length; + } + + if (currentOffset < text.Length) + { + EncodeInternal(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens); + } + + if (addEndOfSentence) + { + tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length))); + } + } + + /// + /// Encode a text to a list of tokens. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// A collection to store the encoded tokens. 
+ /// The input text has to be normalized before calling this method. + private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, List tokens) + { + BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); + + Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); + + if (addBeginOfSentence) + { + tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + } + + for (int index = 0; (uint)index < (uint)symbols.Length; index = symbols[index].next) + { + int id = symbols[index].id; + byte type = symbols[index].type; + + if (id == UninitializedId) + { + if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) + { + id = tokenInfo.Id; + type = tokenInfo.Type; + } + else + { + id = UnknownId; + type = 0; + } + } + + if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) + { + if (id == UnknownId && ByteFallback) + { + EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index); + } + else + { + tokens.Add(new EncodedToken( + id, + GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text), + new Range(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Index + symbols[index].pieceSpan.Length))); + } + continue; + } + + Segment(symbols[index].pieceSpan, text); + } + + ArrayPool.Shared.Return(symbols); + + if (addEndOfSentence) + { + tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length))); + } + + return; + + // Encode the Unknown token to bytes. + void EncodeAsBytes(ReadOnlySpan text, int index) + { + for (int i = 0; i < text.Length; i++) + { + char c = text[i]; + if (c <= 0x7F) + { + int id = (int)c + ByteCodeToIdOffset; // byte code is mapped to the to the Ids starting from 4. + + if (_vocabReverse.TryGetValue(id, out string? token)) + { + tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + 1))); + } + } + else + { + Span utf8Bytes = stackalloc byte[256]; + byte[]? arrayPoolArray = null; + + int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); + if (len > utf8Bytes.Length) + { + arrayPoolArray = ArrayPool.Shared.Rent(len); + utf8Bytes = arrayPoolArray; + } + + // Need to convert the text into UTF-8 bytes and then encode the bytes. + int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes); + int length = text.Length - i; + for (int j = 0; j < bytesWritten; j++) + { + int id = (int)utf8Bytes[j] + ByteCodeToIdOffset; // byte code is mapped to the to the Ids starting from 4. + + if (_vocabReverse.TryGetValue(id, out string? 
token)) + { + tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + length))); + } + + length = 0; + } + + if (arrayPoolArray is not null) + { + ArrayPool.Shared.Return(arrayPoolArray); + } + + break; + } + } + } + + void Segment((int Index, int Length) pieceSpan, ReadOnlySpan text) + { + if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) + { + EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index); + return; + } + + if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || + revMerge is null || + !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) + { + tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), new Range(pieceSpan.Index, pieceSpan.Index + pieceSpan.Length))); + return; + } + + Segment((merge.LeftIndex, merge.LeftLen), text); + Segment((merge.RightIndex, merge.RightLen), text); + } + } + + public override IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, + out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) + { + normalizedText = null; + charsConsumed = 0; + return []; + } + + return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); + } + + /// + /// Encodes input text to token Ids up to maximum number of tokens. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Indicate whether to consider normalization before tokenization. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The length of the text that encompasses the maximum encoded tokens. + /// The maximum number of tokens to encode. + /// The list of encoded Ids. + private IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, + out string? 
normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + if (text.IsEmpty) + { + normalizedText = null; + charsConsumed = 0; + return []; + } + + ReadOnlySpan textToEncode; + + if (considerNormalization && Normalizer is not null) + { + normalizedText = Normalizer.Normalize(text); + textToEncode = normalizedText.AsSpan(); + } + else + { + normalizedText = null; + textToEncode = text; + } + + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0."); + } + + List ids = new(); + + if (SpecialTokensRegex is not null) + { + EncodeToIdsWithAddedToken(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount); + } + else + { + EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount); + } + + return ids; + } + + private int EncodeToIdsWithAddedToken(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue) + { + Debug.Assert(SpecialTokensRegex is not null); + Debug.Assert(maxTokens > 0); + + charsConsumed = 0; + int idsCount = 0; + + if (addBeginOfSentence) + { + accumulatedIds.Add(BeginningOfSentenceId); + idsCount++; + } + + int currentOffset = 0; + + int charsWritten; + + foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!)) + { + if (Offset > currentOffset) + { + idsCount += EncodeToIds(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount); + charsConsumed += charsWritten; + } + + if (idsCount < maxTokens && InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) + { + accumulatedIds.Add(id); + idsCount++; + charsConsumed += Length; + } + + currentOffset = Offset + Length; + } + + if (currentOffset < text.Length && idsCount < maxTokens) + { + idsCount += EncodeToIds(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount); + charsConsumed += charsWritten; + } + + if (addEndOfSentence && idsCount < maxTokens) + { + accumulatedIds.Add(EndOfSentenceId); + idsCount++; + } + + return idsCount; + } + + /// + /// Encode a text to a list of Ids and add them to the accumulatedIds list. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The list of accumulated encoded Ids. + /// The length of the text that encompasses the maximum encoded tokens. + /// The maximum number of tokens to encode. + /// The number of tokens that the input text will be encoded to. + /// The input text has to be normalized before calling this method. 
+ private int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue) + { + charsConsumed = 0; + if (text.IsEmpty) + { + return 0; + } + + int idsCount = 0; + + if (addBeginOfSentence) + { + accumulatedIds.Add(BeginningOfSentenceId); + idsCount++; + } + + BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); + + Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); + + for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next) + { + int id = symbols[index].id; + byte type = symbols[index].type; + + if (id == UninitializedId) + { + if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) + { + id = tokenInfo.Id; + type = tokenInfo.Type; + } + else + { + id = UnknownId; + type = 0; + } + } + + if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) + { + if (id == UnknownId && ByteFallback) + { + if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed)) + { + ArrayPool.Shared.Return(symbols); + return idsCount; + } + } + else + { + if (idsCount < maxTokens) + { + accumulatedIds.Add(id); + charsConsumed += symbols[index].pieceSpan.Length; + idsCount++; + } + else + { + ArrayPool.Shared.Return(symbols); + return idsCount; + } + } + continue; + } + + if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed)) + { + break; + } + } + + ArrayPool.Shared.Return(symbols); + + if (addEndOfSentence) + { + if (idsCount < maxTokens) + { + accumulatedIds.Add(EndOfSentenceId); + idsCount++; + } + } + + return idsCount; + + // Encode the Unknown token to bytes. + bool EncodeAsBytes(ReadOnlySpan text, int index, ref int charsConsumed) + { + for (int i = 0; i < text.Length; i++) + { + char c = text[i]; + if (c <= 0x7F) + { + if (idsCount < maxTokens) + { + charsConsumed++; + accumulatedIds.Add((int)c + ByteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4. + idsCount++; + } + else + { + return false; + } + } + else + { + Span utf8Bytes = stackalloc byte[100]; + byte[]? arrayPoolArray = null; + + int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); + if (len > utf8Bytes.Length) + { + arrayPoolArray = ArrayPool.Shared.Rent(len); + utf8Bytes = arrayPoolArray; + } + + // Need to convert the text into UTF-8 bytes and then encode the bytes. + int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes); + + bool ret; + if (idsCount + bytesWritten <= maxTokens) + { + for (int j = 0; j < bytesWritten; j++) + { + accumulatedIds.Add((int)utf8Bytes[j] + ByteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4. 
+ } + + charsConsumed += text.Length - i; + ret = true; + } + else + { + ret = false; + } + + if (arrayPoolArray is not null) + { + ArrayPool.Shared.Return(arrayPoolArray); + } + + return ret; + } + } + + return true; + } + + bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int charsConsumed) + { + if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) + { + return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed); + } + + if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || + revMerge is null || + !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) + { + if (idsCount < maxTokens) + { + accumulatedIds.Add(id.Id); + charsConsumed += pieceSpan.Length; + idsCount++; + return true; + } + else + { + return false; + } + } + + return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed); + } + } + + public override int CountTokens( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount = int.MaxValue) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); + } + + textSpan = text is null ? textSpan : text.AsSpan(); + + if (textSpan.IsEmpty) + { + normalizedText = null; + charsConsumed = 0; + return 0; + } + + ReadOnlySpan textToEncode; + if (considerNormalization && Normalizer is not null) + { + normalizedText = Normalizer.Normalize(textSpan); + textToEncode = normalizedText.AsSpan(); + } + else + { + normalizedText = null; + textToEncode = textSpan; + } + + return SpecialTokensRegex is not null ? + CountTokensWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount) : + CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount); + } + + private int CountTokensWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue) + { + Debug.Assert(SpecialTokensRegex is not null); + Debug.Assert(maxTokens > 0); + + charsConsumed = 0; + int idsCount = 0; + + if (addBeginOfSentence) + { + idsCount++; + } + + int currentOffset = 0; + + int charsWritten; + + foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!)) + { + if (Offset > currentOffset) + { + idsCount += CountTokens(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount); + charsConsumed += charsWritten; + } + + if (idsCount < maxTokens && InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) + { + idsCount++; + charsConsumed += Length; + } + + currentOffset = Offset + Length; + } + + if (currentOffset < text.Length && idsCount < maxTokens) + { + idsCount += CountTokens(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount); + charsConsumed += charsWritten; + } + + if (addEndOfSentence && idsCount < maxTokens) + { + idsCount++; + } + + return idsCount; + } + + /// + /// Get the number of tokens that the input text will be encoded to. 
+ /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The length of the text that encompasses the maximum encoded tokens. + /// The maximum number of tokens to encode. + /// The number of tokens that the input text will be encoded to. + /// The input text has to be normalized before calling this method. + private int CountTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue) + { + charsConsumed = 0; + if (text.IsEmpty) + { + return 0; + } + + int tokenCount = addBeginOfSentence ? 1 : 0; + + BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); + + Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); + + for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next) + { + int id = symbols[index].id; + byte type = symbols[index].type; + + if (id == UninitializedId) + { + if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) + { + id = tokenInfo.Id; + type = tokenInfo.Type; + } + else + { + id = UnknownId; + type = 0; + } + } + + if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) + { + if (id == UnknownId && ByteFallback) + { + if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed)) + { + break; + } + } + else + { + if (tokenCount < maxTokens) + { + tokenCount++; + charsConsumed += symbols[index].pieceSpan.Length; + } + else + { + break; + } + } + continue; + } + + if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed)) + { + break; + } + } + + ArrayPool.Shared.Return(symbols); + + if (addEndOfSentence) + { + if (tokenCount < maxTokens) + { + tokenCount++; + } + } + + return tokenCount; + + // Encode the Unknown token to bytes. + bool EncodeAsBytes(ReadOnlySpan text, int index, ref int charsConsumed) + { + for (int i = 0; i < text.Length; i++) + { + char c = text[i]; + if (c <= 0x7F) + { + if (tokenCount < maxTokens) + { + tokenCount++; + charsConsumed++; + } + else + { + return false; + } + } + else + { + Span utf8Bytes = stackalloc byte[100]; + byte[]? arrayPoolArray = null; + + int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); + if (len > utf8Bytes.Length) + { + arrayPoolArray = ArrayPool.Shared.Rent(len); + utf8Bytes = arrayPoolArray; + } + + // Need to convert the text into UTF-8 bytes and then encode the bytes. 
+ int encodedCount = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes); + bool ret; + + if (tokenCount + encodedCount <= maxTokens) + { + tokenCount += encodedCount; + charsConsumed += text.Length - i; + ret = true; + } + else + { + ret = false; + } + + if (arrayPoolArray is not null) + { + ArrayPool.Shared.Return(arrayPoolArray); + } + + return ret; + } + } + + return true; + } + + bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int charsConsumed) + { + if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) + { + return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed); + } + + if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || + revMerge is null || + !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) + { + if (tokenCount < maxTokens) + { + tokenCount++; + charsConsumed += pieceSpan.Length; + return true; + } + else + { + return false; + } + } + + return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed); + } + } + + public override int GetIndexByTokenCountFromEnd(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount) + { + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } + + textSpan = text is null ? textSpan : text.AsSpan(); + + if (textSpan.IsEmpty) + { + normalizedText = null; + tokenCount = 0; + return 0; + } + + ReadOnlySpan textToEncode; + if (considerNormalization && Normalizer is not null) + { + normalizedText = Normalizer.Normalize(textSpan); + textToEncode = normalizedText.AsSpan(); + } + else + { + normalizedText = null; + textToEncode = textSpan; + } + + int textIndex; + if (SpecialTokensRegex is not null) + { + tokenCount = CountTokensFromEndWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount); + } + else + { + tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount); + } + + return textIndex; + } + + private int CountTokensFromEndWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens) + { + Debug.Assert(SpecialTokensRegex is not null); + Debug.Assert(maxTokens > 0); + Debug.Assert(text.Length > 0); + + textIndex = text.Length; + int idsCount = 0; + + if (addEndOfSentence) + { + idsCount++; + } + + (int Offset, int Length)[] splits = PreTokenizer.SplitText(text, SpecialTokensRegex!).ToArray(); + + if (splits.Length == 0) + { + return CountTokensFromEnd(text, addBeginOfSentence, addEndOfSentence, out textIndex, maxTokens); + } + + (int Offset, int Length) current = splits[splits.Length - 1]; + + int splitTextIndex; + ReadOnlySpan splitText; + + if (current.Offset + current.Length < text.Length) + { + splitText = text.Slice(current.Offset + current.Length); + idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount); + textIndex -= splitText.Length - splitTextIndex; + } + + for (int i = splits.Length - 1; i >= 0 && idsCount < maxTokens; i--) + { + current = splits[i]; + + if 
(InternalSpecialTokens!.TryGetValue(text.Slice(current.Offset, current.Length), out int id)) + { + idsCount++; + } + textIndex -= current.Length; + + if (current.Offset > 0 && idsCount < maxTokens) + { + int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0; + splitText = text.Slice(start, current.Offset - start); + idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount); + textIndex -= splitText.Length - splitTextIndex; + } + } + + if (addBeginOfSentence && idsCount < maxTokens) + { + idsCount++; + } + + return idsCount; + } + /// + /// Get the number of tokens that the input text will be encoded to. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. + /// The maximum number of tokens to encode. + /// The number of tokens that the input text will be encoded to. + /// The input text has to be normalized before calling this method. + private int CountTokensFromEnd(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue) + { + textIndex = text.Length; + if (text.IsEmpty) + { + return 0; + } + + int tokenCount = addEndOfSentence ? 1 : 0; + + BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); + + Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); + + // Move to the last symbol. + int lastSymbolIndex = 0; + while (symbols[lastSymbolIndex].next != -1 && lastSymbolIndex < symbols.Length) + { + lastSymbolIndex = symbols[lastSymbolIndex].next; + } + + for (int index = lastSymbolIndex; index >= 0; index = symbols[index].prev) + { + int id = symbols[index].id; + byte type = symbols[index].type; + + if (id == UninitializedId) + { + if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) + { + id = tokenInfo.Id; + type = tokenInfo.Type; + } + else + { + id = UnknownId; + type = 0; + } + } + + if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) + { + if (id == UnknownId && ByteFallback) + { + if (!EncodeAsBytesFromEnd(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textIndex)) + { + break; + } + } + else + { + if (tokenCount < maxTokens) + { + tokenCount++; + textIndex -= symbols[index].pieceSpan.Length; + } + else + { + break; + } + } + continue; + } + + if (!SegmentFromEnd(symbols[index].pieceSpan, text, ref textIndex)) + { + break; + } + } + + ArrayPool.Shared.Return(symbols); + + if (addBeginOfSentence) + { + if (tokenCount < maxTokens) + { + tokenCount++; + } + } + + return tokenCount; + + // Encode the Unknown token to bytes. + bool EncodeAsBytesFromEnd(ReadOnlySpan text, int index, ref int textIndex) + { + for (int i = text.Length - 1; i >= 0; i--) + { + char c = text[i]; + if (c <= 0x7F) + { + if (tokenCount < maxTokens) + { + tokenCount++; + textIndex--; + } + else + { + return false; + } + } + else + { + Span utf8Bytes = stackalloc byte[100]; + byte[]? 
arrayPoolArray = null; + + int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); + if (len > utf8Bytes.Length) + { + arrayPoolArray = ArrayPool.Shared.Rent(len); + utf8Bytes = arrayPoolArray; + } + + // Need to convert the text into UTF-8 bytes and then encode the bytes. + int encodedCount = Helpers.GetUtf8Bytes(text.Slice(0, i + 1), utf8Bytes); + bool ret; + + if (tokenCount + encodedCount <= maxTokens) + { + tokenCount += encodedCount; + textIndex -= i + 1; + ret = true; + } + else + { + ret = false; + } + + if (arrayPoolArray is not null) + { + ArrayPool.Shared.Return(arrayPoolArray); + } + + return ret; + } + } + + return true; + } + + bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int textIndex) + { + if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) + { + return EncodeAsBytesFromEnd(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textIndex); + } + + if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || + revMerge is null || + !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) + { + if (tokenCount < maxTokens) + { + tokenCount++; + textIndex -= pieceSpan.Length; + return true; + } + else + { + return false; + } + } + + // Segment the right part first. + return SegmentFromEnd((merge.RightIndex, merge.RightLen), text, ref textIndex) && SegmentFromEnd((merge.LeftIndex, merge.LeftLen), text, ref textIndex); + } + } + + private Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? Encode(ReadOnlySpan text, BpeSymbol[] symbols) + { + Debug.Assert(text.Length > 0); + Debug.Assert(symbols.Length >= text.Length); + + int symbolIndex = 0; + int spanIndex = 0; + + while (spanIndex < text.Length) + { + int len = (Char.IsHighSurrogate(text[spanIndex]) && spanIndex < text.Length - 1 && Char.IsLowSurrogate(text[spanIndex + 1])) ? 2 : 1; + + BpeSymbol s = new( + prev: symbolIndex == 0 ? -1 : symbolIndex - 1, + next: spanIndex + len >= text.Length ? -1 : symbolIndex + 1, + pieceSpan: (spanIndex, len), + id: UninitializedId, + type: 0); + + symbols[symbolIndex++] = s; + spanIndex += len; + } + + PriorityQueue agenda = new(symbolIndex); + Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = null; + + for (int i = 1; i < symbolIndex; i++) + { + TryMerge(i - 1, i, text); + } + + while (agenda.Count > 0) + { + SymbolPair top = agenda.Dequeue(); + + if (symbols[top.Left].pieceSpan.Length == 0 || symbols[top.Right].pieceSpan.Length == 0 || + symbols[top.Left].pieceSpan.Length + symbols[top.Right].pieceSpan.Length != top.Length) + { + continue; + } + + // Replaces symbols with `top` rule. + symbols[top.Left].pieceSpan = (symbols[top.Left].pieceSpan.Index, symbols[top.Left].pieceSpan.Length + symbols[top.Right].pieceSpan.Length); + symbols[top.Left].id = top.Id; + + // Updates prev/next pointers. + symbols[top.Left].next = symbols[top.Right].next; + + if (symbols[top.Right].next >= 0) + { + symbols[symbols[top.Right].next].prev = top.Left; + } + symbols[top.Right].pieceSpan = (0, 0); + + // Adds new symbol pairs which are newly added after symbol replacement. 
+ TryMerge(symbols[top.Left].prev, top.Left, text); + TryMerge(top.Left, symbols[top.Left].next, text); + } + + return revMerge; + + void TryMerge(int left, int right, ReadOnlySpan textSpan) + { + if (left == -1 || right == -1) + { + return; + } + + int pieceLength = symbols[left].pieceSpan.Length + symbols[right].pieceSpan.Length; + if (!_vocab.TryGetValue(textSpan.Slice(symbols[left].pieceSpan.Index, pieceLength), out (int Id, float Score, byte Type) leftId)) + { + return; + } + + symbols[left].type = leftId.Type; + + SymbolPair pair = new(left, right, leftId.Score, pieceLength, leftId.Id); + agenda.Enqueue(pair); + + if (leftId.Type == (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) + { + revMerge ??= new(); + revMerge.Add((symbols[left].pieceSpan.Index, pieceLength), (symbols[left].pieceSpan.Index, symbols[left].pieceSpan.Length, symbols[right].pieceSpan.Index, symbols[right].pieceSpan.Length)); + } + } + } + + // Tries to avoid string allocations if possible. + private string GetTokenString(int id, int index, int length, ReadOnlySpan text) + => _vocabReverse.TryGetValue(id, out string? token) ? token : text.Slice(index, length).ToString(); + + private struct SymbolPair : IEquatable, IComparable + { + public int Left { get; set; } + public int Right { get; set; } + public int Length { get; set; } + public float Score { get; set; } + public int Id { get; set; } + + public SymbolPair(int left, int right, float score, int length, int id) + { + Left = left; + Right = right; + Score = score; + Length = length; + Id = id; + } + + public int CompareTo(SymbolPair other) + { + if (Score != other.Score) + { + return other.Score.CompareTo(Score); + } + + return other.Left.CompareTo(Left); + } + + public override int GetHashCode() + { + int hashCode = 23; + hashCode = (hashCode * 37) + Score.GetHashCode(); + hashCode = (hashCode * 37) + Left.GetHashCode(); + return hashCode; + } + + public bool Equals(SymbolPair other) => Left == other.Left && Score == other.Score; + } + + private record struct BpeSymbol(int prev, int next, (int Index, int Length) pieceSpan, int id, byte type); + } +} diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs index 873dd0c4f6..f41516e270 100644 --- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs @@ -6,13 +6,7 @@ using System; using System.Buffers; using System.Collections.Generic; -using System.Collections.ObjectModel; -using System.Diagnostics; using System.IO; -using System.Linq; -using System.Text; -using System.Text.RegularExpressions; -using System.Threading; namespace Microsoft.ML.Tokenizers { @@ -24,134 +18,82 @@ namespace Microsoft.ML.Tokenizers /// public class SentencePieceTokenizer : Tokenizer { - private const int UninitializedId = -2; // indicate if the symbol contains uninitialized id. - private readonly Dictionary _vocab = new(); - private readonly Dictionary _vocabReverse = new(); - private IReadOnlyDictionary? _publicVocab; - private readonly int _maxByteId; - private readonly int _byteCodeToIdOffset; // offset of mapping byte code to the to the Ids. - private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character. - private readonly Normalizer? _normalizer; - private readonly Regex? _specialTokensRegex; - private readonly Dictionary? _specialTokens; - private readonly Dictionary? 
_specialTokensReverse; + private readonly SentencePieceBaseModel _model; - internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) : - this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto, specialTokens) + internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) { - AddBeginningOfSentence = addBos; - AddEndOfSentence = addEos; - } - - private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary? specialTokens) - { - for (int i = 0; i < modelProto.Pieces.Count; i++) - { - var piece = modelProto.Pieces[i]; - _vocab.Add(new StringSpanOrdinalKey(piece.Piece), (i, piece.Score, (byte)piece.Type)); - _vocabReverse.Add(i, piece.Piece); - - if (piece.Type == ModelProto.Types.SentencePiece.Types.Type.Byte) - { - _maxByteId = i; - } - } - - _byteCodeToIdOffset = _vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) value) ? value.Id : _maxByteId; - _oneByteUtf8EncodingMaxId = _byteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character. - - BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? ""; - BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId; - EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? ""; - EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId; - UnknownToken = modelProto.TrainerSpec.UnkPiece ?? ""; - UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId; - - AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix; - EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces; - TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix; - ByteFallback = modelProto.TrainerSpec.ByteFallback; - - SpecialTokens = specialTokens; - _normalizer = new SentencePieceNormalizer(modelProto.NormalizerSpec.RemoveExtraWhitespaces, AddDummyPrefix, EscapeWhiteSpaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix, specialTokens); - - if (specialTokens is not null && specialTokens.Count > 0) + _model = modelProto.TrainerSpec.ModelType switch { - _specialTokens = new Dictionary(); - _specialTokensReverse = new Dictionary(); - - foreach (var item in specialTokens) - { - _specialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value); - _specialTokensReverse.Add(item.Value, item.Key); - } - - // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created. - _specialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled); - } + TrainerSpec.Types.ModelType.Bpe => new SentencePieceBpeModel(modelProto, addBos, addEos, specialTokens), + TrainerSpec.Types.ModelType.Unigram => new SentencePieceUnigramModel(modelProto, addBos, addEos, specialTokens), + _ => throw new ArgumentException($"The model type '{modelProto.TrainerSpec.ModelType}' is not supported.", nameof(modelProto)) + }; } - public IReadOnlyDictionary? SpecialTokens { get; } + /// + /// The special tokens. + /// + public IReadOnlyDictionary? SpecialTokens => _model.SpecialTokens; /// /// Specifies whether the model will do a byte fallback when it encounters unknown tokens during the encoding process. 
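
// Illustrative sketch, not part of the patch: after this refactoring the tokenizer is a thin
// facade over the model chosen from TrainerSpec.ModelType (Bpe or Unigram), and the public
// properties simply forward to it. Assuming the existing Create(Stream, ...) factory and a
// hypothetical model file path, construction looks roughly like this:
//
using System.IO;
using Microsoft.ML.Tokenizers;

using Stream modelStream = File.OpenRead("sentencepiece.model");   // hypothetical path
SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(modelStream);

bool byteFallback = tokenizer.ByteFallback;   // forwarded to _model.ByteFallback
var specialTokens = tokenizer.SpecialTokens;  // forwarded to _model.SpecialTokens
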
/// - public bool ByteFallback { get; } + public bool ByteFallback => _model.ByteFallback; /// /// Indicate emitting the prefix character U+2581 at the beginning of sentence token during the normalization and encoding. /// - public bool AddDummyPrefix { get; } + public bool AddDummyPrefix => _model.AddDummyPrefix; /// /// Indicate if the spaces should be replaced with character U+2581 during the normalization and encoding. /// - public bool EscapeWhiteSpaces { get; } + public bool EscapeWhiteSpaces => _model.EscapeWhiteSpaces; /// /// Indicate emitting the character U+2581 at the end of the last sentence token instead beginning of sentence token during the normalization and encoding. /// - public bool TreatWhitespaceAsSuffix { get; private set; } + public bool TreatWhitespaceAsSuffix { get => _model.TreatWhitespaceAsSuffix; private set => _model.TreatWhitespaceAsSuffix = value; } /// /// Indicate emitting the beginning of sentence token during the encoding. /// - public bool AddBeginningOfSentence { get; } + public bool AddBeginningOfSentence => _model.AddBeginningOfSentence; /// /// Indicate emitting the end of sentence token during the encoding. /// - public bool AddEndOfSentence { get; } + public bool AddEndOfSentence => _model.AddEndOfSentence; /// /// The beginning of sentence token. /// - public string BeginningOfSentenceToken { get; } + public string BeginningOfSentenceToken => _model.BeginningOfSentenceToken; /// /// The end of sentence token. /// - public string EndOfSentenceToken { get; } + public string EndOfSentenceToken => _model.EndOfSentenceToken; /// /// The unknown token. /// - public string UnknownToken { get; } + public string UnknownToken => _model.UnknownToken; /// /// The id of the beginning of sentence token. /// - public int BeginningOfSentenceId { get; } + public int BeginningOfSentenceId => _model.BeginningOfSentenceId; /// /// The id of the end of sentence token. /// - public int EndOfSentenceId { get; } + public int EndOfSentenceId => _model.EndOfSentenceId; /// /// The id of the unknown token. /// - public int UnknownId { get; } + public int UnknownId => _model.UnknownId; /// /// Gets the PreTokenizer used by the Tokenizer. @@ -161,31 +103,12 @@ private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary /// Gets the Normalizer in use by the Tokenizer. /// - public override Normalizer? Normalizer => _normalizer; + public override Normalizer? Normalizer => _model.Normalizer; /// /// The vocabulary of the model. /// - public IReadOnlyDictionary Vocabulary - { - get - { - IReadOnlyDictionary? publicVocab = Volatile.Read(ref _publicVocab); - if (publicVocab is null) - { - var vocab = new Dictionary(); - foreach (var item in _vocab) - { - vocab.Add(item.Key.ToString(), item.Value.Id); - } - - Interlocked.CompareExchange(ref _publicVocab, new ReadOnlyDictionary(vocab), null); - publicVocab = _publicVocab; - } - - return publicVocab; - } - } + public IReadOnlyDictionary Vocabulary => _model.Vocabulary; /// /// Encodes input text to a list of s. @@ -197,7 +120,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read { return new EncodeResults { - Tokens = EncodeToTokens(text, textSpan, out string? normalizedText, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderPreTokenization, settings.ConsiderNormalization), + Tokens = _model.EncodeToTokens(text, textSpan, out string? 
normalizedText, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization), NormalizedText = normalizedText, CharsConsumed = normalizedText?.Length ?? text?.Length ?? textSpan.Length }; @@ -214,7 +137,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read /// Indicate whether to consider normalization before tokenization. /// The tokenization result includes a list of s with string value of the token, id, and offset. public IReadOnlyList EncodeToTokens(string text, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) - => EncodeToTokens(text, Span.Empty, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization); + => _model.EncodeToTokens(text, Span.Empty, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerNormalization); /// /// Encodes input text a list of s with string value of the token, id, and offset. @@ -227,221 +150,8 @@ public IReadOnlyList EncodeToTokens(string text, out string? norma /// Indicate whether to consider normalization before tokenization. /// The tokenization result includes a list of s with string value of the token, id, and offset. public IReadOnlyList EncodeToTokens(ReadOnlySpan text, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) - => EncodeToTokens(null, text, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization); - - private IReadOnlyList EncodeToTokens(string? text, ReadOnlySpan textSpan, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization) - { - if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) - { - normalizedText = null; - return []; - } - - ReadOnlySpan textToEncode = text is null ? textSpan : text.AsSpan(); - if (considerNormalization && _normalizer is not null) - { - normalizedText = text is not null ? _normalizer.Normalize(text) : _normalizer.Normalize(textSpan); - textToEncode = normalizedText.AsSpan(); - } - else - { - normalizedText = null; - } - - if (textToEncode.Length == 0) - { - return []; - } - - List? 
tokens = new(); - - if (_specialTokensRegex is not null) - { - EncodeWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens); - } - else - { - EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens); - } - - return tokens; - } - - private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, List tokens) - { - Debug.Assert(_specialTokensRegex is not null); - - if (addBeginOfSentence) - { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); - } - - int currentOffset = 0; - - foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, _specialTokensRegex!)) - { - if (Offset > currentOffset) - { - EncodeInternal(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens); - } - - if (_specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) - { - tokens.Add(new EncodedToken(id, _specialTokensReverse![id], new Range(Offset, Offset + Length))); - } - - currentOffset = Offset + Length; - } - - if (currentOffset < text.Length) - { - EncodeInternal(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens); - } - - if (addEndOfSentence) - { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length))); - } - } - - /// - /// Encode a text to a list of tokens. - /// - /// The text to encode. - /// Indicate emitting the beginning of sentence token during the encoding. - /// Indicate emitting the end of sentence token during the encoding. - /// A collection to store the encoded tokens. - /// The input text has to be normalized before calling this method. - private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, List tokens) - { - BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); - - Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); - - if (addBeginOfSentence) - { - tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); - } - - for (int index = 0; (uint)index < (uint)symbols.Length; index = symbols[index].next) - { - int id = symbols[index].id; - byte type = symbols[index].type; - - if (id == UninitializedId) - { - if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) - { - id = tokenInfo.Id; - type = tokenInfo.Type; - } - else - { - id = UnknownId; - type = 0; - } - } - - if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) - { - if (id == UnknownId && ByteFallback) - { - EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index); - } - else - { - tokens.Add(new EncodedToken( - id, - GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text), - new Range(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Index + symbols[index].pieceSpan.Length))); - } - continue; - } - - Segment(symbols[index].pieceSpan, text); - } - - ArrayPool.Shared.Return(symbols); - - if (addEndOfSentence) - { - tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length))); - } - - return; - - // Encode the Unknown token to bytes. 
- void EncodeAsBytes(ReadOnlySpan text, int index) - { - for (int i = 0; i < text.Length; i++) - { - char c = text[i]; - if (c <= 0x7F) - { - int id = (int)c + _byteCodeToIdOffset; // byte code is mapped to the to the Ids starting from 4. - - if (_vocabReverse.TryGetValue(id, out string? token)) - { - tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + 1))); - } - } - else - { - Span utf8Bytes = stackalloc byte[256]; - byte[]? arrayPoolArray = null; - - int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); - if (len > utf8Bytes.Length) - { - arrayPoolArray = ArrayPool.Shared.Rent(len); - utf8Bytes = arrayPoolArray; - } - - // Need to convert the text into UTF-8 bytes and then encode the bytes. - int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes); - int length = text.Length - i; - for (int j = 0; j < bytesWritten; j++) - { - int id = (int)utf8Bytes[j] + _byteCodeToIdOffset; // byte code is mapped to the to the Ids starting from 4. - - if (_vocabReverse.TryGetValue(id, out string? token)) - { - tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + length))); - } - - length = 0; - } - - if (arrayPoolArray is not null) - { - ArrayPool.Shared.Return(arrayPoolArray); - } - - break; - } - } - } - - void Segment((int Index, int Length) pieceSpan, ReadOnlySpan text) - { - if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) - { - EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index); - return; - } - - if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || - revMerge is null || - !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) - { - tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), new Range(pieceSpan.Index, pieceSpan.Index + pieceSpan.Length))); - return; - } + => _model.EncodeToTokens(null, text, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerNormalization); - Segment((merge.LeftIndex, merge.LeftLen), text); - Segment((merge.RightIndex, merge.RightLen), text); - } - } /// /// Encodes input text to token Ids. @@ -454,7 +164,7 @@ protected override EncodeResults EncodeToIds(string? text, ReadOnlySpan { - Tokens = EncodeToIds(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out string? normalizedText, out int charsConsumed, settings.MaxTokenCount), + Tokens = _model.EncodeToIds(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out string? normalizedText, out int charsConsumed, settings.MaxTokenCount), NormalizedText = normalizedText, CharsConsumed = charsConsumed }; @@ -470,7 +180,7 @@ protected override EncodeResults EncodeToIds(string? text, ReadOnlySpanIndicate whether to consider normalization before tokenization. /// The list of encoded Ids. public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) - => EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); + => _model.EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); /// /// Encodes input text to token Ids. 
@@ -482,7 +192,7 @@ public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence, /// Indicate whether to consider normalization before tokenization. /// The list of encoded Ids. public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) - => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); + => _model.EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _); /// /// Encodes input text to token Ids up to maximum number of tokens. @@ -497,7 +207,7 @@ public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginning /// Indicate whether to consider normalization before tokenization. /// The list of encoded Ids. public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) - => EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); + => _model.EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); /// /// Encodes input text to token Ids up to maximum number of tokens. @@ -512,320 +222,7 @@ public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence, /// Indicate whether to consider normalization before tokenization. /// The list of encoded Ids. public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true) - => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); - - - private IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) - { - if (maxTokenCount <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); - } - - if (string.IsNullOrEmpty(text) && textSpan.IsEmpty) - { - normalizedText = null; - charsConsumed = 0; - return []; - } - - return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); - } - - /// - /// Encodes input text to token Ids up to maximum number of tokens. - /// - /// The text to encode. - /// Indicate emitting the beginning of sentence token during the encoding. - /// Indicate emitting the end of sentence token during the encoding. - /// Indicate whether to consider normalization before tokenization. - /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. - /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The list of encoded Ids. 
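
// Illustrative usage sketch, not part of the patch: `tokenizer` is assumed to be a
// SentencePieceTokenizer created from a trusted model stream; parameter names follow the
// EncodeToIds overloads documented here.
IReadOnlyList<int> ids = tokenizer.EncodeToIds(
    "Hello world",
    addBeginningOfSentence: true,
    addEndOfSentence: false);

// The capped overload additionally reports the normalized text and how many characters
// were consumed before the token budget was reached.
IReadOnlyList<int> capped = tokenizer.EncodeToIds(
    "Hello world",
    addBeginningOfSentence: true,
    addEndOfSentence: false,
    maxTokenCount: 8,
    out string? normalizedText,
    out int charsConsumed);
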
- private IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, - out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) - { - if (maxTokenCount <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); - } - - if (text.IsEmpty) - { - normalizedText = null; - charsConsumed = 0; - return []; - } - - ReadOnlySpan textToEncode; - - if (considerNormalization && _normalizer is not null) - { - normalizedText = _normalizer.Normalize(text); - textToEncode = normalizedText.AsSpan(); - } - else - { - normalizedText = null; - textToEncode = text; - } - - if (maxTokenCount <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0."); - } - - List ids = new(); - - if (_specialTokensRegex is not null) - { - EncodeToIdsWithAddedToken(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount); - } - else - { - EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount); - } - - return ids; - } - - private int EncodeToIdsWithAddedToken(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue) - { - Debug.Assert(_specialTokensRegex is not null); - Debug.Assert(maxTokens > 0); - - charsConsumed = 0; - int idsCount = 0; - - if (addBeginOfSentence) - { - accumulatedIds.Add(BeginningOfSentenceId); - idsCount++; - } - - int currentOffset = 0; - - int charsWritten; - - foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, _specialTokensRegex!)) - { - if (Offset > currentOffset) - { - idsCount += EncodeToIds(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount); - charsConsumed += charsWritten; - } - - if (idsCount < maxTokens && _specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) - { - accumulatedIds.Add(id); - idsCount++; - charsConsumed += Length; - } - - currentOffset = Offset + Length; - } - - if (currentOffset < text.Length && idsCount < maxTokens) - { - idsCount += EncodeToIds(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount); - charsConsumed += charsWritten; - } - - if (addEndOfSentence && idsCount < maxTokens) - { - accumulatedIds.Add(EndOfSentenceId); - idsCount++; - } - - return idsCount; - } - - /// - /// Encode a text to a list of Ids and add them to the accumulatedIds list. - /// - /// The text to encode. - /// Indicate emitting the beginning of sentence token during the encoding. - /// Indicate emitting the end of sentence token during the encoding. - /// The list of accumulated encoded Ids. - /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// The input text has to be normalized before calling this method. 
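
// Illustrative sketch, not part of the patch: the byte-fallback path used in this method maps
// each UTF-8 byte of an unknown piece to the id of its <0xNN> vocabulary entry, i.e.
// byteValue + byteCodeToIdOffset. The helper name and parameters below are local to this sketch.
static void AppendByteFallbackIds(string unknownPiece, int byteCodeToIdOffset, IList<int> ids)
{
    foreach (byte b in System.Text.Encoding.UTF8.GetBytes(unknownPiece))
    {
        // Each byte becomes one token id; multi-byte characters therefore expand to several ids.
        ids.Add(b + byteCodeToIdOffset);
    }
}
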
- private int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue) - { - charsConsumed = 0; - if (text.IsEmpty) - { - return 0; - } - - int idsCount = 0; - - if (addBeginOfSentence) - { - accumulatedIds.Add(BeginningOfSentenceId); - idsCount++; - } - - BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); - - Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); - - for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next) - { - int id = symbols[index].id; - byte type = symbols[index].type; - - if (id == UninitializedId) - { - if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) - { - id = tokenInfo.Id; - type = tokenInfo.Type; - } - else - { - id = UnknownId; - type = 0; - } - } - - if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) - { - if (id == UnknownId && ByteFallback) - { - if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed)) - { - ArrayPool.Shared.Return(symbols); - return idsCount; - } - } - else - { - if (idsCount < maxTokens) - { - accumulatedIds.Add(id); - charsConsumed += symbols[index].pieceSpan.Length; - idsCount++; - } - else - { - ArrayPool.Shared.Return(symbols); - return idsCount; - } - } - continue; - } - - if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed)) - { - break; - } - } - - ArrayPool.Shared.Return(symbols); - - if (addEndOfSentence) - { - if (idsCount < maxTokens) - { - accumulatedIds.Add(EndOfSentenceId); - idsCount++; - } - } - - return idsCount; - - // Encode the Unknown token to bytes. - bool EncodeAsBytes(ReadOnlySpan text, int index, ref int charsConsumed) - { - for (int i = 0; i < text.Length; i++) - { - char c = text[i]; - if (c <= 0x7F) - { - if (idsCount < maxTokens) - { - charsConsumed++; - accumulatedIds.Add((int)c + _byteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4. - idsCount++; - } - else - { - return false; - } - } - else - { - Span utf8Bytes = stackalloc byte[100]; - byte[]? arrayPoolArray = null; - - int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); - if (len > utf8Bytes.Length) - { - arrayPoolArray = ArrayPool.Shared.Rent(len); - utf8Bytes = arrayPoolArray; - } - - // Need to convert the text into UTF-8 bytes and then encode the bytes. - int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes); - - bool ret; - if (idsCount + bytesWritten <= maxTokens) - { - for (int j = 0; j < bytesWritten; j++) - { - accumulatedIds.Add((int)utf8Bytes[j] + _byteCodeToIdOffset); // byte code is mapped to the to the Ids starting from 4. 
- } - - charsConsumed += text.Length - i; - ret = true; - } - else - { - ret = false; - } - - if (arrayPoolArray is not null) - { - ArrayPool.Shared.Return(arrayPoolArray); - } - - return ret; - } - } - - return true; - } - - bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int charsConsumed) - { - if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) - { - return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed); - } - - if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || - revMerge is null || - !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) - { - if (idsCount < maxTokens) - { - accumulatedIds.Add(id.Id); - charsConsumed += pieceSpan.Length; - idsCount++; - return true; - } - else - { - return false; - } - } - - return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed); - } - } + => _model.EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); /// /// Get the number of tokens that the input text will be encoded to. @@ -835,12 +232,7 @@ revMerge is null || /// The settings used to encode the text. /// The number of token Ids that the input text will be encoded to. protected override int CountTokens(string? text, ReadOnlySpan textSpan, EncodeSettings settings) - { - return CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out _, out _, settings.MaxTokenCount); - } - - private int CountTokens(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) - => CountTokens(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); + => _model.CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out _, out _, settings.MaxTokenCount); /// /// Get the number of tokens that the input text will be encoded to. @@ -852,7 +244,7 @@ private int CountTokens(string? text, ReadOnlySpan textSpan, bool addBegin /// Indicate whether to consider normalization before tokenization. /// The number of token Ids that the input text will be encoded to. public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) - => CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out _, out _); + => _model.CountTokens(text, ReadOnlySpan.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _, int.MaxValue); /// /// Get the number of tokens that the input text will be encoded to. @@ -864,7 +256,7 @@ public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSe /// Indicate whether to consider normalization before tokenization. /// The number of token Ids that the input text will be encoded to. 
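
// Illustrative usage sketch, not part of the patch: counting tokens against a budget with the
// capped CountTokens overload; `tokenizer` and `longText` are assumed to exist.
int tokenCount = tokenizer.CountTokens(
    longText,
    addBeginningOfSentence: true,
    addEndOfSentence: false,
    considerPreTokenization: true,
    considerNormalization: true,
    out string? normalizedText,
    out int charsConsumed,
    maxTokenCount: 128);
// charsConsumed is the length of the (normalized) prefix that fits within the 128-token budget.
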
public int CountTokens(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true) - => CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out _, out _); + => _model.CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _, int.MaxValue); /// /// Get the number of tokens that the input text will be encoded to. @@ -879,7 +271,7 @@ public int CountTokens(ReadOnlySpan text, bool addBeginningOfSentence, boo /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) - => CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); + => _model.CountTokens(text, ReadOnlySpan.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); /// /// Get the number of tokens that the input text will be encoded to. @@ -894,299 +286,54 @@ public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSe /// The maximum number of tokens to encode. /// The number of tokens that the input text will be encoded to. public int CountTokens(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue) - { - if (maxTokenCount <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero."); - } - - if (text.IsEmpty) - { - normalizedText = null; - charsConsumed = 0; - return 0; - } + => _model.CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount); - ReadOnlySpan textToEncode; - if (considerNormalization && _normalizer is not null) - { - normalizedText = _normalizer.Normalize(text); - textToEncode = normalizedText.AsSpan(); - } - else + /// + /// Find the index of the maximum encoding capacity without surpassing the token limit. + /// + /// The text to encode. + /// The span of the text to encode which will be used if the is . + /// The settings used to encode the text. + /// Indicate whether to find the index from the end of the text. + /// If the tokenizer's normalization is enabled or has is , this will be set to in its normalized form; otherwise, this value will be set to . + /// The token count can be generated which should be smaller than the maximum token count. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// If is , it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the input text or the if the normalization is enabled. + /// If is , it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely, + /// if all tokens fit, the result will be zero. 
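
// Illustrative usage sketch, not part of the patch: trimming text to a token budget with the
// index-based APIs described above. As documented, the returned index points just past the last
// character that fits (0 if nothing fits, the full length if everything fits).
int index = tokenizer.GetIndexByTokenCount(
    longText,
    addBeginningOfSentence: true,
    addEndOfSentence: false,
    maxTokenCount: 512,
    out string? normalizedText,
    out int fittingTokens);

// When normalization ran, the index refers to the normalized text, so slice that form.
string kept = (normalizedText ?? longText).Substring(0, index);
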
+ /// + protected override int GetIndexByTokenCount(string? text, ReadOnlySpan textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedText, out int tokenCount) + { + if (fromEnd) { - normalizedText = null; - textToEncode = text; + return _model.GetIndexByTokenCountFromEnd(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.MaxTokenCount, settings.ConsiderNormalization, out normalizedText, out tokenCount); } - return _specialTokensRegex is not null ? - CountTokensWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount) : - CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount); + tokenCount = _model.CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out normalizedText, out int charsConsumed, settings.MaxTokenCount); + return charsConsumed; } - private int CountTokensWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue) + /// + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. + /// + /// The text to encode. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. + /// The token count can be generated which should be smaller than the maximum token count. + /// Indicate whether to consider pre-tokenization before tokenization. + /// Indicate whether to consider normalization before tokenization. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, + /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. + /// + public int GetIndexByTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? 
normalizedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { - Debug.Assert(_specialTokensRegex is not null); - Debug.Assert(maxTokens > 0); - - charsConsumed = 0; - int idsCount = 0; - - if (addBeginOfSentence) - { - idsCount++; - } - - int currentOffset = 0; - - int charsWritten; - - foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, _specialTokensRegex!)) - { - if (Offset > currentOffset) - { - idsCount += CountTokens(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount); - charsConsumed += charsWritten; - } - - if (idsCount < maxTokens && _specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) - { - idsCount++; - charsConsumed += Length; - } - - currentOffset = Offset + Length; - } - - if (currentOffset < text.Length && idsCount < maxTokens) - { - idsCount += CountTokens(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount); - charsConsumed += charsWritten; - } - - if (addEndOfSentence && idsCount < maxTokens) - { - idsCount++; - } - - return idsCount; - } - - /// - /// Get the number of tokens that the input text will be encoded to. - /// - /// The text to encode. - /// Indicate emitting the beginning of sentence token during the encoding. - /// Indicate emitting the end of sentence token during the encoding. - /// The length of the text that encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// The input text has to be normalized before calling this method. - private int CountTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue) - { - charsConsumed = 0; - if (text.IsEmpty) - { - return 0; - } - - int tokenCount = addBeginOfSentence ? 1 : 0; - - BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); - - Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); - - for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next) - { - int id = symbols[index].id; - byte type = symbols[index].type; - - if (id == UninitializedId) - { - if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) - { - id = tokenInfo.Id; - type = tokenInfo.Type; - } - else - { - id = UnknownId; - type = 0; - } - } - - if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) - { - if (id == UnknownId && ByteFallback) - { - if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed)) - { - break; - } - } - else - { - if (tokenCount < maxTokens) - { - tokenCount++; - charsConsumed += symbols[index].pieceSpan.Length; - } - else - { - break; - } - } - continue; - } - - if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed)) - { - break; - } - } - - ArrayPool.Shared.Return(symbols); - - if (addEndOfSentence) - { - if (tokenCount < maxTokens) - { - tokenCount++; - } - } - - return tokenCount; - - // Encode the Unknown token to bytes. 
- bool EncodeAsBytes(ReadOnlySpan text, int index, ref int charsConsumed) - { - for (int i = 0; i < text.Length; i++) - { - char c = text[i]; - if (c <= 0x7F) - { - if (tokenCount < maxTokens) - { - tokenCount++; - charsConsumed++; - } - else - { - return false; - } - } - else - { - Span utf8Bytes = stackalloc byte[100]; - byte[]? arrayPoolArray = null; - - int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); - if (len > utf8Bytes.Length) - { - arrayPoolArray = ArrayPool.Shared.Rent(len); - utf8Bytes = arrayPoolArray; - } - - // Need to convert the text into UTF-8 bytes and then encode the bytes. - int encodedCount = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes); - bool ret; - - if (tokenCount + encodedCount <= maxTokens) - { - tokenCount += encodedCount; - charsConsumed += text.Length - i; - ret = true; - } - else - { - ret = false; - } - - if (arrayPoolArray is not null) - { - ArrayPool.Shared.Return(arrayPoolArray); - } - - return ret; - } - } - - return true; - } - - bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int charsConsumed) - { - if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) - { - return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed); - } - - if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || - revMerge is null || - !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) - { - if (tokenCount < maxTokens) - { - tokenCount++; - charsConsumed += pieceSpan.Length; - return true; - } - else - { - return false; - } - } - - return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed); - } - } - - /// - /// Find the index of the maximum encoding capacity without surpassing the token limit. - /// - /// The text to encode. - /// The span of the text to encode which will be used if the is . - /// The settings used to encode the text. - /// Indicate whether to find the index from the end of the text. - /// If the tokenizer's normalization is enabled or has is , this will be set to in its normalized form; otherwise, this value will be set to . - /// The token count can be generated which should be smaller than the maximum token count. - /// - /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. - /// If is , it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, - /// if all tokens fit, the result will be length of the input text or the if the normalization is enabled. - /// If is , it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely, - /// if all tokens fit, the result will be zero. - /// - protected override int GetIndexByTokenCount(string? text, ReadOnlySpan textSpan, EncodeSettings settings, bool fromEnd, out string? 
normalizedText, out int tokenCount) - { - if (fromEnd) - { - return GetIndexByTokenCountFromEnd(text, textSpan, settings.MaxTokenCount, settings.ConsiderNormalization, out normalizedText, out tokenCount); - } - - tokenCount = CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out normalizedText, out int charsConsumed, settings.MaxTokenCount); - return charsConsumed; - } - - /// - /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. - /// - /// The text to encode. - /// Indicate emitting the beginning of sentence token during the encoding. - /// Indicate emitting the end of sentence token during the encoding. - /// The maximum token count to limit the encoding capacity. - /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null. - /// The token count can be generated which should be smaller than the maximum token count. - /// Indicate whether to consider pre-tokenization before tokenization. - /// Indicate whether to consider normalization before tokenization. - /// - /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. - /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, - /// if all tokens fit, the result will be length of the text or the if the normalization is enabled. - /// - public int GetIndexByTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) - { - tokenCount = CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount); + tokenCount = _model.CountTokens(text, ReadOnlySpan.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount); return charsConsumed; } @@ -1208,13 +355,10 @@ public int GetIndexByTokenCount(string text, bool addBeginningOfSentence, bool a /// public int GetIndexByTokenCount(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true) { - tokenCount = CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount); + tokenCount = _model.CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount); return charsConsumed; } - private int GetIndexByTokenCountFromEnd(string? text, ReadOnlySpan textSpan, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount) - => GetIndexByTokenCountFromEnd(text is null ? textSpan : text.AsSpan(), AddBeginningOfSentence, AddEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount); - /// /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. /// @@ -1230,7 +374,7 @@ private int GetIndexByTokenCountFromEnd(string? 
text, ReadOnlySpan textSpa /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. /// public int GetIndexByTokenCountFromEnd(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount) - => GetIndexByTokenCountFromEnd(text is null ? ReadOnlySpan.Empty : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount); + => _model.GetIndexByTokenCountFromEnd(text, ReadOnlySpan.Empty, addBeginningOfSentence, addEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount); /// /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. @@ -1247,288 +391,14 @@ public int GetIndexByTokenCountFromEnd(string text, bool addBeginningOfSentence, /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. /// public int GetIndexByTokenCountFromEnd(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount) - { - if (maxTokenCount <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); - } - - if (text.IsEmpty) - { - normalizedText = null; - tokenCount = 0; - return 0; - } - - ReadOnlySpan textToEncode; - if (considerNormalization && _normalizer is not null) - { - normalizedText = _normalizer.Normalize(text); - textToEncode = normalizedText.AsSpan(); - } - else - { - normalizedText = null; - textToEncode = text; - } - - int textIndex; - if (_specialTokensRegex is not null) - { - tokenCount = CountTokensFromEndWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount); - } - else - { - tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount); - } - - return textIndex; - } - - private int CountTokensFromEndWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens) - { - Debug.Assert(_specialTokensRegex is not null); - Debug.Assert(maxTokens > 0); - Debug.Assert(text.Length > 0); - - textIndex = text.Length; - int idsCount = 0; - - if (addEndOfSentence) - { - idsCount++; - } - - (int Offset, int Length)[] splits = PreTokenizer.SplitText(text, _specialTokensRegex!).ToArray(); - - if (splits.Length == 0) - { - return CountTokensFromEnd(text, addBeginOfSentence, addEndOfSentence, out textIndex, maxTokens); - } - - (int Offset, int Length) current = splits[splits.Length - 1]; - - int splitTextIndex; - ReadOnlySpan splitText; - - if (current.Offset + current.Length < text.Length) - { - splitText = text.Slice(current.Offset + current.Length); - idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount); - textIndex -= splitText.Length - splitTextIndex; - } - - for (int i = splits.Length - 1; i >= 0 && idsCount < maxTokens; i--) - { - current = splits[i]; - - if (_specialTokens!.TryGetValue(text.Slice(current.Offset, current.Length), out int id)) - { - idsCount++; - } - textIndex -= current.Length; - - if 
(current.Offset > 0 && idsCount < maxTokens) - { - int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0; - splitText = text.Slice(start, current.Offset - start); - idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount); - textIndex -= splitText.Length - splitTextIndex; - } - } - - if (addBeginOfSentence && idsCount < maxTokens) - { - idsCount++; - } - - return idsCount; - } - - /// - /// Get the number of tokens that the input text will be encoded to. - /// - /// The text to encode. - /// Indicate emitting the beginning of sentence token during the encoding. - /// Indicate emitting the end of sentence token during the encoding. - /// Starting from this index to the end of the text will encompasses the maximum encoded tokens. - /// The maximum number of tokens to encode. - /// The number of tokens that the input text will be encoded to. - /// The input text has to be normalized before calling this method. - private int CountTokensFromEnd(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue) - { - textIndex = text.Length; - if (text.IsEmpty) - { - return 0; - } - - int tokenCount = addEndOfSentence ? 1 : 0; - - BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length); - - Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols); - - // Move to the last symbol. - int lastSymbolIndex = 0; - while (symbols[lastSymbolIndex].next != -1 && lastSymbolIndex < symbols.Length) - { - lastSymbolIndex = symbols[lastSymbolIndex].next; - } - - for (int index = lastSymbolIndex; index >= 0; index = symbols[index].prev) - { - int id = symbols[index].id; - byte type = symbols[index].type; - - if (id == UninitializedId) - { - if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo)) - { - id = tokenInfo.Id; - type = tokenInfo.Type; - } - else - { - id = UnknownId; - type = 0; - } - } - - if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) - { - if (id == UnknownId && ByteFallback) - { - if (!EncodeAsBytesFromEnd(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textIndex)) - { - break; - } - } - else - { - if (tokenCount < maxTokens) - { - tokenCount++; - textIndex -= symbols[index].pieceSpan.Length; - } - else - { - break; - } - } - continue; - } - - if (!SegmentFromEnd(symbols[index].pieceSpan, text, ref textIndex)) - { - break; - } - } - - ArrayPool.Shared.Return(symbols); - - if (addBeginOfSentence) - { - if (tokenCount < maxTokens) - { - tokenCount++; - } - } - - return tokenCount; - - // Encode the Unknown token to bytes. - bool EncodeAsBytesFromEnd(ReadOnlySpan text, int index, ref int textIndex) - { - for (int i = text.Length - 1; i >= 0; i--) - { - char c = text[i]; - if (c <= 0x7F) - { - if (tokenCount < maxTokens) - { - tokenCount++; - textIndex--; - } - else - { - return false; - } - } - else - { - Span utf8Bytes = stackalloc byte[100]; - byte[]? arrayPoolArray = null; - - int len = Encoding.UTF8.GetMaxByteCount(text.Length - i); - if (len > utf8Bytes.Length) - { - arrayPoolArray = ArrayPool.Shared.Rent(len); - utf8Bytes = arrayPoolArray; - } - - // Need to convert the text into UTF-8 bytes and then encode the bytes. 
- int encodedCount = Helpers.GetUtf8Bytes(text.Slice(0, i + 1), utf8Bytes); - bool ret; - - if (tokenCount + encodedCount <= maxTokens) - { - tokenCount += encodedCount; - textIndex -= i + 1; - ret = true; - } - else - { - ret = false; - } - - if (arrayPoolArray is not null) - { - ArrayPool.Shared.Return(arrayPoolArray); - } - - return ret; - } - } - - return true; - } - - bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int textIndex) - { - if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id)) - { - return EncodeAsBytesFromEnd(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textIndex); - } - - if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused || - revMerge is null || - !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge)) - { - if (tokenCount < maxTokens) - { - tokenCount++; - textIndex -= pieceSpan.Length; - return true; - } - else - { - return false; - } - } - - // Segment the right part first. - return SegmentFromEnd((merge.RightIndex, merge.RightLen), text, ref textIndex) && SegmentFromEnd((merge.LeftIndex, merge.LeftLen), text, ref textIndex); - } - } + => _model.GetIndexByTokenCountFromEnd(null, text, addBeginningOfSentence, addEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount); /// /// Decode the given ids, back to a String. /// /// The list of ids that we want to decode. /// The decoded string. - public override string Decode(IEnumerable ids) - => Decode(ids, considerSpecialTokens: false); + public override string Decode(IEnumerable ids) => _model.Decode(ids, considerSpecialTokens: false); /// /// Decode the given ids, back to a String. @@ -1536,231 +406,7 @@ public override string Decode(IEnumerable ids) /// The list of ids that we want to decode. /// Indicate whether to consider special tokens during decoding. /// The decoded string. - public string Decode(IEnumerable ids, bool considerSpecialTokens) - { - if (ids is null) - { - throw new ArgumentNullException(nameof(ids)); - } - - using IEnumerator enumerator = ids.GetEnumerator(); - if (!enumerator.MoveNext()) - { - return string.Empty; - } - - ValueStringBuilder sb = new(stackalloc char[256]); - - int bytesCount = -1; - byte[]? bytesPoolArray = null; - bool prefixRemoved = false; - int suffixIndex = -1; - char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' '; - - if (enumerator.Current <= _maxByteId) - { - // First token is a byte token. - - while (enumerator.Current < _byteCodeToIdOffset) - { - // It is possible listing some special tokens before the byte tokens in the tokenizer's data. - TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb); - - // Skip control tokens. - if (!enumerator.MoveNext()) - { - return sb.ToString(); - } - } - - if (enumerator.Current <= _maxByteId) - { - EncodeByte(enumerator.Current, _oneByteUtf8EncodingMaxId, _byteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb); - } - else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token)) - { - AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex); - } - else - { - TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb); - } - } - else if (_vocabReverse.TryGetValue(enumerator.Current, out string? 
token)) - { - AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex); - } - else - { - TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb); - } - - char[]? charPoolArray = null; - - while (enumerator.MoveNext()) - { - if (enumerator.Current < _byteCodeToIdOffset) - { - if (bytesCount >= 1) - { - FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); - } - - // It is possible listing some special tokens before the byte tokens in the tokenizer's data. - TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb); - - continue; - } - - if (enumerator.Current <= _maxByteId) - { - if (bytesCount >= 1) - { - Debug.Assert(bytesPoolArray is not null); - - if (bytesCount >= bytesPoolArray!.Length) - { - Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2); - } - - bytesPoolArray![bytesCount++] = (byte)(enumerator.Current - _byteCodeToIdOffset); - } - else - { - EncodeByte(enumerator.Current, _oneByteUtf8EncodingMaxId, _byteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb); - } - } - else - { - if (bytesCount >= 1) - { - FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); - } - - if (_vocabReverse.TryGetValue(enumerator.Current, out string? token)) - { - AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex); - } - else - { - TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb); - } - } - } - - if (bytesCount >= 1) - { - FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb); - } - - if (AddDummyPrefix && TreatWhitespaceAsSuffix && suffixIndex >= 0 && sb.Length > 0) - { - Debug.Assert(sb[suffixIndex] == SentencePieceNormalizer.DummyPrefix); - Debug.Assert(sb.Length > suffixIndex); - - sb.Remove(suffixIndex, 1); - } - - if (bytesPoolArray is not null) - { - ArrayPool.Shared.Return(bytesPoolArray); - } - - if (charPoolArray is not null) - { - ArrayPool.Shared.Return(charPoolArray); - } - - return EscapeWhiteSpaces ? sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ') : sb.ToString(); - - static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb) - { - Debug.Assert(bytesCount >= 1); - Debug.Assert(bytesPoolArray is not null); - - int len = Encoding.UTF8.GetMaxCharCount(bytesCount); - - charPoolArray ??= ArrayPool.Shared.Rent(Math.Max(len, 50)); - - if (len > charPoolArray.Length) - { - Helpers.ArrayPoolGrow(ref charPoolArray, len); - } - - int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray); - - sb.Append(charPoolArray.AsSpan(0, charCount)); - bytesCount = -1; - } - - static void EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, ref byte[]? 
bytesPoolArray, ref ValueStringBuilder sb) - { - if (id <= oneByteUtf8EncodingMaxId) - { - sb.Append((char)(id - byteCodeToIdOffset)); - } - else - { - bytesCount = 1; - bytesPoolArray ??= ArrayPool.Shared.Rent(50); - bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset); - } - } - - static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, ref ValueStringBuilder sb, ref bool prefixRemoved, ref int suffixIndex) - { - if (token.Length == 0) - { - return; - } - - if (!addDummyPrefix) - { - sb.Append(token); - return; - } - - if (treatWhitespaceAsSuffix) - { - sb.Append(token); - if (token[token.Length - 1] == prefixSuffixChar) - { - suffixIndex = sb.Length - 1; - } - } - else - { - sb.Append(!prefixRemoved && token[0] == prefixSuffixChar ? token.AsSpan(1) : token.AsSpan()); - } - - prefixRemoved = true; - } - - static void TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb) - { - if (!considerSpecialTokens) - { - return; - } - - if (id == tokenizer.BeginningOfSentenceId) - { - sb.Append(tokenizer.BeginningOfSentenceToken); - } - else if (id == tokenizer.EndOfSentenceId) - { - sb.Append(tokenizer.EndOfSentenceToken); - } - else if (id == tokenizer.UnknownId) - { - sb.Append(tokenizer.UnknownToken); - } - else if (tokenizer._specialTokensReverse?.TryGetValue(id, out string? specialToken) is true) - { - sb.Append(specialToken); - } - } - } + public string Decode(IEnumerable ids, bool considerSpecialTokens) => _model.Decode(ids, considerSpecialTokens); /// /// Decode the given ids back to text and store the result in the span. @@ -1771,7 +417,7 @@ static void TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bo /// The number of characters written to the destination span. /// The operation status indicates whether all IDs were successfully decoded or if the is too small to contain the entire decoded result. public override OperationStatus Decode(IEnumerable ids, Span destination, out int idsConsumed, out int charsWritten) - => Decode(ids, destination, considerSpecialTokens: false, out idsConsumed, out charsWritten); + => _model.Decode(ids, destination, considerSpecialTokens: false, out idsConsumed, out charsWritten); /// /// Decode the given ids back to text and store the result in the span. @@ -1783,517 +429,33 @@ public override OperationStatus Decode(IEnumerable ids, Span destinat /// The number of characters written to the destination span. /// The operation status indicates whether all IDs were successfully decoded or if the is too small to contain the entire decoded result. public OperationStatus Decode(IEnumerable ids, Span destination, bool considerSpecialTokens, out int idsConsumed, out int charsWritten) - { - idsConsumed = 0; - charsWritten = 0; - - if (ids is null) - { - throw new ArgumentNullException(nameof(ids)); - } - - using IEnumerator enumerator = ids.GetEnumerator(); - if (!enumerator.MoveNext()) - { - return OperationStatus.Done; - } - - Span buffer = destination; - - int bytesCount = -1; - byte[]? bytesPoolArray = null; - bool prefixRemoved = false; - int suffixIndex = -1; - char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' '; - - if (enumerator.Current <= _maxByteId) - { - // First token is a byte token. 
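// The byte-fallback handling that follows (and that this change relocates into the shared
// SentencePiece base model) boils down to: ids in the byte range are raw UTF-8 bytes
// (id - byteCodeToIdOffset) that are buffered until a non-byte token flushes them as characters.
// A hedged, simplified sketch of that idea; the names below are hypothetical stand-ins for the
// tokenizer's private fields, and special/control-token handling, pooling, and the
// dummy-prefix/suffix rules are omitted. Requires System.Collections.Generic and System.Text.
static string DecodeWithByteFallback(
    IEnumerable<int> ids,
    int byteCodeToIdOffset,
    int maxByteId,
    IReadOnlyDictionary<int, string> vocabReverse)
{
    var pendingBytes = new List<byte>();
    var sb = new StringBuilder();

    foreach (int id in ids)
    {
        if (id >= byteCodeToIdOffset && id <= maxByteId)
        {
            // Byte token such as "<0x41>": recover the raw UTF-8 byte and defer decoding
            // until the full multi-byte sequence has been collected.
            pendingBytes.Add((byte)(id - byteCodeToIdOffset));
            continue;
        }

        if (pendingBytes.Count > 0)
        {
            // A non-byte token ends the pending UTF-8 sequence; flush it as characters.
            sb.Append(Encoding.UTF8.GetString(pendingBytes.ToArray()));
            pendingBytes.Clear();
        }

        sb.Append(vocabReverse.TryGetValue(id, out string? piece) ? piece : string.Empty);
    }

    if (pendingBytes.Count > 0)
    {
        sb.Append(Encoding.UTF8.GetString(pendingBytes.ToArray()));
    }

    // Map the SentencePiece dummy prefix '▁' back to a regular space.
    return sb.ToString().Replace('\u2581', ' ');
}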
- while (enumerator.Current < _byteCodeToIdOffset) - { - OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten); - if (status != OperationStatus.Done) - { - return status; - } - buffer = destination.Slice(charsWritten); - - // Skip control tokens. - idsConsumed++; - if (!enumerator.MoveNext()) - { - return OperationStatus.Done; - } - } - - if (enumerator.Current <= _maxByteId) - { - if (!EncodeByte(enumerator.Current, _oneByteUtf8EncodingMaxId, _byteCodeToIdOffset, ref bytesCount, buffer, ref charsWritten, ref idsConsumed, ref bytesPoolArray)) - { - return OperationStatus.DestinationTooSmall; - } - } - else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token)) - { - if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten)) - { - return OperationStatus.DestinationTooSmall; - } - } - else - { - OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten); - if (status != OperationStatus.Done) - { - return status; - } - - idsConsumed++; - } - } - else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token)) - { - if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten)) - { - return OperationStatus.DestinationTooSmall; - } - } - else - { - OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten); - if (status != OperationStatus.Done) - { - return status; - } - - idsConsumed++; - } - - char[]? charPoolArray = null; - - while (enumerator.MoveNext()) - { - buffer = destination.Slice(charsWritten); - - if (enumerator.Current < _byteCodeToIdOffset) - { - if (bytesCount >= 1) - { - if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed)) - { - return OperationStatus.DestinationTooSmall; - } - } - - OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten); - if (status != OperationStatus.Done) - { - return status; - } - - idsConsumed++; - continue; - } - - if (enumerator.Current <= _maxByteId) - { - if (bytesCount >= 1) - { - Debug.Assert(bytesPoolArray is not null); - - if (bytesCount >= bytesPoolArray!.Length) - { - Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2); - } - - bytesPoolArray![bytesCount++] = (byte)(enumerator.Current - _byteCodeToIdOffset); - } - else - { - if (!EncodeByte(enumerator.Current, _oneByteUtf8EncodingMaxId, _byteCodeToIdOffset, ref bytesCount, buffer, ref charsWritten, ref idsConsumed, ref bytesPoolArray)) - { - return OperationStatus.DestinationTooSmall; - } - } - } - else - { - if (bytesCount >= 1) - { - if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed)) - { - return OperationStatus.DestinationTooSmall; - } - } - - if (_vocabReverse.TryGetValue(enumerator.Current, out string? 
token)) - { - if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten)) - { - return OperationStatus.DestinationTooSmall; - } - } - else - { - OperationStatus status = TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, buffer, ref charsWritten); - if (status != OperationStatus.Done) - { - return status; - } - - idsConsumed++; - } - } - } - - buffer = destination.Slice(charsWritten); - - if (bytesCount >= 1) - { - if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed)) - { - return OperationStatus.DestinationTooSmall; - } - } - - if (suffixIndex >= 0) - { - Debug.Assert(destination[suffixIndex] == ' '); - - if (suffixIndex < charsWritten - 1) - { - destination.Slice(suffixIndex + 1, charsWritten - suffixIndex - 1).CopyTo(destination.Slice(suffixIndex)); - } - - charsWritten--; - } - - if (bytesPoolArray is not null) - { - ArrayPool.Shared.Return(bytesPoolArray); - } - - if (charPoolArray is not null) - { - ArrayPool.Shared.Return(charPoolArray); - } - - return OperationStatus.Done; - - static OperationStatus TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, Span buffer, ref int charsWritten) - { - string? specialToken = null; - - if (id == tokenizer.BeginningOfSentenceId) - { - specialToken = tokenizer.BeginningOfSentenceToken; - } - else if (id == tokenizer.EndOfSentenceId) - { - specialToken = tokenizer.EndOfSentenceToken; - } - else if (id == tokenizer.UnknownId) - { - specialToken = tokenizer.UnknownToken; - } - else if (!tokenizer._specialTokensReverse?.TryGetValue(id, out specialToken) is true) - { - return OperationStatus.InvalidData; - } - - if (considerSpecialTokens && specialToken is not null) - { - if (buffer.Length < specialToken!.Length) - { - return OperationStatus.DestinationTooSmall; - } - - specialToken.AsSpan().CopyTo(buffer); - charsWritten += specialToken.Length; - } - - return OperationStatus.Done; - } - - static bool FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, Span buffer, ref int charsWritten, ref int idsConsumed) - { - Debug.Assert(bytesCount >= 1); - Debug.Assert(bytesPoolArray is not null); - - int len = Encoding.UTF8.GetMaxCharCount(bytesCount); - - charPoolArray ??= ArrayPool.Shared.Rent(Math.Max(len, 50)); - - if (len > charPoolArray.Length) - { - Helpers.ArrayPoolGrow(ref charPoolArray, len); - } - - int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray); - - if (charCount > buffer.Length) - { - return false; - } - - charPoolArray.AsSpan(0, charCount).CopyTo(buffer); - charsWritten += charCount; - idsConsumed += bytesCount; - bytesCount = -1; - - return true; - } - - static bool EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, Span buffer, ref int charsWritten, ref int idsConsumed, ref byte[]? 
bytesPoolArray) - { - if (id <= oneByteUtf8EncodingMaxId) - { - if (buffer.Length < 1) - { - return false; - } - - buffer[0] = (char)(id - byteCodeToIdOffset); - charsWritten++; - idsConsumed++; - } - else - { - bytesCount = 1; - bytesPoolArray ??= ArrayPool.Shared.Rent(50); - bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset); - } - - return true; - } - - static bool AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, Span destination, ref bool prefixRemoved, ref int suffixIndex, ref int idsConsumed, ref int charsConsumed) - { - if (token.Length == 0) - { - return true; - } - - Span buffer = destination.Slice(charsConsumed); - - ReadOnlySpan tokenSpan = token.AsSpan(); - - if (!addDummyPrefix) - { - if (tokenSpan.Length > buffer.Length) - { - return false; - } - - if (prefixSuffixChar != ' ') - { - for (int i = 0; i < tokenSpan.Length; i++) - { - buffer[i] = tokenSpan[i] == prefixSuffixChar ? ' ' : tokenSpan[i]; - } - } - else - { - tokenSpan.CopyTo(buffer); - } - - buffer = buffer.Slice(tokenSpan.Length); - charsConsumed += tokenSpan.Length; - idsConsumed++; - return true; - } - - if (treatWhitespaceAsSuffix) - { - if (tokenSpan[tokenSpan.Length - 1] == prefixSuffixChar) - { - suffixIndex = charsConsumed + tokenSpan.Length - 1; - } - - if (tokenSpan.Length > buffer.Length) - { - return false; - } - - if (prefixSuffixChar != ' ') - { - for (int i = 0; i < tokenSpan.Length; i++) - { - buffer[i] = tokenSpan[i] == prefixSuffixChar ? ' ' : tokenSpan[i]; - } - } - else - { - tokenSpan.CopyTo(buffer); - } - - charsConsumed += tokenSpan.Length; - - idsConsumed++; - } - else - { - int delta = !prefixRemoved && token[0] == prefixSuffixChar ? 1 : 0; - if (buffer.Length < token.Length - delta) - { - return false; - } + => _model.Decode(ids, destination, considerSpecialTokens, out idsConsumed, out charsWritten); - tokenSpan = tokenSpan.Slice(delta); - if (prefixSuffixChar != ' ') - { - for (int i = 0; i < tokenSpan.Length; i++) - { - buffer[i] = tokenSpan[i] == prefixSuffixChar ? ' ' : tokenSpan[i]; - } - } - else - { - tokenSpan.CopyTo(buffer); - } - - charsConsumed += tokenSpan.Length; - idsConsumed++; - - if (!prefixRemoved && delta == 1) - { - prefixRemoved = true; - } - } - - return true; - } - } - - // Tries to avoid string allocations if possible. - private string GetTokenString(int id, int index, int length, ReadOnlySpan text) - => _vocabReverse.TryGetValue(id, out string? token) ? token : text.Slice(index, length).ToString(); - - private Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? Encode(ReadOnlySpan text, BpeSymbol[] symbols) - { - Debug.Assert(text.Length > 0); - Debug.Assert(symbols.Length >= text.Length); - - int symbolIndex = 0; - int spanIndex = 0; - - while (spanIndex < text.Length) - { - int len = (Char.IsHighSurrogate(text[spanIndex]) && spanIndex < text.Length - 1 && Char.IsLowSurrogate(text[spanIndex + 1])) ? 2 : 1; - - BpeSymbol s = new( - prev: symbolIndex == 0 ? -1 : symbolIndex - 1, - next: spanIndex + len >= text.Length ? -1 : symbolIndex + 1, - pieceSpan: (spanIndex, len), - id: UninitializedId, - type: 0); - - symbols[symbolIndex++] = s; - spanIndex += len; - } - - PriorityQueue agenda = new(symbolIndex); - Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? 
revMerge = null; - - for (int i = 1; i < symbolIndex; i++) - { - TryMerge(i - 1, i, text); - } - - while (agenda.Count > 0) - { - SymbolPair top = agenda.Dequeue(); - - if (symbols[top.Left].pieceSpan.Length == 0 || symbols[top.Right].pieceSpan.Length == 0 || - symbols[top.Left].pieceSpan.Length + symbols[top.Right].pieceSpan.Length != top.Length) - { - continue; - } - - // Replaces symbols with `top` rule. - symbols[top.Left].pieceSpan = (symbols[top.Left].pieceSpan.Index, symbols[top.Left].pieceSpan.Length + symbols[top.Right].pieceSpan.Length); - symbols[top.Left].id = top.Id; - - // Updates prev/next pointers. - symbols[top.Left].next = symbols[top.Right].next; - - if (symbols[top.Right].next >= 0) - { - symbols[symbols[top.Right].next].prev = top.Left; - } - symbols[top.Right].pieceSpan = (0, 0); - - // Adds new symbol pairs which are newly added after symbol replacement. - TryMerge(symbols[top.Left].prev, top.Left, text); - TryMerge(top.Left, symbols[top.Left].next, text); - } - - return revMerge; - - void TryMerge(int left, int right, ReadOnlySpan textSpan) - { - if (left == -1 || right == -1) - { - return; - } - - int pieceLength = symbols[left].pieceSpan.Length + symbols[right].pieceSpan.Length; - if (!_vocab.TryGetValue(textSpan.Slice(symbols[left].pieceSpan.Index, pieceLength), out (int Id, float Score, byte Type) leftId)) - { - return; - } - - symbols[left].type = leftId.Type; - - SymbolPair pair = new(left, right, leftId.Score, pieceLength, leftId.Id); - agenda.Enqueue(pair); - - if (leftId.Type == (byte)ModelProto.Types.SentencePiece.Types.Type.Unused) - { - revMerge ??= new(); - revMerge.Add((symbols[left].pieceSpan.Index, pieceLength), (symbols[left].pieceSpan.Index, symbols[left].pieceSpan.Length, symbols[right].pieceSpan.Index, symbols[right].pieceSpan.Length)); - } - } - } - - private struct SymbolPair : IEquatable, IComparable + /// + /// Creates an instance of SentencePieceTokenizer. The model stream should contain a SentencePiece model as specified in the following documentation: + /// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto. + /// + /// The stream containing the SentencePiece Bpe or Unigram model. + /// Indicate emitting the beginning of sentence token during the encoding. + /// Indicate emitting the end of sentence token during the encoding. + /// The additional tokens to add to the vocabulary. + /// + /// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider. + /// + public static SentencePieceTokenizer Create( + Stream modelStream, + bool addBeginOfSentence = true, + bool addEndOfSentence = false, + IReadOnlyDictionary? 
specialTokens = null) { - public int Left { get; set; } - public int Right { get; set; } - public int Length { get; set; } - public float Score { get; set; } - public int Id { get; set; } - - public SymbolPair(int left, int right, float score, int length, int id) - { - Left = left; - Right = right; - Score = score; - Length = length; - Id = id; - } - - public int CompareTo(SymbolPair other) - { - if (Score != other.Score) - { - return other.Score.CompareTo(Score); - } + ModelProto modelProto = ModelProto.Parser.ParseFrom(modelStream); - return other.Left.CompareTo(Left); - } - - public override int GetHashCode() + if (modelProto is null) { - int hashCode = 23; - hashCode = (hashCode * 37) + Score.GetHashCode(); - hashCode = (hashCode * 37) + Left.GetHashCode(); - return hashCode; + throw new ArgumentNullException(nameof(modelProto)); } - public bool Equals(SymbolPair other) => Left == other.Left && Score == other.Score; + return new SentencePieceTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens); } - - private record struct BpeSymbol(int prev, int next, (int Index, int Length) pieceSpan, int id, byte type); } } diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs new file mode 100644 index 0000000000..a578362f73 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceUnigramModel.cs @@ -0,0 +1,1399 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Sentencepiece; +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Diagnostics; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using System.Text.RegularExpressions; + +namespace Microsoft.ML.Tokenizers +{ + internal sealed class SentencePieceUnigramModel : SentencePieceBaseModel + { + private readonly SortedDictionary _vocab; + private readonly (string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)[] _vocabReverse; + private readonly DoubleArrayTrie _trie; + private readonly float _minScore; + private readonly float _maxScore; + private const float UnkPenalty = 10.0f; + + public SentencePieceUnigramModel(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? 
specialTokens = null) : base(modelProto, addBos, addEos, specialTokens) + { + _vocab = new SortedDictionary(OrdinalUtf8StringComparer.Instance); + _vocabReverse = new (string Piece, float Score, ModelProto.Types.SentencePiece.Types.Type Type)[modelProto.Pieces.Count]; + + _minScore = float.MaxValue; + _maxScore = float.MinValue; + + for (int i = 0; i < modelProto.Pieces.Count; i++) + { + if (modelProto.Pieces[i].Type == ModelProto.Types.SentencePiece.Types.Type.Normal || + modelProto.Pieces[i].Type == ModelProto.Types.SentencePiece.Types.Type.UserDefined || + modelProto.Pieces[i].Type == ModelProto.Types.SentencePiece.Types.Type.Unused) + { + string piece = modelProto.Pieces[i].Piece; + float score = modelProto.Pieces[i].Score; + _vocabReverse[i] = (piece, score, modelProto.Pieces[i].Type); + _vocab.Add(piece, i); + _minScore = Math.Min(_minScore, score); + _maxScore = Math.Max(_maxScore, score); + } + else if (modelProto.Pieces[i].Type == ModelProto.Types.SentencePiece.Types.Type.Byte) + { + MaxByteId = i; + } + else if (modelProto.Pieces[i].Type == ModelProto.Types.SentencePiece.Types.Type.Unknown) + { + // Ensure the unknown token is cached + _vocabReverse[i] = (modelProto.Pieces[i].Piece, modelProto.Pieces[i].Score, ModelProto.Types.SentencePiece.Types.Type.Unknown); + } + } + + ByteCodeToIdOffset = _vocab.TryGetValue("<0x00>", out int id) ? id : MaxByteId; + OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character. + MaxIdByteFallbackId = ByteCodeToIdOffset + 0xFF; // from <0x00> to <0xFF>. + + _trie = new DoubleArrayTrie(_vocab); + + _vocabReverse[BeginningOfSentenceId] = (BeginningOfSentenceToken, 0f, 0); + _vocabReverse[EndOfSentenceId] = (EndOfSentenceToken, 0f, 0); + } + + public override IReadOnlyDictionary Vocabulary => new ReadOnlyDictionary(_vocab); + + public int MaxIdByteFallbackId { get; } + + public override IReadOnlyList EncodeToTokens(string? text, ReadOnlySpan textSpan, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization) + { + ReadOnlySpan textToEncode = string.IsNullOrEmpty(text) ? textSpan : text.AsSpan(); + if (textToEncode.IsEmpty) + { + normalizedText = string.Empty; + return Array.Empty(); + } + + List tokens = new(); + + // Rent a buffer that approximately enough to hold the Utf8 encoded bytes, the normalization of the encoded buffer, and some extra memory to for encoding results. + int[] buffer = ArrayPool.Shared.Rent(textToEncode.Length * 3); + + // Hold the Utf16 normalized string. + char[] normalizedString = ArrayPool.Shared.Rent(textToEncode.Length + 2); + + if (SpecialTokensRegex is not null) + { + EncodeToTokensWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, considerNormalization, tokens, buffer, ref normalizedString, out normalizedText); + } + else + { + EncodeToTokensWithoutSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, considerNormalization, tokens, buffer, ref normalizedString, out normalizedText); + } + + ArrayPool.Shared.Return(normalizedString); + ArrayPool.Shared.Return(buffer); + + return tokens; + } + + public override bool TryMapIdToToken(int id, out string? 
token) + { + if ((uint)id >= (uint)(_vocabReverse.Length)) + { + token = null; + return false; + } + + token = _vocabReverse[id].Piece; + return true; + } + + private void StoreNormalizedTextFromEnd(ReadOnlySpan text, ref char[] normalizedString, ref int normalizedStringIndexFromEnd) + { + int remainingLength = normalizedString.Length - normalizedStringIndexFromEnd; + if (text.Length > remainingLength) + { + char[] utf16NormalizedString = ArrayPool.Shared.Rent(normalizedString.Length << 1); + normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringIndexFromEnd)); + ArrayPool.Shared.Return(normalizedString); + normalizedString = utf16NormalizedString; + } + + text.CopyTo(normalizedString.AsSpan(normalizedString.Length - normalizedStringIndexFromEnd - text.Length)); + normalizedStringIndexFromEnd += text.Length; + } + + private void StoreNormalizedTextFromEnd(ReadOnlySpan utf8Bytes, ref char[] normalizedString, ref int normalizedStringIndexFromEnd) + { + int remainingLength = normalizedString.Length - normalizedStringIndexFromEnd; + int expectedCount = Helpers.GetUtf16LengthFromUtf8Bytes(utf8Bytes); + + if (expectedCount > remainingLength) + { + char[] utf16NormalizedString = ArrayPool.Shared.Rent(normalizedString.Length << 1); + normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).CopyTo(utf16NormalizedString.AsSpan(utf16NormalizedString.Length - normalizedStringIndexFromEnd)); + ArrayPool.Shared.Return(normalizedString); + normalizedString = utf16NormalizedString; + } + + bool res = Helpers.ConvertUtf8ToUtf16(utf8Bytes, normalizedString.AsSpan(normalizedString.Length - normalizedStringIndexFromEnd - expectedCount), out int bytesConsumed, out int charsWritten); + Debug.Assert(res); + Debug.Assert(bytesConsumed == utf8Bytes.Length); + Debug.Assert(charsWritten == expectedCount); + normalizedStringIndexFromEnd += expectedCount; + } + + private void StoreNormalizedText(ReadOnlySpan text, ref char[] normalizedString, ref int normalizedStringIndex) + { + Span utf16NormalizedString = normalizedString.AsSpan().Slice(normalizedStringIndex); + + if (text.Length > utf16NormalizedString.Length) + { + Helpers.ArrayPoolGrow(ref normalizedString, normalizedString.Length << 1); + utf16NormalizedString = normalizedString.AsSpan().Slice(normalizedStringIndex); + } + + text.CopyTo(utf16NormalizedString); + normalizedStringIndex += text.Length; + } + + private void StoreNormalizedText(ReadOnlySpan normalizationSpan, ref char[] normalizedString, ref int normalizedStringIndex) + { + Span normalizedUtf16Span = normalizedString.AsSpan().Slice(normalizedStringIndex); + if (Encoding.UTF8.GetMaxCharCount(normalizationSpan.Length) > normalizedUtf16Span.Length) + { + Helpers.ArrayPoolGrow(ref normalizedString, normalizedString.Length << 1); + normalizedUtf16Span = normalizedString.AsSpan().Slice(normalizedStringIndex); + } + + bool res = Helpers.ConvertUtf8ToUtf16(normalizationSpan, normalizedUtf16Span, out int bytesConsumed, out int charsWritten); + Debug.Assert(res); + normalizedStringIndex += charsWritten; + } + + private void EncodeToTokensWithSpecialTokens( + ReadOnlySpan text, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + List tokens, + int[] buffer, + ref char[] normalizedString, + out string? 
normalizedText) + { + Debug.Assert(SpecialTokensRegex is not null); + + if (addBeginningOfSentence) + { + tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + } + + int currentOffset = 0; + int progressOffset = 0; + int normalizedStringIndex = 0; + + foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!)) + { + if (Offset > currentOffset) + { + EncodeToTokensInternal(text.Slice(currentOffset, Offset - currentOffset), considerNormalization, ref progressOffset, tokens, buffer, ref normalizedString, ref normalizedStringIndex); + } + + if (InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) + { + tokens.Add(new EncodedToken(id, SpecialTokensReverse![id], new Range(progressOffset, progressOffset + Length))); + progressOffset += Length; + + StoreNormalizedText(text.Slice(Offset, Length), ref normalizedString, ref normalizedStringIndex); + } + + currentOffset = Offset + Length; + } + + if (currentOffset < text.Length) + { + EncodeToTokensInternal(text.Slice(currentOffset), considerNormalization, ref progressOffset, tokens, buffer, ref normalizedString, ref normalizedStringIndex); + } + + if (addEndOfSentence) + { + tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(progressOffset, progressOffset))); + } + + normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); + } + + private void EncodeToTokensWithoutSpecialTokens( + ReadOnlySpan text, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + List tokens, + int[] buffer, + ref char[] normalizedString, + out string? normalizedText) + { + if (addBeginningOfSentence) + { + tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0))); + } + + int progressOffset = 0; + int normalizedStringIndex = 0; + + EncodeToTokensInternal(text, considerNormalization, ref progressOffset, tokens, buffer, ref normalizedString, ref normalizedStringIndex); + + if (addEndOfSentence) + { + tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(progressOffset, progressOffset))); + } + + normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); + } + + private void NormalizeText( + ReadOnlySpan text, + bool considerNormalization, + int[] buffer, + out byte[]? normalizedArrayPool, + out Span normalizationSpan) + { + Debug.Assert(Encoding.UTF8.GetMaxByteCount(text.Length) * 3 <= buffer.Length * sizeof(int)); + Span byteSpan = MemoryMarshal.AsBytes(buffer.AsSpan()); + + // Unigram is currently working with Utf8 encoded bytes. + // if considerNormalization is true, the utf8 encoded bytes will be normalized to utf8 bytes too. 
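// The rented buffer is an int[] of length text.Length * 3, viewed below as a byte span of roughly
// 12 bytes per input character: the raw UTF-8 encoding (at most 3 bytes per UTF-16 char) is written
// at the front, and the remaining bytes serve as the initial target for the normalized output. If
// normalization still needs more room, Normalize can grow into a separate pooled array
// (normalizedArrayPool). Illustrative shape only, with hypothetical local names:
//
//     int[] rented = ArrayPool<int>.Shared.Rent(text.Length * 3);     // ~12 bytes per input char
//     Span<byte> bytes = MemoryMarshal.AsBytes(rented.AsSpan());
//     int utf8Count = Encoding.UTF8.GetBytes(text, bytes);            // raw UTF-8 written first
//     Span<byte> target = bytes.Slice(utf8Count);                     // normalized bytes written after it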
+ int byteCount = Helpers.GetUtf8Bytes(text, byteSpan); + normalizationSpan = byteSpan.Slice(byteCount); + + Debug.Assert(normalizationSpan.Length >= (byteCount << 1)); + normalizedArrayPool = null; + + if (considerNormalization) + { + int normalizationCount = Normalizer!.Normalize(byteSpan.Slice(0, byteCount), ref normalizationSpan, ref normalizedArrayPool); + normalizationSpan = normalizationSpan.Slice(0, normalizationCount); + if (normalizationCount == 0) + { + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + normalizedArrayPool = null; + } + + return; + } + } + else + { + normalizationSpan = byteSpan.Slice(0, byteCount); + } + } + + private void EncodeToTokensInternal( + ReadOnlySpan text, + bool considerNormalization, + ref int tokensOffset, + List tokens, + int[] buffer, + ref char[] normalizedString, + ref int normalizedStringIndex) + { + // + // Normalize text + // + + NormalizeText(text, considerNormalization, buffer, out byte[]? normalizedArrayPool, out Span normalizationSpan); + + // + // Encode using Unigram algorithm + // + + BestPathNode[] bestPathEndsAt = ArrayPool.Shared.Rent(normalizationSpan.Length + 1); + + Encode(normalizationSpan, bestPathEndsAt); + + // + // Fill the results + // + + // Backtrack to identify the best path. + int insertionStartPosition = tokens.Count; + int endsAt = normalizationSpan.Length; + bool unknownEncountered = false; + + while (endsAt > 0) + { + ref BestPathNode node = ref bestPathEndsAt[endsAt]; + + string stringToken = node.Id == UnknownId ? Helpers.GetString(normalizationSpan.Slice(node.StartsAt, endsAt - node.StartsAt)) : _vocabReverse[node.Id].Piece; + int tokenLength = stringToken.Length; + + tokens.Add(new EncodedToken(node.Id, stringToken, new Range(0, tokenLength))); // we will update the range later. + endsAt = node.StartsAt; + unknownEncountered = unknownEncountered || node.Id == UnknownId; + } + + int start = insertionStartPosition; + int end = tokens.Count - 1; + + // Reverse the stored tokens and fix the encoded tokens offset. 
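// Backtracking above appended the tokens from the end of the span back to the beginning, each with a
// placeholder range of (0, token length). The swap loop below puts them back in reading order and, in
// the same pass, rebases the range of the first half onto the running tokensOffset so offsets line up
// with the normalized text accumulated so far; the loop after it fixes the offsets of the remaining
// second half (and the middle element when the reversed region has odd length).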
+ while (start < end) + { + EncodedToken temp = tokens[start]; + tokens[start] = tokens[end]; + tokens[end] = temp; + + int tokenLength = tokens[start].Offset.End.Value; + // Fix the offsets + tokens[start] = new EncodedToken(tokens[start].Id, tokens[start].Value, new Range(tokensOffset, tokensOffset + tokenLength)); + tokensOffset += tokenLength; + + start++; + end--; + } + + while (start < tokens.Count) + { + int tokenLength = tokens[start].Offset.End.Value; + // Fix the offsets + tokens[start] = new EncodedToken(tokens[start].Id, tokens[start].Value, new Range(tokensOffset, tokensOffset + tokenLength)); + tokensOffset += tokenLength; + start++; + } + + StoreNormalizedText(normalizationSpan, ref normalizedString, ref normalizedStringIndex); + + if (ByteFallback && unknownEncountered) + { + FallbackToByteEncoding(normalizedString, tokens, insertionStartPosition); + } + + ArrayPool.Shared.Return(bestPathEndsAt); + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + } + + private void FallbackToByteEncoding(ReadOnlySpan normalizationSpan, List tokens, int insertionStartPosition) + { + Span destination = stackalloc byte[4]; + + while (insertionStartPosition < tokens.Count) + { + if (tokens[insertionStartPosition].Id == UnknownId) + { + int offsetStart = tokens[insertionStartPosition].Offset.Start.Value; + int tokenLength = tokens[insertionStartPosition].Offset.End.Value - offsetStart; + + tokens.RemoveAt(insertionStartPosition); + + int charLength = 0; + for (int i = 0; i < tokenLength; i += charLength) + { + int codepointLength = Helpers.EncodeNextUtf8(normalizationSpan.Slice(offsetStart), destination); + charLength = codepointLength == 4 ? 2 : 1; + + Debug.Assert(codepointLength > 0); + + int id = ByteCodeToIdOffset + destination[0]; + tokens.Insert(insertionStartPosition++, new EncodedToken(id, _vocabReverse[id].Piece, new Range(offsetStart, offsetStart + charLength))); + + for (int j = 1; j < codepointLength; j++) + { + id = ByteCodeToIdOffset + destination[j]; + tokens.Insert(insertionStartPosition++, new EncodedToken(id, _vocabReverse[id].Piece, new Range(offsetStart + charLength, offsetStart + charLength))); + } + + offsetStart += charLength; + } + + continue; + } + + insertionStartPosition++; + } + } + + private struct BestPathNode + { + public BestPathNode() + { + Id = -1; + BestPathScore = 0f; + StartsAt = -1; + } + + // The vocab id. (maybe -1 for UNK) + public int Id { get; set; } + + // The total score of the best path ending at this node. + public float BestPathScore { get; set; } + + // The starting position (in utf-8) of this node. The entire best path can be constructed by backtracking along this link. + public int StartsAt { get; set; } + }; + + private void Encode(ReadOnlySpan normalized, Span bestPathEndsAt) + { + Debug.Assert(normalized.Length > 0); + + int size = normalized.Length; + float unkScore = _minScore - UnkPenalty; + + Debug.Assert(bestPathEndsAt.Length >= size + 1); + + // The ends are exclusive. + for (int i = 0; i < size + 1; i++) + { + bestPathEndsAt[i] = new BestPathNode(); + } + + // Generate lattice on-the-fly (not stored) and update best_path_ends_at. 
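// The loop below is, in effect, a Viterbi pass over an implicit lattice of vocabulary pieces:
//   bestScore(0)   = 0
//   bestScore(end) = max over pieces p matching bytes [start, end) of bestScore(start) + score(p)
// with a one-character fallback node scored (minScore - UnkPenalty) so every position stays reachable.
// The double-array trie enumerates all pieces matching at a given start position in one traversal.
// A hedged, simplified illustration of the same recurrence over a plain string and a piece->score map
// (hypothetical names; not the API used here, which works on UTF-8 bytes and a trie):
static (int Start, string Piece)[] ViterbiSegment(string text, IReadOnlyDictionary<string, float> pieces, float unknownScore)
{
    float[] best = new float[text.Length + 1];          // best[0] stays 0
    int[] backStart = new int[text.Length + 1];
    string[] backPiece = new string[text.Length + 1];

    for (int end = 1; end <= text.Length; end++)
    {
        best[end] = float.NegativeInfinity;
        for (int start = 0; start < end; start++)
        {
            string candidate = text.Substring(start, end - start);
            bool known = pieces.TryGetValue(candidate, out float pieceScore);

            // Unknown pieces are only allowed as single characters, mirroring the UNK fallback above.
            if (!known && end - start != 1)
            {
                continue;
            }

            float total = best[start] + (known ? pieceScore : unknownScore);
            if (total > best[end])
            {
                best[end] = total;
                backStart[end] = start;
                backPiece[end] = candidate;
            }
        }
    }

    // Backtrack from the end; the pieces come out in reverse order.
    var segments = new List<(int Start, string Piece)>();
    for (int end = text.Length; end > 0; end = backStart[end])
    {
        segments.Add((backStart[end], backPiece[end]));
    }

    segments.Reverse();
    return segments.ToArray();
}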
+ int startsAt = 0; + + while (startsAt < size) + { + int nodePos = 0; + int keyPos = startsAt; + float bestPathScoreTillHere = bestPathEndsAt[startsAt].BestPathScore; + bool hasSingleNode = false; + int mbLen = Helpers.OneCharLen(normalized[startsAt]); + while (keyPos < size) + { + int ret = _trie.Traverse(normalized, ref nodePos, ref keyPos, keyPos + 1); + if (ret == -2) + { + break; + } + + if (ret >= 0) + { + if (_vocabReverse[ret].Type == ModelProto.Types.SentencePiece.Types.Type.Unused) + { + continue; + } + + // Update the best path node. + ref BestPathNode targetNode = ref bestPathEndsAt[keyPos]; + int length = keyPos - startsAt; + + // User defined symbol receives extra bonus to always be selected. + float score = _vocabReverse[ret].Type == ModelProto.Types.SentencePiece.Types.Type.UserDefined ? length * _maxScore - 0.1f : _vocabReverse[ret].Score; + float candidateBestPathScore = score + bestPathScoreTillHere; + + if (targetNode.StartsAt == -1 || candidateBestPathScore > targetNode.BestPathScore) + { + targetNode.BestPathScore = candidateBestPathScore; + targetNode.StartsAt = startsAt; + targetNode.Id = ret; + } + + if (!hasSingleNode && length == mbLen) + { + hasSingleNode = true; + } + } + } + + if (!hasSingleNode) + { + ref BestPathNode targetNode = ref bestPathEndsAt[startsAt + mbLen]; + float candidateBestPathScore = unkScore + bestPathScoreTillHere; + + if (targetNode.StartsAt == -1 || candidateBestPathScore > targetNode.BestPathScore) + { + targetNode.BestPathScore = candidateBestPathScore; + targetNode.StartsAt = startsAt; + targetNode.Id = UnknownId; + } + } + + // Move by one unicode character. + startsAt += mbLen; + } + } + + public override IReadOnlyList EncodeToIds( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount = int.MaxValue) + { + ReadOnlySpan textToEncode = string.IsNullOrEmpty(text) ? textSpan : text.AsSpan(); + + if (textToEncode.IsEmpty || maxTokenCount <= 0) + { + normalizedText = null; + charsConsumed = 0; + return Array.Empty(); + } + + List? ids = new(); + + if (addBeginningOfSentence) + { + ids.Add(BeginningOfSentenceId); + if (maxTokenCount == 1) + { + normalizedText = null; + charsConsumed = 0; + return ids; // done. no more space for anything else. + } + } + + // Rent a buffer that approximately enough to hold the Utf8 encoded bytes, the normalization of the encoded buffer, and some extra memory to for encoding results. + int[] buffer = ArrayPool.Shared.Rent(textToEncode.Length * 3); + + // when maxTokenCount == int.MaxValue we don't need to return the normalized string as most likely we can handle the whole input text without need to continuation. + char[]? normalizedString = maxTokenCount == int.MaxValue ? 
null : ArrayPool.Shared.Rent(textToEncode.Length + 2); + + if (SpecialTokensRegex is not null) + { + EncodeToIdsWithSpecialTokens(textToEncode, considerNormalization, ids, buffer, ref normalizedString, out normalizedText, out charsConsumed, maxTokenCount); + } + else + { + EncodeToIdsWithoutSpecialTokens(textToEncode, considerNormalization, ids, buffer, ref normalizedString, out normalizedText, out charsConsumed, maxTokenCount); + } + + if (addEndOfSentence && ids.Count < maxTokenCount) + { + ids.Add(EndOfSentenceId); + } + + if (normalizedString is not null) + { + ArrayPool.Shared.Return(normalizedString); + } + + ArrayPool.Shared.Return(buffer); + + return ids; + } + + private void StoreNormalizedText(ReadOnlySpan text, bool considerNormalization, int[] buffer, ref char[]? normalizedString, ref int normalizedStringIndex) + { + Debug.Assert(normalizedString is not null); + + if (!considerNormalization) + { + StoreNormalizedText(text, ref normalizedString!, ref normalizedStringIndex); + } + else + { + NormalizeText(text, considerNormalization, buffer, out byte[]? normalizedArrayPool, out Span normalizationSpan); + StoreNormalizedText(normalizationSpan, ref normalizedString!, ref normalizedStringIndex); + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + } + } + + private void EncodeToIdsWithSpecialTokens( + ReadOnlySpan text, + bool considerNormalization, + List ids, + int[] buffer, + ref char[]? normalizedString, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount) + { + Debug.Assert(SpecialTokensRegex is not null); + Debug.Assert(maxTokenCount > 0); + + charsConsumed = 0; + normalizedText = null; + + int currentOffset = 0; + int normalizedStringIndex = 0; + + foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!)) + { + if (Offset > currentOffset) + { + if (ids.Count >= maxTokenCount) + { + if (normalizedString is not null) + { + StoreNormalizedText(text.Slice(currentOffset, Offset - currentOffset), considerNormalization, buffer, ref normalizedString, ref normalizedStringIndex); + } + } + else + { + EncodeToIdsInternal(text.Slice(currentOffset, Offset - currentOffset), considerNormalization, ids, buffer, ref normalizedString, ref normalizedStringIndex, ref charsConsumed, maxTokenCount); + } + } + + if (InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) + { + if (normalizedString is not null) + { + StoreNormalizedText(text.Slice(Offset, Length), ref normalizedString, ref normalizedStringIndex); + } + + if (ids.Count < maxTokenCount) + { + ids.Add(id); // special token id + + charsConsumed += Length; + } + } + + currentOffset = Offset + Length; + } + + if (currentOffset < text.Length) + { + if (ids.Count < maxTokenCount) + { + EncodeToIdsInternal(text.Slice(currentOffset), considerNormalization, ids, buffer, ref normalizedString, ref normalizedStringIndex, ref charsConsumed, maxTokenCount); + } + else if (normalizedString is not null) + { + StoreNormalizedText(text.Slice(currentOffset), considerNormalization, buffer, ref normalizedString, ref normalizedStringIndex); + } + } + + if (normalizedString is not null) + { + normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); + } + } + + private void EncodeToIdsWithoutSpecialTokens( + ReadOnlySpan text, + bool considerNormalization, + List ids, + int[] buffer, + ref char[]? normalizedString, + out string? 
normalizedText, + out int charsConsumed, + int maxTokenCount) + { + charsConsumed = 0; + normalizedText = null; + int normalizedStringIndex = 0; + + EncodeToIdsInternal(text, considerNormalization, ids, buffer, ref normalizedString, ref normalizedStringIndex, ref charsConsumed, maxTokenCount); + + if (normalizedString is not null) + { + normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); + } + } + + private void FallbackToByteEncoding(List ids, ReadOnlySpan normalizationSpan, (int IdsIndex, int Utf8Index, int Utf8Length)[] unknownTokensTracking, int unknownTokensCount) + { + Debug.Assert(unknownTokensCount > 0); + Debug.Assert(unknownTokensTracking is not null && unknownTokensTracking.Length >= unknownTokensCount); + + // validate reverse ordered. + Debug.Assert(unknownTokensCount == 1 || unknownTokensTracking![0].IdsIndex > unknownTokensTracking![1].IdsIndex); + + int accumulatedOffsets = 0; + for (int i = unknownTokensCount - 1; i >= 0; i--) + { + unknownTokensTracking![i].IdsIndex += accumulatedOffsets; + (int IdsIndex, int Utf8Index, int Utf8Length) = unknownTokensTracking![i]; + + if (IdsIndex >= ids.Count) + { + continue; // already removed. + } + + Debug.Assert(ids[IdsIndex] == UnknownId); + + // Replace the Unknown id entry with the byte encoding. + ids.RemoveAt(IdsIndex); + + for (int j = Utf8Length - 1; j >= 0; j--) + { + ids.Insert(IdsIndex, ByteCodeToIdOffset + normalizationSpan[Utf8Index + j]); + } + + // -1 because we removed the Unknown id entry. + accumulatedOffsets += Utf8Length - 1; + } + } + + private void EncodeToIdsInternal( + ReadOnlySpan text, + bool considerNormalization, + List ids, + int[] buffer, + ref char[]? normalizedString, + ref int normalizedStringIndex, + ref int charsConsumed, + int maxTokenCount) + { + if (ids.Count >= maxTokenCount) + { + return; + } + + // + // Normalize the input text. + // + + NormalizeText(text, considerNormalization, buffer, out byte[]? normalizedArrayPool, out Span normalizationSpan); + + // + // Do the actual encoding + // + + BestPathNode[] bestPathEndsAt = ArrayPool.Shared.Rent(normalizationSpan.Length + 1); + + Encode(normalizationSpan, bestPathEndsAt); + + // Backtrack to identify the best path. + int insertionStartPosition = ids.Count; + int endsAt = normalizationSpan.Length; + + int unknownTokensCount = 0; + (int IdsIndex, int Utf8Index, int Utf8Length)[]? 
unknownTokensTracking = null; + bool needToTrackUnknown = ByteFallback || maxTokenCount != int.MaxValue; + + while (endsAt > 0) + { + ref BestPathNode node = ref bestPathEndsAt[endsAt]; + + ids.Add(node.Id); + + if (node.Id == UnknownId && needToTrackUnknown) + { + unknownTokensCount++; + if (unknownTokensTracking is null) + { + unknownTokensTracking = ArrayPool<(int IdsIndex, int Utf8Index, int Utf8Length)>.Shared.Rent(10); + } + else if (unknownTokensTracking.Length == unknownTokensCount) + { + Helpers.ArrayPoolGrow(ref unknownTokensTracking, unknownTokensCount << 1); + } + + unknownTokensTracking[unknownTokensCount - 1] = (ids.Count - 1, node.StartsAt, endsAt - node.StartsAt); + } + + endsAt = node.StartsAt; + } + + ArrayPool.Shared.Return(bestPathEndsAt); + + ids.Reverse(insertionStartPosition, ids.Count - insertionStartPosition); + + if (unknownTokensCount > 0) + { + Debug.Assert(unknownTokensTracking is not null && unknownTokensTracking.Length >= unknownTokensCount); + + int end = ids.Count - 1; + + // Fix the id indexes after swapping + for (int i = 0; i < unknownTokensCount; i++) + { + unknownTokensTracking![i].IdsIndex = insertionStartPosition + (end - unknownTokensTracking![i].IdsIndex); + } + } + + // + // Handle maxTokenCount + // + + if (maxTokenCount == int.MaxValue) + { + Debug.Assert(unknownTokensCount == 0 && unknownTokensTracking is null); + + if (ByteFallback && unknownTokensCount > 0) + { + Debug.Assert(unknownTokensTracking is not null && unknownTokensTracking.Length >= unknownTokensCount); + FallbackToByteEncoding(ids, normalizationSpan, unknownTokensTracking!, unknownTokensCount); + } + + // sure we should be consumed the whole text. + charsConsumed += text.Length; + + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + + // done't bother storing the normalized string as we return null when we can handle the whole input text. + Debug.Assert(normalizedString is null); + + return; + } + + // Check if we need to truncate the tokens. and calculate the accurate consumed characters count. + int index = insertionStartPosition; + int addedTokensCount = 0; + + while (index < ids.Count && index + addedTokensCount < maxTokenCount) + { + if (ids[index] == UnknownId) + { + Debug.Assert(unknownTokensCount > 0 && unknownTokensTracking is not null && unknownTokensTracking.Length >= unknownTokensCount); + + int j = 0; + for (; j < unknownTokensCount; j++) + { + if (unknownTokensTracking![j].IdsIndex == index) + { + break; + } + } + + Debug.Assert(j < unknownTokensCount); + + ReadOnlySpan utf8UnknownBytes = normalizationSpan.Slice(unknownTokensTracking![j].Utf8Index, unknownTokensTracking![j].Utf8Length); + + if (ByteFallback) + { + if (index + utf8UnknownBytes.Length > maxTokenCount) + { + break; // not enough space + } + + addedTokensCount += utf8UnknownBytes.Length - 1; + } + + charsConsumed += Helpers.GetUtf16LengthFromUtf8Bytes(utf8UnknownBytes); + } + else + { + charsConsumed += _vocabReverse[ids[index]].Piece.Length; + } + + index++; + } + + if (index < ids.Count) + { + ids.RemoveRange(index, ids.Count - index); + } + + if (unknownTokensCount > 0 && ByteFallback) + { + Debug.Assert(unknownTokensTracking is not null && unknownTokensTracking.Length >= unknownTokensCount); + FallbackToByteEncoding(ids, normalizationSpan, unknownTokensTracking!, unknownTokensCount); + } + + // + // Create the normalized string. 
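// normalizedString is only non-null when the caller passed a bounded maxTokenCount; in that case the
// normalized text has to be surfaced so the caller can map the consumed character count back to a
// position in it and resume encoding after truncation. In the unbounded case the method returned
// earlier without materializing it.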
+ // + + if (normalizedString is not null) + { + StoreNormalizedText(normalizationSpan, ref normalizedString, ref normalizedStringIndex); + } + + if (unknownTokensTracking is not null) + { + ArrayPool<(int IdsIndex, int Utf8Index, int Utf8Length)>.Shared.Return(unknownTokensTracking); + } + + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + } + + public override int CountTokens( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + bool considerNormalization, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount = int.MaxValue) + { + ReadOnlySpan textToEncode = string.IsNullOrEmpty(text) ? textSpan : text.AsSpan(); + + if (textToEncode.IsEmpty || maxTokenCount <= 0) + { + normalizedText = null; + charsConsumed = 0; + return 0; + } + + int tokenCount = 0; + + if (addBeginningOfSentence) + { + tokenCount++; + + if (maxTokenCount == 1) + { + normalizedText = null; + charsConsumed = 0; + return tokenCount; + } + } + + // Rent a buffer that approximately enough to hold the Utf8 encoded bytes, the normalization of the encoded buffer, and some extra memory to for encoding results. + int[] buffer = ArrayPool.Shared.Rent(textToEncode.Length * 3); + + // when maxTokenCount == int.MaxValue we don't need to return the normalized string as most likely we can handle the whole input text without need to continuation. + char[]? normalizedString = maxTokenCount == int.MaxValue ? null : ArrayPool.Shared.Rent(textToEncode.Length + 2); + + if (SpecialTokensRegex is not null) + { + CountTokensWithSpecialTokens(textToEncode, considerNormalization, ref tokenCount, buffer, ref normalizedString, out normalizedText, out charsConsumed, maxTokenCount); + } + else + { + CountTokensWithoutSpecialTokens(textToEncode, considerNormalization, ref tokenCount, buffer, ref normalizedString, out normalizedText, out charsConsumed, maxTokenCount); + } + + if (addEndOfSentence && tokenCount < maxTokenCount) + { + tokenCount++; + } + + if (normalizedString is not null) + { + ArrayPool.Shared.Return(normalizedString); + } + + ArrayPool.Shared.Return(buffer); + + return tokenCount; + } + + private void CountTokensWithSpecialTokens( + ReadOnlySpan text, + bool considerNormalization, + ref int tokenCount, + int[] buffer, + ref char[]? normalizedString, + out string? 
normalizedText, + out int charsConsumed, + int maxTokenCount) + { + Debug.Assert(SpecialTokensRegex is not null); + Debug.Assert(maxTokenCount > 0); + + charsConsumed = 0; + normalizedText = null; + + int currentOffset = 0; + int normalizedStringIndex = 0; + + foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!)) + { + if (Offset > currentOffset) + { + if (tokenCount >= maxTokenCount) + { + if (normalizedString is not null) + { + StoreNormalizedText(text.Slice(currentOffset, Offset - currentOffset), considerNormalization, buffer, ref normalizedString, ref normalizedStringIndex); + } + } + else + { + CountTokensInternal(text.Slice(currentOffset, Offset - currentOffset), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndex, ref charsConsumed, maxTokenCount); + } + } + + if (InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id)) + { + if (normalizedString is not null) + { + StoreNormalizedText(text.Slice(Offset, Length), ref normalizedString, ref normalizedStringIndex); + } + + if (tokenCount < maxTokenCount) + { + tokenCount++; // special token id + charsConsumed += Length; + } + } + + currentOffset = Offset + Length; + } + + if (currentOffset < text.Length && tokenCount < maxTokenCount) + { + if (tokenCount < maxTokenCount) + { + CountTokensInternal(text.Slice(currentOffset), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndex, ref charsConsumed, maxTokenCount); + } + else if (normalizedString is not null) + { + StoreNormalizedText(text.Slice(currentOffset), considerNormalization, buffer, ref normalizedString, ref normalizedStringIndex); + } + } + + if (normalizedString is not null) + { + normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); + } + } + + private void CountTokensWithoutSpecialTokens( + ReadOnlySpan text, + bool considerNormalization, + ref int tokenCount, + int[] buffer, + ref char[]? normalizedString, + out string? normalizedText, + out int charsConsumed, + int maxTokenCount) + { + charsConsumed = 0; + normalizedText = null; + int normalizedStringIndex = 0; + + CountTokensInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndex, ref charsConsumed, maxTokenCount); + + if (normalizedString is not null) + { + normalizedText = normalizedString.AsSpan().Slice(0, normalizedStringIndex).ToString(); + } + } + + private void CountTokensInternal( + ReadOnlySpan text, + bool considerNormalization, + ref int tokenCount, + int[] buffer, + ref char[]? normalizedString, + ref int normalizedStringIndex, + ref int charsConsumed, + int maxTokenCount) + { + // + // Normalize the input text. + // + + NormalizeText(text, considerNormalization, buffer, out byte[]? normalizedArrayPool, out Span normalizationSpan); + + // + // Do the actual encoding + // + + BestPathNode[] bestPathEndsAt = ArrayPool.Shared.Rent(normalizationSpan.Length + 1); + + Encode(normalizationSpan, bestPathEndsAt); + + // Need to check for unknown tokens and update the charsConsumed. + + (int Id, int UtfStartOffset, int Utf8Length)[] ids = ArrayPool<(int Id, int UtfStartOffset, int Utf8Length)>.Shared.Rent(bestPathEndsAt.Length); + + // Backtrack to identify the best path. 
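// The rented ids buffer is filled from its last slot backwards while following the StartsAt links from
// the end of the normalized bytes, so the surviving entries in [idsIndex, ids.Length) are already in
// reading order and no reverse pass is needed. Counting then differs from encoding only in how unknown
// pieces are charged: with ByteFallback enabled an unknown piece counts as its UTF-8 byte length,
// otherwise it counts as a single UNK token.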
+ int idsIndex = ids.Length - 1; + int endsAt = normalizationSpan.Length; + + bool unknownEncountered = false; + while (endsAt > 0) + { + ref BestPathNode node = ref bestPathEndsAt[endsAt]; + + ids[idsIndex--] = (node.Id, node.StartsAt, endsAt - node.StartsAt); + + unknownEncountered = unknownEncountered || node.Id == UnknownId; + + endsAt = node.StartsAt; + } + + idsIndex++; // Index starting the collected tokens. + + ArrayPool.Shared.Return(bestPathEndsAt); + + if ((!ByteFallback || !unknownEncountered) && (maxTokenCount == int.MaxValue || (tokenCount + ids.Length - idsIndex <= maxTokenCount))) + { + // sure we should be consumed the whole text. + charsConsumed += Helpers.GetUtf16LengthFromUtf8Bytes(normalizationSpan); + tokenCount += ids.Length - idsIndex; + + if (normalizedString is not null) + { + StoreNormalizedText(normalizationSpan, ref normalizedString, ref normalizedStringIndex); + } + + ArrayPool<(int Id, int UtfStartOffset, int Utf8Length)>.Shared.Return(ids); + + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + + return; + } + + // Manually count the tokens up to the max. + for (int i = idsIndex; tokenCount < maxTokenCount && i < ids.Length; i++) + { + if (ids[i].Id == UnknownId) + { + if (ByteFallback) + { + if (tokenCount + ids[i].Utf8Length > maxTokenCount) + { + break; + } + + tokenCount += ids[i].Utf8Length; + } + else + { + tokenCount++; + } + + charsConsumed += Helpers.GetUtf16LengthFromUtf8Bytes(normalizationSpan.Slice(ids[i].UtfStartOffset, ids[i].Utf8Length)); + } + else + { + charsConsumed += _vocabReverse[ids[i].Id].Piece.Length; + tokenCount++; + } + } + + // + // Create the normalized string. + // + + ArrayPool<(int Id, int UtfStartOffset, int Utf8Length)>.Shared.Return(ids); + + if (normalizedString is not null) + { + StoreNormalizedText(normalizationSpan, ref normalizedString, ref normalizedStringIndex); + } + + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + } + + public override int GetIndexByTokenCountFromEnd( + string? text, + ReadOnlySpan textSpan, + bool addBeginningOfSentence, + bool addEndOfSentence, + int maxTokenCount, + bool considerNormalization, + out string? normalizedText, + out int tokenCount) + { + ReadOnlySpan textToEncode = string.IsNullOrEmpty(text) ? textSpan : text.AsSpan(); + + tokenCount = 0; + if (textToEncode.IsEmpty || maxTokenCount <= 0) + { + normalizedText = null; + return textToEncode.Length; + } + + if (addEndOfSentence) + { + tokenCount++; + + if (maxTokenCount == 1) + { + normalizedText = null; + return textToEncode.Length; + } + } + + // Rent a buffer that approximately enough to hold the Utf8 encoded bytes, the normalization of the encoded buffer, and some extra memory to for encoding results. + int[] buffer = ArrayPool.Shared.Rent(textToEncode.Length * 3); + + // when maxTokenCount == int.MaxValue we don't need to return the normalized string as most likely we can handle the whole input text without need to continuation. + char[]? normalizedString = maxTokenCount == int.MaxValue ? 
null : ArrayPool.Shared.Rent(textToEncode.Length + 2); + + int charConsumedFromEnd; + + if (SpecialTokensRegex is not null) + { + GetIndexByTokenCountFromEndWithSpecialTokens(textToEncode, considerNormalization, ref tokenCount, buffer, ref normalizedString, out charConsumedFromEnd, out normalizedText, maxTokenCount); + } + else + { + GetIndexByTokenCountFromEndWithoutSpecialTokens(textToEncode, considerNormalization, ref tokenCount, buffer, ref normalizedString, out charConsumedFromEnd, out normalizedText, maxTokenCount); + } + + if (addBeginningOfSentence && tokenCount < maxTokenCount) + { + tokenCount++; + } + + ArrayPool.Shared.Return(buffer); + + return normalizedText is not null ? normalizedText.Length - charConsumedFromEnd : 0; + } + + private void GetIndexByTokenCountFromEndWithSpecialTokens( + ReadOnlySpan text, + bool considerNormalization, + ref int tokenCount, + int[] buffer, + ref char[]? normalizedString, + out int charConsumedFromEnd, + out string? normalizedText, + int maxTokenCount) + { + Debug.Assert(SpecialTokensRegex is not null); + Debug.Assert(maxTokenCount > 0); + + charConsumedFromEnd = 0; + int normalizedStringIndexFromEnd = 0; + + (int Offset, int Length)[] splits = PreTokenizer.SplitText(text, SpecialTokensRegex!).ToArray(); + + if (splits.Length == 0) + { + GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount); + normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - charConsumedFromEnd).ToString() : null; + } + + (int Offset, int Length) current = splits[splits.Length - 1]; + + // Last part is not a special token + if (current.Offset + current.Length < text.Length) + { + GetIndexByTokenCountFromEndInternal(text.Slice(current.Offset + current.Length), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount); + } + + for (int i = splits.Length - 1; i >= 0; i--) + { + current = splits[i]; // special token + + if (tokenCount < maxTokenCount) + { + if (InternalSpecialTokens!.TryGetValue(text.Slice(current.Offset, current.Length), out int id)) + { + tokenCount++; + } + + charConsumedFromEnd += current.Length; + } + + if (normalizedString is not null) + { + StoreNormalizedTextFromEnd(text.Slice(current.Offset, current.Length), ref normalizedString, ref normalizedStringIndexFromEnd); + } + + if (current.Offset > 0) + { + int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0; + GetIndexByTokenCountFromEndInternal(text.Slice(start, current.Offset - start), considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount); + } + } + + normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).ToString() : null; + } + + private void GetIndexByTokenCountFromEndWithoutSpecialTokens( + ReadOnlySpan text, + bool considerNormalization, + ref int tokenCount, + int[] buffer, + ref char[]? normalizedString, + out int charConsumedFromEnd, + out string? 
normalizedText, + int maxTokenCount) + { + charConsumedFromEnd = 0; + int normalizedStringIndexFromEnd = 0; + + GetIndexByTokenCountFromEndInternal(text, considerNormalization, ref tokenCount, buffer, ref normalizedString, ref normalizedStringIndexFromEnd, ref charConsumedFromEnd, maxTokenCount); + + normalizedText = normalizedString is not null ? normalizedString.AsSpan().Slice(normalizedString.Length - normalizedStringIndexFromEnd).ToString() : null; + } + + private void GetIndexByTokenCountFromEndInternal( + ReadOnlySpan text, + bool considerNormalization, + ref int tokenCount, + int[] buffer, + ref char[]? normalizedString, + ref int normalizedStringIndexFromEnd, + ref int charConsumedFromEnd, + int maxTokenCount) + { + // + // Normalize the input text. + // + + NormalizeText(text, considerNormalization, buffer, out byte[]? normalizedArrayPool, out Span normalizationSpan); + + // + // Do the actual encoding + // + + BestPathNode[] bestPathEndsAt = ArrayPool.Shared.Rent(normalizationSpan.Length + 1); + + Encode(normalizationSpan, bestPathEndsAt); + + int consumedCharacters = 0; + int endsAt = normalizationSpan.Length; + + while (endsAt > 0 && tokenCount < maxTokenCount) + { + ref BestPathNode node = ref bestPathEndsAt[endsAt]; + + if (node.Id == UnknownId) + { + int length = endsAt - node.StartsAt; + if (ByteFallback) + { + if (tokenCount + length > maxTokenCount) + { + break; + } + + tokenCount += length; + } + else + { + tokenCount++; + } + + consumedCharacters += Helpers.GetUtf16LengthFromUtf8Bytes(normalizationSpan.Slice(node.StartsAt, length)); + } + else + { + consumedCharacters += _vocabReverse[node.Id].Piece.Length; + tokenCount++; + } + + endsAt = node.StartsAt; + } + + charConsumedFromEnd += consumedCharacters; + + if (normalizedString is not null) + { + if (considerNormalization) + { + StoreNormalizedTextFromEnd(normalizationSpan, ref normalizedString, ref normalizedStringIndexFromEnd); + } + else + { + StoreNormalizedTextFromEnd(text, ref normalizedString, ref normalizedStringIndexFromEnd); + } + } + + ArrayPool.Shared.Return(bestPathEndsAt); + if (normalizedArrayPool is not null) + { + ArrayPool.Shared.Return(normalizedArrayPool); + } + } + } +} diff --git a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs index f1e0ebaad4..97b3f4a7a4 100644 --- a/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs +++ b/src/Microsoft.ML.Tokenizers/Normalizer/SentencePieceNormalizer.cs @@ -6,18 +6,27 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.Text; namespace Microsoft.ML.Tokenizers { /// - /// Normalize the string to lowercase form before processing it with the tokenizer. + /// Normalize the string according to SentencePiece normalization. /// public sealed class SentencePieceNormalizer : Normalizer { + // Maximum size of the return value of Trie, which corresponds to the maximum size of shared common prefix in the chars map. + private const int MaxTrieResultsSize = 32; internal const char DummyPrefix = '\u2581'; // '▁' (LOWER ONE EIGHT BLOCK) + private static readonly byte[] _spaceSymbol = { 0xe2, 0x96, 0x81 }; // Utf8 of DummyPrefix; Null terminated. + private static readonly byte[] _space = { (byte)' ' }; + private static readonly byte[] _replacementBytes = { 0xEF, 0xBF, 0xBD, 0 }; // Utf8 of 0xFFFD; Null terminated. + + private readonly DoubleArrayTrie? _trie; + private readonly byte[]? 
_normalized; /// - /// Creates a LowerCaseNormalizer object. + /// Creates a SentencePieceNormalizer object. /// public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, bool escapeWhiteSpaces, bool treatWhitespaceAsSuffix, IReadOnlyDictionary? specialTokens) { @@ -28,6 +37,25 @@ public SentencePieceNormalizer(bool removeExtraWhiteSpaces, bool addDummyPrefix, SpecialTokens = specialTokens; } + internal SentencePieceNormalizer( + ReadOnlySpan precompiledCharsMap, + bool removeExtraWhiteSpaces, + bool addDummyPrefix, + bool escapeWhiteSpaces, + bool treatWhitespaceAsSuffix, + IReadOnlyDictionary? specialTokens) : this(removeExtraWhiteSpaces, addDummyPrefix, escapeWhiteSpaces, treatWhitespaceAsSuffix, specialTokens) + { + if (precompiledCharsMap.IsEmpty) + { + return; + } + + DecodePrecompiledCharsMap(precompiledCharsMap, out DoubleArrayUnit[]? trieBlob, out _normalized); + + Debug.Assert(trieBlob is not null); + _trie = new DoubleArrayTrie(trieBlob!); + } + /// /// Indicate removing extra white spaces from the original string during the normalization. /// @@ -239,5 +267,269 @@ private void InsertDummyPrefixAtEnd(Span span, ref int bufferIndex) bufferIndex++; } } + + // Returns the longest consumed prefix of |input| that can be normalized. + // if we return normalizedPrefix == default, means no normalization and the original input span should be used. + private int NormalizePrefix(ReadOnlySpan input, out Memory normalizedPrefix) + { + Debug.Assert(!input.IsEmpty); + + int longestLength = 0; + int longestValue = 0; + + if (_trie is not null) + { + // Allocates trie_results in stack, which makes the encoding speed 36% faster. (38k sentences/sec => 60k sentences/sec). + // Builder checks that the result size never exceeds kMaxTrieResultsSize. This array consumes 0.5kByte in stack, + // which is less than default stack frames (16kByte). + Span trieResults = stackalloc DoubleArrayResultPair[MaxTrieResultsSize]; + + int numNodes = _trie.CommonPrefixSearch(input, trieResults); + + // Finds the longest rule. + for (int k = 0; k < numNodes; ++k) + { + if (longestLength == 0 || trieResults[k].Length > longestLength) + { + longestLength = trieResults[k].Length; // length of prefix + longestValue = trieResults[k].Value; // pointer to |_normalized|. + } + } + } + + int result; + + if (longestLength == 0) + { + if (!Helpers.IsValidDecodeUtf8(input, out int length)) + { + // Found a malformed utf8. + // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER), which is a valid Unicode of three bytes in utf8, but here we only consume one byte. + result = 1; + normalizedPrefix = new Memory(_replacementBytes, 0, 3); + } + else + { + result = length; + normalizedPrefix = default; + } + } + else + { + Debug.Assert(_normalized is not null); + + result = longestLength; + + // Calculate the length of the normalized prefix. + int normalizedLength = longestValue; + while (normalizedLength < _normalized!.Length && _normalized[normalizedLength] != 0) + { + normalizedLength++; + } + normalizedPrefix = new Memory(_normalized, longestValue, normalizedLength - longestValue); + } + + return result; + } + + internal int Normalize(ReadOnlySpan input, ref Span normalized, ref byte[]? poolArray) + { + if (input.IsEmpty) + { + return 0; + } + + int consumed = 0; + + // Ignores heading space. 
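+ // The loop below trims leading whitespace by peeling one normalized prefix at a time:
+ // a prefix counts as a leading space only when NormalizePrefix consumed exactly one byte
+ // and that byte normalizes to an ASCII space; the first prefix failing either check stops
+ // the trimming. For example, "  hello" loses both leading spaces before the main loop runs.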
+ if (RemoveExtraWhiteSpaces) + { + while (!input.IsEmpty) + { + int p = NormalizePrefix(input, out Memory normalizedPrefix); + + Debug.Assert(p > 0); + + if (p != 1) + { + break; + } + + ReadOnlySpan normalizedByte = normalizedPrefix.Equals(default(Memory)) ? input.Slice(0, p) : normalizedPrefix.Span; + if (normalizedByte[0] != (byte)' ') + { + break; + } + + input = input.Slice(p); + consumed += p; + } + } + + // all chars are whitespace. + if (input.IsEmpty) + { + return 0; + } + + int normalizedIndex = 0; + + // Adds a space symbol as a prefix (default is true) With this prefix, "world" and "hello world" are converted into + // "_world" and "_hello_world", which help the trainer to extract "_world" as one symbol. + if (!TreatWhitespaceAsSuffix && AddDummyPrefix) + { + AddWhiteSpace(this, normalized, ref normalizedIndex, ref poolArray); + } + + bool isPrevSpace = RemoveExtraWhiteSpaces; + + while (!input.IsEmpty) + { + int p = NormalizePrefix(input, out Memory normalizedPrefix); + ReadOnlySpan sp = normalizedPrefix.Equals(default(Memory)) ? input.Slice(0, p) : normalizedPrefix.Span; + + // Removes heading spaces in sentence piece, if the previous sentence piece ends with whitespace. + while (isPrevSpace && sp.Length > 0 && sp[0] == (byte)' ') + { + sp = sp.Slice(1); + } + + if (!sp.IsEmpty) + { + for (int n = 0; n < sp.Length; ++n) + { + if (EscapeWhiteSpaces && sp[n] == ' ') + { + if (normalized.Length <= normalizedIndex + _spaceSymbol.Length) + { + Helpers.ArrayPoolGrow(ref normalized, ref poolArray, (normalizedIndex + _spaceSymbol.Length) << 1); + } + + // replace ' ' with _spaceSymbol. + _spaceSymbol.AsSpan().CopyTo(normalized.Slice(normalizedIndex)); + normalizedIndex += _spaceSymbol.Length; + + } + else + { + if (normalized.Length <= normalizedIndex + 1) + { + Helpers.ArrayPoolGrow(ref normalized, ref poolArray, (normalizedIndex + 1) << 1); + } + + normalized[normalizedIndex++] = sp[n]; + + } + } + + // Checks whether the last character of sp is whitespace. + isPrevSpace = sp[sp.Length - 1] == (byte)' '; + } + + input = input.Slice(p); + + if (!RemoveExtraWhiteSpaces) + { + isPrevSpace = false; + } + } + + // Ignores trailing space. + if (RemoveExtraWhiteSpaces) + { + Span space = EscapeWhiteSpaces ? _spaceSymbol : _space; + while (normalized.Slice(0, normalizedIndex).EndsWith(space)) + { + int length = normalizedIndex - space.Length; + if (length < 0) + { + return normalizedIndex; + } + + normalizedIndex = length; // cut spaces + + } + } + + // Adds a space symbol as a suffix (default is false) + if (TreatWhitespaceAsSuffix && AddDummyPrefix) + { + AddWhiteSpace(this, normalized, ref normalizedIndex, ref poolArray); + } + + return normalizedIndex; + + // adds _spaceSymbol to the current context. + static void AddWhiteSpace(SentencePieceNormalizer normalizer, Span normalized, ref int normalizedIndex, ref byte[]? poolArray) + { + if (normalizer.EscapeWhiteSpaces) + { + if (normalized.Length <= normalizedIndex + _spaceSymbol.Length) + { + Helpers.ArrayPoolGrow(ref normalized, ref poolArray, (normalizedIndex + _spaceSymbol.Length) << 1); + } + _spaceSymbol.AsSpan().CopyTo(normalized.Slice(normalizedIndex)); + normalizedIndex += _spaceSymbol.Length; + } + else + { + if (normalized.Length <= normalizedIndex + 1) + { + Helpers.ArrayPoolGrow(ref normalized, ref poolArray, (normalizedIndex + 1) << 1); + } + normalized[normalizedIndex] = (byte)' '; + normalizedIndex++; + } + } + } + + private unsafe void DecodePrecompiledCharsMap(ReadOnlySpan blob, out DoubleArrayUnit[]? 
trieBlob, out byte[]? normalized) + { + uint trieBlobSize = 0; + + if (blob.Length <= sizeof(uint)) + { + throw new ArgumentException("Blob for normalization rule is broken."); + } + + fixed (byte* pBlob = blob) + { + trieBlobSize = *(uint*)pBlob; + } + + if (!BitConverter.IsLittleEndian) + { + trieBlobSize = Helpers.Swap32(trieBlobSize); + } + + if (trieBlobSize >= blob.Length) + { + throw new ArgumentException("Trie data size exceeds the input blob size."); + } + + blob = blob.Slice(sizeof(uint)); + + if (!BitConverter.IsLittleEndian) + { + fixed (byte* pBlob = blob) + { + uint* data = (uint*)pBlob; + + // Perform necessary operations for Big Endian + for (int i = 0; i < trieBlobSize / 4; ++i) + { + data[i] = Helpers.Swap32(data[i]); + } + } + } + + fixed (byte* pBlob = blob.Slice(0, (int)trieBlobSize)) + { + + trieBlob = new Span((DoubleArrayUnit*)pBlob, (int)trieBlobSize / 4).ToArray(); + } + + normalized = blob.Slice((int)trieBlobSize).ToArray(); + } } } diff --git a/src/Microsoft.ML.Tokenizers/Utils/DoubleArrayTrie.cs b/src/Microsoft.ML.Tokenizers/Utils/DoubleArrayTrie.cs new file mode 100644 index 0000000000..8e1fafbd74 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Utils/DoubleArrayTrie.cs @@ -0,0 +1,1143 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// The implementation of the DoubleArrayBuilder class is based on the following C# port of the C++ implementation of the Double-Array Trie (DART) data structure. +// The original C++ implementation is available at https://github.com/s-yata/darts-clone/blob/master/include/darts.h and used under BSD 2-clause license. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; + +#if Test +namespace Microsoft.ML.Tokenizers.Tests +#else +namespace Microsoft.ML.Tokenizers +#endif // Test +{ + // + // Succinct bit vector. + // + public class BitVector + { + private const int UnitSize = sizeof(uint) * 8; + private readonly List _units = new(); + private uint[]? 
_ranks; + private uint _numOnes; + private uint _size; + + public BitVector() { } + + public bool this[int id] + { + get => (_units[id / UnitSize] >> (id % UnitSize) & 1) == 1; + } + + private static uint PopCount(uint unit) + { + unit = ((unit & 0xAAAAAAAA) >> 1) + (unit & 0x55555555); + unit = ((unit & 0xCCCCCCCC) >> 2) + (unit & 0x33333333); + unit = ((unit >> 4) + unit) & 0x0F0F0F0F; + unit += unit >> 8; + unit += unit >> 16; + return unit & 0xFF; + } + + public uint Rank(uint id) + { + uint unitId = id / UnitSize; + return (uint)(_ranks![(int)unitId] + PopCount((uint)(_units[(int)unitId] & (~0U >> (int)(UnitSize - (id % UnitSize) - 1))))); + } + + public void Set(uint id, bool bit) + { + if (bit) + { + _units[(int)(id / UnitSize)] |= 1U << (int)(id % UnitSize); + } + else + { + _units[(int)(id / UnitSize)] &= ~(1U << (int)(id % UnitSize)); + } + } + + public bool IsEmpty => _units.Count == 0; + + public uint NumOnes => _numOnes; + + public uint Size => _size; + + public void Append() + { + if ((_size % UnitSize) == 0) + { + _units.Add(0); + } + ++_size; + } + + public void Build() + { + _ranks = new uint[_units.Count]; + + _numOnes = 0; + for (int i = 0; i < _units.Count; ++i) + { + _ranks[i] = _numOnes; + _numOnes += PopCount(_units[i]); + } + } + } + + internal class AutoPool + { + private T[] _buf = Array.Empty(); + private int _size; + private int _capacity; + + public AutoPool() { } + + public ref T this[int id] + { + get => ref _buf[id]; + } + + public T[] Buffer => _buf; + + public bool Empty => _size == 0; + + public int Size => _size; + + public void Clear() + { + _buf = Array.Empty(); + _size = 0; + _capacity = 0; + } + + public void ResizeBuf(int size) + { + if (size <= _capacity) + { + return; + } + + int capacity; + if (size >= _capacity * 2) + { + capacity = size; + } + else + { + capacity = 1; + while (capacity < size) + { + capacity <<= 1; + } + } + + T[] buf = new T[capacity]; + + if (_buf is not null) + { + Array.Copy(_buf, buf, _size); + } + + _buf = buf; + _capacity = capacity; + } + + public void Append(T value) + { + if (_size == _capacity) + { + ResizeBuf(_size + 1); + } + + _buf[_size++] = value!; + } + + public void Append() + { + if (_size == _capacity) + { + ResizeBuf(_size + 1); + } + _buf[_size++] = default!; + } + + public void PushBack(T value) => Append(value); + + public void PopBack() + { + if (Empty) + { + return; + } + + _buf[--_size] = default!; + } + + public void Resize(int size) + { + while (_size > size) + { + _buf[--_size] = default!; + } + + if (size > _capacity) + { + ResizeBuf(size); + } + + while (_size < size) + { + _buf[_size++] = default!; + } + } + + public void Resize(int size, T value) + { + while (_size > size) + { + _buf[--_size] = default!; + } + + if (size > _capacity) + { + ResizeBuf(size); + } + + while (_size < size) + { + _buf[_size++] = value; + } + } + + public void Reserve(int size) + { + if (size > _capacity) + { + ResizeBuf(size); + } + } + } + + // + // Fixed unit of Directed Acyclic Word Graph (DAWG). + // + + // + // Node of Directed Acyclic Word Graph (DAWG). + // + internal struct DawgNode + { + public DawgNode() { } + + public uint Child { get; set; } + public uint Sibling { get; set; } + public bool IsState { get; set; } + public bool HasSibling { get; set; } + public byte Label { get; set; } + public int Value { get => (int)Child; set => Child = (uint)value; } + + public uint Unit + { + get + { + if (Label == 0) + { + return (Child << 1) | (uint)(HasSibling ? 1 : 0); + } + return (Child << 2) | (uint)(IsState ? 
2 : 0) | (uint)(HasSibling ? 1 : 0); + } + } + } + + internal struct DawgUnit + { + private readonly uint _unit; + public DawgUnit(uint unit = 0) => _unit = unit; + public DawgUnit(DawgUnit unit) => _unit = unit._unit; + + public static implicit operator DawgUnit(uint unit) => new DawgUnit(unit); + + public uint Unit => _unit; + + public uint Child => _unit >> 2; + + public bool HasSibling => (_unit & 1) == 1; + + public int Value => (int)(_unit >> 1); + + public bool IsState => (_unit & 2) == 2; + } + + // + // Directed Acyclic Word Graph (DAWG) builder. + // + + public class DawgBuilder + { + private const int InitialTableSize = 1 << 10; + + private readonly AutoPool _nodes = new(); + private readonly AutoPool _units = new(); + private readonly AutoPool _labels = new(); + private readonly AutoPool _table = new(); + private readonly BitVector _isIntersections = new(); + private readonly Stack _nodeStack = new(); + private readonly Stack _recycleBin = new(); + private int _numStates; + + public DawgBuilder() { } + + public uint Root => 0; + + public uint Child(uint id) => _units[(int)id].Child; + + public uint Sibling(uint id) => _units[(int)id].HasSibling ? id + 1 : 0; + + public int Value(uint id) => _units[(int)id].Value; + public byte Label(uint id) => _labels[(int)id]; + + public bool IsLeaf(uint id) => Label(id) == 0; + + public bool IsIntersection(uint id) => _isIntersections[(int)id]; + + public uint IntersectionId(uint id) => _isIntersections.Rank(id) - 1; + + public int NumIntersections => (int)_isIntersections.NumOnes; + + public int Size => _units.Size; + + private static uint Hash(uint key) + { + key = ~key + (key << 15); // key = (key << 15) - key - 1; + key = key ^ (key >> 12); + key = key + (key << 2); + key = key ^ (key >> 4); + key = key * 2057; // key = (key + (key << 3)) + (key << 11); + key = key ^ (key >> 16); + return key; + } + + private void FreeNode(uint id) => _recycleBin.Push(id); + + public void Finish() + { + Flush(0); + + _units[0] = _nodes[0].Unit; + _labels[0] = _nodes[0].Label; + _isIntersections.Build(); + } + + public void Insert(ReadOnlySpan key, int length, int value) + { + if (value < 0) + { + throw new ArgumentException("failed to insert key: negative value"); + } + else if (length == 0) + { + throw new ArgumentException("failed to insert key: zero-length key"); + } + + uint id = 0; + int keyPos = 0; + + for (; keyPos <= length; ++keyPos) + { + uint childId = _nodes[(int)id].Child; + if (childId == 0) + { + break; + } + + byte keyLabel = key[keyPos]; + if (keyPos < length && keyLabel == 0) + { + throw new InvalidOperationException("failed to insert key: invalid null character"); + } + + byte unitLabel = _nodes[(int)childId].Label; + if (keyLabel < unitLabel) + { + throw new InvalidOperationException("failed to insert key: wrong key order"); + } + else if (keyLabel > unitLabel) + { + _nodes[(int)childId].HasSibling = true; + Flush(childId); + break; + } + + id = childId; + } + + if (keyPos > length) + { + return; + } + + for (; keyPos <= length; ++keyPos) + { + byte keyLabel = keyPos < length ? 
key[keyPos] : (byte)0; + uint childId = AppendNode(); + + if (_nodes[(int)id].Child == 0) + { + _nodes[(int)childId].IsState = true; + } + + _nodes[(int)childId].Sibling = _nodes[(int)id].Child; + _nodes[(int)childId].Label = keyLabel; + _nodes[(int)id].Child = childId; + _nodeStack.Push(childId); + id = childId; + } + _nodes[(int)id].Value = value; + } + + private uint AppendNode() + { + uint id; + if (_recycleBin.Count == 0) + { + id = (uint)_nodes.Size; + _nodes.Append(); + } + else + { + id = _recycleBin.Pop(); + _nodes[(int)id] = new DawgNode(); + } + return id; + } + + private uint AppendUnit() + { + _isIntersections.Append(); + _units.Append(); + _labels.Append(); + return _isIntersections.Size - 1; + } + + public void Init() + { + _table.Resize(InitialTableSize, 0); + + AppendNode(); + AppendUnit(); + + _numStates = 1; + + _nodes[0].Label = 0xFF; + _nodeStack.Push(0); + } + + private void ExpandTable() + { + int tableSize = _table.Size << 1; + _table.Clear(); + _table.Resize(tableSize, 0); + + for (int i = 1; i < _units.Size; ++i) + { + uint id = (uint)i; + if (_labels[i] == 0 || _units[i].IsState) + { + FindUnit(id, out uint hashId); + _table[(int)hashId] = id; + } + } + } + + private uint HashNode(uint id) + { + uint hashValue = 0; + for (; id != 0; id = _nodes[(int)id].Sibling) + { + uint unit = _nodes[(int)id].Unit; + byte label = _nodes[(int)id].Label; + hashValue ^= Hash((uint)((label << 24) ^ unit)); + } + + return hashValue; + } + + private bool AreEqual(uint nodeId, uint unitId) + { + for (uint i = _nodes[(int)nodeId].Sibling; i != 0; i = _nodes[(int)i].Sibling) + { + if (!_units[(int)unitId].HasSibling) + { + return false; + } + + ++unitId; + } + + if (_units[(int)unitId].HasSibling) + { + return false; + } + + for (uint i = nodeId; i != 0; i = _nodes[(int)i].Sibling, --unitId) + { + if (_nodes[(int)i].Unit != _units[(int)unitId].Unit || _nodes[(int)i].Label != _labels[(int)unitId]) + { + return false; + } + } + + return true; + } + + private uint FindNode(uint nodeId, out uint hashId) + { + hashId = (uint)(HashNode(nodeId) % _table.Size); + for (; ; hashId = (uint)((hashId + 1) % _table.Size)) + { + uint unitId = _table[(int)hashId]; + if (unitId == 0) + { + break; + } + + if (AreEqual(nodeId, unitId)) + { + return unitId; + } + } + + return 0; + } + + private uint HashUnit(uint id) + { + uint hashValue = 0; + for (; id != 0; ++id) + { + uint unit = _units[(int)id].Unit; + byte label = _labels[(int)id]; + hashValue ^= Hash((uint)((label << 24) ^ unit)); + + if (!_units[(int)id].HasSibling) + { + break; + } + } + return hashValue; + } + private uint FindUnit(uint id, out uint hashId) + { + hashId = (uint)(HashUnit(id) % _table.Size); + for (; ; hashId = (uint)((hashId + 1) % _table.Size)) + { + uint unitId = _table[(int)hashId]; + if (unitId == 0) + { + break; + } + + // There must not be the same unit. 
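+ // Unlike FindNode, no content comparison is needed here. FindUnit is only called from
+ // ExpandTable while rehashing units that were already interned (and are therefore unique),
+ // so probing simply continues until an empty slot is found and its index is returned via hashId.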
+ } + return 0; + } + + private void Flush(uint id) + { + while (_nodeStack.Peek() != id) + { + uint nodeId = _nodeStack.Pop(); + + if (_numStates >= _table.Size - (_table.Size >> 2)) + { + ExpandTable(); + } + + uint numSiblings = 0; + for (uint i = nodeId; i != 0; i = _nodes[(int)i].Sibling) + { + ++numSiblings; + } + + uint matchId = FindNode(nodeId, out uint hashId); + if (matchId != 0) + { + _isIntersections.Set(matchId, true); + } + else + { + uint unitId = 0; + for (uint i = 0; i < numSiblings; ++i) + { + unitId = AppendUnit(); + } + + for (uint i = nodeId; i != 0; i = _nodes[(int)i].Sibling) + { + _units[(int)unitId] = _nodes[(int)i].Unit; + _labels[(int)unitId] = _nodes[(int)i].Label; + --unitId; + } + + matchId = unitId + 1; + _table[(int)hashId] = matchId; + ++_numStates; + } + + for (uint i = nodeId, next; i != 0; i = next) + { + next = _nodes[(int)i].Sibling; + FreeNode(i); + } + + _nodes[(int)_nodeStack!.Peek()].Child = matchId; + } + + _nodeStack.Pop(); + } + } + + internal struct DoubleArrayUnit + { + private uint _unit; + public DoubleArrayUnit() { } + + // returns whether a leaf unit is immediately derived from the unit (true) or not (false). + public bool HasLeaf + { + get => ((_unit >> 8) & 1) == 1; + set + { + if (value) + { + _unit |= 1U << 8; + } + else + { + _unit &= ~(1U << 8); + } + } + } + + // value() returns the value stored in the unit, and thus value() is + // available when and only when the unit is a leaf unit. + public uint Value + { + get => _unit & ((1U << 31) - 1); + set => _unit = value | (1U << 31); + } + + // returns the label associated with the unit. Note that a leaf unit always returns an invalid label. + // For this feature, leaf unit's label returns an id that has the MSB of 1. + public uint Label + { + get => _unit & ((1U << 31) | 0xFF); + set + { + _unit = (_unit & ~0xFFU) | value; + } + } + + // offset() returns the offset from the unit to its derived units. + public uint Offset + { + get => (_unit >> 10) << (int)((_unit & (1U << 9)) >> 6); + set + { + if (value >= 1U << 29) + { + throw new InvalidOperationException("failed to modify unit: too large offset"); + } + + _unit &= (1U << 31) | (1U << 8) | 0xFF; + + if (value < 1U << 21) + { + _unit |= value << 10; + } + else + { + _unit |= (value << 2) | (1U << 9); + } + } + } + + } + + // + // Extra unit of double-array builder. + // + + internal struct DoubleArrayBuilderExtraUnit + { + + private uint _prev; + private uint _next; + private bool _isFixed; + private bool _isUsed; + + public DoubleArrayBuilderExtraUnit() { } + + public uint Prev + { + get => _prev; + set => _prev = value; + } + + public uint Next + { + get => _next; + set => _next = value; + } + + public bool IsFixed + { + get => _isFixed; + set => _isFixed = value; + } + + public bool IsUsed + { + get => _isUsed; + set => _isUsed = value; + } + } + + internal class DoubleArrayBuilder + { + private const int BlockSize = 256; + private const int NumExtraBlock = 16; + private const int NumExtras = BlockSize * NumExtraBlock; + private const int UpperMask = 0xFF << 21; + private const int LowerMask = 0xFF; + + private readonly AutoPool _units = new(); + private readonly DoubleArrayBuilderExtraUnit[] _extras = new DoubleArrayBuilderExtraUnit[NumExtras]; + private readonly AutoPool _labels = new(); + private uint[]? 
_table; + private uint _extrasHead; + + private int NumBlocks() => _units.Size / BlockSize; + + public DoubleArrayUnit[] Units => _units.Buffer; + public int UnitsSize => _units.Size; + + private ref DoubleArrayBuilderExtraUnit this[uint id] + { + get => ref _extras[id % NumExtras]; + } + + internal unsafe void BuildDawg(SortedDictionary dictionary, DawgBuilder dawgBuilder) + { + dawgBuilder.Init(); + + Span bytes = stackalloc byte[512]; + byte[]? array = null; + + foreach (KeyValuePair pair in dictionary) + { + int encodingLength = Encoding.UTF8.GetByteCount(pair.Key); + if (encodingLength > bytes.Length) + { + if (array is not null) + { + ArrayPool.Shared.Return(array); + } + + array = ArrayPool.Shared.Rent(encodingLength * 2); + bytes = array; + } + + encodingLength = Helpers.GetUtf8Bytes(pair.Key.AsSpan(), bytes); + + dawgBuilder.Insert(bytes, encodingLength, pair.Value); + } + + if (array is not null) + { + ArrayPool.Shared.Return(array); + } + + dawgBuilder.Finish(); + } + + internal void FixBlock(uint blockId) + { + uint begin = blockId * BlockSize; + uint end = begin + BlockSize; + + uint unusedOffset = 0; + for (uint offset = begin; offset != end; ++offset) + { + if (!this[offset].IsUsed) + { + unusedOffset = offset; + break; + } + } + + for (uint id = begin; id != end; ++id) + { + if (!this[id].IsFixed) + { + ReserveId(id); + _units[(int)id].Label = (byte)(id ^ unusedOffset); + } + } + } + + internal void ExpandUnits() + { + uint srcNumUnits = (uint)_units.Size; + uint srcNumBlocks = (uint)NumBlocks(); + + uint destNumUnits = srcNumUnits + BlockSize; + uint destNumBlocks = srcNumBlocks + 1; + + if (destNumBlocks > NumExtraBlock) + { + FixBlock(srcNumBlocks - NumExtraBlock); + } + + _units.Resize((int)destNumUnits); + + if (destNumBlocks > NumExtraBlock) + { + for (uint id = srcNumUnits; id < destNumUnits; ++id) + { + this[id].IsUsed = false; + this[id].IsFixed = false; + } + } + + for (uint i = srcNumUnits + 1; i < destNumUnits; ++i) + { + this[i - 1].Next = i; + this[i].Prev = i - 1; + } + + this[srcNumUnits].Prev = destNumUnits - 1; + this[destNumUnits - 1].Next = srcNumUnits; + this[srcNumUnits].Prev = this[_extrasHead].Prev; + this[destNumUnits - 1].Next = _extrasHead; + this[this[_extrasHead].Prev].Next = srcNumUnits; + this[_extrasHead].Prev = destNumUnits - 1; + } + + internal void ReserveId(uint id) + { + if (id >= _units.Size) + { + ExpandUnits(); + } + + if (id == _extrasHead) + { + _extrasHead = this[id].Next; + if (_extrasHead == id) + { + _extrasHead = (uint)_units.Size; + } + } + + this[this[id].Prev].Next = this[id].Next; + this[this[id].Next].Prev = this[id].Prev; + this[id].IsFixed = true; + } + + internal bool IsValidOffset(uint id, uint offset) + { + if (this[offset].IsUsed) + { + return false; + } + + uint relOffset = id ^ offset; + if ((relOffset & LowerMask) != 0 && (relOffset & UpperMask) != 0) + { + return false; + } + + for (int i = 1; i < _labels.Size; ++i) + { + if (this[offset ^ _labels[i]].IsFixed) + { + return false; + } + } + + return true; + } + + internal uint FindValidOffset(uint id) + { + if (_extrasHead >= _units.Size) + { + return (uint)_units.Size | (id & LowerMask); + } + + uint unfixedId = _extrasHead; + do + { + uint offset = unfixedId ^ _labels[0]; + if (IsValidOffset(id, offset)) + { + return offset; + } + + unfixedId = this[unfixedId].Next; + } while (unfixedId != _extrasHead); + + return (uint)_units.Size | (id & LowerMask); + } + + internal uint ArrangeFromDawg(DawgBuilder dawg, uint dawgId, uint dicId) + { + _labels.Resize(0); + + 
uint dawgChildId = dawg.Child(dawgId); + while (dawgChildId != 0) + { + _labels.Append(dawg.Label(dawgChildId)); + dawgChildId = dawg.Sibling(dawgChildId); + } + + uint offset = FindValidOffset(dicId); + _units[(int)dicId].Offset = dicId ^ offset; + + dawgChildId = dawg.Child(dawgId); + for (int i = 0; i < _labels.Size; ++i) + { + uint dicChildId = offset ^ _labels[i]; + ReserveId(dicChildId); + + if (dawg.IsLeaf(dawgChildId)) + { + _units[(int)dicId].HasLeaf = true; + _units[(int)dicChildId].Value = (uint)dawg.Value(dawgChildId); + } + else + { + _units[(int)dicChildId].Label = _labels[i]; + } + + dawgChildId = dawg.Sibling(dawgChildId); + } + + this[offset].IsUsed = true; + + return offset; + } + + internal void BuildFromDawg(DawgBuilder dawg, uint dawgId, uint dicId) + { + uint dawgChildId = dawg.Child(dawgId); + uint offset; + if (dawg.IsIntersection(dawgChildId)) + { + uint intersectionId = dawg.IntersectionId(dawgChildId); + offset = _table![intersectionId]; + if (offset != 0) + { + offset ^= dicId; + if ((offset & UpperMask) == 0 || (offset & LowerMask) == 0) + { + if (dawg.IsLeaf(dawgChildId)) + { + _units[(int)dicId].HasLeaf = true; + } + _units[(int)dicId].Offset = offset; + return; + } + } + } + + offset = ArrangeFromDawg(dawg, dawgId, dicId); + if (dawg.IsIntersection(dawgChildId)) + { + _table![dawg.IntersectionId(dawgChildId)] = offset; + } + + do + { + byte childLabel = dawg.Label(dawgChildId); + uint dicChildId = offset ^ childLabel; + if (childLabel != 0) + { + BuildFromDawg(dawg, dawgChildId, dicChildId); + } + + dawgChildId = dawg.Sibling(dawgChildId); + } while (dawgChildId != 0); + } + + internal void FixAllBlocks() + { + uint begin = 0; + if (NumBlocks() > NumExtraBlock) + { + begin = (uint)NumBlocks() - NumExtraBlock; + } + + uint end = (uint)NumBlocks(); + + for (uint blockId = begin; blockId != end; ++blockId) + { + FixBlock(blockId); + } + } + + internal void BuildFromDawg(DawgBuilder dawg) + { + int numUnits = 1; + while (numUnits < dawg.Size) + { + numUnits <<= 1; + } + + _units.Reserve(numUnits); + _table = new uint[dawg.NumIntersections]; + + ReserveId(0); + + this[0].IsUsed = true; + _units[0].Offset = 1; + _units[0].Label = 0; + + if (dawg.Child(dawg.Root) != 0) + { + BuildFromDawg(dawg, dawg.Root, 0); + } + + FixAllBlocks(); + } + + public void Build(SortedDictionary dictionary) + { + DawgBuilder dawgBuilder = new(); + BuildDawg(dictionary, dawgBuilder); + BuildFromDawg(dawgBuilder); + } + } + + internal struct DoubleArrayResultPair + { + public int Value { get; set; } + public int Length { get; set; } + }; + + internal class DoubleArrayTrie + { + private readonly int _size; + private readonly DoubleArrayUnit[] _array; + + internal DoubleArrayUnit[] ArrayUnits => _array; + internal int Size => _size; + + // Sorted Dictionary to store the key value pairs + public DoubleArrayTrie(SortedDictionary dictionary) + { + DoubleArrayBuilder builder = new DoubleArrayBuilder(); + builder.Build(dictionary); + + _size = builder.UnitsSize; + _array = builder.Units; + } + + public DoubleArrayTrie(DoubleArrayUnit[] preCompiledData) + { + if (preCompiledData is null) + { + throw new ArgumentNullException(nameof(preCompiledData)); + } + + _size = preCompiledData.Length; + _array = preCompiledData; + } + + public int CommonPrefixSearch(ReadOnlySpan key, Span results, int nodePos = 0) + { + int numResults = 0; + + DoubleArrayUnit unit = _array[nodePos]; + nodePos ^= (int)unit.Offset; + + for (int i = 0; i < key.Length; ++i) + { + nodePos ^= key[i]; + unit = _array[nodePos]; + + 
if (unit.Label != key[i]) + { + return numResults; + } + + nodePos ^= (int)unit.Offset; + + if (unit.HasLeaf) + { + if (numResults < results.Length) + { + results[numResults].Value = (int)_array[nodePos].Value; + results[numResults].Length = i + 1; + } + + ++numResults; + } + } + + return numResults; + } + + public int Traverse(ReadOnlySpan key, ref int nodePos, ref int keyPos, int length) + { + uint id = (uint)nodePos; + DoubleArrayUnit unit = _array[id]; + + if (length != 0) + { + for (; keyPos < length; ++keyPos) + { + id ^= unit.Offset ^ key[keyPos]; + unit = _array[id]; + if (unit.Label != key[keyPos]) + { + return -2; + } + + nodePos = (int)id; + } + } + else + { + for (; key[keyPos] != 0; ++keyPos) + { + id ^= unit.Offset ^ key[keyPos]; + unit = _array[id]; + if (unit.Label != key[keyPos]) + { + return -2; + } + + nodePos = (int)id; + } + } + + if (!unit.HasLeaf) + { + return -1; + } + + unit = _array[id ^ unit.Offset]; + return (int)unit.Value; + } + } +} + diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.cs index 4517eaa615..8ae0dcd862 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.cs @@ -5,12 +5,19 @@ using System; using System.Buffers; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Text; +#if Test +namespace Microsoft.ML.Tokenizers.Tests +#else namespace Microsoft.ML.Tokenizers +#endif // Test { internal static partial class Helpers { + private const int UnicodeError = 0xFFFD; + internal static void ArrayPoolGrow(ref T[] arrayPoolArray, int requiredCapacity) { T[] tmp = ArrayPool.Shared.Rent(Math.Max(arrayPoolArray.Length * 2, requiredCapacity)); @@ -19,6 +26,43 @@ internal static void ArrayPoolGrow(ref T[] arrayPoolArray, int requiredCapaci arrayPoolArray = tmp; } + internal static void ArrayPoolGrow(ref Span span, ref T[]? poolArray, int newSize) + { + Debug.Assert(span.Length <= newSize); + + T[] newPoolArray = ArrayPool.Shared.Rent(newSize); + span.CopyTo(newPoolArray); + + if (poolArray is not null) + { + ArrayPool.Shared.Return(poolArray); + } + + poolArray = newPoolArray; + span = poolArray; + } + + private static readonly int[] _oneCharLen = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + + // Return length of a single UTF-8 source character + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int OneCharLen(byte src) => _oneCharLen[(src & 0xFF) >> 4]; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int GetUtf16LengthFromUtf8Bytes(ReadOnlySpan utf8Bytes) + { + int length = 0; + + while (utf8Bytes.Length > 0) + { + int bytesLength = OneCharLen(utf8Bytes[0]); + length += bytesLength == 4 ? 
2 : 1; + utf8Bytes = utf8Bytes.Slice(Math.Min(bytesLength, utf8Bytes.Length)); + } + + return length; + } + internal static int EncodeToUtf8(ReadOnlySpan text, Span destination, Span indexMapping) { Debug.Assert(!text.IsEmpty); @@ -73,6 +117,44 @@ internal static int EncodeToUtf8(ReadOnlySpan text, Span destination return targetIndex; } + internal static int EncodeNextUtf8(ReadOnlySpan text, Span destination) + { + Debug.Assert(!text.IsEmpty); + Debug.Assert(destination.Length >= 4); + + uint c = (uint)text[0]; + if (c <= 0x7Fu) + { + destination[0] = (byte)c; + return 1; + } + + if (c <= 0x7FFu) + { + // Scalar 00000yyy yyxxxxxx -> bytes [ 110yyyyy 10xxxxxx ] + destination[0] = (byte)((c + (0b110u << 11)) >> 6); + destination[1] = (byte)((c & 0x3Fu) + 0x80u); + return 2; + } + + if (text.Length > 1 && char.IsSurrogatePair((char)c, text[1])) + { + // Scalar 000uuuuu zzzzyyyy yyxxxxxx -> bytes [ 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx ] + uint value = (uint)char.ConvertToUtf32((char)c, text[1]); + destination[0] = (byte)((value + (0b11110 << 21)) >> 18); + destination[1] = (byte)(((value & (0x3Fu << 12)) >> 12) + 0x80u); + destination[2] = (byte)(((value & (0x3Fu << 6)) >> 6) + 0x80u); + destination[3] = (byte)((value & 0x3Fu) + 0x80u); + return 4; + } + + // Scalar zzzzyyyy yyxxxxxx -> bytes [ 1110zzzz 10yyyyyy 10xxxxxx ] + destination[0] = (byte)((c + (0b1110 << 16)) >> 12); + destination[1] = (byte)(((c & (0x3Fu << 6)) >> 6) + 0x80u); + destination[2] = (byte)((c & 0x3Fu) + 0x80u); + return 3; + } + internal static int EncodeToUtf8AndTransform(ReadOnlySpan text, Span destination, Span indexMapping) { Debug.Assert(!text.IsEmpty); @@ -130,7 +212,7 @@ internal static int EncodeToUtf8AndTransform(ReadOnlySpan text, Span public static bool ConvertUtf8ToUtf16(ReadOnlySpan utf8Bytes, Span utf16Chars, out int bytesConsumed, out int charsWritten) { - Debug.Assert(utf16Chars.Length >= Encoding.UTF8.GetMaxCharCount(utf8Bytes.Length)); + Debug.Assert(utf16Chars.Length >= GetUtf16LengthFromUtf8Bytes(utf8Bytes)); int byteIndex = 0; int charIndex = 0; @@ -210,5 +292,64 @@ public static bool ConvertUtf8ToUtf16(ReadOnlySpan utf8Bytes, Span u return true; } + + // encodedLength stores the number of bytes consumed after decoding. + internal static int DecodeUtf8(ReadOnlySpan input, out int encodedLength) + { + Debug.Assert(input.Length > 0); + + if (input[0] < 0x80) + { + encodedLength = 1; + return input[0]; + } + else if (input.Length >= 2 && (input[0] & 0xE0) == 0xC0) + { + int cp = (((input[0] & 0x1F) << 6) | ((input[1] & 0x3F))); + if (IsTrailByte(input[1]) && cp >= 0x0080 && IsValidCodepoint(cp)) + { + encodedLength = 2; + return cp; + } + } + else if (input.Length >= 3 && (input[0] & 0xF0) == 0xE0) + { + int cp = (((input[0] & 0x0F) << 12) | ((input[1] & 0x3F) << 6) | ((input[2] & 0x3F))); + if (IsTrailByte(input[1]) && IsTrailByte(input[2]) && cp >= 0x0800 && IsValidCodepoint(cp)) + { + encodedLength = 3; + return cp; + } + } + else if (input.Length >= 4 && (input[0] & 0xf8) == 0xF0) + { + int cp = (((input[0] & 0x07) << 18) | ((input[1] & 0x3F) << 12) | ((input[2] & 0x3F) << 6) | ((input[3] & 0x3F))); + if (IsTrailByte(input[1]) && IsTrailByte(input[2]) && IsTrailByte(input[3]) && cp >= 0x10000 && IsValidCodepoint(cp)) + { + encodedLength = 4; + return cp; + } + } + + // Invalid UTF-8. 
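+ // Malformed sequences follow the same convention as the normalizer: consume exactly one byte
+ // and report U+FFFD so the caller can substitute the three-byte UTF-8 replacement character
+ // without losing its position in the input.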
+ encodedLength = 1; + return UnicodeError; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static bool IsTrailByte(byte x) => (sbyte)x < -0x40; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static bool IsValidCodepoint(int c) => ((uint)c) < 0xD800 || (c >= 0xE000 && c <= 0x10FFFF); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static bool IsValidDecodeUtf8(ReadOnlySpan input, out int encodedLength) + { + int c = DecodeUtf8(input, out encodedLength); + return c != UnicodeError || encodedLength == 3; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static uint Swap32(uint x) => ((x & 0x000000FF) << 24) | ((x & 0x0000FF00) << 8) | ((x & 0x00FF0000) >> 8) | ((x & 0xFF000000) >> 24); } } diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs index dcff2e6d80..a7ce495033 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs @@ -12,7 +12,11 @@ using System.Threading; using System.Net.Http; +#if Test +namespace Microsoft.ML.Tokenizers.Tests +#else namespace Microsoft.ML.Tokenizers +#endif // Test { internal static partial class Helpers { @@ -62,6 +66,8 @@ internal static int GetChars(ReadOnlySpan bytes, Span chars) internal static void Replace(Span span, char oldValue, char newValue) => span.Replace(oldValue, newValue); + internal static void Replace(ReadOnlySpan source, Span destination, char oldValue, char newValue) => source.Replace(destination, oldValue, newValue); + /// /// Encode the next code point in the text to UTF-8. /// diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs index 4824ccd67d..bafa8bf09f 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs @@ -10,7 +10,11 @@ using System.Threading; using System.Threading.Tasks; +#if Test +namespace Microsoft.ML.Tokenizers.Tests +#else namespace Microsoft.ML.Tokenizers +#endif // Test { internal static partial class Helpers { @@ -113,6 +117,16 @@ internal static void Replace(Span span, char oldValue, char newValue) span[i] = newValue; } + internal static void Replace(ReadOnlySpan source, Span destination, char oldValue, char newValue) + { + Debug.Assert(source.Length <= destination.Length); + + for (int i = 0; i < source.Length; i++) + { + destination[i] = source[i] == oldValue ? newValue : source[i]; + } + } + /// /// Encode the next code point in the text to UTF-8. /// diff --git a/src/Microsoft.ML.Tokenizers/Utils/OrdinalUtf8StringComparer.cs b/src/Microsoft.ML.Tokenizers/Utils/OrdinalUtf8StringComparer.cs new file mode 100644 index 0000000000..6b03eaf2b7 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Utils/OrdinalUtf8StringComparer.cs @@ -0,0 +1,94 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; + +#if Test +namespace Microsoft.ML.Tokenizers.Tests +#else +namespace Microsoft.ML.Tokenizers +#endif // Test +{ + internal class OrdinalUtf8StringComparer : IComparer + { + internal static readonly OrdinalUtf8StringComparer Instance = new OrdinalUtf8StringComparer(); + public int Compare(string? x, string? 
y) + { + if (x == null || y == null) + { + return x == y ? 0 : (x == null ? -1 : 1); + } + + Span buffer1 = stackalloc byte[520]; + Span buffer2 = stackalloc byte[520]; + + int minLength = Math.Min(x.Length, y.Length); + for (int i = 0; i < minLength; i++) + { + char c = x[i]; + char d = y[i]; + + if (c == d) + { + continue; + } + + if (!Char.IsSurrogate(c) && !Char.IsSurrogate(d)) + { + return (int)x[i] - (int)y[i]; + } + + // Need to consider surrogate conversions to UTF-8 before comparing. + + while (i > 0 && (Char.IsSurrogate(x[i - 1]) || Char.IsSurrogate(y[i - 1]))) + { + i--; + } + + int xLen = x.Length - i; + int yLen = y.Length - i; + + byte[]? bytes1 = null; + byte[]? bytes2 = null; + + int requiredLength1 = Encoding.UTF8.GetMaxByteCount(xLen); + int requiredLength2 = Encoding.UTF8.GetMaxByteCount(yLen); + + if (requiredLength1 > buffer1.Length) + { + bytes1 = ArrayPool.Shared.Rent(requiredLength1); + buffer1 = bytes1; + } + + if (requiredLength2 > buffer2.Length) + { + bytes2 = ArrayPool.Shared.Rent(requiredLength2); + buffer2 = bytes2; + } + + xLen = Helpers.GetUtf8Bytes(x.AsSpan(i), buffer1); + yLen = Helpers.GetUtf8Bytes(y.AsSpan(i), buffer2); + + int result = buffer1.Slice(0, xLen).SequenceCompareTo(buffer2.Slice(0, yLen)); + + if (bytes1 != null) + { + ArrayPool.Shared.Return(bytes1); + } + + if (bytes2 != null) + { + ArrayPool.Shared.Return(bytes2); + } + + return result; + } + + return x.Length - y.Length; + } + } +} diff --git a/test/Microsoft.ML.Tokenizers.Tests/DoubleArrayTrieTest.cs b/test/Microsoft.ML.Tokenizers.Tests/DoubleArrayTrieTest.cs new file mode 100644 index 0000000000..f54c9461c8 --- /dev/null +++ b/test/Microsoft.ML.Tokenizers.Tests/DoubleArrayTrieTest.cs @@ -0,0 +1,107 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
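+
+// Illustrative sketch of the usage pattern exercised below (not part of the shipped API surface;
+// the key/value choices here are made up for the example): the trie is built from keys sorted in
+// UTF-8 byte order and queried with CommonPrefixSearch, roughly like this:
+//
+//     var dict = new SortedDictionary<string, int>(OrdinalUtf8StringComparer.Instance)
+//     {
+//         { "a", 1 }, { "ab", 2 }, { "abc", 3 }
+//     };
+//     DoubleArrayTrie trie = new DoubleArrayTrie(dict);
+//     Span<DoubleArrayResultPair> results = stackalloc DoubleArrayResultPair[8];
+//     int hits = trie.CommonPrefixSearch(Encoding.UTF8.GetBytes("abcd"), results);
+//     // hits == 3: the prefixes "a", "ab", and "abc" match with values 1, 2, and 3.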
+ +using System; +using System.Collections.Generic; +using System.Text; +using Xunit; + +namespace Microsoft.ML.Tokenizers.Tests +{ + public class DoubleArrayTrieTests + { + private static readonly (string Key, int Value)[] _entries = new[] + { + // Higher ranges and surrogates + ("\uD83D\uDE00ab", 26), // \uF0B7 ordered before \uD83D\uDE00 in utf-8 + ("\uF0B7ab", 25), // \uF0B7 ordered before \uD83D\uDE00 in utf-8 + + // Different scripts + ("\u0627\u0644", 24), // Arabic + ("\u0391\u0393", 23), // Greek + + // Higher ranges and surrogates proceeded by Latin + ("a\uD83D\uDE00b", 22), // \uF0B7 ordered before \uD83D\uDE00 in utf-8 + ("a\uF0B7b", 21), // \uF0B7 ordered before \uD83D\uDE00 in utf-8 + + ("abcdefghijklmnopqrstu", 20), + ("abcdefghijklmnopqrst", 19), + ("abcdefghijklmnopqrs", 18), + ("abcdefghijklmnopqr", 17), + ("abcdefghijklmnopq", 16), + ("abcdefghijklmnop", 15), + ("abcdefghijklmno", 14), + ("abcdefghijklmn", 13), + ("abcdefghijklm", 12), + ("abcdefghijkl", 11), + ("abcdefghij", 10), + ("abcdefghi", 9), + ("abcdefgh", 8), + ("abcdefg", 7), + ("abcdef", 6), + ("abcde", 5), + ("abcd", 4), + ("abc", 3), + ("ab", 2), + ("a", 1) + }; + + [Fact] + public void DoubleArrayTrieTest() + { + SortedDictionary dict = new SortedDictionary(OrdinalUtf8StringComparer.Instance); + foreach (var (key, value) in _entries) + { + dict.Add(key, value); + } + + // + // Ensure expected order by OrdinalUtf8StringComparer + // + + int i = 1; + foreach (var kvp in dict) + { + Assert.Equal(i, kvp.Value); // Validate the sort order + i++; + } + + // + // test DoubleArrayTrie with prefix search + // + + DoubleArrayTrie trie = new DoubleArrayTrie(dict); + DoubleArrayResultPair[] doubleArrayResultPairs = new DoubleArrayResultPair[_entries.Length]; + + foreach (var (key, value) in _entries) + { + byte[] utf8Bytes = Encoding.UTF8.GetBytes(key); + int resultCount = trie.CommonPrefixSearch(utf8Bytes, doubleArrayResultPairs); + for (i = 0; i < resultCount; i++) + { + Assert.True(doubleArrayResultPairs[i].Value <= value); + Assert.StartsWith(Helpers.GetString(utf8Bytes.AsSpan(0, doubleArrayResultPairs[i].Length)), key, StringComparison.Ordinal); + } + } + + // + // test DoubleArrayTrie with travers search + // + + foreach (var (key, value) in _entries) + { + byte[] utf8Bytes = Encoding.UTF8.GetBytes(key); + + int nodePos = 0; + int keyPos = 0; + + int result = trie.Traverse(utf8Bytes, ref nodePos, ref keyPos, utf8Bytes.Length); + + Assert.True(trie.ArrayUnits[nodePos].HasLeaf); + Assert.Equal(utf8Bytes.Length, keyPos); + Assert.Equal(value, result); + } + } + } +} diff --git a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj index 8e1f741552..3297c8a70f 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj +++ b/test/Microsoft.ML.Tokenizers.Tests/Microsoft.ML.Tokenizers.Tests.csproj @@ -5,6 +5,8 @@ Test $(NoWarn);MSML_ExtendBaseTestClass enable + true + $(DefineConstants);Test @@ -25,6 +27,18 @@ + + + + + + + + + + + + diff --git a/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs new file mode 100644 index 0000000000..b90ab7a414 --- /dev/null +++ b/test/Microsoft.ML.Tokenizers.Tests/UnigramTests.cs @@ -0,0 +1,514 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
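+
+// Illustrative sketch of what these tests exercise, using the values from the first
+// UnigramTestData row (the model file path is the one the fixture loads):
+//
+//     using Stream stream = File.OpenRead(
+//         Path.Combine("Paraphrase-multilingual-MiniLM-L12-v2", "sentencepiece.bpe.model"));
+//     SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(stream);
+//     IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello, world!",
+//         addBeginningOfSentence: false, addEndOfSentence: false);
+//     // ids == { 35377, 3, 8998, 37 }, i.e. "▁Hello", ",", "▁world", "!",
+//     // with the normalized form "▁Hello,▁world!".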
+ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection.Metadata; +using Microsoft.ML.Tokenizers; +using Xunit; + +namespace Microsoft.ML.Tokenizers.Tests +{ + public class UnigramTests + { + private static SentencePieceTokenizer _unigramTokenizer = CreateUnigramTokenizer(); + private static SentencePieceTokenizer _unigramTokenizerWithSpecialTokens = CreateUnigramTokenizerWithSpecialTokens(); + + private static SentencePieceTokenizer CreateUnigramTokenizer() + { + // @"https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/resolve/main/sentencepiece.bpe.model?download=true"; + using Stream remoteStream = File.OpenRead(Path.Combine(@"Paraphrase-multilingual-MiniLM-L12-v2", "sentencepiece.bpe.model")); + return SentencePieceTokenizer.Create(remoteStream); + } + + private static SentencePieceTokenizer CreateUnigramTokenizerWithSpecialTokens() + { + // @"https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2/resolve/main/sentencepiece.bpe.model?download=true"; + using Stream remoteStream = File.OpenRead(Path.Combine(@"Paraphrase-multilingual-MiniLM-L12-v2", "sentencepiece.bpe.model")); + return SentencePieceTokenizer.Create(remoteStream, specialTokens: + new Dictionary + { + { "", 0 }, + { "", 1 }, + { "", 2 }, + { "", 7 }, + { "", 8 }, + }); + } + + public static IEnumerable UnigramTestData() + { + // tokenizer, input text, normalized text, decoded text, ids, tokens, offsets + yield return new object[] + { + "Hello, world!", + "▁Hello,▁world!", + "Hello, world!", + new int[] { 35377, 3, 8998, 37 }, + new string[] { "▁Hello", ",", "▁world", "!" }, + new Range[] { new Range(0, 6), new Range(6, 7), new Range(7, 13), new Range(13, 14) } + }; + + yield return new object[] + { + "Hello, ①カタカナff⁰⅓Ⅳ \U00010200 world! \uD800\uDE00", // include normalization and unknown characters + "▁Hello,▁1カタカナff01⁄3IV▁\U00010200▁world!▁\U00010200", + "Hello, 1カタカナff01⁄3IV world! ", + new int[] { 35377, 3, 105, 10044, 10792, 10044, 17455, 4901, 6745, 244258, 362, 15582, 5, 0, 8998, 37, 5, 0 }, // Unknown Id is 0 + new string[] { "▁Hello", ",", "▁1", "カ", "タ", "カ", "ナ", "ff", "01", "⁄", "3", "IV", "▁", "\U00010200", "▁world", "!", "▁", "\U00010200" }, + new Range[] + { + new Range(0, 6), new Range(6, 7), new Range(7, 9), new Range(9, 10), new Range(10, 11), new Range(11, 12), + new Range(12, 13), new Range(13, 15), new Range(15, 17), new Range(17, 18), new Range(18, 19), new Range(19, 21), + new Range(21, 22), new Range(22, 24), new Range(24, 30), new Range(30, 31), new Range(31, 32), new Range(32, 34) + } + }; + + yield return new object[] + { + "", + "", + "", + new int[0], + new string[0], + new Range[0] + }; + + yield return new object[] + { + @"The sun dipped below the horizon, casting a warm golden hue across the tranquil meadow. Birds fluttered from " + + "tree to tree, their melodic songs filling the air. A gentle breeze rustled the leaves, carrying with it the scent of " + + "blooming flowers. In the distance, the silhouette of a lone figure stood atop a hill, gazing out at the vast expanse " + + "before them. 
It was a moment frozen in time, where nature and solitude merged into something magical.", + + "▁The▁sun▁dipped▁below▁the▁horizon,▁casting▁a▁warm▁golden▁hue▁across▁the▁tranquil▁meadow.▁Birds▁fluttered▁from▁tree▁to▁tree,▁their" + + "▁melodic▁songs▁filling▁the▁air.▁A▁gentle▁breeze▁rustled▁the▁leaves,▁carrying▁with▁it▁the▁scent▁of▁blooming▁flowers.▁In▁the▁distance" + + ",▁the▁silhouette▁of▁a▁lone▁figure▁stood▁atop▁a▁hill,▁gazing▁out▁at▁the▁vast▁expanse▁before▁them.▁It▁was▁a▁moment▁frozen▁in▁time,▁" + + "where▁nature▁and▁solitude▁merged▁into▁something▁magical.", + + @"The sun dipped below the horizon, casting a warm golden hue across the tranquil meadow. Birds fluttered from " + + "tree to tree, their melodic songs filling the air. A gentle breeze rustled the leaves, carrying with it the scent of " + + "blooming flowers. In the distance, the silhouette of a lone figure stood atop a hill, gazing out at the vast expanse " + + "before them. It was a moment frozen in time, where nature and solitude merged into something magical.", + + new int[] + { + 580, 4261, 44, 48397, 35063, 69, 5, 156633, 3, 176049, 9, 24813, 158043, 78023, 36879, 69, 46193, 10547, 24292, 4, + 72606, 6, 139099, 55, 296, 1294, 53200, 46, 53200, 3, 2362, 43670, 237, 52335, 26291, 213, 69, 1830, 4, 61, 21506, + 132, 12561, 6658, 52647, 6258, 69, 31357, 6, 3, 85357, 213, 677, 441, 69, 25453, 17, 110, 29694, 305, 213, 189066, + 4, 359, 69, 62487, 3, 69, 5794, 13884, 8675, 110, 9, 458, 85, 26365, 192941, 9, 13783, 9, 130472, 3, 13958, 213, + 1809, 98, 69, 18409, 14699, 20539, 8107, 2855, 4, 1649, 508, 9, 3094, 1237, 70462, 22, 1732, 3, 7439, 31424, 135, + 3114, 21752, 12, 42563, 70, 3933, 9843, 49845, 288, 4 + }, + + new string[] + { + "▁The", "▁sun", "▁di", "pped", "▁below", "▁the", "▁", "horizon", ",", "▁casting", "▁a", "▁warm", "▁golden", "▁hue", + "▁across", "▁the", "▁tranquil", "▁mea", "dow", ".", "▁Bird", "s", "▁flutt", "er", "ed", "▁from", "▁tree", "▁to", "▁tree", + ",", "▁their", "▁melodi", "c", "▁songs", "▁fill", "ing", "▁the", "▁air", ".", "▁A", "▁gent", "le", "▁bre", "eze", "▁rust", + "led", "▁the", "▁leave", "s", ",", "▁carry", "ing", "▁with", "▁it", "▁the", "▁scen", "t", "▁of", "▁blo", "om", "ing", + "▁flowers", ".", "▁In", "▁the", "▁distance", ",", "▁the", "▁sil", "hou", "ette", "▁of", "▁a", "▁lo", "ne", "▁figure", + "▁stood", "▁a", "top", "▁a", "▁hill", ",", "▁gaz", "ing", "▁out", "▁at", "▁the", "▁vast", "▁exp", "anse", "▁before", + "▁them", ".", "▁It", "▁was", "▁a", "▁moment", "▁f", "rozen", "▁in", "▁time", ",", "▁where", "▁nature", "▁and", "▁sol", + "itud", "e", "▁merge", "d", "▁into", "▁something", "▁magic", "al", "." 
+ }, + + new Range[] + { + new Range(0, 4), new Range(4, 8), new Range(8, 11), new Range(11, 15), new Range(15, 21), new Range(21, 25), + new Range(25, 26), new Range(26, 33), new Range(33, 34), new Range(34, 42), new Range(42, 44), new Range(44, 49), new Range(49, 56), + new Range(56, 60), new Range(60, 67), new Range(67, 71), new Range(71, 80), new Range(80, 84), new Range(84, 87), new Range(87, 88), + new Range(88, 93), new Range(93, 94), new Range(94, 100), new Range(100, 102), new Range(102, 104), new Range(104, 109), new Range(109, 114), + new Range(114, 117), new Range(117, 122), new Range(122, 123), new Range(123, 129), new Range(129, 136), new Range(136, 137), + new Range(137, 143), new Range(143, 148), new Range(148, 151), new Range(151, 155), new Range(155, 159), new Range(159, 160), + new Range(160, 162), new Range(162, 167), new Range(167, 169), new Range(169, 173), new Range(173, 176), new Range(176, 181), + new Range(181, 184), new Range(184, 188), new Range(188, 194), new Range(194, 195), new Range(195, 196), new Range(196, 202), + new Range(202, 205), new Range(205, 210), new Range(210, 213), new Range(213, 217), new Range(217, 222), new Range(222, 223), + new Range(223, 226), new Range(226, 230), new Range(230, 232), new Range(232, 235), new Range(235, 243), new Range(243, 244), + new Range(244, 247), new Range(247, 251), new Range(251, 260), new Range(260, 261), new Range(261, 265), new Range(265, 269), + new Range(269, 272), new Range(272, 276), new Range(276, 279), new Range(279, 281), new Range(281, 284), new Range(284, 286), + new Range(286, 293), new Range(293, 299), new Range(299, 301), new Range(301, 304), new Range(304, 306), new Range(306, 311), + new Range(311, 312), new Range(312, 316), new Range(316, 319), new Range(319, 323), new Range(323, 326), new Range(326, 330), + new Range(330, 335), new Range(335, 339), new Range(339, 343), new Range(343, 350), new Range(350, 355), new Range(355, 356), + new Range(356, 359), new Range(359, 363), new Range(363, 365), new Range(365, 372), new Range(372, 374), new Range(374, 379), + new Range(379, 382), new Range(382, 387), new Range(387, 388), new Range(388, 394), new Range(394, 401), new Range(401, 405), + new Range(405, 409), new Range(409, 413), new Range(413, 414), new Range(414, 420), new Range(420, 421), new Range(421, 426), + new Range(426, 436), new Range(436, 442), new Range(442, 444), new Range(444, 445) + } + }; + + yield return new object[] + { + "This is 👍, an emoji.", + "▁This▁is▁👍,▁an▁emoji.", + "This is 👍, an emoji.", + new int[] { 3292, 82, 5, 118279, 3, 141, 27, 121504, 4 }, + new string[] { "▁This", "▁is", "▁", "👍", ",", "▁an", "▁e", "moji", "." 
}, + new Range[] { new Range(0, 5), new Range(5, 8), new Range(8, 9), new Range(9, 11), new Range(11, 12), new Range(12, 15), new Range(15, 17), new Range(17, 21), new Range(21, 22) } + }; + + yield return new object[] + { + "清水寺は京都にある。", // Japanese + "▁清水寺は京都にある。", + "清水寺は京都にある。", + new int[] { 5, 177585, 32566, 341, 60423, 24432, 29 }, + new string[] { "▁", "清水", "寺", "は", "京都", "にある", "。" }, + new Range[] { new Range(0, 1), new Range(1, 3), new Range(3, 4), new Range(4, 5), new Range(5, 7), new Range(7, 10), new Range(10, 11) } + }; + + yield return new object[] + { + "xyz東京", // Latin-Japanese + "▁xyz東京", + "xyz東京", + new int[] { 1021, 32188, 22887 }, + new string[] { "▁x", "yz", "東京" }, + new Range[] { new Range(0, 2), new Range(2, 4), new Range(4, 6) } + }; + + yield return new object[] + { + "㍻", // Japanese with normalization + "▁平成", + "平成", + new int[] { 5, 44405 }, + new string[] { "▁", "平成" }, + new Range[] { new Range(0, 1), new Range(1, 3) } + }; + + yield return new object[] + { + "KADOKAWAABC", // Full-width Latin to normalize to normal width + "▁KADOKAWAABC", + "KADOKAWAABC", + new int[] { 340, 41387, 218268, 186943 }, + new string[] { "▁K", "ADO", "KAWA", "ABC" }, + new Range[] { new Range(0, 2), new Range(2, 5), new Range(5, 9), new Range(9, 12) } + }; + + yield return new object[] + { + "ℌ𝔢𝔩𝔩𝔬 𝔚𝔬𝔯𝔩𝔡!", // Gothic script + "▁Hello▁World!", + "Hello World!", + new int[] { 35377, 6660, 37 }, + new string[] { "▁Hello", "▁World", "!" }, + new Range[] { new Range(0, 6), new Range(6, 12), new Range(12, 13) } + }; + + yield return new object[] + { + "𝛢𝛷𝛢𝛪𝛯𝛪", // Greek script + "▁ΑΦΑΙΞΙ", + "ΑΦΑΙΞΙ", + new int[] { 3866, 203768, 15470, 72125, 15470 }, + new string[] { "▁Α", "ΦΑ", "Ι", "Ξ", "Ι" }, + new Range[] { new Range(0, 2), new Range(2, 4), new Range(4, 5), new Range(5, 6), new Range(6, 7) } + }; + + yield return new object[] + { + "𝖘𝖙𝖗𝖆𝖓𝖎𝖈𝖆", // Russian script + "▁stranica", + "stranica", + new int[] { 60133 }, + new string[] { "▁stranica" }, + new Range[] { new Range(0, 9) } + }; + + yield return new object[] + { + "老師", // Chinese + "▁老師", + "老師", + new int[] { 5, 25924 }, + new string[] { "▁", "老師" }, + new Range[] { new Range(0, 1), new Range(1, 3) } + }; + } + + private (IEnumerable Ids, IEnumerable Tokens, IEnumerable Offsets) ExtractedIds( + SentencePieceTokenizer tokenizer, + IReadOnlyList tokens, + string? 
+            bool addBeginningOfSentence,
+            bool addEndOfSentence)
+        {
+            List<EncodedToken> writableTokens = tokens.ToList();
+            if (addBeginningOfSentence && writableTokens.Count > 0)
+            {
+                Assert.True(writableTokens[0].Id == tokenizer.BeginningOfSentenceId);
+                Assert.True(writableTokens[0].Value == tokenizer.BeginningOfSentenceToken);
+                Assert.True(writableTokens[0].Offset.Equals(new Range(0, 0)));
+                writableTokens.RemoveAt(0);
+            }
+
+            if (addEndOfSentence && writableTokens.Count > 0)
+            {
+                Assert.True(writableTokens[writableTokens.Count - 1].Id == tokenizer.EndOfSentenceId);
+                Assert.True(writableTokens[writableTokens.Count - 1].Value == tokenizer.EndOfSentenceToken);
+
+                if (normalized is not null)
+                {
+                    Assert.True(writableTokens[writableTokens.Count - 1].Offset.Equals(new Range(normalized.Length, normalized.Length)));
+                }
+                writableTokens.RemoveAt(writableTokens.Count - 1);
+            }
+
+            return (
+                writableTokens.Select(t => t.Id),
+                writableTokens.Select(t => t.Value),
+                writableTokens.Select(t => t.Offset)
+            );
+        }
+
+        private void Validate((IEnumerable<int> Ids, IEnumerable<string> Tokens, IEnumerable<Range> Offsets) extracted, int[] ids, string[] tokens, Range[] offsets)
+        {
+            Assert.Equal(ids, extracted.Ids);
+            Assert.Equal(tokens, extracted.Tokens);
+            Assert.Equal(offsets, extracted.Offsets);
+        }
+
+        [Theory]
+        [MemberData(nameof(UnigramTestData))]
+        public void EncodeToTokensTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
+        {
+            Assert.True(decodedString is not null); // to make the compiler happy
+            IReadOnlyList<EncodedToken> result = _unigramTokenizer.EncodeToTokens(inputText, out string? normalized);
+            (IEnumerable<int> Ids, IEnumerable<string> Tokens, IEnumerable<Range> Offsets) extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, _unigramTokenizer.AddBeginningOfSentence, _unigramTokenizer.AddEndOfSentence);
+            Validate(extracted, ids, tokens, offsets);
+
+            result = _unigramTokenizer.EncodeToTokens(inputText.AsSpan(), out normalized);
+            extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, _unigramTokenizer.AddBeginningOfSentence, _unigramTokenizer.AddEndOfSentence);
+            Validate(extracted, ids, tokens, offsets);
+
+            result = _unigramTokenizer.EncodeToTokens(inputText, out normalized, addBeginningOfSentence: true, addEndOfSentence: false);
+            extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, false);
+            Validate(extracted, ids, tokens, offsets);
+
+            result = _unigramTokenizer.EncodeToTokens(inputText.AsSpan(), out normalized, addBeginningOfSentence: true, addEndOfSentence: false);
+            extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, false);
+            Validate(extracted, ids, tokens, offsets);
+
+            result = _unigramTokenizer.EncodeToTokens(inputText, out normalized, addBeginningOfSentence: true, addEndOfSentence: true);
+            extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, true);
+            Validate(extracted, ids, tokens, offsets);
+
+            result = _unigramTokenizer.EncodeToTokens(inputText.AsSpan(), out normalized, addBeginningOfSentence: true, addEndOfSentence: true);
+            extracted = ExtractedIds(_unigramTokenizer, result, normalizedText, true, true);
+            Validate(extracted, ids, tokens, offsets);
+
+            string newString = $"{_unigramTokenizer.BeginningOfSentenceToken}{inputText}{inputText}{_unigramTokenizer.EndOfSentenceToken}";
+            result = _unigramTokenizerWithSpecialTokens.EncodeToTokens(newString, out normalized, addBeginningOfSentence: false, addEndOfSentence: false);
+            extracted = ExtractedIds(_unigramTokenizerWithSpecialTokens, result, normalizedText, false, false);
+
+            int[] expectedIds = new int[ids.Length * 2 + 3];
+            expectedIds[0] = _unigramTokenizerWithSpecialTokens.BeginningOfSentenceId;
+            Array.Copy(ids, 0, expectedIds, 1, ids.Length);
+            expectedIds[ids.Length + 1] = _unigramTokenizerWithSpecialTokens.SpecialTokens![""];
+            Array.Copy(ids, 0, expectedIds, ids.Length + 2, ids.Length);
+            expectedIds[ids.Length * 2 + 2] = _unigramTokenizerWithSpecialTokens.EndOfSentenceId;
+            Assert.Equal(expectedIds, extracted.Ids);
+
+            string[] expectedTokens = new string[tokens.Length * 2 + 3];
+            expectedTokens[0] = _unigramTokenizerWithSpecialTokens.BeginningOfSentenceToken;
+            Array.Copy(tokens, 0, expectedTokens, 1, tokens.Length);
+            expectedTokens[tokens.Length + 1] = "";
+            Array.Copy(tokens, 0, expectedTokens, tokens.Length + 2, tokens.Length);
+            expectedTokens[tokens.Length * 2 + 2] = _unigramTokenizerWithSpecialTokens.EndOfSentenceToken;
+            Assert.Equal(expectedTokens, extracted.Tokens);
+        }
+
+        [Theory]
+        [MemberData(nameof(UnigramTestData))]
+        public void EncodeToIdsTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
+        {
+            Assert.True(decodedString is not null); // to make the compiler happy
+            Assert.True(tokens is not null); // to make the compiler happy
+            Assert.True(offsets is not null); // to make the compiler happy
+
+            IReadOnlyList<int> result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false);
+            Assert.Equal(ids, result);
+            result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false);
+            Assert.Equal(ids, result);
+
+            result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: false);
+            List<int> ints = result is List<int> list ? list : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+            }
+            Assert.Equal(ids, ints);
+
+            result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: false);
+            ints = result is List<int> ? (List<int>)result : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+            }
+            Assert.Equal(ids, ints);
+
+            result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: true);
+            ints = result is List<int> ? (List<int>)result : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+                ints.RemoveAt(ints.Count - 1);
+            }
+            Assert.Equal(ids, ints);
+
+            result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true);
+            ints = result is List<int> ? (List<int>)result : result.ToList();
+            if (ints.Count > 0)
+            {
+                ints.RemoveAt(0);
+                ints.RemoveAt(ints.Count - 1);
+            }
+            Assert.Equal(ids, ints);
+
+            // Verify truncation: with maxTokenCount = i, encoding should return only the leading ids that fit.
+            for (int i = 1; i <= ids.Length; i++)
+            {
+                result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out string? normalized, out int charConsumed);
+                Assert.Equal(ids.Take(i), result);
+                Assert.Equal(normalizedText, normalized);
+
+                result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
+                Assert.Equal(ids.Take(i), result);
+                Assert.Equal(normalizedText, normalized);
+
+                result = _unigramTokenizer.EncodeToIds(inputText, addBeginningOfSentence: true, addEndOfSentence: true, maxTokenCount: i, out normalized, out charConsumed);
+                ints = result is List<int> ? (List<int>)result : result.ToList();
+                if (ints.Count > 0)
+                {
+                    ints.RemoveAt(0);
+                }
+                if (ints.Count > ids.Length)
+                {
+                    ints.RemoveAt(ints.Count - 1);
+                }
+                Assert.Equal(ids.Take(i - 1), ints); // Exclude the counted BoS token
+                if (normalized is not null)
+                {
+                    Assert.Equal(normalizedText, normalized);
+                }
+
+                result = _unigramTokenizer.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: true, addEndOfSentence: true, maxTokenCount: i, out normalized, out charConsumed);
+                ints = result is List<int> ? (List<int>)result : result.ToList();
+                if (ints.Count > 0)
+                {
+                    ints.RemoveAt(0);
+                }
+                if (ints.Count > ids.Length)
+                {
+                    ints.RemoveAt(ints.Count - 1);
+                }
+                Assert.Equal(ids.Take(i - 1), ints); // Exclude the counted BoS token
+                if (normalized is not null)
+                {
+                    Assert.Equal(normalizedText, normalized);
+                }
+            }
+
+            // Wrap the input with the BoS/EoS token literals; the special-token-aware tokenizer should map them to their configured ids.
+            inputText = $"{_unigramTokenizerWithSpecialTokens.BeginningOfSentenceToken}{inputText}{inputText}{_unigramTokenizerWithSpecialTokens.EndOfSentenceToken}";
+            int[] expectedIds = new int[ids.Length * 2 + 3];
+            expectedIds[0] = _unigramTokenizerWithSpecialTokens.BeginningOfSentenceId;
+            Array.Copy(ids, 0, expectedIds, 1, ids.Length);
+            expectedIds[ids.Length + 1] = _unigramTokenizerWithSpecialTokens.SpecialTokens![""];
+            Array.Copy(ids, 0, expectedIds, ids.Length + 2, ids.Length);
+            expectedIds[ids.Length * 2 + 2] = _unigramTokenizerWithSpecialTokens.EndOfSentenceId;
+            string expectedNormalized = $"{_unigramTokenizerWithSpecialTokens.BeginningOfSentenceToken}{normalizedText}{normalizedText}{_unigramTokenizerWithSpecialTokens.EndOfSentenceToken}";
+
+            for (int i = 1; i <= expectedIds.Length; i++)
+            {
+                result = _unigramTokenizerWithSpecialTokens.EncodeToIds(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out string? normalized, out int charConsumed);
+                Assert.Equal(expectedIds.Take(i), result);
+                Assert.Equal(expectedNormalized, normalized);
+
+                result = _unigramTokenizerWithSpecialTokens.EncodeToIds(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
+                Assert.Equal(expectedIds.Take(i), result);
+                Assert.Equal(expectedNormalized, normalized);
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(UnigramTestData))]
+        public void GetIndexByTokenCountTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
+        {
+            Assert.True(decodedString is not null); // to make the compiler happy
+            Assert.True(tokens is not null); // to make the compiler happy
+            Assert.True(offsets is not null); // to make the compiler happy
+
+            int totalTokens = ids.Length;
+
+            // At every cut point produced by GetIndexByTokenCount, encoding the two halves separately must reproduce the full id sequence.
+            for (int i = 1; i <= totalTokens; i++)
+            {
+                int index = _unigramTokenizer.GetIndexByTokenCount(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out string? normalized, out int charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                IReadOnlyList<int> ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                IReadOnlyList<int> ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(ids, ids1.Concat(ids2).ToList());
+
+                index = _unigramTokenizer.GetIndexByTokenCount(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(ids, ids1.Concat(ids2).ToList());
+
+                index = _unigramTokenizer.GetIndexByTokenCountFromEnd(inputText, addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, considerNormalization: true, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(ids, ids1.Concat(ids2).ToList());
+
+                index = _unigramTokenizer.GetIndexByTokenCountFromEnd(inputText.AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, maxTokenCount: i, considerNormalization: true, out normalized, out charConsumed);
+                Assert.Equal(normalizedText, normalized);
+                ids1 = _unigramTokenizer.EncodeToIds(normalized!.Substring(0, index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false);
+                ids2 = index < normalized.Length ? _unigramTokenizer.EncodeToIds(normalized!.Substring(index).AsSpan(), addBeginningOfSentence: false, addEndOfSentence: false, considerNormalization: false) : new List<int>();
+                Assert.Equal(ids, ids1.Concat(ids2).ToList());
+            }
+        }
+
+        [Theory]
+        [MemberData(nameof(UnigramTestData))]
+        public void DecodeTest(string inputText, string normalizedText, string decodedString, int[] ids, string[] tokens, Range[] offsets)
+        {
+            Assert.True(tokens is not null); // to make the compiler happy
+            Assert.True(offsets is not null); // to make the compiler happy
+            Assert.True(inputText is not null); // to make the compiler happy
+            Assert.True(normalizedText is not null); // to make the compiler happy
+
+            string result = _unigramTokenizer.Decode(ids, considerSpecialTokens: false);
+            Assert.Equal(decodedString, result);
+
+            char[] buffer = new char[decodedString.Length];
+
+            OperationStatus status = _unigramTokenizer.Decode(ids, buffer, considerSpecialTokens: false, out int idsConsumed, out int charsWritten);
+            Assert.Equal(OperationStatus.Done, status);
+            Assert.Equal(ids.Length, idsConsumed);
+            Assert.Equal(decodedString, buffer.AsSpan().Slice(0, charsWritten).ToString());
+
+            // A destination buffer shorter than the decoded text must report DestinationTooSmall and write only a prefix.
+            for (int i = 0; i < decodedString.Length - 1; i++)
+            {
+                status = _unigramTokenizer.Decode(ids, buffer.AsSpan().Slice(0, i), considerSpecialTokens: false, out idsConsumed, out charsWritten);
+                Assert.Equal(OperationStatus.DestinationTooSmall, status);
+                Assert.Equal(decodedString.AsSpan().Slice(0, charsWritten).ToString(), buffer.AsSpan().Slice(0, charsWritten).ToString());
+            }
+        }
+
+        [Fact]
+        public void SpecialTokensTest()
+        {
+            Assert.Equal("<unk>", _unigramTokenizer.UnknownToken);
+            Assert.Equal(0, _unigramTokenizer.UnknownId);
+            Assert.Equal("<s>", _unigramTokenizer.BeginningOfSentenceToken);
+            Assert.Equal(1, _unigramTokenizer.BeginningOfSentenceId);
+            Assert.Equal("</s>", _unigramTokenizer.EndOfSentenceToken);
+            Assert.Equal(2, _unigramTokenizer.EndOfSentenceId);
+        }
+    }
+}
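For context, the tests above exercise the public Unigram SentencePiece surface: EncodeToIds, EncodeToTokens, GetIndexByTokenCount(FromEnd), and Decode. The minimal sketch below (not part of the diff) shows how those calls fit together; it assumes a local SentencePiece Unigram model file named "sentencepiece.model" and a SentencePieceTokenizer.Create(Stream, ...) factory overload with addBeginOfSentence/addEndOfSentence parameters — adjust both to your environment.

// Minimal usage sketch, assuming a local "sentencepiece.model" and the
// SentencePieceTokenizer.Create overload described above (hypothetical file name).
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.ML.Tokenizers;

using Stream modelStream = File.OpenRead("sentencepiece.model"); // hypothetical model path

SentencePieceTokenizer tokenizer = SentencePieceTokenizer.Create(
    modelStream,
    addBeginOfSentence: true,
    addEndOfSentence: false);

// Encode text to ids; the first id is the BoS token because addBeginOfSentence is true.
IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello, world!");

// Decode back to text, dropping special tokens such as the BoS piece.
string text = tokenizer.Decode(ids, considerSpecialTokens: false);
Console.WriteLine(text);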