diff --git a/THIRD-PARTY-NOTICES.TXT b/THIRD-PARTY-NOTICES.TXT
index 3bc1463084..47f9d3cd1d 100644
--- a/THIRD-PARTY-NOTICES.TXT
+++ b/THIRD-PARTY-NOTICES.TXT
@@ -133,6 +133,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+License notice for DoubleArrayTrie (DART)
+--------------------------------------------
+The BSD 2-clause license
+
+https://github.com/s-yata/darts-clone/blob/master/COPYING.md
+
+Copyright (c) 2008-2014, Susumu Yata. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
+OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
License notice for CodeGen Tokenizer
--------------------------------------------
diff --git a/eng/Versions.props b/eng/Versions.props
index 0cc944a243..920c9a766e 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -100,7 +100,7 @@
0.0.13-test
0.0.6-test
0.0.7-test
- 2.0.0-beta.24455.2
+ 2.0.0-beta.25110.1
4.9.0
1.0.118
1.6.24
diff --git a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
index b3ee022ad3..ac7ea9e6f2 100644
--- a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs
@@ -22,6 +22,9 @@ namespace Microsoft.ML.Tokenizers
///
public class CodeGenTokenizer : Tokenizer
{
+ // The CodeGen tokenizer implementation is primarily adapted from
+ // https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/tokenization_codegen.py,
+ // with modifications to align with C# code style, the API, and the tokenizer library design.
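+ // Illustrative usage sketch (hypothetical file names; assumes the two-stream Create overload, with EncodeToIds coming from the Tokenizer base class):
+ //   using Stream vocabStream = File.OpenRead("vocab.json");
+ //   using Stream mergesStream = File.OpenRead("merges.txt");
+ //   CodeGenTokenizer tokenizer = CodeGenTokenizer.Create(vocabStream, mergesStream);
+ //   IReadOnlyList<int> ids = tokenizer.EncodeToIds("print('hello')");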
+ private readonly Dictionary<StringSpanOrdinalKey, (int, string)> _vocab;
+ private IReadOnlyDictionary<string, int>? _vocabOriginal;
+ private readonly IReadOnlyDictionary<int, string> _vocabReverse;
diff --git a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
index e5c5ca4e70..57ab57b13a 100644
--- a/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/LlamaTokenizer.cs
@@ -31,7 +31,7 @@ internal LlamaTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOn
/// <remarks>
/// When creating the tokenizer, ensure that the vocabulary stream is sourced from a trusted provider.
/// </remarks>
- public static LlamaTokenizer Create(
+ public static new LlamaTokenizer Create(
Stream modelStream,
bool addBeginOfSentence = true,
bool addEndOfSentence = false,
@@ -54,13 +54,6 @@ public static LlamaTokenizer Create(
throw new ArgumentException($"Normalization '{modelProto.NormalizerSpec.Name}' is not supported.", nameof(modelProto));
}
- SentencePieceNormalizer normalizer = new(
- modelProto.NormalizerSpec.RemoveExtraWhitespaces,
- modelProto.NormalizerSpec.AddDummyPrefix,
- modelProto.NormalizerSpec.EscapeWhitespaces,
- modelProto.TrainerSpec.TreatWhitespaceAsSuffix,
- specialTokens);
-
return new LlamaTokenizer(modelProto, addBeginOfSentence, addEndOfSentence, specialTokens);
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
new file mode 100644
index 0000000000..a1553ec4cd
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBaseModel.cs
@@ -0,0 +1,754 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Sentencepiece;
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Text.RegularExpressions;
+
+namespace Microsoft.ML.Tokenizers
+{
+ internal abstract class SentencePieceBaseModel
+ {
+ internal SentencePieceBaseModel(ModelProto modelProto, bool addBos = false, bool addEos = false, IReadOnlyDictionary<string, int>? specialTokens = null)
+ {
+ if (modelProto is null)
+ {
+ throw new ArgumentNullException(nameof(modelProto));
+ }
+
+ AddBeginningOfSentence = addBos;
+ AddEndOfSentence = addEos;
+ BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? "";
+ BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId;
+ EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? "";
+ EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId;
+ UnknownToken = modelProto.TrainerSpec.UnkPiece ?? "";
+ UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId;
+ AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix;
+ EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
+ TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
+ ByteFallback = modelProto.TrainerSpec.ByteFallback;
+ SpecialTokens = specialTokens;
+
+ if (specialTokens is not null && specialTokens.Count > 0)
+ {
+ InternalSpecialTokens = new Dictionary<StringSpanOrdinalKey, int>();
+ SpecialTokensReverse = new Dictionary<int, string>();
+
+ foreach (var item in specialTokens)
+ {
+ InternalSpecialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
+ SpecialTokensReverse.Add(item.Value, item.Key);
+ }
+
+ // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
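+ // For example, special tokens { "<s>", "</s>" } produce the pattern "<s>|</s>", which is used to split the input around special tokens before encoding.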
+ SpecialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
+ }
+
+ Normalizer = new SentencePieceNormalizer(
+ modelProto.NormalizerSpec.PrecompiledCharsmap.Span,
+ modelProto.NormalizerSpec.RemoveExtraWhitespaces,
+ AddDummyPrefix, EscapeWhiteSpaces,
+ modelProto.TrainerSpec.TreatWhitespaceAsSuffix,
+ specialTokens);
+ }
+
+ internal Regex? SpecialTokensRegex { get; }
+
+ internal Dictionary<StringSpanOrdinalKey, int>? InternalSpecialTokens { get; }
+
+ internal Dictionary<int, string>? SpecialTokensReverse { get; }
+
+ internal int MaxByteId { get; set; } // the maximum value of the byte ids.
+
+ internal int ByteCodeToIdOffset { get; set; } // the offset for mapping byte codes to the Ids.
+
+ internal int OneByteUtf8EncodingMaxId { get; set; } // the maximum id of the one-byte UTF-8 characters.
+
+ public IReadOnlyDictionary<string, int>? SpecialTokens { get; }
+
+ public bool ByteFallback { get; }
+
+ public bool AddDummyPrefix { get; }
+
+ public bool EscapeWhiteSpaces { get; }
+
+ public bool TreatWhitespaceAsSuffix { get; internal set; }
+
+ public bool AddBeginningOfSentence { get; }
+
+ public bool AddEndOfSentence { get; }
+
+ public string BeginningOfSentenceToken { get; }
+
+ public string EndOfSentenceToken { get; }
+
+ public string UnknownToken { get; }
+
+ public int BeginningOfSentenceId { get; }
+
+ public int EndOfSentenceId { get; }
+
+ public int UnknownId { get; }
+
+ public SentencePieceNormalizer? Normalizer { get; }
+
+ public abstract IReadOnlyDictionary<string, int> Vocabulary { get; }
+
+ public abstract IReadOnlyList<EncodedToken> EncodeToTokens(
+ string? text,
+ ReadOnlySpan<char> textSpan,
+ out string? normalizedText,
+ bool addBeginningOfSentence,
+ bool addEndOfSentence,
+ bool considerNormalization);
+
+ public abstract IReadOnlyList<int> EncodeToIds(
+ string? text,
+ ReadOnlySpan<char> textSpan,
+ bool addBeginningOfSentence,
+ bool addEndOfSentence,
+ bool considerNormalization,
+ out string? normalizedText,
+ out int charsConsumed,
+ int maxTokenCount = int.MaxValue);
+
+ public abstract int CountTokens(
+ string? text,
+ ReadOnlySpan<char> textSpan,
+ bool addBeginningOfSentence,
+ bool addEndOfSentence,
+ bool considerNormalization,
+ out string? normalizedText,
+ out int charsConsumed,
+ int maxTokenCount = int.MaxValue);
+
+ public abstract int GetIndexByTokenCountFromEnd(
+ string? text,
+ ReadOnlySpan<char> textSpan,
+ bool addBeginningOfSentence,
+ bool addEndOfSentence,
+ int maxTokenCount,
+ bool considerNormalization,
+ out string? normalizedText,
+ out int tokenCount);
+
+ public abstract bool TryMapIdToToken(int id, out string? token);
+
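+ // Rough initial capacity used when renting buffers for byte-fallback decoding; the buffers grow on demand.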
+ private const int ApproximatedMaxEncodedBytesCount = 50;
+
+ public virtual string Decode(IEnumerable<int> ids, bool considerSpecialTokens)
+ {
+ if (ids is null)
+ {
+ throw new ArgumentNullException(nameof(ids));
+ }
+
+ using IEnumerator<int> enumerator = ids.GetEnumerator();
+ if (!enumerator.MoveNext())
+ {
+ return string.Empty;
+ }
+
+ ValueStringBuilder sb = new(stackalloc char[256]);
+
+ int bytesCount = -1;
+ byte[]? bytesPoolArray = null;
+ bool prefixRemoved = false;
+ int suffixIndex = -1;
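+ // SentencePiece escapes whitespace as the dummy prefix character U+2581 ('▁'); when escaping is disabled, a plain space plays that role.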
+ char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' ';
+
+ int current = enumerator.Current;
+ if (current <= MaxByteId && ByteFallback)
+ {
+ // First token is a byte token.
+
+ while (current < ByteCodeToIdOffset)
+ {
+ // Some special tokens may be listed before the byte tokens in the tokenizer's data.
+ TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb);
+
+ // Skip control tokens.
+ if (!enumerator.MoveNext())
+ {
+ return sb.ToString();
+ }
+
+ current = enumerator.Current;
+ }
+
+ if (current <= MaxByteId && ByteFallback)
+ {
+ EncodeByte(current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb);
+ }
+ else if (!TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb) && TryMapIdToToken(current, out string? token))
+ {
+ AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
+ }
+ }
+ else if (!TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb) && TryMapIdToToken(current, out string? token))
+ {
+ AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
+ }
+
+ char[]? charPoolArray = null;
+
+ while (enumerator.MoveNext())
+ {
+ current = enumerator.Current;
+ if (current < ByteCodeToIdOffset)
+ {
+ if (bytesCount >= 1)
+ {
+ FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
+ }
+
+ // Some special tokens may be listed before the byte tokens in the tokenizer's data.
+ TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb);
+
+ continue;
+ }
+
+ if (current <= MaxByteId && ByteFallback)
+ {
+ if (bytesCount >= 1)
+ {
+ Debug.Assert(bytesPoolArray is not null);
+
+ if (bytesCount >= bytesPoolArray!.Length)
+ {
+ Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2);
+ }
+
+ bytesPoolArray![bytesCount++] = (byte)(current - ByteCodeToIdOffset);
+ }
+ else
+ {
+ EncodeByte(current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb);
+ }
+ }
+ else
+ {
+ if (bytesCount >= 1)
+ {
+ FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
+ }
+
+ if (!TryDecodeAsSpecialToken(this, current, considerSpecialTokens, ref sb) && TryMapIdToToken(current, out string? token))
+ {
+ AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
+ }
+ }
+ }
+
+ if (bytesCount >= 1)
+ {
+ FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
+ }
+
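+ // In suffix mode the dummy prefix marker ends up trailing the decoded text; remove that final marker so decoding round-trips cleanly.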
+ if (AddDummyPrefix && TreatWhitespaceAsSuffix && suffixIndex >= 0 && sb.Length > 0)
+ {
+ Debug.Assert(sb[suffixIndex] == SentencePieceNormalizer.DummyPrefix);
+ Debug.Assert(sb.Length > suffixIndex);
+
+ sb.Remove(suffixIndex, 1);
+ }
+
+ if (bytesPoolArray is not null)
+ {
+ ArrayPool<byte>.Shared.Return(bytesPoolArray);
+ }
+
+ if (charPoolArray is not null)
+ {
+ ArrayPool<char>.Shared.Return(charPoolArray);
+ }
+
+ return EscapeWhiteSpaces ? sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ') : sb.ToString();
+
+ static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb)
+ {
+ Debug.Assert(bytesCount >= 1);
+ Debug.Assert(bytesPoolArray is not null);
+
+ int len = Encoding.UTF8.GetMaxCharCount(bytesCount);
+
+ charPoolArray ??= ArrayPool<char>.Shared.Rent(Math.Max(len, ApproximatedMaxEncodedBytesCount >> 1));
+
+ if (len > charPoolArray.Length)
+ {
+ Helpers.ArrayPoolGrow(ref charPoolArray, len);
+ }
+
+ int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray);
+
+ sb.Append(charPoolArray.AsSpan(0, charCount));
+ bytesCount = -1;
+ }
+
+ static void EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, ref byte[]? bytesPoolArray, ref ValueStringBuilder sb)
+ {
+ if (id <= oneByteUtf8EncodingMaxId)
+ {
+ sb.Append((char)(id - byteCodeToIdOffset));
+ }
+ else
+ {
+ bytesCount = 1;
+ bytesPoolArray ??= ArrayPool<byte>.Shared.Rent(ApproximatedMaxEncodedBytesCount);
+ bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset);
+ }
+ }
+
+ static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, ref ValueStringBuilder sb, ref bool prefixRemoved, ref int suffixIndex)
+ {
+ if (token.Length == 0)
+ {
+ return;
+ }
+
+ if (!addDummyPrefix)
+ {
+ sb.Append(token);
+ return;
+ }
+
+ if (treatWhitespaceAsSuffix)
+ {
+ sb.Append(token);
+ if (token[token.Length - 1] == prefixSuffixChar)
+ {
+ suffixIndex = sb.Length - 1;
+ }
+ }
+ else
+ {
+ sb.Append(!prefixRemoved && token[0] == prefixSuffixChar ? token.AsSpan(1) : token.AsSpan());
+ }
+
+ prefixRemoved = true;
+ }
+
+ static bool TryDecodeAsSpecialToken(SentencePieceBaseModel model, int id, bool considerSpecialTokens, ref ValueStringBuilder sb)
+ {
+ string? token = null;
+
+ if (id == model.BeginningOfSentenceId)
+ {
+ token = model.BeginningOfSentenceToken;
+ }
+ else if (id == model.EndOfSentenceId)
+ {
+ token = model.EndOfSentenceToken;
+ }
+ else if (id == model.UnknownId)
+ {
+ token = model.UnknownToken;
+ }
+ else if (model.SpecialTokensReverse?.TryGetValue(id, out string? specialToken) is true)
+ {
+ token = specialToken;
+ }
+
+ if (token is not null && considerSpecialTokens)
+ {
+ sb.Append(token);
+ }
+
+ return token is not null;
+ }
+ }
+
+ public virtual OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, bool considerSpecialTokens, out int idsConsumed, out int charsWritten)
+ {
+ idsConsumed = 0;
+ charsWritten = 0;
+
+ if (ids is null)
+ {
+ throw new ArgumentNullException(nameof(ids));
+ }
+
+ using IEnumerator<int> enumerator = ids.GetEnumerator();
+ if (!enumerator.MoveNext())
+ {
+ return OperationStatus.Done;
+ }
+
+ Span<char> buffer = destination;
+
+ int bytesCount = -1;
+ byte[]? bytesPoolArray = null;
+ bool prefixRemoved = false;
+ int suffixIndex = -1;
+ char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' ';
+
+ int current = enumerator.Current;
+ if (current <= MaxByteId && ByteFallback)
+ {
+ // First token is a byte token.
+ while (current < ByteCodeToIdOffset)
+ {
+ OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken);
+ if (status != OperationStatus.Done)
+ {
+ return status;
+ }
+ buffer = destination.Slice(charsWritten);
+
+ // Skip control tokens.
+ idsConsumed++;
+ if (!enumerator.MoveNext())
+ {
+ return OperationStatus.Done;
+ }
+
+ current = enumerator.Current;
+ }
+
+ if (current <= MaxByteId && ByteFallback)
+ {
+ if (!EncodeByte(enumerator.Current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, buffer, ref charsWritten, ref idsConsumed, ref bytesPoolArray))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+ else
+ {
+ OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken);
+ if (status != OperationStatus.Done)
+ {
+ return status;
+ }
+
+ if (!isSpecialToken && TryMapIdToToken(current, out string? token))
+ {
+ if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+ else
+ {
+ idsConsumed++;
+ }
+ }
+ }
+ else
+ {
+ OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken);
+ if (status != OperationStatus.Done)
+ {
+ return status;
+ }
+
+ if (!isSpecialToken && TryMapIdToToken(current, out string? token))
+ {
+ if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+ else
+ {
+ idsConsumed++;
+ }
+ }
+
+ char[]? charPoolArray = null;
+
+ while (enumerator.MoveNext())
+ {
+ current = enumerator.Current;
+ buffer = destination.Slice(charsWritten);
+
+ if (current < ByteCodeToIdOffset)
+ {
+ if (bytesCount >= 1)
+ {
+ if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+
+ OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken);
+ if (status != OperationStatus.Done)
+ {
+ return status;
+ }
+
+ idsConsumed++;
+ continue;
+ }
+
+ if (current <= MaxByteId && ByteFallback)
+ {
+ if (bytesCount >= 1)
+ {
+ Debug.Assert(bytesPoolArray is not null);
+
+ if (bytesCount >= bytesPoolArray!.Length)
+ {
+ Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2);
+ }
+
+ bytesPoolArray![bytesCount++] = (byte)(current - ByteCodeToIdOffset);
+ }
+ else
+ {
+ if (!EncodeByte(current, OneByteUtf8EncodingMaxId, ByteCodeToIdOffset, ref bytesCount, buffer, ref charsWritten, ref idsConsumed, ref bytesPoolArray))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+ }
+ else
+ {
+ if (bytesCount >= 1)
+ {
+ if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+
+ OperationStatus status = TryDecodeAsSpecialToken(this, current, considerSpecialTokens, buffer, ref charsWritten, out bool isSpecialToken);
+ if (status != OperationStatus.Done)
+ {
+ return status;
+ }
+
+ if (!isSpecialToken && TryMapIdToToken(current, out string? token))
+ {
+ if (!AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token!, prefixSuffixChar, destination, ref prefixRemoved, ref suffixIndex, ref idsConsumed, ref charsWritten))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+ else
+ {
+ idsConsumed++;
+ }
+ }
+ }
+
+ buffer = destination.Slice(charsWritten);
+
+ if (bytesCount >= 1)
+ {
+ if (!FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, buffer, ref charsWritten, ref idsConsumed))
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+ }
+
+ if (suffixIndex >= 0)
+ {
+ Debug.Assert(destination[suffixIndex] == ' ');
+
+ if (suffixIndex < charsWritten - 1)
+ {
+ destination.Slice(suffixIndex + 1, charsWritten - suffixIndex - 1).CopyTo(destination.Slice(suffixIndex));
+ }
+
+ charsWritten--;
+ }
+
+ if (bytesPoolArray is not null)
+ {
+ ArrayPool<byte>.Shared.Return(bytesPoolArray);
+ }
+
+ if (charPoolArray is not null)
+ {
+ ArrayPool<char>.Shared.Return(charPoolArray);
+ }
+
+ return OperationStatus.Done;
+
+ static OperationStatus TryDecodeAsSpecialToken(SentencePieceBaseModel model, int id, bool considerSpecialTokens, Span<char> buffer, ref int charsWritten, out bool isSpecialToken)
+ {
+ string? specialToken = null;
+
+ if (id == model.BeginningOfSentenceId)
+ {
+ specialToken = model.BeginningOfSentenceToken;
+ }
+ else if (id == model.EndOfSentenceId)
+ {
+ specialToken = model.EndOfSentenceToken;
+ }
+ else if (id == model.UnknownId)
+ {
+ specialToken = model.UnknownToken;
+ }
+ else
+ {
+ model.SpecialTokensReverse?.TryGetValue(id, out specialToken);
+ }
+
+ isSpecialToken = specialToken is not null;
+
+ if (considerSpecialTokens && isSpecialToken)
+ {
+ if (buffer.Length < specialToken!.Length)
+ {
+ return OperationStatus.DestinationTooSmall;
+ }
+
+ specialToken.AsSpan().CopyTo(buffer);
+ charsWritten += specialToken.Length;
+ }
+
+ return OperationStatus.Done;
+ }
+
+ static bool FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, Span<char> buffer, ref int charsWritten, ref int idsConsumed)
+ {
+ Debug.Assert(bytesCount >= 1);
+ Debug.Assert(bytesPoolArray is not null);
+
+ int len = Encoding.UTF8.GetMaxCharCount(bytesCount);
+
+ charPoolArray ??= ArrayPool<char>.Shared.Rent(Math.Max(len, ApproximatedMaxEncodedBytesCount >> 1));
+
+ if (len > charPoolArray.Length)
+ {
+ Helpers.ArrayPoolGrow(ref charPoolArray, len);
+ }
+
+ int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray);
+
+ if (charCount > buffer.Length)
+ {
+ return false;
+ }
+
+ charPoolArray.AsSpan(0, charCount).CopyTo(buffer);
+ charsWritten += charCount;
+ idsConsumed += bytesCount;
+ bytesCount = -1;
+
+ return true;
+ }
+
+ static bool EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, Span<char> buffer, ref int charsWritten, ref int idsConsumed, ref byte[]? bytesPoolArray)
+ {
+ if (id <= oneByteUtf8EncodingMaxId)
+ {
+ if (buffer.Length < 1)
+ {
+ return false;
+ }
+
+ buffer[0] = (char)(id - byteCodeToIdOffset);
+ charsWritten++;
+ idsConsumed++;
+ }
+ else
+ {
+ bytesCount = 1;
+ bytesPoolArray ??= ArrayPool<byte>.Shared.Rent(ApproximatedMaxEncodedBytesCount);
+ bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset);
+ }
+
+ return true;
+ }
+
+ static bool AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, Span<char> destination, ref bool prefixRemoved, ref int suffixIndex, ref int idsConsumed, ref int charsConsumed)
+ {
+ if (token.Length == 0)
+ {
+ return true;
+ }
+
+ Span<char> buffer = destination.Slice(charsConsumed);
+
+ ReadOnlySpan<char> tokenSpan = token.AsSpan();
+
+ if (!addDummyPrefix)
+ {
+ if (tokenSpan.Length > buffer.Length)
+ {
+ return false;
+ }
+
+ if (prefixSuffixChar != ' ')
+ {
+ Helpers.Replace(tokenSpan, buffer, prefixSuffixChar, ' ');
+ }
+ else
+ {
+ tokenSpan.CopyTo(buffer);
+ }
+
+ buffer = buffer.Slice(tokenSpan.Length);
+ charsConsumed += tokenSpan.Length;
+ idsConsumed++;
+ return true;
+ }
+
+ if (treatWhitespaceAsSuffix)
+ {
+ if (tokenSpan[tokenSpan.Length - 1] == prefixSuffixChar)
+ {
+ suffixIndex = charsConsumed + tokenSpan.Length - 1;
+ }
+
+ if (tokenSpan.Length > buffer.Length)
+ {
+ return false;
+ }
+
+ if (prefixSuffixChar != ' ')
+ {
+ Helpers.Replace(tokenSpan, buffer, prefixSuffixChar, ' ');
+ }
+ else
+ {
+ tokenSpan.CopyTo(buffer);
+ }
+
+ charsConsumed += tokenSpan.Length;
+
+ idsConsumed++;
+ }
+ else
+ {
+ int delta = !prefixRemoved && token[0] == prefixSuffixChar ? 1 : 0;
+ if (buffer.Length < token.Length - delta)
+ {
+ return false;
+ }
+
+ tokenSpan = tokenSpan.Slice(delta);
+ if (prefixSuffixChar != ' ')
+ {
+ Helpers.Replace(tokenSpan, buffer, prefixSuffixChar, ' ');
+ }
+ else
+ {
+ tokenSpan.CopyTo(buffer);
+ }
+
+ charsConsumed += tokenSpan.Length;
+ idsConsumed++;
+
+ if (!prefixRemoved && delta == 1)
+ {
+ prefixRemoved = true;
+ }
+ }
+
+ return true;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
new file mode 100644
index 0000000000..85c85c1677
--- /dev/null
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceBpeModel.cs
@@ -0,0 +1,1260 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Sentencepiece;
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Text.RegularExpressions;
+using System.Threading;
+
+namespace Microsoft.ML.Tokenizers
+{
+ internal sealed class SentencePieceBpeModel : SentencePieceBaseModel
+ {
+ private const int UninitializedId = -2; // indicates that the symbol's id has not been resolved yet.
+ private readonly Dictionary<StringSpanOrdinalKey, (int Id, float Score, byte Type)> _vocab = new();
+ private readonly Dictionary<int, string> _vocabReverse = new();
+ private IReadOnlyDictionary<string, int>? _publicVocab;
+
+ internal SentencePieceBpeModel(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary<string, int>? specialTokens = null) : base(modelProto, addBos, addEos, specialTokens)
+ {
+ for (int i = 0; i < modelProto.Pieces.Count; i++)
+ {
+ var piece = modelProto.Pieces[i];
+ _vocab.Add(new StringSpanOrdinalKey(piece.Piece), (i, piece.Score, (byte)piece.Type));
+ _vocabReverse.Add(i, piece.Piece);
+
+ if (piece.Type == ModelProto.Types.SentencePiece.Types.Type.Byte)
+ {
+ MaxByteId = i;
+ }
+ }
+
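+ // Byte-fallback pieces, when present, occupy consecutive ids beginning at <0x00>; that id anchors the byte-to-id mapping below.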
+ ByteCodeToIdOffset = _vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) value) ? value.Id : MaxByteId;
+ OneByteUtf8EncodingMaxId = ByteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
+ }
+
+ public override IReadOnlyDictionary<string, int> Vocabulary
+ {
+ get
+ {
+ IReadOnlyDictionary<string, int>? publicVocab = Volatile.Read(ref _publicVocab);
+ if (publicVocab is null)
+ {
+ var vocab = new Dictionary<string, int>();
+ foreach (var item in _vocab)
+ {
+ vocab.Add(item.Key.ToString(), item.Value.Id);
+ }
+
+ Interlocked.CompareExchange(ref _publicVocab, new ReadOnlyDictionary<string, int>(vocab), null);
+ publicVocab = _publicVocab;
+ }
+
+ return publicVocab;
+ }
+ }
+
+ public override bool TryMapIdToToken(int id, out string? token) => _vocabReverse.TryGetValue(id, out token);
+
+ public override IReadOnlyList<EncodedToken> EncodeToTokens(string? text, ReadOnlySpan<char> textSpan, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization)
+ {
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedText = null;
+ return [];
+ }
+
+ ReadOnlySpan<char> textToEncode = text is null ? textSpan : text.AsSpan();
+ if (considerNormalization && Normalizer is not null)
+ {
+ normalizedText = text is not null ? Normalizer.Normalize(text) : Normalizer.Normalize(textSpan);
+ textToEncode = normalizedText.AsSpan();
+ }
+ else
+ {
+ normalizedText = null;
+ }
+
+ if (textToEncode.Length == 0)
+ {
+ return [];
+ }
+
+ List<EncodedToken> tokens = new();
+
+ if (SpecialTokensRegex is not null)
+ {
+ EncodeWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens);
+ }
+ else
+ {
+ EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens);
+ }
+
+ return tokens;
+ }
+
+ private void EncodeWithSpecialTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, List<EncodedToken> tokens)
+ {
+ Debug.Assert(SpecialTokensRegex is not null);
+
+ if (addBeginOfSentence)
+ {
+ tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0)));
+ }
+
+ int currentOffset = 0;
+
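+ // PreTokenizer.SplitText yields the (offset, length) of each special-token match; the gaps between matches are encoded through the regular BPE path.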
+ foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!))
+ {
+ if (Offset > currentOffset)
+ {
+ EncodeInternal(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens);
+ }
+
+ if (InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
+ {
+ tokens.Add(new EncodedToken(id, SpecialTokensReverse![id], new Range(Offset, Offset + Length)));
+ }
+
+ currentOffset = Offset + Length;
+ }
+
+ if (currentOffset < text.Length)
+ {
+ EncodeInternal(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens);
+ }
+
+ if (addEndOfSentence)
+ {
+ tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length)));
+ }
+ }
+
+ /// <summary>
+ /// Encode a text to a list of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="tokens">A collection to store the encoded tokens.</param>
+ /// <remarks>The input text has to be normalized before calling this method.</remarks>
+ private void EncodeInternal(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, List<EncodedToken> tokens)
+ {
+ BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
+
+ Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
+
+ if (addBeginOfSentence)
+ {
+ tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0)));
+ }
+
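+ // Walk the merged linked list. A symbol still holding UninitializedId was never merged, so resolve it against the vocabulary here.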
+ for (int index = 0; (uint)index < (uint)symbols.Length; index = symbols[index].next)
+ {
+ int id = symbols[index].id;
+ byte type = symbols[index].type;
+
+ if (id == UninitializedId)
+ {
+ if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
+ {
+ id = tokenInfo.Id;
+ type = tokenInfo.Type;
+ }
+ else
+ {
+ id = UnknownId;
+ type = 0;
+ }
+ }
+
+ if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
+ {
+ if (id == UnknownId && ByteFallback)
+ {
+ EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
+ }
+ else
+ {
+ tokens.Add(new EncodedToken(
+ id,
+ GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text),
+ new Range(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Index + symbols[index].pieceSpan.Length)));
+ }
+ continue;
+ }
+
+ Segment(symbols[index].pieceSpan, text);
+ }
+
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+
+ if (addEndOfSentence)
+ {
+ tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length)));
+ }
+
+ return;
+
+ // Encode the Unknown token to bytes.
+ void EncodeAsBytes(ReadOnlySpan<char> text, int index)
+ {
+ for (int i = 0; i < text.Length; i++)
+ {
+ char c = text[i];
+ if (c <= 0x7F)
+ {
+ int id = (int)c + ByteCodeToIdOffset; // byte code is mapped to the Ids starting from 4.
+
+ if (_vocabReverse.TryGetValue(id, out string? token))
+ {
+ tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + 1)));
+ }
+ }
+ else
+ {
+ Span<byte> utf8Bytes = stackalloc byte[256];
+ byte[]? arrayPoolArray = null;
+
+ int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
+ if (len > utf8Bytes.Length)
+ {
+ arrayPoolArray = ArrayPool<byte>.Shared.Rent(len);
+ utf8Bytes = arrayPoolArray;
+ }
+
+ // Need to convert the text into UTF-8 bytes and then encode the bytes.
+ int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
+ int length = text.Length - i;
+ for (int j = 0; j < bytesWritten; j++)
+ {
+ int id = (int)utf8Bytes[j] + ByteCodeToIdOffset; // byte code is mapped to the Ids starting from 4.
+
+ if (_vocabReverse.TryGetValue(id, out string? token))
+ {
+ tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + length)));
+ }
+
+ length = 0;
+ }
+
+ if (arrayPoolArray is not null)
+ {
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
+ }
+
+ break;
+ }
+ }
+ }
+
+ void Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text)
+ {
+ if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
+ {
+ EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index);
+ return;
+ }
+
+ if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
+ revMerge is null ||
+ !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
+ {
+ tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), new Range(pieceSpan.Index, pieceSpan.Index + pieceSpan.Length)));
+ return;
+ }
+
+ Segment((merge.LeftIndex, merge.LeftLen), text);
+ Segment((merge.RightIndex, merge.RightLen), text);
+ }
+ }
+
+ public override IReadOnlyList<int> EncodeToIds(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization,
+ out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
+ {
+ normalizedText = null;
+ charsConsumed = 0;
+ return [];
+ }
+
+ return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
+ }
+
+ /// <summary>
+ /// Encodes input text to token Ids up to maximum number of tokens.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginningOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="considerNormalization">Indicate whether to consider normalization before tokenization.</param>
+ /// <param name="normalizedText">If the tokenizer's normalization is enabled or <paramref name="considerNormalization" /> is false, this will be set to <paramref name="text" /> in its normalized form; otherwise, this value will be set to null.</param>
+ /// <param name="charsConsumed">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="maxTokenCount">The maximum number of tokens to encode.</param>
+ /// <returns>The list of encoded Ids.</returns>
+ private IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization,
+ out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ if (text.IsEmpty)
+ {
+ normalizedText = null;
+ charsConsumed = 0;
+ return [];
+ }
+
+ ReadOnlySpan<char> textToEncode;
+
+ if (considerNormalization && Normalizer is not null)
+ {
+ normalizedText = Normalizer.Normalize(text);
+ textToEncode = normalizedText.AsSpan();
+ }
+ else
+ {
+ normalizedText = null;
+ textToEncode = text;
+ }
+
+ List<int> ids = new();
+
+ if (SpecialTokensRegex is not null)
+ {
+ EncodeToIdsWithAddedToken(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount);
+ }
+ else
+ {
+ EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount);
+ }
+
+ return ids;
+ }
+
+ private int EncodeToIdsWithAddedToken(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue)
+ {
+ Debug.Assert(SpecialTokensRegex is not null);
+ Debug.Assert(maxTokens > 0);
+
+ charsConsumed = 0;
+ int idsCount = 0;
+
+ if (addBeginOfSentence)
+ {
+ accumulatedIds.Add(BeginningOfSentenceId);
+ idsCount++;
+ }
+
+ int currentOffset = 0;
+
+ int charsWritten;
+
+ foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!))
+ {
+ if (Offset > currentOffset)
+ {
+ idsCount += EncodeToIds(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount);
+ charsConsumed += charsWritten;
+ }
+
+ if (idsCount < maxTokens && InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
+ {
+ accumulatedIds.Add(id);
+ idsCount++;
+ charsConsumed += Length;
+ }
+
+ currentOffset = Offset + Length;
+ }
+
+ if (currentOffset < text.Length && idsCount < maxTokens)
+ {
+ idsCount += EncodeToIds(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount);
+ charsConsumed += charsWritten;
+ }
+
+ if (addEndOfSentence && idsCount < maxTokens)
+ {
+ accumulatedIds.Add(EndOfSentenceId);
+ idsCount++;
+ }
+
+ return idsCount;
+ }
+
+ /// <summary>
+ /// Encode a text to a list of Ids and add them to the accumulatedIds list.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
+ /// <param name="charsConsumed">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="maxTokens">The maximum number of tokens to encode.</param>
+ /// <returns>The number of tokens that the input text will be encoded to.</returns>
+ /// <remarks>The input text has to be normalized before calling this method.</remarks>
+ private int EncodeToIds(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, IList<int> accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue)
+ {
+ charsConsumed = 0;
+ if (text.IsEmpty)
+ {
+ return 0;
+ }
+
+ int idsCount = 0;
+
+ if (addBeginOfSentence)
+ {
+ accumulatedIds.Add(BeginningOfSentenceId);
+ idsCount++;
+ }
+
+ BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
+
+ Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
+
+ for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next)
+ {
+ int id = symbols[index].id;
+ byte type = symbols[index].type;
+
+ if (id == UninitializedId)
+ {
+ if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
+ {
+ id = tokenInfo.Id;
+ type = tokenInfo.Type;
+ }
+ else
+ {
+ id = UnknownId;
+ type = 0;
+ }
+ }
+
+ if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
+ {
+ if (id == UnknownId && ByteFallback)
+ {
+ if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed))
+ {
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+ return idsCount;
+ }
+ }
+ else
+ {
+ if (idsCount < maxTokens)
+ {
+ accumulatedIds.Add(id);
+ charsConsumed += symbols[index].pieceSpan.Length;
+ idsCount++;
+ }
+ else
+ {
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+ return idsCount;
+ }
+ }
+ continue;
+ }
+
+ if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed))
+ {
+ break;
+ }
+ }
+
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+
+ if (addEndOfSentence)
+ {
+ if (idsCount < maxTokens)
+ {
+ accumulatedIds.Add(EndOfSentenceId);
+ idsCount++;
+ }
+ }
+
+ return idsCount;
+
+ // Encode the Unknown token to bytes.
+ bool EncodeAsBytes(ReadOnlySpan<char> text, int index, ref int charsConsumed)
+ {
+ for (int i = 0; i < text.Length; i++)
+ {
+ char c = text[i];
+ if (c <= 0x7F)
+ {
+ if (idsCount < maxTokens)
+ {
+ charsConsumed++;
+ accumulatedIds.Add((int)c + ByteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
+ idsCount++;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else
+ {
+ Span<byte> utf8Bytes = stackalloc byte[100];
+ byte[]? arrayPoolArray = null;
+
+ int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
+ if (len > utf8Bytes.Length)
+ {
+ arrayPoolArray = ArrayPool<byte>.Shared.Rent(len);
+ utf8Bytes = arrayPoolArray;
+ }
+
+ // Need to convert the text into UTF-8 bytes and then encode the bytes.
+ int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
+
+ bool ret;
+ if (idsCount + bytesWritten <= maxTokens)
+ {
+ for (int j = 0; j < bytesWritten; j++)
+ {
+ accumulatedIds.Add((int)utf8Bytes[j] + ByteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
+ }
+
+ charsConsumed += text.Length - i;
+ ret = true;
+ }
+ else
+ {
+ ret = false;
+ }
+
+ if (arrayPoolArray is not null)
+ {
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
+ }
+
+ return ret;
+ }
+ }
+
+ return true;
+ }
+
+ bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int charsConsumed)
+ {
+ if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
+ {
+ return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed);
+ }
+
+ if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
+ revMerge is null ||
+ !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
+ {
+ if (idsCount < maxTokens)
+ {
+ accumulatedIds.Add(id.Id);
+ charsConsumed += pieceSpan.Length;
+ idsCount++;
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed);
+ }
+ }
+
+ public override int CountTokens(
+ string? text,
+ ReadOnlySpan<char> textSpan,
+ bool addBeginningOfSentence,
+ bool addEndOfSentence,
+ bool considerNormalization,
+ out string? normalizedText,
+ out int charsConsumed,
+ int maxTokenCount = int.MaxValue)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
+ }
+
+ textSpan = text is null ? textSpan : text.AsSpan();
+
+ if (textSpan.IsEmpty)
+ {
+ normalizedText = null;
+ charsConsumed = 0;
+ return 0;
+ }
+
+ ReadOnlySpan<char> textToEncode;
+ if (considerNormalization && Normalizer is not null)
+ {
+ normalizedText = Normalizer.Normalize(textSpan);
+ textToEncode = normalizedText.AsSpan();
+ }
+ else
+ {
+ normalizedText = null;
+ textToEncode = textSpan;
+ }
+
+ return SpecialTokensRegex is not null ?
+ CountTokensWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount) :
+ CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount);
+ }
+
+ private int CountTokensWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue)
+ {
+ Debug.Assert(SpecialTokensRegex is not null);
+ Debug.Assert(maxTokens > 0);
+
+ charsConsumed = 0;
+ int idsCount = 0;
+
+ if (addBeginOfSentence)
+ {
+ idsCount++;
+ }
+
+ int currentOffset = 0;
+
+ int charsWritten;
+
+ foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, SpecialTokensRegex!))
+ {
+ if (Offset > currentOffset)
+ {
+ idsCount += CountTokens(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount);
+ charsConsumed += charsWritten;
+ }
+
+ if (idsCount < maxTokens && InternalSpecialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
+ {
+ idsCount++;
+ charsConsumed += Length;
+ }
+
+ currentOffset = Offset + Length;
+ }
+
+ if (currentOffset < text.Length && idsCount < maxTokens)
+ {
+ idsCount += CountTokens(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount);
+ charsConsumed += charsWritten;
+ }
+
+ if (addEndOfSentence && idsCount < maxTokens)
+ {
+ idsCount++;
+ }
+
+ return idsCount;
+ }
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="charsConsumed">The length of the text that encompasses the maximum encoded tokens.</param>
+ /// <param name="maxTokens">The maximum number of tokens to encode.</param>
+ /// <returns>The number of tokens that the input text will be encoded to.</returns>
+ /// <remarks>The input text has to be normalized before calling this method.</remarks>
+ private int CountTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue)
+ {
+ charsConsumed = 0;
+ if (text.IsEmpty)
+ {
+ return 0;
+ }
+
+ int tokenCount = addBeginOfSentence ? 1 : 0;
+
+ BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
+
+ Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
+
+ for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next)
+ {
+ int id = symbols[index].id;
+ byte type = symbols[index].type;
+
+ if (id == UninitializedId)
+ {
+ if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
+ {
+ id = tokenInfo.Id;
+ type = tokenInfo.Type;
+ }
+ else
+ {
+ id = UnknownId;
+ type = 0;
+ }
+ }
+
+ if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
+ {
+ if (id == UnknownId && ByteFallback)
+ {
+ if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed))
+ {
+ break;
+ }
+ }
+ else
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ charsConsumed += symbols[index].pieceSpan.Length;
+ }
+ else
+ {
+ break;
+ }
+ }
+ continue;
+ }
+
+ if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed))
+ {
+ break;
+ }
+ }
+
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+
+ if (addEndOfSentence)
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ }
+ }
+
+ return tokenCount;
+
+ // Encode the Unknown token to bytes.
+ bool EncodeAsBytes(ReadOnlySpan<char> text, int index, ref int charsConsumed)
+ {
+ for (int i = 0; i < text.Length; i++)
+ {
+ char c = text[i];
+ if (c <= 0x7F)
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ charsConsumed++;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else
+ {
+ Span<byte> utf8Bytes = stackalloc byte[100];
+ byte[]? arrayPoolArray = null;
+
+ int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
+ if (len > utf8Bytes.Length)
+ {
+ arrayPoolArray = ArrayPool<byte>.Shared.Rent(len);
+ utf8Bytes = arrayPoolArray;
+ }
+
+ // Need to convert the text into UTF-8 bytes and then encode the bytes.
+ int encodedCount = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
+ bool ret;
+
+ if (tokenCount + encodedCount <= maxTokens)
+ {
+ tokenCount += encodedCount;
+ charsConsumed += text.Length - i;
+ ret = true;
+ }
+ else
+ {
+ ret = false;
+ }
+
+ if (arrayPoolArray is not null)
+ {
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
+ }
+
+ return ret;
+ }
+ }
+
+ return true;
+ }
+
+ bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int charsConsumed)
+ {
+ if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
+ {
+ return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed);
+ }
+
+ if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
+ revMerge is null ||
+ !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ charsConsumed += pieceSpan.Length;
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed);
+ }
+ }
+
+ public override int GetIndexByTokenCountFromEnd(string? text, ReadOnlySpan<char> textSpan, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount)
+ {
+ if (maxTokenCount <= 0)
+ {
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+ }
+
+ textSpan = text is null ? textSpan : text.AsSpan();
+
+ if (textSpan.IsEmpty)
+ {
+ normalizedText = null;
+ tokenCount = 0;
+ return 0;
+ }
+
+ ReadOnlySpan<char> textToEncode;
+ if (considerNormalization && Normalizer is not null)
+ {
+ normalizedText = Normalizer.Normalize(textSpan);
+ textToEncode = normalizedText.AsSpan();
+ }
+ else
+ {
+ normalizedText = null;
+ textToEncode = textSpan;
+ }
+
+ int textIndex;
+ if (SpecialTokensRegex is not null)
+ {
+ tokenCount = CountTokensFromEndWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount);
+ }
+ else
+ {
+ tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount);
+ }
+
+ return textIndex;
+ }
+
+ private int CountTokensFromEndWithSpecialTokens(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens)
+ {
+ Debug.Assert(SpecialTokensRegex is not null);
+ Debug.Assert(maxTokens > 0);
+ Debug.Assert(text.Length > 0);
+
+ textIndex = text.Length;
+ int idsCount = 0;
+
+ if (addEndOfSentence)
+ {
+ idsCount++;
+ }
+
+ (int Offset, int Length)[] splits = PreTokenizer.SplitText(text, SpecialTokensRegex!).ToArray();
+
+ if (splits.Length == 0)
+ {
+ return CountTokensFromEnd(text, addBeginOfSentence, addEndOfSentence, out textIndex, maxTokens);
+ }
+
+ (int Offset, int Length) current = splits[splits.Length - 1];
+
+ int splitTextIndex;
+ ReadOnlySpan<char> splitText;
+
+ if (current.Offset + current.Length < text.Length)
+ {
+ splitText = text.Slice(current.Offset + current.Length);
+ idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount);
+ textIndex -= splitText.Length - splitTextIndex;
+ }
+
+ for (int i = splits.Length - 1; i >= 0 && idsCount < maxTokens; i--)
+ {
+ current = splits[i];
+
+ if (InternalSpecialTokens!.TryGetValue(text.Slice(current.Offset, current.Length), out int id))
+ {
+ idsCount++;
+ }
+ textIndex -= current.Length;
+
+ if (current.Offset > 0 && idsCount < maxTokens)
+ {
+ int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0;
+ splitText = text.Slice(start, current.Offset - start);
+ idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount);
+ textIndex -= splitText.Length - splitTextIndex;
+ }
+ }
+
+ if (addBeginOfSentence && idsCount < maxTokens)
+ {
+ idsCount++;
+ }
+
+ return idsCount;
+ }
+
+ /// <summary>
+ /// Get the number of tokens that the input text will be encoded to.
+ /// </summary>
+ /// <param name="text">The text to encode.</param>
+ /// <param name="addBeginOfSentence">Indicate emitting the beginning of sentence token during the encoding.</param>
+ /// <param name="addEndOfSentence">Indicate emitting the end of sentence token during the encoding.</param>
+ /// <param name="textIndex">The index at which the suffix of the text, running to the end, encompasses the maximum encoded tokens.</param>
+ /// <param name="maxTokens">The maximum number of tokens to encode.</param>
+ /// <returns>The number of tokens that the input text will be encoded to.</returns>
+ /// <remarks>The input text has to be normalized before calling this method.</remarks>
+ private int CountTokensFromEnd(ReadOnlySpan<char> text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
+ {
+ textIndex = text.Length;
+ if (text.IsEmpty)
+ {
+ return 0;
+ }
+
+ int tokenCount = addEndOfSentence ? 1 : 0;
+
+ BpeSymbol[] symbols = ArrayPool<BpeSymbol>.Shared.Rent(text.Length);
+
+ Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
+
+ // Move to the last symbol.
+ int lastSymbolIndex = 0;
+ while (symbols[lastSymbolIndex].next != -1 && lastSymbolIndex < symbols.Length)
+ {
+ lastSymbolIndex = symbols[lastSymbolIndex].next;
+ }
+
+ for (int index = lastSymbolIndex; index >= 0; index = symbols[index].prev)
+ {
+ int id = symbols[index].id;
+ byte type = symbols[index].type;
+
+ if (id == UninitializedId)
+ {
+ if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
+ {
+ id = tokenInfo.Id;
+ type = tokenInfo.Type;
+ }
+ else
+ {
+ id = UnknownId;
+ type = 0;
+ }
+ }
+
+ if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
+ {
+ if (id == UnknownId && ByteFallback)
+ {
+ if (!EncodeAsBytesFromEnd(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textIndex))
+ {
+ break;
+ }
+ }
+ else
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ textIndex -= symbols[index].pieceSpan.Length;
+ }
+ else
+ {
+ break;
+ }
+ }
+ continue;
+ }
+
+ if (!SegmentFromEnd(symbols[index].pieceSpan, text, ref textIndex))
+ {
+ break;
+ }
+ }
+
+ ArrayPool<BpeSymbol>.Shared.Return(symbols);
+
+ if (addBeginOfSentence)
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ }
+ }
+
+ return tokenCount;
+
+ // Encode the Unknown token to bytes.
+ bool EncodeAsBytesFromEnd(ReadOnlySpan<char> text, int index, ref int textIndex)
+ {
+ for (int i = text.Length - 1; i >= 0; i--)
+ {
+ char c = text[i];
+ if (c <= 0x7F)
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ textIndex--;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ else
+ {
+ Span<byte> utf8Bytes = stackalloc byte[100];
+ byte[]? arrayPoolArray = null;
+
+ int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
+ if (len > utf8Bytes.Length)
+ {
+ arrayPoolArray = ArrayPool<byte>.Shared.Rent(len);
+ utf8Bytes = arrayPoolArray;
+ }
+
+ // Need to convert the text into UTF-8 bytes and then encode the bytes.
+ int encodedCount = Helpers.GetUtf8Bytes(text.Slice(0, i + 1), utf8Bytes);
+ bool ret;
+
+ if (tokenCount + encodedCount <= maxTokens)
+ {
+ tokenCount += encodedCount;
+ textIndex -= i + 1;
+ ret = true;
+ }
+ else
+ {
+ ret = false;
+ }
+
+ if (arrayPoolArray is not null)
+ {
+ ArrayPool<byte>.Shared.Return(arrayPoolArray);
+ }
+
+ return ret;
+ }
+ }
+
+ return true;
+ }
+
+ bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan<char> text, ref int textIndex)
+ {
+ if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
+ {
+ return EncodeAsBytesFromEnd(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textIndex);
+ }
+
+ if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
+ revMerge is null ||
+ !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
+ {
+ if (tokenCount < maxTokens)
+ {
+ tokenCount++;
+ textIndex -= pieceSpan.Length;
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ // Segment the right part first.
+ return SegmentFromEnd((merge.RightIndex, merge.RightLen), text, ref textIndex) && SegmentFromEnd((merge.LeftIndex, merge.LeftLen), text, ref textIndex);
+ }
+ }
+
+ private Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? Encode(ReadOnlySpan<char> text, BpeSymbol[] symbols)
+ {
+ Debug.Assert(text.Length > 0);
+ Debug.Assert(symbols.Length >= text.Length);
+
+ int symbolIndex = 0;
+ int spanIndex = 0;
+
+ while (spanIndex < text.Length)
+ {
+ int len = (Char.IsHighSurrogate(text[spanIndex]) && spanIndex < text.Length - 1 && Char.IsLowSurrogate(text[spanIndex + 1])) ? 2 : 1;
+
+ BpeSymbol s = new(
+ prev: symbolIndex == 0 ? -1 : symbolIndex - 1,
+ next: spanIndex + len >= text.Length ? -1 : symbolIndex + 1,
+ pieceSpan: (spanIndex, len),
+ id: UninitializedId,
+ type: 0);
+
+ symbols[symbolIndex++] = s;
+ spanIndex += len;
+ }
+
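+ // `symbols` now forms a doubly linked list over the text, one entry per character (surrogate pairs kept together). The agenda below repeatedly merges the best-scoring adjacent pair, as in the SentencePiece BPE algorithm.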
+ PriorityQueue<SymbolPair> agenda = new(symbolIndex);
+ Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = null;
+
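+ // Seed the agenda with every adjacent pair that already exists in the vocabulary.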
+ for (int i = 1; i < symbolIndex; i++)
+ {
+ TryMerge(i - 1, i, text);
+ }
+
+ while (agenda.Count > 0)
+ {
+ SymbolPair top = agenda.Dequeue();
+
+ if (symbols[top.Left].pieceSpan.Length == 0 || symbols[top.Right].pieceSpan.Length == 0 ||
+ symbols[top.Left].pieceSpan.Length + symbols[top.Right].pieceSpan.Length != top.Length)
+ {
+ continue;
+ }
+
+ // Replaces symbols with `top` rule.
+ symbols[top.Left].pieceSpan = (symbols[top.Left].pieceSpan.Index, symbols[top.Left].pieceSpan.Length + symbols[top.Right].pieceSpan.Length);
+ symbols[top.Left].id = top.Id;
+
+ // Updates prev/next pointers.
+ symbols[top.Left].next = symbols[top.Right].next;
+
+ if (symbols[top.Right].next >= 0)
+ {
+ symbols[symbols[top.Right].next].prev = top.Left;
+ }
+ symbols[top.Right].pieceSpan = (0, 0);
+
+ // Adds new symbol pairs which are newly added after symbol replacement.
+ TryMerge(symbols[top.Left].prev, top.Left, text);
+ TryMerge(top.Left, symbols[top.Left].next, text);
+ }
+
+ return revMerge;
+
+ void TryMerge(int left, int right, ReadOnlySpan<char> textSpan)
+ {
+ if (left == -1 || right == -1)
+ {
+ return;
+ }
+
+ int pieceLength = symbols[left].pieceSpan.Length + symbols[right].pieceSpan.Length;
+ if (!_vocab.TryGetValue(textSpan.Slice(symbols[left].pieceSpan.Index, pieceLength), out (int Id, float Score, byte Type) leftId))
+ {
+ return;
+ }
+
+ symbols[left].type = leftId.Type;
+
+ SymbolPair pair = new(left, right, leftId.Score, pieceLength, leftId.Id);
+ agenda.Enqueue(pair);
+
+ if (leftId.Type == (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
+ {
+ revMerge ??= new();
+ revMerge.Add((symbols[left].pieceSpan.Index, pieceLength), (symbols[left].pieceSpan.Index, symbols[left].pieceSpan.Length, symbols[right].pieceSpan.Index, symbols[right].pieceSpan.Length));
+ }
+ }
+ }
+
+ // Tries to avoid string allocations if possible.
+ private string GetTokenString(int id, int index, int length, ReadOnlySpan<char> text)
+ => _vocabReverse.TryGetValue(id, out string? token) ? token : text.Slice(index, length).ToString();
+
+ private struct SymbolPair : IEquatable<SymbolPair>, IComparable<SymbolPair>
+ {
+ public int Left { get; set; }
+ public int Right { get; set; }
+ public int Length { get; set; }
+ public float Score { get; set; }
+ public int Id { get; set; }
+
+ public SymbolPair(int left, int right, float score, int length, int id)
+ {
+ Left = left;
+ Right = right;
+ Score = score;
+ Length = length;
+ Id = id;
+ }
+
+ public int CompareTo(SymbolPair other)
+ {
+ if (Score != other.Score)
+ {
+ return other.Score.CompareTo(Score);
+ }
+
+ return other.Left.CompareTo(Left);
+ }
+
+ public override int GetHashCode()
+ {
+ int hashCode = 23;
+ hashCode = (hashCode * 37) + Score.GetHashCode();
+ hashCode = (hashCode * 37) + Left.GetHashCode();
+ return hashCode;
+ }
+
+ public bool Equals(SymbolPair other) => Left == other.Left && Score == other.Score;
+ }
+
+ private record struct BpeSymbol(int prev, int next, (int Index, int Length) pieceSpan, int id, byte type);
+ }
+}
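The Encode method in the block above drives an agenda-based BPE loop: seed a priority queue with all adjacent symbol pairs found in the vocabulary, repeatedly merge the best-scoring pair, then re-examine the new pairs the merge exposes. For readers unfamiliar with the pattern, here is a minimal standalone sketch of the same idea. It is not the library's code: it uses .NET's built-in PriorityQueue rather than the internal one above, and the vocabulary, scores, and identifiers are invented for illustration.

    using System;
    using System.Collections.Generic;

    Console.WriteLine(string.Join(" | ", BpeSketch.Encode("abcd"))); // prints: abc | d

    internal static class BpeSketch
    {
        // Invented vocabulary: merged piece -> score (higher score merges first).
        private static readonly Dictionary<string, float> Vocab = new()
        {
            ["ab"] = 2.0f,
            ["bc"] = 1.0f,
            ["abc"] = 3.0f,
        };

        internal static List<string> Encode(string text)
        {
            int n = text.Length;
            if (n == 0)
            {
                return new List<string>();
            }

            string[] piece = new string[n]; // current piece per slot; "" once merged away
            int[] prev = new int[n];
            int[] next = new int[n];
            for (int i = 0; i < n; i++)
            {
                piece[i] = text[i].ToString();
                prev[i] = i - 1;
                next[i] = i + 1 < n ? i + 1 : -1;
            }

            // .NET's PriorityQueue is a min-heap, so invert the comparison to pop the best score first.
            var agenda = new PriorityQueue<(int Left, int Right, string Merged), float>(
                Comparer<float>.Create((a, b) => b.CompareTo(a)));

            void TryMerge(int left, int right)
            {
                if (left == -1 || right == -1)
                {
                    return;
                }

                string candidate = piece[left] + piece[right];
                if (Vocab.TryGetValue(candidate, out float score))
                {
                    agenda.Enqueue((left, right, candidate), score);
                }
            }

            // Seed the agenda with all adjacent pairs.
            for (int i = 1; i < n; i++)
            {
                TryMerge(i - 1, i);
            }

            while (agenda.Count > 0)
            {
                (int left, int right, string merged) = agenda.Dequeue();

                // Stale entry: a side was already consumed by an earlier, better-scoring merge.
                if (piece[left].Length == 0 || piece[right].Length == 0 || piece[left] + piece[right] != merged)
                {
                    continue;
                }

                piece[left] = merged; // grow the left symbol and retire the right one
                piece[right] = string.Empty;
                next[left] = next[right];
                if (next[right] != -1)
                {
                    prev[next[right]] = left;
                }

                // The merge exposes two new candidate pairs.
                TryMerge(prev[left], left);
                TryMerge(left, next[left]);
            }

            List<string> result = new();
            for (int i = 0; i != -1; i = next[i])
            {
                result.Add(piece[i]);
            }

            return result;
        }
    }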
diff --git a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
index 873dd0c4f6..f41516e270 100644
--- a/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/SentencePieceTokenizer.cs
@@ -6,13 +6,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
-using System.Collections.ObjectModel;
-using System.Diagnostics;
using System.IO;
-using System.Linq;
-using System.Text;
-using System.Text.RegularExpressions;
-using System.Threading;
namespace Microsoft.ML.Tokenizers
{
@@ -24,134 +18,82 @@ namespace Microsoft.ML.Tokenizers
///
public class SentencePieceTokenizer : Tokenizer
{
- private const int UninitializedId = -2; // indicates that the symbol contains an uninitialized id.
- private readonly Dictionary _vocab = new();
- private readonly Dictionary _vocabReverse = new();
- private IReadOnlyDictionary? _publicVocab;
- private readonly int _maxByteId;
- private readonly int _byteCodeToIdOffset; // offset for mapping byte codes to the Ids.
- private readonly int _oneByteUtf8EncodingMaxId; // the maximum value of the one byte UTF-8 character.
- private readonly Normalizer? _normalizer;
- private readonly Regex? _specialTokensRegex;
- private readonly Dictionary? _specialTokens;
- private readonly Dictionary? _specialTokensReverse;
+ private readonly SentencePieceBaseModel _model;
- internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary? specialTokens = null) :
- this(modelProto is null ? throw new ArgumentNullException(nameof(modelProto)) : modelProto, specialTokens)
+ internal SentencePieceTokenizer(ModelProto modelProto, bool addBos, bool addEos, IReadOnlyDictionary<string, int>? specialTokens = null)
{
- AddBeginningOfSentence = addBos;
- AddEndOfSentence = addEos;
- }
-
- private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary? specialTokens)
- {
- for (int i = 0; i < modelProto.Pieces.Count; i++)
- {
- var piece = modelProto.Pieces[i];
- _vocab.Add(new StringSpanOrdinalKey(piece.Piece), (i, piece.Score, (byte)piece.Type));
- _vocabReverse.Add(i, piece.Piece);
-
- if (piece.Type == ModelProto.Types.SentencePiece.Types.Type.Byte)
- {
- _maxByteId = i;
- }
- }
-
- _byteCodeToIdOffset = _vocab.TryGetValue("<0x00>", out (int Id, float Score, byte Type) value) ? value.Id : _maxByteId;
- _oneByteUtf8EncodingMaxId = _byteCodeToIdOffset + 0x7F; // 0x7F is the maximum value of the one byte UTF-8 character.
-
- BeginningOfSentenceToken = modelProto.TrainerSpec.BosPiece ?? "";
- BeginningOfSentenceId = modelProto.TrainerSpec.BosId <= 0 ? 1 : modelProto.TrainerSpec.BosId;
- EndOfSentenceToken = modelProto.TrainerSpec.EosPiece ?? "";
- EndOfSentenceId = modelProto.TrainerSpec.EosId <= 0 ? 1 : modelProto.TrainerSpec.EosId;
- UnknownToken = modelProto.TrainerSpec.UnkPiece ?? "";
- UnknownId = modelProto.TrainerSpec.UnkId < 0 ? 0 : modelProto.TrainerSpec.UnkId;
-
- AddDummyPrefix = modelProto.NormalizerSpec.AddDummyPrefix;
- EscapeWhiteSpaces = modelProto.NormalizerSpec.EscapeWhitespaces;
- TreatWhitespaceAsSuffix = modelProto.TrainerSpec.TreatWhitespaceAsSuffix;
- ByteFallback = modelProto.TrainerSpec.ByteFallback;
-
- SpecialTokens = specialTokens;
- _normalizer = new SentencePieceNormalizer(modelProto.NormalizerSpec.RemoveExtraWhitespaces, AddDummyPrefix, EscapeWhiteSpaces, modelProto.TrainerSpec.TreatWhitespaceAsSuffix, specialTokens);
-
- if (specialTokens is not null && specialTokens.Count > 0)
+ _model = modelProto.TrainerSpec.ModelType switch
{
- _specialTokens = new Dictionary();
- _specialTokensReverse = new Dictionary();
-
- foreach (var item in specialTokens)
- {
- _specialTokens.Add(new StringSpanOrdinalKey(item.Key), item.Value);
- _specialTokensReverse.Add(item.Value, item.Key);
- }
-
- // We create this Regex object without a timeout, as we expect the match operation to complete in O(N) time complexity. Note that `specialTokens` are treated as constants after the tokenizer is created.
- _specialTokensRegex = new Regex(string.Join("|", specialTokens.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
- }
+ TrainerSpec.Types.ModelType.Bpe => new SentencePieceBpeModel(modelProto, addBos, addEos, specialTokens),
+ TrainerSpec.Types.ModelType.Unigram => new SentencePieceUnigramModel(modelProto, addBos, addEos, specialTokens),
+ _ => throw new ArgumentException($"The model type '{modelProto.TrainerSpec.ModelType}' is not supported.", nameof(modelProto))
+ };
}
- public IReadOnlyDictionary? SpecialTokens { get; }
+ ///
+ /// The special tokens.
+ ///
+ public IReadOnlyDictionary<string, int>? SpecialTokens => _model.SpecialTokens;
///
/// Specifies whether the model will do a byte fallback when it encounters unknown tokens during the encoding process.
///
- public bool ByteFallback { get; }
+ public bool ByteFallback => _model.ByteFallback;
///
/// Indicate emitting the prefix character U+2581 at the beginning of sentence token during the normalization and encoding.
///
- public bool AddDummyPrefix { get; }
+ public bool AddDummyPrefix => _model.AddDummyPrefix;
///
/// Indicate if the spaces should be replaced with character U+2581 during the normalization and encoding.
///
- public bool EscapeWhiteSpaces { get; }
+ public bool EscapeWhiteSpaces => _model.EscapeWhiteSpaces;
///
/// Indicate emitting the character U+2581 at the end of the last sentence token instead of at the beginning of the sentence token during the normalization and encoding.
///
- public bool TreatWhitespaceAsSuffix { get; private set; }
+ public bool TreatWhitespaceAsSuffix { get => _model.TreatWhitespaceAsSuffix; private set => _model.TreatWhitespaceAsSuffix = value; }
///
/// Indicate emitting the beginning of sentence token during the encoding.
///
- public bool AddBeginningOfSentence { get; }
+ public bool AddBeginningOfSentence => _model.AddBeginningOfSentence;
///
/// Indicate emitting the end of sentence token during the encoding.
///
- public bool AddEndOfSentence { get; }
+ public bool AddEndOfSentence => _model.AddEndOfSentence;
///
/// The beginning of sentence token.
///
- public string BeginningOfSentenceToken { get; }
+ public string BeginningOfSentenceToken => _model.BeginningOfSentenceToken;
///
/// The end of sentence token.
///
- public string EndOfSentenceToken { get; }
+ public string EndOfSentenceToken => _model.EndOfSentenceToken;
///
/// The unknown token.
///
- public string UnknownToken { get; }
+ public string UnknownToken => _model.UnknownToken;
///
/// The id of the beginning of sentence token.
///
- public int BeginningOfSentenceId { get; }
+ public int BeginningOfSentenceId => _model.BeginningOfSentenceId;
///
/// The id of the end of sentence token.
///
- public int EndOfSentenceId { get; }
+ public int EndOfSentenceId => _model.EndOfSentenceId;
///
/// The id of the unknown token.
///
- public int UnknownId { get; }
+ public int UnknownId => _model.UnknownId;
///
/// Gets the PreTokenizer used by the Tokenizer.
@@ -161,31 +103,12 @@ private SentencePieceTokenizer(ModelProto modelProto, IReadOnlyDictionary
/// Gets the Normalizer in use by the Tokenizer.
///
- public override Normalizer? Normalizer => _normalizer;
+ public override Normalizer? Normalizer => _model.Normalizer;
///
/// The vocabulary of the model.
///
- public IReadOnlyDictionary Vocabulary
- {
- get
- {
- IReadOnlyDictionary? publicVocab = Volatile.Read(ref _publicVocab);
- if (publicVocab is null)
- {
- var vocab = new Dictionary();
- foreach (var item in _vocab)
- {
- vocab.Add(item.Key.ToString(), item.Value.Id);
- }
-
- Interlocked.CompareExchange(ref _publicVocab, new ReadOnlyDictionary(vocab), null);
- publicVocab = _publicVocab;
- }
-
- return publicVocab;
- }
- }
+ public IReadOnlyDictionary<string, int> Vocabulary => _model.Vocabulary;
///
/// Encodes input text to a list of s.
@@ -197,7 +120,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read
{
return new EncodeResults
{
- Tokens = EncodeToTokens(text, textSpan, out string? normalizedText, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderPreTokenization, settings.ConsiderNormalization),
+ Tokens = _model.EncodeToTokens(text, textSpan, out string? normalizedText, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization),
NormalizedText = normalizedText,
CharsConsumed = normalizedText?.Length ?? text?.Length ?? textSpan.Length
};
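As a usage note: this refactoring moves the work into SentencePieceBaseModel but keeps the public surface intact. A hedged sketch of the call path exercised by this override, assuming a locally available SentencePiece model file (the path and input are illustrative):

    using System;
    using System.Collections.Generic;
    using System.IO;
    using Microsoft.ML.Tokenizers;

    using Stream modelStream = File.OpenRead("tokenizer.model"); // illustrative path
    LlamaTokenizer tokenizer = LlamaTokenizer.Create(modelStream);

    // Overload documented in this file: explicit control over BOS/EOS emission.
    IReadOnlyList<EncodedToken> tokens = tokenizer.EncodeToTokens(
        "Hello world", out string? normalized, addBeginningOfSentence: true, addEndOfSentence: false);

    foreach (EncodedToken token in tokens)
    {
        // Each EncodedToken carries the token string, its id, and its offset range.
        Console.WriteLine($"{token.Value} -> {token.Id} @ {token.Offset}");
    }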
@@ -214,7 +137,7 @@ protected override EncodeResults EncodeToTokens(string? text, Read
/// Indicate whether to consider normalization before tokenization.
/// The tokenization result includes a list of s with string value of the token, id, and offset.
public IReadOnlyList<EncodedToken> EncodeToTokens(string text, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
- => EncodeToTokens(text, Span.Empty, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
+ => _model.EncodeToTokens(text, Span<char>.Empty, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerNormalization);
///
/// Encodes input text to a list of s with string value of the token, id, and offset.
@@ -227,221 +150,8 @@ public IReadOnlyList EncodeToTokens(string text, out string? norma
/// Indicate whether to consider normalization before tokenization.
/// The tokenization result includes a list of s with string value of the token, id, and offset.
public IReadOnlyList<EncodedToken> EncodeToTokens(ReadOnlySpan<char> text, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
- => EncodeToTokens(null, text, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization);
-
- private IReadOnlyList EncodeToTokens(string? text, ReadOnlySpan textSpan, out string? normalizedText, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization)
- {
- if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
- {
- normalizedText = null;
- return [];
- }
-
- ReadOnlySpan textToEncode = text is null ? textSpan : text.AsSpan();
- if (considerNormalization && _normalizer is not null)
- {
- normalizedText = text is not null ? _normalizer.Normalize(text) : _normalizer.Normalize(textSpan);
- textToEncode = normalizedText.AsSpan();
- }
- else
- {
- normalizedText = null;
- }
-
- if (textToEncode.Length == 0)
- {
- return [];
- }
-
- List? tokens = new();
-
- if (_specialTokensRegex is not null)
- {
- EncodeWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens);
- }
- else
- {
- EncodeInternal(textToEncode, addBeginningOfSentence, addEndOfSentence, tokens);
- }
-
- return tokens;
- }
-
- private void EncodeWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, List tokens)
- {
- Debug.Assert(_specialTokensRegex is not null);
-
- if (addBeginOfSentence)
- {
- tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0)));
- }
-
- int currentOffset = 0;
-
- foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, _specialTokensRegex!))
- {
- if (Offset > currentOffset)
- {
- EncodeInternal(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens);
- }
-
- if (_specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
- {
- tokens.Add(new EncodedToken(id, _specialTokensReverse![id], new Range(Offset, Offset + Length)));
- }
-
- currentOffset = Offset + Length;
- }
-
- if (currentOffset < text.Length)
- {
- EncodeInternal(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, tokens);
- }
-
- if (addEndOfSentence)
- {
- tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length)));
- }
- }
-
- ///
- /// Encode a text to a list of tokens.
- ///
- /// The text to encode.
- /// Indicate emitting the beginning of sentence token during the encoding.
- /// Indicate emitting the end of sentence token during the encoding.
- /// A collection to store the encoded tokens.
- /// The input text has to be normalized before calling this method.
- private void EncodeInternal(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, List tokens)
- {
- BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length);
-
- Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
-
- if (addBeginOfSentence)
- {
- tokens.Add(new EncodedToken(BeginningOfSentenceId, BeginningOfSentenceToken, new Range(0, 0)));
- }
-
- for (int index = 0; (uint)index < (uint)symbols.Length; index = symbols[index].next)
- {
- int id = symbols[index].id;
- byte type = symbols[index].type;
-
- if (id == UninitializedId)
- {
- if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
- {
- id = tokenInfo.Id;
- type = tokenInfo.Type;
- }
- else
- {
- id = UnknownId;
- type = 0;
- }
- }
-
- if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
- {
- if (id == UnknownId && ByteFallback)
- {
- EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index);
- }
- else
- {
- tokens.Add(new EncodedToken(
- id,
- GetTokenString(id, symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length, text),
- new Range(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Index + symbols[index].pieceSpan.Length)));
- }
- continue;
- }
-
- Segment(symbols[index].pieceSpan, text);
- }
-
- ArrayPool.Shared.Return(symbols);
-
- if (addEndOfSentence)
- {
- tokens.Add(new EncodedToken(EndOfSentenceId, EndOfSentenceToken, new Range(text.Length, text.Length)));
- }
-
- return;
-
- // Encode the Unknown token to bytes.
- void EncodeAsBytes(ReadOnlySpan text, int index)
- {
- for (int i = 0; i < text.Length; i++)
- {
- char c = text[i];
- if (c <= 0x7F)
- {
- int id = (int)c + _byteCodeToIdOffset; // byte code is mapped to the Ids starting from 4.
-
- if (_vocabReverse.TryGetValue(id, out string? token))
- {
- tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + 1)));
- }
- }
- else
- {
- Span utf8Bytes = stackalloc byte[256];
- byte[]? arrayPoolArray = null;
-
- int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
- if (len > utf8Bytes.Length)
- {
- arrayPoolArray = ArrayPool.Shared.Rent(len);
- utf8Bytes = arrayPoolArray;
- }
-
- // Need to convert the text into UTF-8 bytes and then encode the bytes.
- int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
- int length = text.Length - i;
- for (int j = 0; j < bytesWritten; j++)
- {
- int id = (int)utf8Bytes[j] + _byteCodeToIdOffset; // byte code is mapped to the Ids starting from 4.
-
- if (_vocabReverse.TryGetValue(id, out string? token))
- {
- tokens.Add(new EncodedToken(id, token, new Range(index + i, index + i + length)));
- }
-
- length = 0;
- }
-
- if (arrayPoolArray is not null)
- {
- ArrayPool.Shared.Return(arrayPoolArray);
- }
-
- break;
- }
- }
- }
-
- void Segment((int Index, int Length) pieceSpan, ReadOnlySpan text)
- {
- if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
- {
- EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index);
- return;
- }
-
- if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
- revMerge is null ||
- !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
- {
- tokens.Add(new EncodedToken(id.Id, text.Slice(pieceSpan.Index, pieceSpan.Length).ToString(), new Range(pieceSpan.Index, pieceSpan.Index + pieceSpan.Length)));
- return;
- }
+ => _model.EncodeToTokens(null, text, out normalizedText, addBeginningOfSentence, addEndOfSentence, considerNormalization);
- Segment((merge.LeftIndex, merge.LeftLen), text);
- Segment((merge.RightIndex, merge.RightLen), text);
- }
- }
///
/// Encodes input text to token Ids.
@@ -454,7 +164,7 @@ protected override EncodeResults EncodeToIds(string? text, ReadOnlySpan
{
- Tokens = EncodeToIds(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out string? normalizedText, out int charsConsumed, settings.MaxTokenCount),
+ Tokens = _model.EncodeToIds(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out string? normalizedText, out int charsConsumed, settings.MaxTokenCount),
NormalizedText = normalizedText,
CharsConsumed = charsConsumed
};
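Continuing the sketch above with the same tokenizer instance, the maxTokenCount overload shown in this hunk also reports how much of the (possibly normalized) text the returned ids cover:

    IReadOnlyList<int> ids = tokenizer.EncodeToIds(
        "Hello world",
        addBeginningOfSentence: true,
        addEndOfSentence: false,
        maxTokenCount: 8,
        out string? normalizedText,
        out int charsConsumed);
    // charsConsumed is the length of the (normalized) text prefix covered by the returned ids.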
@@ -470,7 +180,7 @@ protected override EncodeResults EncodeToIds(string? text, ReadOnlySpan
/// Indicate whether to consider normalization before tokenization.
/// The list of encoded Ids.
public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
- => EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
+ => _model.EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
///
/// Encodes input text to token Ids.
@@ -482,7 +192,7 @@ public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence,
/// Indicate whether to consider normalization before tokenization.
/// The list of encoded Ids.
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
- => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
+ => _model.EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _);
///
/// Encodes input text to token Ids up to maximum number of tokens.
@@ -497,7 +207,7 @@ public IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginning
/// Indicate whether to consider normalization before tokenization.
/// The list of encoded Ids.
public IReadOnlyList<int> EncodeToIds(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
- => EncodeToIds(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
+ => _model.EncodeToIds(text, Span<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
///
/// Encodes input text to token Ids up to maximum number of tokens.
@@ -512,320 +222,7 @@ public IReadOnlyList EncodeToIds(string text, bool addBeginningOfSentence,
/// Indicate whether to consider normalization before tokenization.
/// The list of encoded Ids.
public IReadOnlyList<int> EncodeToIds(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int charsConsumed, bool considerPreTokenization = true, bool considerNormalization = true)
- => EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
-
-
- private IReadOnlyList EncodeToIds(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
- {
- if (maxTokenCount <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
- }
-
- if (string.IsNullOrEmpty(text) && textSpan.IsEmpty)
- {
- normalizedText = null;
- charsConsumed = 0;
- return [];
- }
-
- return EncodeToIds(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
- }
-
- ///
- /// Encodes input text to token Ids up to maximum number of tokens.
- ///
- /// The text to encode.
- /// Indicate emitting the beginning of sentence token during the encoding.
- /// Indicate emitting the end of sentence token during the encoding.
- /// Indicate whether to consider normalization before tokenization.
- /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null.
- /// The length of the text that encompasses the maximum encoded tokens.
- /// The maximum number of tokens to encode.
- /// The list of encoded Ids.
- private IReadOnlyList EncodeToIds(ReadOnlySpan text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerNormalization,
- out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
- {
- if (maxTokenCount <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
- }
-
- if (text.IsEmpty)
- {
- normalizedText = null;
- charsConsumed = 0;
- return [];
- }
-
- ReadOnlySpan textToEncode;
-
- if (considerNormalization && _normalizer is not null)
- {
- normalizedText = _normalizer.Normalize(text);
- textToEncode = normalizedText.AsSpan();
- }
- else
- {
- normalizedText = null;
- textToEncode = text;
- }
-
- if (maxTokenCount <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than 0.");
- }
-
- List ids = new();
-
- if (_specialTokensRegex is not null)
- {
- EncodeToIdsWithAddedToken(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount);
- }
- else
- {
- EncodeToIds(textToEncode, addBeginningOfSentence, addEndOfSentence, ids, out charsConsumed, maxTokenCount);
- }
-
- return ids;
- }
-
- private int EncodeToIdsWithAddedToken(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue)
- {
- Debug.Assert(_specialTokensRegex is not null);
- Debug.Assert(maxTokens > 0);
-
- charsConsumed = 0;
- int idsCount = 0;
-
- if (addBeginOfSentence)
- {
- accumulatedIds.Add(BeginningOfSentenceId);
- idsCount++;
- }
-
- int currentOffset = 0;
-
- int charsWritten;
-
- foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, _specialTokensRegex!))
- {
- if (Offset > currentOffset)
- {
- idsCount += EncodeToIds(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount);
- charsConsumed += charsWritten;
- }
-
- if (idsCount < maxTokens && _specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
- {
- accumulatedIds.Add(id);
- idsCount++;
- charsConsumed += Length;
- }
-
- currentOffset = Offset + Length;
- }
-
- if (currentOffset < text.Length && idsCount < maxTokens)
- {
- idsCount += EncodeToIds(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, accumulatedIds, out charsWritten, maxTokens - idsCount);
- charsConsumed += charsWritten;
- }
-
- if (addEndOfSentence && idsCount < maxTokens)
- {
- accumulatedIds.Add(EndOfSentenceId);
- idsCount++;
- }
-
- return idsCount;
- }
-
- ///
- /// Encode a text to a list of Ids and add them to the accumulatedIds list.
- ///
- /// The text to encode.
- /// Indicate emitting the beginning of sentence token during the encoding.
- /// Indicate emitting the end of sentence token during the encoding.
- /// The list of accumulated encoded Ids.
- /// The length of the text that encompasses the maximum encoded tokens.
- /// The maximum number of tokens to encode.
- /// The number of tokens that the input text will be encoded to.
- /// The input text has to be normalized before calling this method.
- private int EncodeToIds(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, IList accumulatedIds, out int charsConsumed, int maxTokens = int.MaxValue)
- {
- charsConsumed = 0;
- if (text.IsEmpty)
- {
- return 0;
- }
-
- int idsCount = 0;
-
- if (addBeginOfSentence)
- {
- accumulatedIds.Add(BeginningOfSentenceId);
- idsCount++;
- }
-
- BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length);
-
- Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
-
- for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next)
- {
- int id = symbols[index].id;
- byte type = symbols[index].type;
-
- if (id == UninitializedId)
- {
- if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
- {
- id = tokenInfo.Id;
- type = tokenInfo.Type;
- }
- else
- {
- id = UnknownId;
- type = 0;
- }
- }
-
- if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
- {
- if (id == UnknownId && ByteFallback)
- {
- if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed))
- {
- ArrayPool.Shared.Return(symbols);
- return idsCount;
- }
- }
- else
- {
- if (idsCount < maxTokens)
- {
- accumulatedIds.Add(id);
- charsConsumed += symbols[index].pieceSpan.Length;
- idsCount++;
- }
- else
- {
- ArrayPool.Shared.Return(symbols);
- return idsCount;
- }
- }
- continue;
- }
-
- if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed))
- {
- break;
- }
- }
-
- ArrayPool.Shared.Return(symbols);
-
- if (addEndOfSentence)
- {
- if (idsCount < maxTokens)
- {
- accumulatedIds.Add(EndOfSentenceId);
- idsCount++;
- }
- }
-
- return idsCount;
-
- // Encode the Unknown token to bytes.
- bool EncodeAsBytes(ReadOnlySpan text, int index, ref int charsConsumed)
- {
- for (int i = 0; i < text.Length; i++)
- {
- char c = text[i];
- if (c <= 0x7F)
- {
- if (idsCount < maxTokens)
- {
- charsConsumed++;
- accumulatedIds.Add((int)c + _byteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
- idsCount++;
- }
- else
- {
- return false;
- }
- }
- else
- {
- Span utf8Bytes = stackalloc byte[100];
- byte[]? arrayPoolArray = null;
-
- int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
- if (len > utf8Bytes.Length)
- {
- arrayPoolArray = ArrayPool.Shared.Rent(len);
- utf8Bytes = arrayPoolArray;
- }
-
- // Need to convert the text into UTF-8 bytes and then encode the bytes.
- int bytesWritten = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
-
- bool ret;
- if (idsCount + bytesWritten <= maxTokens)
- {
- for (int j = 0; j < bytesWritten; j++)
- {
- accumulatedIds.Add((int)utf8Bytes[j] + _byteCodeToIdOffset); // byte code is mapped to the Ids starting from 4.
- }
-
- charsConsumed += text.Length - i;
- ret = true;
- }
- else
- {
- ret = false;
- }
-
- if (arrayPoolArray is not null)
- {
- ArrayPool.Shared.Return(arrayPoolArray);
- }
-
- return ret;
- }
- }
-
- return true;
- }
-
- bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int charsConsumed)
- {
- if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
- {
- return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed);
- }
-
- if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
- revMerge is null ||
- !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
- {
- if (idsCount < maxTokens)
- {
- accumulatedIds.Add(id.Id);
- charsConsumed += pieceSpan.Length;
- idsCount++;
- return true;
- }
- else
- {
- return false;
- }
- }
-
- return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed);
- }
- }
+ => _model.EncodeToIds(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
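The block removed above implemented byte fallback: when a piece is missing from the vocabulary, each of its UTF-8 bytes becomes one token id, offset by a model-defined base (_byteCodeToIdOffset). A minimal sketch of that mapping follows; the offset value is invented, since the real one is read from the model file:

    using System;
    using System.Text;

    const int byteCodeToIdOffset = 4; // illustrative; the real offset comes from the model

    static int[] ByteFallbackIds(string unknownPiece)
    {
        // One token per UTF-8 byte of the unknown piece.
        byte[] utf8 = Encoding.UTF8.GetBytes(unknownPiece);
        int[] ids = new int[utf8.Length];
        for (int i = 0; i < utf8.Length; i++)
        {
            ids[i] = utf8[i] + byteCodeToIdOffset;
        }

        return ids;
    }

    // "é" is two UTF-8 bytes (0xC3, 0xA9), so it yields two byte-fallback ids.
    Console.WriteLine(string.Join(", ", ByteFallbackIds("é")));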
///
/// Get the number of tokens that the input text will be encoded to.
@@ -835,12 +232,7 @@ revMerge is null ||
/// The settings used to encode the text.
/// The number of token Ids that the input text will be encoded to.
protected override int CountTokens(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings)
- {
- return CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out _, out _, settings.MaxTokenCount);
- }
-
- private int CountTokens(string? text, ReadOnlySpan textSpan, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
- => CountTokens(text is null ? textSpan : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
+ => _model.CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out _, out _, settings.MaxTokenCount);
///
/// Get the number of tokens that the input text will be encoded to.
@@ -852,7 +244,7 @@ private int CountTokens(string? text, ReadOnlySpan textSpan, bool addBegin
/// Indicate whether to consider normalization before tokenization.
/// The number of token Ids that the input text will be encoded to.
public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
- => CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out _, out _);
+ => _model.CountTokens(text, ReadOnlySpan<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _, int.MaxValue);
///
/// Get the number of tokens that the input text will be encoded to.
@@ -864,7 +256,7 @@ public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSe
/// Indicate whether to consider normalization before tokenization.
/// The number of token Ids that the input text will be encoded to.
public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization = true, bool considerNormalization = true)
- => CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out _, out _);
+ => _model.CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out _, out _, int.MaxValue);
///
/// Get the number of tokens that the input text will be encoded to.
@@ -879,7 +271,7 @@ public int CountTokens(ReadOnlySpan text, bool addBeginningOfSentence, boo
/// The maximum number of tokens to encode.
/// The number of tokens that the input text will be encoded to.
public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
- => CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
+ => _model.CountTokens(text, ReadOnlySpan<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
///
/// Get the number of tokens that the input text will be encoded to.
@@ -894,299 +286,54 @@ public int CountTokens(string text, bool addBeginningOfSentence, bool addEndOfSe
/// The maximum number of tokens to encode.
/// The number of tokens that the input text will be encoded to.
public int CountTokens(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, bool considerPreTokenization, bool considerNormalization, out string? normalizedText, out int charsConsumed, int maxTokenCount = int.MaxValue)
- {
- if (maxTokenCount <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The maximum number of tokens must be greater than zero.");
- }
-
- if (text.IsEmpty)
- {
- normalizedText = null;
- charsConsumed = 0;
- return 0;
- }
+ => _model.CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out charsConsumed, maxTokenCount);
- ReadOnlySpan textToEncode;
- if (considerNormalization && _normalizer is not null)
- {
- normalizedText = _normalizer.Normalize(text);
- textToEncode = normalizedText.AsSpan();
- }
- else
+ ///
+ /// Find the index of the maximum encoding capacity without surpassing the token limit.
+ ///
+ /// The text to encode.
+ /// The span of the text to encode, which will be used if the text is null.
+ /// The settings used to encode the text.
+ /// Indicate whether to find the index from the end of the text.
+ /// If the tokenizer's normalization is enabled and the settings request normalization, this will be set to the text in its normalized form; otherwise, this value will be set to null.
+ /// The number of tokens generated, which will not exceed the maximum token count.
+ ///
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// If fromEnd is false, it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the input text, or of the normalized text if normalization is enabled.
+ /// If fromEnd is true, it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
+ /// if all tokens fit, the result will be zero.
+ ///
+ protected override int GetIndexByTokenCount(string? text, ReadOnlySpan<char> textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedText, out int tokenCount)
+ {
+ if (fromEnd)
{
- normalizedText = null;
- textToEncode = text;
+ return _model.GetIndexByTokenCountFromEnd(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.MaxTokenCount, settings.ConsiderNormalization, out normalizedText, out tokenCount);
}
- return _specialTokensRegex is not null ?
- CountTokensWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount) :
- CountTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out charsConsumed, maxTokenCount);
+ tokenCount = _model.CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderNormalization, out normalizedText, out int charsConsumed, settings.MaxTokenCount);
+ return charsConsumed;
}
- private int CountTokensWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue)
+ ///
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
+ ///
+ /// The text to encode.
+ /// Indicate emitting the beginning of sentence token during the encoding.
+ /// Indicate emitting the end of sentence token during the encoding.
+ /// The maximum token count to limit the encoding capacity.
+ /// If the tokenizer's normalization is enabled and considerNormalization is true, this will be set to the text in its normalized form; otherwise, this value will be set to null.
+ /// The number of tokens generated, which will not exceed the maximum token count.
+ /// Indicate whether to consider pre-tokenization before tokenization.
+ /// Indicate whether to consider normalization before tokenization.
+ ///
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
+ /// if all tokens fit, the result will be the length of the text, or of the normalized text if normalization is enabled.
+ ///
+ public int GetIndexByTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
- Debug.Assert(_specialTokensRegex is not null);
- Debug.Assert(maxTokens > 0);
-
- charsConsumed = 0;
- int idsCount = 0;
-
- if (addBeginOfSentence)
- {
- idsCount++;
- }
-
- int currentOffset = 0;
-
- int charsWritten;
-
- foreach ((int Offset, int Length) in PreTokenizer.SplitText(text, _specialTokensRegex!))
- {
- if (Offset > currentOffset)
- {
- idsCount += CountTokens(text.Slice(currentOffset, Offset - currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount);
- charsConsumed += charsWritten;
- }
-
- if (idsCount < maxTokens && _specialTokens!.TryGetValue(text.Slice(Offset, Length), out int id))
- {
- idsCount++;
- charsConsumed += Length;
- }
-
- currentOffset = Offset + Length;
- }
-
- if (currentOffset < text.Length && idsCount < maxTokens)
- {
- idsCount += CountTokens(text.Slice(currentOffset), addBeginOfSentence: false, addEndOfSentence: false, out charsWritten, maxTokens - idsCount);
- charsConsumed += charsWritten;
- }
-
- if (addEndOfSentence && idsCount < maxTokens)
- {
- idsCount++;
- }
-
- return idsCount;
- }
-
- ///
- /// Get the number of tokens that the input text will be encoded to.
- ///
- /// The text to encode.
- /// Indicate emitting the beginning of sentence token during the encoding.
- /// Indicate emitting the end of sentence token during the encoding.
- /// The length of the text that encompasses the maximum encoded tokens.
- /// The maximum number of tokens to encode.
- /// The number of tokens that the input text will be encoded to.
- /// The input text has to be normalized before calling this method.
- private int CountTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int charsConsumed, int maxTokens = int.MaxValue)
- {
- charsConsumed = 0;
- if (text.IsEmpty)
- {
- return 0;
- }
-
- int tokenCount = addBeginOfSentence ? 1 : 0;
-
- BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length);
-
- Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
-
- for (int index = 0; index != -1 && index < symbols.Length; index = symbols[index].next)
- {
- int id = symbols[index].id;
- byte type = symbols[index].type;
-
- if (id == UninitializedId)
- {
- if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
- {
- id = tokenInfo.Id;
- type = tokenInfo.Type;
- }
- else
- {
- id = UnknownId;
- type = 0;
- }
- }
-
- if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
- {
- if (id == UnknownId && ByteFallback)
- {
- if (!EncodeAsBytes(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref charsConsumed))
- {
- break;
- }
- }
- else
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- charsConsumed += symbols[index].pieceSpan.Length;
- }
- else
- {
- break;
- }
- }
- continue;
- }
-
- if (!Segment(symbols[index].pieceSpan, text, ref charsConsumed))
- {
- break;
- }
- }
-
- ArrayPool.Shared.Return(symbols);
-
- if (addEndOfSentence)
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- }
- }
-
- return tokenCount;
-
- // Encode the Unknown token to bytes.
- bool EncodeAsBytes(ReadOnlySpan text, int index, ref int charsConsumed)
- {
- for (int i = 0; i < text.Length; i++)
- {
- char c = text[i];
- if (c <= 0x7F)
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- charsConsumed++;
- }
- else
- {
- return false;
- }
- }
- else
- {
- Span utf8Bytes = stackalloc byte[100];
- byte[]? arrayPoolArray = null;
-
- int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
- if (len > utf8Bytes.Length)
- {
- arrayPoolArray = ArrayPool.Shared.Rent(len);
- utf8Bytes = arrayPoolArray;
- }
-
- // Need to convert the text into UTF-8 bytes and then encode the bytes.
- int encodedCount = Helpers.GetUtf8Bytes(text.Slice(i), utf8Bytes);
- bool ret;
-
- if (tokenCount + encodedCount <= maxTokens)
- {
- tokenCount += encodedCount;
- charsConsumed += text.Length - i;
- ret = true;
- }
- else
- {
- ret = false;
- }
-
- if (arrayPoolArray is not null)
- {
- ArrayPool.Shared.Return(arrayPoolArray);
- }
-
- return ret;
- }
- }
-
- return true;
- }
-
- bool Segment((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int charsConsumed)
- {
- if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
- {
- return EncodeAsBytes(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref charsConsumed);
- }
-
- if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
- revMerge is null ||
- !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- charsConsumed += pieceSpan.Length;
- return true;
- }
- else
- {
- return false;
- }
- }
-
- return Segment((merge.LeftIndex, merge.LeftLen), text, ref charsConsumed) && Segment((merge.RightIndex, merge.RightLen), text, ref charsConsumed);
- }
- }
-
- ///
- /// Find the index of the maximum encoding capacity without surpassing the token limit.
- ///
- /// The text to encode.
- /// The span of the text to encode which will be used if the is .
- /// The settings used to encode the text.
- /// Indicate whether to find the index from the end of the text.
- /// If the tokenizer's normalization is enabled or has is , this will be set to in its normalized form; otherwise, this value will be set to .
- /// The token count can be generated which should be smaller than the maximum token count.
- ///
- /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
- /// If is , it represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
- /// if all tokens fit, the result will be length of the input text or the if the normalization is enabled.
- /// If is , it represents the index of the first character to be included. In cases where no tokens fit, the result will be the text length; conversely,
- /// if all tokens fit, the result will be zero.
- ///
- protected override int GetIndexByTokenCount(string? text, ReadOnlySpan textSpan, EncodeSettings settings, bool fromEnd, out string? normalizedText, out int tokenCount)
- {
- if (fromEnd)
- {
- return GetIndexByTokenCountFromEnd(text, textSpan, settings.MaxTokenCount, settings.ConsiderNormalization, out normalizedText, out tokenCount);
- }
-
- tokenCount = CountTokens(text, textSpan, AddBeginningOfSentence, AddEndOfSentence, settings.ConsiderPreTokenization, settings.ConsiderNormalization, out normalizedText, out int charsConsumed, settings.MaxTokenCount);
- return charsConsumed;
- }
-
- ///
- /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
- ///
- /// The text to encode.
- /// Indicate emitting the beginning of sentence token during the encoding.
- /// Indicate emitting the end of sentence token during the encoding.
- /// The maximum token count to limit the encoding capacity.
- /// If the tokenizer's normalization is enabled or is false, this will be set to in its normalized form; otherwise, this value will be set to null.
- /// The token count can be generated which should be smaller than the maximum token count.
- /// Indicate whether to consider pre-tokenization before tokenization.
- /// Indicate whether to consider normalization before tokenization.
- ///
- /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
- /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely,
- /// if all tokens fit, the result will be length of the text or the if the normalization is enabled.
- ///
- public int GetIndexByTokenCount(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
- {
- tokenCount = CountTokens(text, Span.Empty, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount);
+ tokenCount = _model.CountTokens(text, ReadOnlySpan<char>.Empty, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount);
return charsConsumed;
}
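For example (a hedged sketch reusing the tokenizer from the earlier examples; the budget and input are arbitrary), the returned index can be used to truncate a prompt to a token budget:

    string longText = new string('a', 10_000); // placeholder input
    int index = tokenizer.GetIndexByTokenCount(
        longText,
        addBeginningOfSentence: true,
        addEndOfSentence: false,
        maxTokenCount: 512,
        out string? normalizedLong,
        out int fittedCount);

    // The index points just past the last character that fits the budget,
    // measured against the normalized text when normalization ran.
    string truncated = (normalizedLong ?? longText).Substring(0, index);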
@@ -1208,13 +355,10 @@ public int GetIndexByTokenCount(string text, bool addBeginningOfSentence, bool a
///
public int GetIndexByTokenCount(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, out string? normalizedText, out int tokenCount, bool considerPreTokenization = true, bool considerNormalization = true)
{
- tokenCount = CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerPreTokenization, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount);
+ tokenCount = _model.CountTokens(null, text, addBeginningOfSentence, addEndOfSentence, considerNormalization, out normalizedText, out int charsConsumed, maxTokenCount);
return charsConsumed;
}
- private int GetIndexByTokenCountFromEnd(string? text, ReadOnlySpan textSpan, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount)
- => GetIndexByTokenCountFromEnd(text is null ? textSpan : text.AsSpan(), AddBeginningOfSentence, AddEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount);
-
///
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
///
@@ -1230,7 +374,7 @@ private int GetIndexByTokenCountFromEnd(string? text, ReadOnlySpan textSpa
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0.
///
public int GetIndexByTokenCountFromEnd(string text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount)
- => GetIndexByTokenCountFromEnd(text is null ? ReadOnlySpan.Empty : text.AsSpan(), addBeginningOfSentence, addEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount);
+ => _model.GetIndexByTokenCountFromEnd(text, ReadOnlySpan<char>.Empty, addBeginningOfSentence, addEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount);
///
/// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
@@ -1247,288 +391,14 @@ public int GetIndexByTokenCountFromEnd(string text, bool addBeginningOfSentence,
/// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0.
///
public int GetIndexByTokenCountFromEnd(ReadOnlySpan<char> text, bool addBeginningOfSentence, bool addEndOfSentence, int maxTokenCount, bool considerNormalization, out string? normalizedText, out int tokenCount)
- {
- if (maxTokenCount <= 0)
- {
- throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
- }
-
- if (text.IsEmpty)
- {
- normalizedText = null;
- tokenCount = 0;
- return 0;
- }
-
- ReadOnlySpan textToEncode;
- if (considerNormalization && _normalizer is not null)
- {
- normalizedText = _normalizer.Normalize(text);
- textToEncode = normalizedText.AsSpan();
- }
- else
- {
- normalizedText = null;
- textToEncode = text;
- }
-
- int textIndex;
- if (_specialTokensRegex is not null)
- {
- tokenCount = CountTokensFromEndWithSpecialTokens(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount);
- }
- else
- {
- tokenCount = CountTokensFromEnd(textToEncode, addBeginningOfSentence, addEndOfSentence, out textIndex, maxTokenCount);
- }
-
- return textIndex;
- }
-
- private int CountTokensFromEndWithSpecialTokens(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens)
- {
- Debug.Assert(_specialTokensRegex is not null);
- Debug.Assert(maxTokens > 0);
- Debug.Assert(text.Length > 0);
-
- textIndex = text.Length;
- int idsCount = 0;
-
- if (addEndOfSentence)
- {
- idsCount++;
- }
-
- (int Offset, int Length)[] splits = PreTokenizer.SplitText(text, _specialTokensRegex!).ToArray();
-
- if (splits.Length == 0)
- {
- return CountTokensFromEnd(text, addBeginOfSentence, addEndOfSentence, out textIndex, maxTokens);
- }
-
- (int Offset, int Length) current = splits[splits.Length - 1];
-
- int splitTextIndex;
- ReadOnlySpan splitText;
-
- if (current.Offset + current.Length < text.Length)
- {
- splitText = text.Slice(current.Offset + current.Length);
- idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount);
- textIndex -= splitText.Length - splitTextIndex;
- }
-
- for (int i = splits.Length - 1; i >= 0 && idsCount < maxTokens; i--)
- {
- current = splits[i];
-
- if (_specialTokens!.TryGetValue(text.Slice(current.Offset, current.Length), out int id))
- {
- idsCount++;
- }
- textIndex -= current.Length;
-
- if (current.Offset > 0 && idsCount < maxTokens)
- {
- int start = i > 0 ? splits[i - 1].Offset + splits[i - 1].Length : 0;
- splitText = text.Slice(start, current.Offset - start);
- idsCount += CountTokensFromEnd(splitText, addBeginOfSentence: false, addEndOfSentence: false, out splitTextIndex, maxTokens - idsCount);
- textIndex -= splitText.Length - splitTextIndex;
- }
- }
-
- if (addBeginOfSentence && idsCount < maxTokens)
- {
- idsCount++;
- }
-
- return idsCount;
- }
-
- ///
- /// Get the number of tokens that the input text will be encoded to.
- ///
- /// The text to encode.
- /// Indicate emitting the beginning of sentence token during the encoding.
- /// Indicate emitting the end of sentence token during the encoding.
- /// Starting from this index, the text through to the end encompasses the maximum encoded tokens.
- /// The maximum number of tokens to encode.
- /// The number of tokens that the input text will be encoded to.
- /// The input text has to be normalized before calling this method.
- private int CountTokensFromEnd(ReadOnlySpan text, bool addBeginOfSentence, bool addEndOfSentence, out int textIndex, int maxTokens = int.MaxValue)
- {
- textIndex = text.Length;
- if (text.IsEmpty)
- {
- return 0;
- }
-
- int tokenCount = addEndOfSentence ? 1 : 0;
-
- BpeSymbol[] symbols = ArrayPool.Shared.Rent(text.Length);
-
- Dictionary<(int Index, int Len), (int LeftIndex, int LeftLen, int RightIndex, int RightLen)>? revMerge = Encode(text, symbols);
-
- // Move to the last symbol.
- int lastSymbolIndex = 0;
- while (symbols[lastSymbolIndex].next != -1 && lastSymbolIndex < symbols.Length)
- {
- lastSymbolIndex = symbols[lastSymbolIndex].next;
- }
-
- for (int index = lastSymbolIndex; index >= 0; index = symbols[index].prev)
- {
- int id = symbols[index].id;
- byte type = symbols[index].type;
-
- if (id == UninitializedId)
- {
- if (_vocab.TryGetValue(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), out (int Id, float Score, byte Type) tokenInfo))
- {
- id = tokenInfo.Id;
- type = tokenInfo.Type;
- }
- else
- {
- id = UnknownId;
- type = 0;
- }
- }
-
- if (type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused)
- {
- if (id == UnknownId && ByteFallback)
- {
- if (!EncodeAsBytesFromEnd(text.Slice(symbols[index].pieceSpan.Index, symbols[index].pieceSpan.Length), symbols[index].pieceSpan.Index, ref textIndex))
- {
- break;
- }
- }
- else
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- textIndex -= symbols[index].pieceSpan.Length;
- }
- else
- {
- break;
- }
- }
- continue;
- }
-
- if (!SegmentFromEnd(symbols[index].pieceSpan, text, ref textIndex))
- {
- break;
- }
- }
-
- ArrayPool.Shared.Return(symbols);
-
- if (addBeginOfSentence)
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- }
- }
-
- return tokenCount;
-
- // Encode the Unknown token to bytes.
- bool EncodeAsBytesFromEnd(ReadOnlySpan text, int index, ref int textIndex)
- {
- for (int i = text.Length - 1; i >= 0; i--)
- {
- char c = text[i];
- if (c <= 0x7F)
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- textIndex--;
- }
- else
- {
- return false;
- }
- }
- else
- {
- Span utf8Bytes = stackalloc byte[100];
- byte[]? arrayPoolArray = null;
-
- int len = Encoding.UTF8.GetMaxByteCount(text.Length - i);
- if (len > utf8Bytes.Length)
- {
- arrayPoolArray = ArrayPool.Shared.Rent(len);
- utf8Bytes = arrayPoolArray;
- }
-
- // Need to convert the text into UTF-8 bytes and then encode the bytes.
- int encodedCount = Helpers.GetUtf8Bytes(text.Slice(0, i + 1), utf8Bytes);
- bool ret;
-
- if (tokenCount + encodedCount <= maxTokens)
- {
- tokenCount += encodedCount;
- textIndex -= i + 1;
- ret = true;
- }
- else
- {
- ret = false;
- }
-
- if (arrayPoolArray is not null)
- {
- ArrayPool.Shared.Return(arrayPoolArray);
- }
-
- return ret;
- }
- }
-
- return true;
- }
-
- bool SegmentFromEnd((int Index, int Length) pieceSpan, ReadOnlySpan text, ref int textIndex)
- {
- if (!_vocab.TryGetValue(text.Slice(pieceSpan.Index, pieceSpan.Length), out (int Id, float Score, byte Type) id))
- {
- return EncodeAsBytesFromEnd(text.Slice(pieceSpan.Index, pieceSpan.Length), pieceSpan.Index, ref textIndex);
- }
-
- if (id.Type != (byte)ModelProto.Types.SentencePiece.Types.Type.Unused ||
- revMerge is null ||
- !revMerge.TryGetValue((pieceSpan.Index, pieceSpan.Length), out (int LeftIndex, int LeftLen, int RightIndex, int RightLen) merge))
- {
- if (tokenCount < maxTokens)
- {
- tokenCount++;
- textIndex -= pieceSpan.Length;
- return true;
- }
- else
- {
- return false;
- }
- }
-
- // Segment the right part first.
- return SegmentFromEnd((merge.RightIndex, merge.RightLen), text, ref textIndex) && SegmentFromEnd((merge.LeftIndex, merge.LeftLen), text, ref textIndex);
- }
- }
+ => _model.GetIndexByTokenCountFromEnd(null, text, addBeginningOfSentence, addEndOfSentence, maxTokenCount, considerNormalization, out normalizedText, out tokenCount);
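The from-end variant covers the symmetric case of keeping the tail of an over-long input; a hedged sketch against the signature above, continuing the earlier example:

    int startIndex = tokenizer.GetIndexByTokenCountFromEnd(
        longText,
        addBeginningOfSentence: false,
        addEndOfSentence: false,
        maxTokenCount: 512,
        considerNormalization: true,
        out string? normalizedTail,
        out int tailCount);

    // startIndex is the first character that fits; everything before it is dropped.
    string tail = (normalizedTail ?? longText).Substring(startIndex);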
///
/// Decode the given ids, back to a String.
///
/// The list of ids that we want to decode.
/// The decoded string.
- public override string Decode(IEnumerable<int> ids)
- => Decode(ids, considerSpecialTokens: false);
+ public override string Decode(IEnumerable<int> ids) => _model.Decode(ids, considerSpecialTokens: false);
/// <summary>
/// Decode the given ids, back to a String.
@@ -1536,231 +406,7 @@ public override string Decode(IEnumerable<int> ids)
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="considerSpecialTokens">Indicate whether to consider special tokens during decoding.</param>
/// <returns>The decoded string.</returns>
- public string Decode(IEnumerable<int> ids, bool considerSpecialTokens)
- {
- if (ids is null)
- {
- throw new ArgumentNullException(nameof(ids));
- }
-
- using IEnumerator<int> enumerator = ids.GetEnumerator();
- if (!enumerator.MoveNext())
- {
- return string.Empty;
- }
-
- ValueStringBuilder sb = new(stackalloc char[256]);
-
- int bytesCount = -1;
- byte[]? bytesPoolArray = null;
- bool prefixRemoved = false;
- int suffixIndex = -1;
- char prefixSuffixChar = EscapeWhiteSpaces ? SentencePieceNormalizer.DummyPrefix : ' ';
-
- if (enumerator.Current <= _maxByteId)
- {
- // First token is a byte token.
-
- while (enumerator.Current < _byteCodeToIdOffset)
- {
- // It is possible that some special tokens are listed before the byte tokens in the tokenizer's data.
- TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
-
- // Skip control tokens.
- if (!enumerator.MoveNext())
- {
- return sb.ToString();
- }
- }
-
- if (enumerator.Current <= _maxByteId)
- {
- EncodeByte(enumerator.Current, _oneByteUtf8EncodingMaxId, _byteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb);
- }
- else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
- {
- AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
- }
- else
- {
- TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
- }
- }
- else if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
- {
- AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
- }
- else
- {
- TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
- }
-
- char[]? charPoolArray = null;
-
- while (enumerator.MoveNext())
- {
- if (enumerator.Current < _byteCodeToIdOffset)
- {
- if (bytesCount >= 1)
- {
- FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
- }
-
- // It is possible that some special tokens are listed before the byte tokens in the tokenizer's data.
- TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
-
- continue;
- }
-
- if (enumerator.Current <= _maxByteId)
- {
- if (bytesCount >= 1)
- {
- Debug.Assert(bytesPoolArray is not null);
-
- if (bytesCount >= bytesPoolArray!.Length)
- {
- Helpers.ArrayPoolGrow(ref bytesPoolArray, bytesCount * 2);
- }
-
- bytesPoolArray![bytesCount++] = (byte)(enumerator.Current - _byteCodeToIdOffset);
- }
- else
- {
- EncodeByte(enumerator.Current, _oneByteUtf8EncodingMaxId, _byteCodeToIdOffset, ref bytesCount, ref bytesPoolArray, ref sb);
- }
- }
- else
- {
- if (bytesCount >= 1)
- {
- FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
- }
-
- if (_vocabReverse.TryGetValue(enumerator.Current, out string? token))
- {
- AppendTokenWithCheckingPrefix(AddDummyPrefix, TreatWhitespaceAsSuffix, token, prefixSuffixChar, ref sb, ref prefixRemoved, ref suffixIndex);
- }
- else
- {
- TryDecodeAsSpecialToken(this, enumerator.Current, considerSpecialTokens, ref sb);
- }
- }
- }
-
- if (bytesCount >= 1)
- {
- FlushBytes(ref bytesCount, ref bytesPoolArray, ref charPoolArray, ref sb);
- }
-
- if (AddDummyPrefix && TreatWhitespaceAsSuffix && suffixIndex >= 0 && sb.Length > 0)
- {
- Debug.Assert(sb[suffixIndex] == SentencePieceNormalizer.DummyPrefix);
- Debug.Assert(sb.Length > suffixIndex);
-
- sb.Remove(suffixIndex, 1);
- }
-
- if (bytesPoolArray is not null)
- {
- ArrayPool<byte>.Shared.Return(bytesPoolArray);
- }
-
- if (charPoolArray is not null)
- {
- ArrayPool<char>.Shared.Return(charPoolArray);
- }
-
- return EscapeWhiteSpaces ? sb.ToString(SentencePieceNormalizer.DummyPrefix, ' ') : sb.ToString();
-
- static void FlushBytes(ref int bytesCount, ref byte[]? bytesPoolArray, ref char[]? charPoolArray, ref ValueStringBuilder sb)
- {
- Debug.Assert(bytesCount >= 1);
- Debug.Assert(bytesPoolArray is not null);
-
- int len = Encoding.UTF8.GetMaxCharCount(bytesCount);
-
- charPoolArray ??= ArrayPool<char>.Shared.Rent(Math.Max(len, 50));
-
- if (len > charPoolArray.Length)
- {
- Helpers.ArrayPoolGrow(ref charPoolArray, len);
- }
-
- int charCount = Helpers.GetChars(bytesPoolArray.AsSpan(0, bytesCount), charPoolArray);
-
- sb.Append(charPoolArray.AsSpan(0, charCount));
- bytesCount = -1;
- }
-
- static void EncodeByte(int id, int oneByteUtf8EncodingMaxId, int byteCodeToIdOffset, ref int bytesCount, ref byte[]? bytesPoolArray, ref ValueStringBuilder sb)
- {
- if (id <= oneByteUtf8EncodingMaxId)
- {
- sb.Append((char)(id - byteCodeToIdOffset));
- }
- else
- {
- bytesCount = 1;
- bytesPoolArray ??= ArrayPool<byte>.Shared.Rent(50);
- bytesPoolArray[0] = (byte)(id - byteCodeToIdOffset);
- }
- }
-
- static void AppendTokenWithCheckingPrefix(bool addDummyPrefix, bool treatWhitespaceAsSuffix, string token, char prefixSuffixChar, ref ValueStringBuilder sb, ref bool prefixRemoved, ref int suffixIndex)
- {
- if (token.Length == 0)
- {
- return;
- }
-
- if (!addDummyPrefix)
- {
- sb.Append(token);
- return;
- }
-
- if (treatWhitespaceAsSuffix)
- {
- sb.Append(token);
- if (token[token.Length - 1] == prefixSuffixChar)
- {
- suffixIndex = sb.Length - 1;
- }
- }
- else
- {
- sb.Append(!prefixRemoved && token[0] == prefixSuffixChar ? token.AsSpan(1) : token.AsSpan());
- }
-
- prefixRemoved = true;
- }
-
- static void TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bool considerSpecialTokens, ref ValueStringBuilder sb)
- {
- if (!considerSpecialTokens)
- {
- return;
- }
-
- if (id == tokenizer.BeginningOfSentenceId)
- {
- sb.Append(tokenizer.BeginningOfSentenceToken);
- }
- else if (id == tokenizer.EndOfSentenceId)
- {
- sb.Append(tokenizer.EndOfSentenceToken);
- }
- else if (id == tokenizer.UnknownId)
- {
- sb.Append(tokenizer.UnknownToken);
- }
- else if (tokenizer._specialTokensReverse?.TryGetValue(id, out string? specialToken) is true)
- {
- sb.Append(specialToken);
- }
- }
- }
+ public string Decode(IEnumerable<int> ids, bool considerSpecialTokens) => _model.Decode(ids, considerSpecialTokens);
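[Editor's note: illustrative sketch, not part of the diff] A hedged round-trip showing how the two Decode overloads above differ; `tokenizer` is any SentencePieceTokenizer-derived instance (such as the LlamaTokenizer created in the earlier sketch) and the sample text is arbitrary.

    // EncodeToIds is the inherited Tokenizer entry point for producing ids.
    IReadOnlyList<int> ids = tokenizer.EncodeToIds("Hello world");

    // Default overload: special tokens (e.g. <s>, </s>) are omitted from the output.
    string plain = tokenizer.Decode(ids);

    // Explicit overload: special tokens are rendered back into the decoded string.
    string withSpecials = tokenizer.Decode(ids, considerSpecialTokens: true);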
/// <summary>
/// Decode the given ids back to text and store the result in the <paramref name="destination"/> span.
@@ -1771,7 +417,7 @@ static void TryDecodeAsSpecialToken(SentencePieceTokenizer tokenizer, int id, bo
/// <param name="charsWritten">The number of characters written to the destination span.</param>
/// <returns>The operation status indicates whether all IDs were successfully decoded or if the <paramref name="destination"/> is too small to contain the entire decoded result.</returns>
public override OperationStatus Decode(IEnumerable<int> ids, Span<char> destination, out int idsConsumed, out int charsWritten)
- => Decode(ids, destination, considerSpecialTokens: false, out idsConsumed, out charsWritten);
+ => _model.Decode(ids, destination, considerSpecialTokens: false, out idsConsumed, out charsWritten);
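[Editor's note: illustrative sketch, not part of the diff] The span-based overload reports progress through System.Buffers.OperationStatus, which supports an allocation-aware retry loop such as the one below; the initial buffer size is arbitrary, and `tokenizer`/`ids` carry over from the previous sketch.

    using System.Buffers;

    // Decode into a caller-owned buffer; if the tokenizer reports
    // DestinationTooSmall, double the buffer and restart the decode.
    char[] buffer = new char[64];
    int idsConsumed, charsWritten;
    while (tokenizer.Decode(ids, buffer, out idsConsumed, out charsWritten) == OperationStatus.DestinationTooSmall)
    {
        buffer = new char[buffer.Length * 2];
    }
    string decoded = new string(buffer, 0, charsWritten);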
/// <summary>