From ca18ba750142f18d3bc163cdbd914a0180b620d0 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Mon, 4 Mar 2024 14:58:40 -0800 Subject: [PATCH 1/7] Adding needed Tokenizer's APIs --- src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 52 ++++++- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 142 +++++++++++++++--- .../Microsoft.ML.Tokenizers.Tests/BpeTests.cs | 2 + .../EnglishRobertaTests.cs | 1 + .../TitokenTests.cs | 19 ++- .../TokenizerTests.cs | 67 +++++++++ 6 files changed, 252 insertions(+), 31 deletions(-) create mode 100644 test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 60e9282a81..feb6e5c39e 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -9,6 +9,7 @@ using System.IO; using System.Linq; using System.Text; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; @@ -100,6 +101,43 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTo } } + /// + /// Create a Tiktoken tokenizer based on model name and vocab file. + /// + /// Model name + /// The stream to the BPE vocab file. + /// Extra special tokens other than the built-in ones for the model + /// The size of the cache to use. + /// To normalize the text before tokenization + /// The tokenizer + public static Tokenizer CreateByModelName( + string modelName, + Stream vocabStream, + IReadOnlyDictionary? extraSpecialTokens = null, + int cacheSize = LruCache.DefaultCacheSize, + Normalizer? normalizer = null) + { + if (string.IsNullOrEmpty(modelName)) + { + throw new ArgumentNullException(nameof(modelName)); + } + + (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = Tokenizer.GetTiktokenConfigurations(modelName); + + if (extraSpecialTokens is not null) + { + foreach (var extraSpecialToken in extraSpecialTokens) + { + tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); + } + } + + return new Tokenizer( + new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize), + new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + normalizer); + } + private static (Dictionary?, Dictionary?) CreateEncoderDecoder(IReadOnlyDictionary? specialTokens) { if (specialTokens is not null) @@ -233,7 +271,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : /// The text to encode. /// Indicate if the token is a special token. /// The list of tokens generated from the text tokenization. - public override IReadOnlyList Encode(string text, bool isSpecialToken) + public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { Token[] tokens; @@ -462,12 +500,14 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) /// The decoded string. public override string? Decode(IEnumerable ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true) { - // Tiktoken does not ensure a one-to-one mapping between IDs and tokens. Consequently, decoding individual IDs into tokens is not supported; - // instead, decoding all IDs must be done collectively. - // Here is example of case that map one character to multiple Ids: - // '⭐' U-2B50 is mapped to Ids [2928, 99834] in the Tiktoken model. - // In other words, the character '⭐' has UTF-8 code point 0xE2, 0xAD, 0x90, Tiktoken will map 0xE2 to [2928] and 0xAD, 0x90 to [99834]. + // Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words. + // Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively. + // Here's an example case that maps one character to multiple IDs: + // '⭐' U-2B50 is mapped to IDs [2928, 99834] in the Tiktoken model. + // In other words, the character '⭐' with UTF-8 code point 0xE2, 0xAD, 0x90 will be mapped by Tiktoken as follows: 0xE2 to [2928] + // and 0xAD, 0x90 to [99834]. Decoding 2928 and 99834 individually won't reconstruct the original UTF-16 string '⭐' U-2B50; + // decoding all IDs together is required to get the expected result. if (ids is null) { return null; diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index c64ebf256e..e52dfef879 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Net.Http; using System.Text.RegularExpressions; using System.Threading; @@ -136,6 +137,76 @@ public int CountTokens(string text, bool considerSpecialTokens = true) return idsCount; } + /// + /// Find the maximum encoding capacity within the input text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// Indicate if want to trim from the start of the text. + /// Indicate if want to consider the special tokens during the encoding. + /// + /// The entire normalized text, the starting offset within the returned text for token counting, the length of text constrained by the maximum token count, + /// and the token count can be generated using the provided subtext offset and length. + /// + /// The input text is null. + /// The maximum token count must be greater than 0. + /// + /// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text. + /// If is true, the returned offset will be 0. Otherwise the returned offset will be the starting index of the subtext. + /// If the provided is greater than the token count of the input text, the returned length will be the length of the input text. + /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0. + /// + public (string Text, int Offset, int Length, int TokenCount) TrimWithinTokenLimit(string text, int maxTokenCount, bool fromStart = true, bool considerSpecialTokens = true) + { + if (text is null) + { + throw new ArgumentNullException(nameof(text)); + } + + if (maxTokenCount <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } + + string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text; + int idsCount = 0; + + if (fromStart) + { + foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens)) + { + int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); + + if (tokenCount + idsCount > maxTokenCount) + { + return (normalized, 0, split.Offset.Index, idsCount); + } + + idsCount += tokenCount; + } + + return (normalized, 0, normalized.Length, idsCount); + } + + // from end + Split[] splits = PreTokenizer.PreTokenize(normalized, considerSpecialTokens).ToArray(); + + for (int i = splits.Length - 1; i >= 0; i--) + { + Split split = splits[i]; + int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); + + if (tokenCount + idsCount > maxTokenCount) + { + return (normalized, split.Offset.Index + split.Offset.Length, normalized.Length - split.Offset.Index - split.Offset.Length, idsCount); + } + + idsCount += tokenCount; + } + + return (normalized, 0, normalized.Length, idsCount); + } + /// /// Decodes the Id to the mapped token. /// @@ -230,6 +301,56 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo { "gpt2", ModelEncoding.GPT2 } }; + private static ModelEncoding GetModelEncoding(string modelName) + { + if (!_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoder)) + { + foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) + { + if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)) + { + encoder = Encoding; + break; + } + } + } + + if (encoder == ModelEncoding.None) + { + throw new NotImplementedException($"Doesn't support this model [{modelName}]"); + } + + return encoder; + } + + internal static (Dictionary SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName) + { + ModelEncoding modelEncoding = GetModelEncoding(modelName); + + switch (modelEncoding) + { + case ModelEncoding.Cl100kBase: + return (new Dictionary + { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex()); + + case ModelEncoding.P50kBase: + return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); + + case ModelEncoding.P50kEdit: + return (new Dictionary + { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex()); + + case ModelEncoding.R50kBase: + return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); + + case ModelEncoding.GPT2: + return (new Dictionary { { EndOfText, 50256 }, }, P50kBaseRegex()); + + default: + Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); + throw new NotImplementedException($"Doesn't support model '{modelName}'"); + } + } /// /// Create tokenizer based on model name @@ -247,26 +368,7 @@ public static Task CreateByModelNameAsync( { try { - ModelEncoding encoder; - - if (!_modelToEncoding.TryGetValue(modelName, out encoder)) - { - foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) - { - if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)) - { - encoder = Encoding; - break; - } - } - } - - if (encoder == ModelEncoding.None) - { - throw new NotImplementedException($"Doesn't support this model [{modelName}]"); - } - - return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken); + return CreateByEncoderNameAsync(GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken); } catch (Exception ex) { diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index 2959184b5d..9ae8517494 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -225,6 +225,8 @@ public async void TestGpt2Vocab() Assert.Equal(12, encoding.Ids.Count); Assert.Equal(encoding.Ids, ids); Assert.Equal(12, tokenizer.CountTokens(text)); + + TokenizerTests.TestTokenLimits(tokenizer); } private static string WriteToMergeFile((string, string)[] mergeEntries) diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs index ccf0e66ef9..4518f169bf 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs @@ -96,6 +96,7 @@ public async void TokenizationTest() Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance); TestTokenizer(tokenizer); + TokenizerTests.TestTokenLimits(tokenizer); tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance); TestTokenizer(tokenizer); diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs index af6581401a..9362f767ac 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs @@ -61,6 +61,12 @@ public async void TestTokenizerCreation() tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer); } TestGPT4TokenizationEncoding(tokenizer); + + using (Stream stream = File.OpenRead(tokenizerDataFileName)) + { + tokenizer = Tiktoken.CreateByModelName("gpt-4", stream); + } + TestGPT4TokenizationEncoding(tokenizer); } finally { @@ -82,6 +88,8 @@ private void TestGPT4TokenizationEncoding(Tokenizer tokenizer) Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets); Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded, result.Ids); + + TestGPT4Tokenizer(tokenizer); } [Fact] @@ -101,13 +109,12 @@ public void TestEncode1() Assert.Equal(encoded, result.Ids); } - [Fact] - public void TestEncode2() + private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer) { string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); - IReadOnlyList encoded = GPT4.EncodeToIds(text, considerSpecialTokens: false); + IReadOnlyList encoded = gpt4Tokenizer.EncodeToIds(text, considerSpecialTokens: false); Assert.Equal(5584, encoded.Count); - int idsCount = GPT4.CountTokens(text, considerSpecialTokens: false); + int idsCount = gpt4Tokenizer.CountTokens(text, considerSpecialTokens: false); Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens.json")) @@ -116,8 +123,10 @@ public void TestEncode2() Assert.Equal(expected!, encoded.ToArray()); } - string? decoded = GPT4.Decode(encoded.ToArray()); + string? decoded = gpt4Tokenizer.Decode(encoded.ToArray()); Assert.Equal(text, decoded!); + + TokenizerTests.TestTokenLimits(gpt4Tokenizer); } [Fact] diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs new file mode 100644 index 0000000000..88cd57bc45 --- /dev/null +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Tokenizers; +using System; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tokenizers.Tests +{ + public class TokenizerTests + { + internal static void TestTokenLimits(Tokenizer tokenizer) + { + string input = @" + OpenAI's large language models (sometimes referred to as GPT's) process text using tokens, which are common sequences of characters found in a set of text. + The models learn to understand the statistical relationships between these tokens, and excel at producing the next token in a sequence of tokens. + You can use the tool below to understand how a piece of text might be tokenized by a language model, and the total count of tokens in that piece of text. + It's important to note that the exact tokenization process varies between models. Newer models like GPT-3.5 and GPT-4 use a different tokenizer than previous models, + and will produce different tokens for the same input text. + "; + + IReadOnlyList fullIdsList = tokenizer.EncodeToIds(input); + + for (int i = 1; i <= fullIdsList.Count; i++) + { + (string Text, int Offset, int Length, int TokenCount) result1 = tokenizer.TrimWithinTokenLimit(input, maxTokenCount: i, fromStart: true); + (string Text, int Offset, int Length, int TokenCount) result2 = tokenizer.TrimWithinTokenLimit(input, maxTokenCount: i, fromStart: false); + + IReadOnlyList? prefixIds = null; + IReadOnlyList? suffixIds = null; + + if (result1.TokenCount > 0) + { + Assert.Equal(0, result1.Offset); + string prefixString = result1.Text.Substring(result1.Offset, result1.Length); + prefixIds = tokenizer.EncodeToIds(prefixString); + Assert.Equal(result1.TokenCount, prefixIds.Count); + Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count)); + } + + if (result2.TokenCount > 0) + { + Assert.Equal(result2.Text.Length, result2.Offset + result2.Length); + string suffixString = result2.Text.Substring(result2.Offset, result2.Length); + suffixIds = tokenizer.EncodeToIds(suffixString); + Assert.Equal(result2.TokenCount, suffixIds.Count); + Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count)); + } + + if (i == fullIdsList.Count) + { + Assert.Equal(result1.Text.Length, result1.Length); + Assert.Equal(result2.Text.Length, result2.Length); + Assert.Equal(fullIdsList, prefixIds); + Assert.Equal(fullIdsList, suffixIds); + } + } + + Assert.Throws(() => tokenizer.TrimWithinTokenLimit(input, maxTokenCount: 0, fromStart: true)); + Assert.Throws(() => tokenizer.TrimWithinTokenLimit(input, maxTokenCount: -1, fromStart: true)); + Assert.Throws(() => tokenizer.TrimWithinTokenLimit(null!, maxTokenCount: 0, fromStart: false)); + } + } +} \ No newline at end of file From 4832f9a3dc740f747c4708d22eab9322f0b4b9c7 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Mon, 4 Mar 2024 19:16:51 -0800 Subject: [PATCH 2/7] Address the feedback --- src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 40 +++++++++++ src/Microsoft.ML.Tokenizers/Tokenizer.cs | 72 ++++++++++--------- .../TitokenTests.cs | 6 ++ .../TokenizerTests.cs | 14 ++-- 4 files changed, 93 insertions(+), 39 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index feb6e5c39e..58637b97af 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -138,6 +138,46 @@ public static Tokenizer CreateByModelName( normalizer); } + /// + /// Create a Tiktoken tokenizer based on model name and vocab file. + /// + /// Model name + /// The stream to the BPE vocab file. + /// Extra special tokens other than the built-in ones for the model + /// The size of the cache to use. + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + public static async Task CreateByModelNameAsync( + string modelName, + Stream vocabStream, + IReadOnlyDictionary? extraSpecialTokens = null, + int cacheSize = LruCache.DefaultCacheSize, + Normalizer? normalizer = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(modelName)) + { + throw new ArgumentNullException(nameof(modelName)); + } + + (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = Tokenizer.GetTiktokenConfigurations(modelName); + + if (extraSpecialTokens is not null) + { + foreach (var extraSpecialToken in extraSpecialTokens) + { + tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); + } + } + + return new Tokenizer( + await CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false), + new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + normalizer); + } + + private static (Dictionary?, Dictionary?) CreateEncoderDecoder(IReadOnlyDictionary? specialTokens) { if (specialTokens is not null) diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index e52dfef879..00f688b704 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -138,11 +138,10 @@ public int CountTokens(string text, bool considerSpecialTokens = true) } /// - /// Find the maximum encoding capacity within the input text without surpassing the token limit. + /// Find the maximum encoding capacity from beginning within the input text without surpassing the token limit. /// /// The text to encode. /// The maximum token count to limit the encoding capacity. - /// Indicate if want to trim from the start of the text. /// Indicate if want to consider the special tokens during the encoding. /// /// The entire normalized text, the starting offset within the returned text for token counting, the length of text constrained by the maximum token count, @@ -152,11 +151,33 @@ public int CountTokens(string text, bool considerSpecialTokens = true) /// The maximum token count must be greater than 0. /// /// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text. - /// If is true, the returned offset will be 0. Otherwise the returned offset will be the starting index of the subtext. /// If the provided is greater than the token count of the input text, the returned length will be the length of the input text. /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0. /// - public (string Text, int Offset, int Length, int TokenCount) TrimWithinTokenLimit(string text, int maxTokenCount, bool fromStart = true, bool considerSpecialTokens = true) + public (string Text, int Offset, int Length, int TokenCount) TrimSuffixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) => + TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: true, considerSpecialTokens); + + /// + /// Find the maximum encoding capacity from the end within the input text without surpassing the token limit. + /// + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// Indicate if want to consider the special tokens during the encoding. + /// + /// The entire normalized text, the starting offset within the returned text for token counting, the length of text constrained by the maximum token count, + /// and the token count can be generated using the provided subtext offset and length. + /// + /// The input text is null. + /// The maximum token count must be greater than 0. + /// + /// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text. + /// If the provided is greater than the token count of the input text, the returned length will be the length of the input text. + /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0. + /// + public (string Text, int Offset, int Length, int TokenCount) TrimPrefixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) => + TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: false, considerSpecialTokens); + + private (string Text, int Offset, int Length, int TokenCount) TrimWithinTokenLimit(string text, int maxTokenCount, bool trimSuffix = true, bool considerSpecialTokens = true) { if (text is null) { @@ -171,34 +192,15 @@ public int CountTokens(string text, bool considerSpecialTokens = true) string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text; int idsCount = 0; - if (fromStart) - { - foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens)) - { - int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); - - if (tokenCount + idsCount > maxTokenCount) - { - return (normalized, 0, split.Offset.Index, idsCount); - } - - idsCount += tokenCount; - } - - return (normalized, 0, normalized.Length, idsCount); - } - - // from end - Split[] splits = PreTokenizer.PreTokenize(normalized, considerSpecialTokens).ToArray(); - - for (int i = splits.Length - 1; i >= 0; i--) + IEnumerable splits = PreTokenizer.PreTokenize(normalized, considerSpecialTokens); + foreach (Split split in (trimSuffix ? splits : splits.Reverse())) { - Split split = splits[i]; int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); - - if (tokenCount + idsCount > maxTokenCount) + if (tokenCount > maxTokenCount - idsCount) { - return (normalized, split.Offset.Index + split.Offset.Length, normalized.Length - split.Offset.Index - split.Offset.Length, idsCount); + return trimSuffix ? + (normalized, 0, split.Offset.Index, idsCount) : + (normalized, split.Offset.Index + split.Offset.Length, normalized.Length - split.Offset.Index - split.Offset.Length, idsCount); } idsCount += tokenCount; @@ -317,7 +319,7 @@ private static ModelEncoding GetModelEncoding(string modelName) if (encoder == ModelEncoding.None) { - throw new NotImplementedException($"Doesn't support this model [{modelName}]"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); } return encoder; @@ -348,7 +350,7 @@ internal static (Dictionary SpecialTokens, Regex Regex) GetTiktoken default: Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); - throw new NotImplementedException($"Doesn't support model '{modelName}'"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); } } @@ -368,7 +370,7 @@ public static Task CreateByModelNameAsync( { try { - return CreateByEncoderNameAsync(GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken); + return CreateByEncoderNameAsync(modelName, GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken); } catch (Exception ex) { @@ -403,13 +405,15 @@ public static Task CreateByModelNameAsync( /// /// Create tokenizer based on encoder name and extra special tokens /// + /// Model name /// Encoder label /// Extra special tokens other than the built-in ones for the encoder /// To normalize the text before tokenization /// used to request cancellation of the operation. /// The tokenizer - /// Throws if the encoder is not supported + /// Throws if the model name is not supported private static Task CreateByEncoderNameAsync( + string modelName, ModelEncoding modelEncoding, IReadOnlyDictionary? extraSpecialTokens, Normalizer? normalizer, @@ -441,7 +445,7 @@ private static Task CreateByEncoderNameAsync( default: Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); - throw new NotImplementedException($"Doesn't support this encoder [{modelEncoding}]"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs index 9362f767ac..e73766818f 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs @@ -67,6 +67,12 @@ public async void TestTokenizerCreation() tokenizer = Tiktoken.CreateByModelName("gpt-4", stream); } TestGPT4TokenizationEncoding(tokenizer); + + using (Stream stream = File.OpenRead(tokenizerDataFileName)) + { + tokenizer = await Tiktoken.CreateByModelNameAsync("gpt-3.5-turbo", stream); + } + TestGPT4TokenizationEncoding(tokenizer); } finally { diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs index 88cd57bc45..8fb9f10508 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -26,8 +26,8 @@ internal static void TestTokenLimits(Tokenizer tokenizer) for (int i = 1; i <= fullIdsList.Count; i++) { - (string Text, int Offset, int Length, int TokenCount) result1 = tokenizer.TrimWithinTokenLimit(input, maxTokenCount: i, fromStart: true); - (string Text, int Offset, int Length, int TokenCount) result2 = tokenizer.TrimWithinTokenLimit(input, maxTokenCount: i, fromStart: false); + (string Text, int Offset, int Length, int TokenCount) result1 = tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: i); + (string Text, int Offset, int Length, int TokenCount) result2 = tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: i); IReadOnlyList? prefixIds = null; IReadOnlyList? suffixIds = null; @@ -59,9 +59,13 @@ internal static void TestTokenLimits(Tokenizer tokenizer) } } - Assert.Throws(() => tokenizer.TrimWithinTokenLimit(input, maxTokenCount: 0, fromStart: true)); - Assert.Throws(() => tokenizer.TrimWithinTokenLimit(input, maxTokenCount: -1, fromStart: true)); - Assert.Throws(() => tokenizer.TrimWithinTokenLimit(null!, maxTokenCount: 0, fromStart: false)); + Assert.Throws(() => tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: 0)); + Assert.Throws(() => tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: -1)); + Assert.Throws(() => tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: 0)); + Assert.Throws(() => tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: -1)); + + Assert.Throws(() => tokenizer.TrimSuffixWithinTokenLimit(null!, maxTokenCount: 0)); + Assert.Throws(() => tokenizer.TrimPrefixWithinTokenLimit(null!, maxTokenCount: 0)); } } } \ No newline at end of file From d9c1524f0df855e8ea276a052d4dd5cd42fa4bec Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Tue, 5 Mar 2024 09:30:46 -0800 Subject: [PATCH 3/7] Small update to the newly exposed APIs --- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 28 ++++++++++++------- .../TokenizerTests.cs | 12 ++++---- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 00f688b704..09a962a721 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -144,8 +144,9 @@ public int CountTokens(string text, bool considerSpecialTokens = true) /// The maximum token count to limit the encoding capacity. /// Indicate if want to consider the special tokens during the encoding. /// - /// The entire normalized text, the starting offset within the returned text for token counting, the length of text constrained by the maximum token count, - /// and the token count can be generated using the provided subtext offset and length. + /// - The entire normalized text. + /// - The length of text from the beginning of the normalized which is limited by the maximum token count + /// - The token count can be generated using the provided length which should be smaller than the maximum token count. /// /// The input text is null. /// The maximum token count must be greater than 0. @@ -154,8 +155,11 @@ public int CountTokens(string text, bool considerSpecialTokens = true) /// If the provided is greater than the token count of the input text, the returned length will be the length of the input text. /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0. /// - public (string Text, int Offset, int Length, int TokenCount) TrimSuffixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) => - TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: true, considerSpecialTokens); + public (string Text, int Length, int TokenCount) TrimSuffixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) + { + (string Text, int Offset, int Length, int TokenCount) result = TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: true, considerSpecialTokens); + return (result.Text, result.Length, result.TokenCount); + } /// /// Find the maximum encoding capacity from the end within the input text without surpassing the token limit. @@ -164,18 +168,22 @@ public int CountTokens(string text, bool considerSpecialTokens = true) /// The maximum token count to limit the encoding capacity. /// Indicate if want to consider the special tokens during the encoding. /// - /// The entire normalized text, the starting offset within the returned text for token counting, the length of text constrained by the maximum token count, - /// and the token count can be generated using the provided subtext offset and length. + /// - The entire normalized text. + /// - The starting offset within the returned normalized text for token counting. + /// - The token count can be generated which should be smaller than the maximum token count. /// /// The input text is null. /// The maximum token count must be greater than 0. /// /// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text. - /// If the provided is greater than the token count of the input text, the returned length will be the length of the input text. - /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0. + /// If the provided is greater than the token count of the input text, the returned Offset will be 0. + /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned Offset will be equal to normalized text length and the returned TokenCount will be 0. /// - public (string Text, int Offset, int Length, int TokenCount) TrimPrefixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) => - TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: false, considerSpecialTokens); + public (string Text, int Offset, int TokenCount) TrimPrefixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) + { + (string Text, int Offset, int Length, int TokenCount) result = TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: false, considerSpecialTokens); + return (result.Text, result.Offset, result.TokenCount); + } private (string Text, int Offset, int Length, int TokenCount) TrimWithinTokenLimit(string text, int maxTokenCount, bool trimSuffix = true, bool considerSpecialTokens = true) { diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs index 8fb9f10508..b8ffce87e0 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -26,16 +26,15 @@ internal static void TestTokenLimits(Tokenizer tokenizer) for (int i = 1; i <= fullIdsList.Count; i++) { - (string Text, int Offset, int Length, int TokenCount) result1 = tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: i); - (string Text, int Offset, int Length, int TokenCount) result2 = tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: i); + (string Text, int Length, int TokenCount) result1 = tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: i); + (string Text, int Offset, int TokenCount) result2 = tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: i); IReadOnlyList? prefixIds = null; IReadOnlyList? suffixIds = null; if (result1.TokenCount > 0) { - Assert.Equal(0, result1.Offset); - string prefixString = result1.Text.Substring(result1.Offset, result1.Length); + string prefixString = result1.Text.Substring(0, result1.Length); prefixIds = tokenizer.EncodeToIds(prefixString); Assert.Equal(result1.TokenCount, prefixIds.Count); Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count)); @@ -43,8 +42,7 @@ internal static void TestTokenLimits(Tokenizer tokenizer) if (result2.TokenCount > 0) { - Assert.Equal(result2.Text.Length, result2.Offset + result2.Length); - string suffixString = result2.Text.Substring(result2.Offset, result2.Length); + string suffixString = result2.Text.Substring(result2.Offset); suffixIds = tokenizer.EncodeToIds(suffixString); Assert.Equal(result2.TokenCount, suffixIds.Count); Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count)); @@ -53,7 +51,7 @@ internal static void TestTokenLimits(Tokenizer tokenizer) if (i == fullIdsList.Count) { Assert.Equal(result1.Text.Length, result1.Length); - Assert.Equal(result2.Text.Length, result2.Length); + Assert.Equal(0, result2.Offset); Assert.Equal(fullIdsList, prefixIds); Assert.Equal(fullIdsList, suffixIds); } From 5041ebf135e615097683aa4ef9bc79cf105312a8 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Tue, 5 Mar 2024 11:48:39 -0800 Subject: [PATCH 4/7] fix comments --- src/Microsoft.ML.Tokenizers/Model/BPE.cs | 4 ++-- src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs | 4 ++-- src/Microsoft.ML.Tokenizers/Model/Model.cs | 4 ++-- src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 4 ++-- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 20cfe7f38b..8f915b2a7d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -177,7 +177,7 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri /// Encode a text string to a list of tokens. /// /// The text to encode. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -193,7 +193,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// /// The text to split. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds); diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index 3155c778ec..6f0479155c 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -177,7 +177,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes /// Encode a text string to a list of tokens. /// /// The text to encode. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -225,7 +225,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// /// The text to split. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds); diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs index 815bd04a0b..d48eca7a6c 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Model.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs @@ -17,7 +17,7 @@ public abstract class Model /// Encode a text to a list of tokens. /// /// The text to encode. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of tokens generated from the text tokenization. public abstract IReadOnlyList Encode(string text, bool isSpecialToken = false); @@ -25,7 +25,7 @@ public abstract class Model /// Encode a text to a list of Ids and add them to the accumulatedIds list. /// /// The text to encode. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of accumulated encoded Ids. /// /// This method does the default implementation that uses the Encode method to get the token's Ids. diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 58637b97af..ed0d550a8e 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -309,7 +309,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : /// Encode a split text string to a list of tokens. /// /// The text to encode. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -372,7 +372,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// Encode text to a list of Ids. /// /// The text to encode. - /// Indicate if the token is a special token. + /// Indicate if the text is a special token. /// The list of accumulated Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) { diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 09a962a721..ea5c5f4dd1 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -145,7 +145,7 @@ public int CountTokens(string text, bool considerSpecialTokens = true) /// Indicate if want to consider the special tokens during the encoding. /// /// - The entire normalized text. - /// - The length of text from the beginning of the normalized which is limited by the maximum token count + /// - The length of text from the beginning of the normalized which is limited by the maximum token count. /// - The token count can be generated using the provided length which should be smaller than the maximum token count. /// /// The input text is null. From 765f79a6f79332a92048a99f36947efb06b93d0a Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Tue, 5 Mar 2024 13:54:26 -0800 Subject: [PATCH 5/7] Update the APIs signatures --- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 64 +++++++------------ .../TokenizerTests.cs | 33 +++++----- 2 files changed, 41 insertions(+), 56 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index ea5c5f4dd1..853c01fe77 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -138,54 +138,40 @@ public int CountTokens(string text, bool considerSpecialTokens = true) } /// - /// Find the maximum encoding capacity from beginning within the input text without surpassing the token limit. + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. /// /// The text to encode. /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. + /// The token count can be generated which should be smaller than the maximum token count. /// Indicate if want to consider the special tokens during the encoding. - /// - /// - The entire normalized text. - /// - The length of text from the beginning of the normalized which is limited by the maximum token count. - /// - The token count can be generated using the provided length which should be smaller than the maximum token count. - /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. /// The input text is null. /// The maximum token count must be greater than 0. /// - /// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text. - /// If the provided is greater than the token count of the input text, the returned length will be the length of the input text. - /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned length will be 0 and returned TokenCount will be 0. + /// If the whole text can be encoded within the token limit, the returned index will be the length of the processed text. /// - public (string Text, int Length, int TokenCount) TrimSuffixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) - { - (string Text, int Offset, int Length, int TokenCount) result = TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: true, considerSpecialTokens); - return (result.Text, result.Length, result.TokenCount); - } + public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true) + => IndexOF(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount); /// - /// Find the maximum encoding capacity from the end within the input text without surpassing the token limit. + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. /// /// The text to encode. /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. + /// The token count can be generated which should be smaller than the maximum token count. /// Indicate if want to consider the special tokens during the encoding. - /// - /// - The entire normalized text. - /// - The starting offset within the returned normalized text for token counting. - /// - The token count can be generated which should be smaller than the maximum token count. - /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. /// The input text is null. /// The maximum token count must be greater than 0. /// - /// If the tokenizer has a normalizer, the returned text will be the normalized text. Otherwise the returned text will be the input text. - /// If the provided is greater than the token count of the input text, the returned Offset will be 0. - /// If the provided is smaller enough to hold smallest number of grouped Ids, the returned Offset will be equal to normalized text length and the returned TokenCount will be 0. + /// If the whole text can be encoded within the token limit, the returned index will be 0. /// - public (string Text, int Offset, int TokenCount) TrimPrefixWithinTokenLimit(string text, int maxTokenCount, bool considerSpecialTokens = true) - { - (string Text, int Offset, int Length, int TokenCount) result = TrimWithinTokenLimit(text, maxTokenCount, trimSuffix: false, considerSpecialTokens); - return (result.Text, result.Offset, result.TokenCount); - } + public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true) + => IndexOF(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount); - private (string Text, int Offset, int Length, int TokenCount) TrimWithinTokenLimit(string text, int maxTokenCount, bool trimSuffix = true, bool considerSpecialTokens = true) + private int IndexOF(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount) { if (text is null) { @@ -197,24 +183,22 @@ public int CountTokens(string text, bool considerSpecialTokens = true) throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); } - string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text; - int idsCount = 0; + processedText = Normalizer is not null ? Normalizer.Normalize(text) : text; + tokenCount = 0; - IEnumerable splits = PreTokenizer.PreTokenize(normalized, considerSpecialTokens); - foreach (Split split in (trimSuffix ? splits : splits.Reverse())) + IEnumerable splits = PreTokenizer.PreTokenize(processedText, considerSpecialTokens); + foreach (Split split in (fromStart ? splits : splits.Reverse())) { - int tokenCount = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); - if (tokenCount > maxTokenCount - idsCount) + int count = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); + if (tokenCount > maxTokenCount - count) { - return trimSuffix ? - (normalized, 0, split.Offset.Index, idsCount) : - (normalized, split.Offset.Index + split.Offset.Length, normalized.Length - split.Offset.Index - split.Offset.Length, idsCount); + return fromStart ? split.Offset.Index : split.Offset.Index + split.Offset.Length; } - idsCount += tokenCount; + tokenCount += count; } - return (normalized, 0, normalized.Length, idsCount); + return fromStart ? processedText.Length : 0; } /// diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs index b8ffce87e0..48efe20f78 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -26,44 +26,45 @@ internal static void TestTokenLimits(Tokenizer tokenizer) for (int i = 1; i <= fullIdsList.Count; i++) { - (string Text, int Length, int TokenCount) result1 = tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: i); - (string Text, int Offset, int TokenCount) result2 = tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: i); + int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1); + int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2); + IReadOnlyList? prefixIds = null; IReadOnlyList? suffixIds = null; - if (result1.TokenCount > 0) + if (tokenCount1 > 0) { - string prefixString = result1.Text.Substring(0, result1.Length); + string prefixString = processedText1.Substring(0, index1); prefixIds = tokenizer.EncodeToIds(prefixString); - Assert.Equal(result1.TokenCount, prefixIds.Count); + Assert.Equal(tokenCount1, prefixIds.Count); Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count)); } - if (result2.TokenCount > 0) + if (tokenCount2 > 0) { - string suffixString = result2.Text.Substring(result2.Offset); + string suffixString = processedText2.Substring(index2); suffixIds = tokenizer.EncodeToIds(suffixString); - Assert.Equal(result2.TokenCount, suffixIds.Count); + Assert.Equal(tokenCount2, suffixIds.Count); Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count)); } if (i == fullIdsList.Count) { - Assert.Equal(result1.Text.Length, result1.Length); - Assert.Equal(0, result2.Offset); + Assert.Equal(processedText1.Length, index1); + Assert.Equal(0, index2); Assert.Equal(fullIdsList, prefixIds); Assert.Equal(fullIdsList, suffixIds); } } - Assert.Throws(() => tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: 0)); - Assert.Throws(() => tokenizer.TrimSuffixWithinTokenLimit(input, maxTokenCount: -1)); - Assert.Throws(() => tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: 0)); - Assert.Throws(() => tokenizer.TrimPrefixWithinTokenLimit(input, maxTokenCount: -1)); + Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); + Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); + Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); + Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); - Assert.Throws(() => tokenizer.TrimSuffixWithinTokenLimit(null!, maxTokenCount: 0)); - Assert.Throws(() => tokenizer.TrimPrefixWithinTokenLimit(null!, maxTokenCount: 0)); + Assert.Throws(() => tokenizer.IndexOfTokenCount(null!, maxTokenCount: 10, out _, out _)); + Assert.Throws(() => tokenizer.LastIndexOfTokenCount(null!, maxTokenCount: 10, out _, out _)); } } } \ No newline at end of file From b9e95b2491d04ede8142f14a3904816c06e4d3e4 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Wed, 6 Mar 2024 12:54:09 -0800 Subject: [PATCH 6/7] More feedback addressing --- src/Microsoft.ML.Tokenizers/AddedToken.cs | 91 ------ src/Microsoft.ML.Tokenizers/Model/BPE.cs | 6 +- .../Model/EnglishRoberta.cs | 6 +- src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 273 ++++++++++++++++- .../PreTokenizer/Roberta.cs | 2 +- src/Microsoft.ML.Tokenizers/Tokenizer.cs | 286 +----------------- .../TitokenTests.cs | 12 +- 7 files changed, 294 insertions(+), 382 deletions(-) delete mode 100644 src/Microsoft.ML.Tokenizers/AddedToken.cs diff --git a/src/Microsoft.ML.Tokenizers/AddedToken.cs b/src/Microsoft.ML.Tokenizers/AddedToken.cs deleted file mode 100644 index 01684cf49f..0000000000 --- a/src/Microsoft.ML.Tokenizers/AddedToken.cs +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.ML.Tokenizers -{ - /// - /// Represent a token added by the user on top of the existing Model vocabulary. - /// AddedToken can be configured to specify the behavior they should have in various situations - /// like: - /// - Whether they should only match single words - /// - Whether to include any WhiteSpace on its left or right - /// - public struct AddedToken : IEquatable - { - /// - /// Gets or sets the content of the added token - /// - public string Content { get; set; } - - /// - /// Gets or sets whether this token must be a single word or can break words - /// - internal bool SingleWord { get; set; } - - /// - /// Gets or sets whether this token should strip WhiteSpaces on its left - /// - internal bool LeftStrip { get; set; } - - /// - /// Gets or sets whether this token should strip WhiteSpaces on its right - /// - internal bool RightStrip { get; set; } - - /// - /// Gets or sets whether this token should be normalized - /// - internal bool Normalized { get; set; } - - /// - /// Gets or sets whether this token is special - /// - internal bool Special { get; set; } - - /// - /// Create a new AddedToken object. - /// - public AddedToken() - { - Content = ""; - SingleWord = LeftStrip = RightStrip = Special = false; - Normalized = true; - } - - /// - /// Create a new AddedToken object from the given content, specifying if it is intended to be a - /// special token. Special tokens are not normalized by default. - /// - /// The content of the added token. - /// Indicate whether this token is special. - public AddedToken(string content, bool special = false) : this() - { - Content = content ?? ""; - Special = special; - Normalized = !special; - } - - /// - /// Determines whether two token instances are equal. - /// - /// The token to compare with the current token. - public bool Equals(AddedToken other) => Content == other.Content; - - // We only want to hash on the content. AddedToken cannot be added multiple times with different options - /// - /// Returns the hash code for the current token. - /// - public override int GetHashCode() => Content.GetHashCode(); - - - /// - /// Defines an implicit conversion of a string object to AddedToken. - /// - public static implicit operator AddedToken(string token) => new AddedToken(token); - } -} diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 8f915b2a7d..63bc1e64ea 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -177,7 +177,7 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri /// Encode a text string to a list of tokens. /// /// The text to encode. - /// Indicate if the text is a special token. + /// Indicate if the text is a special token. This parameter is ignored in this model. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -193,7 +193,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// /// The text to split. - /// Indicate if the text is a special token. + /// Indicate if the text is a special token. This parameter is ignored in this model. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds); @@ -202,7 +202,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// The text to encode. /// Indicate if the token is special token. - /// The number of tokens that the input text will be encoded to. + /// The number of tokens that the input text will be encoded to. This parameter is ignored in this model. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIdsWithCache(text, null); /// diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index 6f0479155c..b4e2d7650f 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -177,7 +177,7 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes /// Encode a text string to a list of tokens. /// /// The text to encode. - /// Indicate if the text is a special token. + /// Indicate if the text is a special token. This parameter is ignored in this model. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -225,7 +225,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// /// The text to split. - /// Indicate if the text is a special token. + /// Indicate if the text is a special token. This parameter is ignored in this model. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds); @@ -233,7 +233,7 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// Get the number of tokens that the input text will be encoded to. /// /// The text to encode. - /// Indicate if the token is special token. + /// Indicate if the token is special token. This parameter is ignored in this model. /// The number of tokens that the input text will be encoded to. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIds(text, null); diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index ed0d550a8e..8e54369a0d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -4,10 +4,12 @@ using System; using System.Buffers; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; +using System.Net.Http; using System.Text; using System.Text.RegularExpressions; using System.Threading; @@ -18,7 +20,7 @@ namespace Microsoft.ML.Tokenizers /// /// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken. /// - public sealed class Tiktoken : Model + public sealed partial class Tiktoken : Model { private readonly Dictionary, int> _encoder; private readonly Dictionary> _decoder; @@ -122,7 +124,7 @@ public static Tokenizer CreateByModelName( throw new ArgumentNullException(nameof(modelName)); } - (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = Tokenizer.GetTiktokenConfigurations(modelName); + (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName); if (extraSpecialTokens is not null) { @@ -161,7 +163,7 @@ public static async Task CreateByModelNameAsync( throw new ArgumentNullException(nameof(modelName)); } - (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = Tokenizer.GetTiktokenConfigurations(modelName); + (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName); if (extraSpecialTokens is not null) { @@ -636,6 +638,271 @@ static void ArrayPoolGrow(ref Span utf8Bytes, ref byte[]? arrayPoolArray, /// public IReadOnlyDictionary> Decoder => _decoder; + private const string EndOfText = "<|endoftext|>"; + private const string FimPrefix = "<|fim_prefix|>"; + private const string FimMiddle = "<|fim_middle|>"; + private const string FimSuffix = "<|fim_suffix|>"; + private const string EndOfPrompt = "<|endofprompt|>"; + + private static readonly HttpClient _httpClient = new HttpClient(); + + private enum ModelEncoding + { + None, + Cl100kBase, + P50kBase, + P50kEdit, + R50kBase, + GPT2 + } + + private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding = + [ + // chat + ("gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k + ("gpt-3.5-turbo-", ModelEncoding.Cl100kBase) // e.g, gpt-3.5-turbo-0301, -0401, etc. + ]; + + private static readonly Dictionary _modelToEncoding = + new Dictionary(StringComparer.OrdinalIgnoreCase) + { + // chat + { "gpt-4", ModelEncoding.Cl100kBase }, + { "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, + + // text + { "text-davinci-003", ModelEncoding.P50kBase }, + { "text-davinci-002", ModelEncoding.P50kBase }, + { "text-davinci-001", ModelEncoding.R50kBase }, + { "text-curie-001", ModelEncoding.R50kBase }, + { "text-babbage-001", ModelEncoding.R50kBase }, + { "text-ada-001", ModelEncoding.R50kBase }, + { "davinci", ModelEncoding.R50kBase }, + { "curie", ModelEncoding.R50kBase }, + { "babbage", ModelEncoding.R50kBase }, + { "ada", ModelEncoding.R50kBase }, + + // code + { "code-davinci-002", ModelEncoding.P50kBase }, + { "code-davinci-001", ModelEncoding.P50kBase }, + { "code-cushman-002", ModelEncoding.P50kBase }, + { "code-cushman-001", ModelEncoding.P50kBase }, + { "davinci-codex", ModelEncoding.P50kBase }, + { "cushman-codex", ModelEncoding.P50kBase }, + + // edit + { "text-davinci-edit-001", ModelEncoding.P50kEdit }, + { "code-davinci-edit-001", ModelEncoding.P50kEdit }, + + // embeddings + // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings + { "text-embedding-ada-002", ModelEncoding.Cl100kBase }, + { "text-embedding-3-small", ModelEncoding.Cl100kBase }, + { "text-embedding-3-large", ModelEncoding.Cl100kBase }, + + // old embeddings + { "text-similarity-davinci-001", ModelEncoding.R50kBase }, + { "text-similarity-curie-001", ModelEncoding.R50kBase }, + { "text-similarity-babbage-001", ModelEncoding.R50kBase }, + { "text-similarity-ada-001", ModelEncoding.R50kBase }, + { "text-search-davinci-doc-001", ModelEncoding.R50kBase }, + { "text-search-curie-doc-001", ModelEncoding.R50kBase }, + { "text-search-babbage-doc-001", ModelEncoding.R50kBase }, + { "text-search-ada-doc-001", ModelEncoding.R50kBase }, + { "code-search-babbage-code-001", ModelEncoding.R50kBase }, + { "code-search-ada-code-001", ModelEncoding.R50kBase }, + + // open source + { "gpt2", ModelEncoding.GPT2 } + }; + + private static ModelEncoding GetModelEncoding(string modelName) + { + if (!_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoder)) + { + foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) + { + if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)) + { + encoder = Encoding; + break; + } + } + } + + if (encoder == ModelEncoding.None) + { + throw new NotSupportedException($"The model '{modelName}' is not supported."); + } + + return encoder; + } + + internal static (Dictionary SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName) + { + ModelEncoding modelEncoding = GetModelEncoding(modelName); + + switch (modelEncoding) + { + case ModelEncoding.Cl100kBase: + return (new Dictionary + { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex()); + + case ModelEncoding.P50kBase: + return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); + + case ModelEncoding.P50kEdit: + return (new Dictionary + { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex()); + + case ModelEncoding.R50kBase: + return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); + + case ModelEncoding.GPT2: + return (new Dictionary { { EndOfText, 50256 }, }, P50kBaseRegex()); + + default: + Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); + } + } + + /// + /// Create tokenizer based on model name + /// + /// Model name + /// Extra special tokens other than the built-in ones for the model + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + public static Task CreateByModelNameAsync( + string modelName, + IReadOnlyDictionary? extraSpecialTokens = null, + Normalizer? normalizer = null, + CancellationToken cancellationToken = default) + { + try + { + return CreateByEncoderNameAsync(modelName, GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken); + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py + + private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"; + private const string P50kBaseRegexPattern = /*lang=regex*/ @"'(?:[sdmt]|re|ve|ll)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; + + private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"; + private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"; + private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"; + private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken"; + +#if NET7_0_OR_GREATER + [GeneratedRegex(Cl100kBaseRegexPattern)] + private static partial Regex Cl100kBaseRegex(); + + [GeneratedRegex(P50kBaseRegexPattern)] + internal static partial Regex P50kBaseRegex(); +#else + private static Regex? _cl100kBaseRegex; + private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled); + + private static Regex? _p50kBaseRegex; + internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled); +#endif + + /// + /// Create tokenizer based on encoder name and extra special tokens + /// + /// Model name + /// Encoder label + /// Extra special tokens other than the built-in ones for the encoder + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + /// Throws if the model name is not supported + private static Task CreateByEncoderNameAsync( + string modelName, + ModelEncoding modelEncoding, + IReadOnlyDictionary? extraSpecialTokens, + Normalizer? normalizer, + CancellationToken cancellationToken) + { + switch (modelEncoding) + { + case ModelEncoding.Cl100kBase: + var specialTokens = new Dictionary + { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }; + return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.P50kBase: + specialTokens = new Dictionary { { EndOfText, 50256 } }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.P50kEdit: + specialTokens = new Dictionary + { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.R50kBase: + specialTokens = new Dictionary { { EndOfText, 50256 } }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.GPT2: + specialTokens = new Dictionary { { EndOfText, 50256 }, }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + default: + Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); + } + } + + private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); + + /// + /// Create tokenizer based on regex pattern, BPE rank file and special tokens + /// + /// Regex to break a long string + /// BPE rank file + /// Special tokens mapping. This may be mutated by the method. + /// Extra special tokens other than the built-in ones for the encoder + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + private static async Task CreateTikTokenTokenizerAsync( + Regex regex, + string mergeableRanksFileUrl, + Dictionary specialTokens, + IReadOnlyDictionary? extraSpecialTokens, + Normalizer? normalizer, + CancellationToken cancellationToken) + { + if (extraSpecialTokens is not null) + { + foreach (var extraSpecialToken in extraSpecialTokens) + { + specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); + } + } + + if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache)) + { + using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false)) + { + cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false); + } + + _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache); + } + + return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer); + } + private static unsafe int GetUtf8Bytes(ReadOnlySpan source, Span destination) { #if NETCOREAPP diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs index 81b4eb642f..9e8db6c9c8 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs @@ -30,7 +30,7 @@ public override IEnumerable PreTokenize(string text, bool considerSpecial return Array.Empty(); } - return SplitText(text, Tokenizer.P50kBaseRegex()); + return SplitText(text, Tiktoken.P50kBaseRegex()); } } } diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index 853c01fe77..254ffa0801 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -3,12 +3,10 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; -using System.Net.Http; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; @@ -145,14 +143,14 @@ public int CountTokens(string text, bool considerSpecialTokens = true) /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. /// The token count can be generated which should be smaller than the maximum token count. /// Indicate if want to consider the special tokens during the encoding. - /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be length of the . + /// /// The input text is null. /// The maximum token count must be greater than 0. - /// - /// If the whole text can be encoded within the token limit, the returned index will be the length of the processed text. - /// public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true) - => IndexOF(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount); + => IndexOf(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount); /// /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. @@ -162,16 +160,19 @@ public int IndexOfTokenCount(string text, int maxTokenCount, out string processe /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. /// The token count can be generated which should be smaller than the maximum token count. /// Indicate if want to consider the special tokens during the encoding. - /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index at the first character to be included. In cases where no tokens fit, the result will be length of the ; conversely, if all tokens fit, the result will be 0. + /// /// The input text is null. /// The maximum token count must be greater than 0. /// /// If the whole text can be encoded within the token limit, the returned index will be 0. /// public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true) - => IndexOF(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount); + => IndexOf(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount); - private int IndexOF(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount) + private int IndexOf(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount) { if (text is null) { @@ -216,270 +217,5 @@ private int IndexOF(string text, int maxTokenCount, bool fromStart, bool conside /// Whether the special tokens should be kept in the decoded string. /// The decoded string. public string? Decode(IEnumerable ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens); - - private const string EndOfText = "<|endoftext|>"; - private const string FimPrefix = "<|fim_prefix|>"; - private const string FimMiddle = "<|fim_middle|>"; - private const string FimSuffix = "<|fim_suffix|>"; - private const string EndOfPrompt = "<|endofprompt|>"; - - private static readonly HttpClient _httpClient = new HttpClient(); - - private enum ModelEncoding - { - None, - Cl100kBase, - P50kBase, - P50kEdit, - R50kBase, - GPT2 - } - - private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding = - [ - // chat - ("gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k - ("gpt-3.5-turbo-", ModelEncoding.Cl100kBase) // e.g, gpt-3.5-turbo-0301, -0401, etc. - ]; - - private static readonly Dictionary _modelToEncoding = - new Dictionary(StringComparer.OrdinalIgnoreCase) - { - // chat - { "gpt-4", ModelEncoding.Cl100kBase }, - { "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, - - // text - { "text-davinci-003", ModelEncoding.P50kBase }, - { "text-davinci-002", ModelEncoding.P50kBase }, - { "text-davinci-001", ModelEncoding.R50kBase }, - { "text-curie-001", ModelEncoding.R50kBase }, - { "text-babbage-001", ModelEncoding.R50kBase }, - { "text-ada-001", ModelEncoding.R50kBase }, - { "davinci", ModelEncoding.R50kBase }, - { "curie", ModelEncoding.R50kBase }, - { "babbage", ModelEncoding.R50kBase }, - { "ada", ModelEncoding.R50kBase }, - - // code - { "code-davinci-002", ModelEncoding.P50kBase }, - { "code-davinci-001", ModelEncoding.P50kBase }, - { "code-cushman-002", ModelEncoding.P50kBase }, - { "code-cushman-001", ModelEncoding.P50kBase }, - { "davinci-codex", ModelEncoding.P50kBase }, - { "cushman-codex", ModelEncoding.P50kBase }, - - // edit - { "text-davinci-edit-001", ModelEncoding.P50kEdit }, - { "code-davinci-edit-001", ModelEncoding.P50kEdit }, - - // embeddings - // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings - { "text-embedding-ada-002", ModelEncoding.Cl100kBase }, - { "text-embedding-3-small", ModelEncoding.Cl100kBase }, - { "text-embedding-3-large", ModelEncoding.Cl100kBase }, - - // old embeddings - { "text-similarity-davinci-001", ModelEncoding.R50kBase }, - { "text-similarity-curie-001", ModelEncoding.R50kBase }, - { "text-similarity-babbage-001", ModelEncoding.R50kBase }, - { "text-similarity-ada-001", ModelEncoding.R50kBase }, - { "text-search-davinci-doc-001", ModelEncoding.R50kBase }, - { "text-search-curie-doc-001", ModelEncoding.R50kBase }, - { "text-search-babbage-doc-001", ModelEncoding.R50kBase }, - { "text-search-ada-doc-001", ModelEncoding.R50kBase }, - { "code-search-babbage-code-001", ModelEncoding.R50kBase }, - { "code-search-ada-code-001", ModelEncoding.R50kBase }, - - // open source - { "gpt2", ModelEncoding.GPT2 } - }; - - private static ModelEncoding GetModelEncoding(string modelName) - { - if (!_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoder)) - { - foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) - { - if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)) - { - encoder = Encoding; - break; - } - } - } - - if (encoder == ModelEncoding.None) - { - throw new NotSupportedException($"The model '{modelName}' is not supported."); - } - - return encoder; - } - - internal static (Dictionary SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName) - { - ModelEncoding modelEncoding = GetModelEncoding(modelName); - - switch (modelEncoding) - { - case ModelEncoding.Cl100kBase: - return (new Dictionary - { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex()); - - case ModelEncoding.P50kBase: - return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); - - case ModelEncoding.P50kEdit: - return (new Dictionary - { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex()); - - case ModelEncoding.R50kBase: - return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); - - case ModelEncoding.GPT2: - return (new Dictionary { { EndOfText, 50256 }, }, P50kBaseRegex()); - - default: - Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); - throw new NotSupportedException($"The model '{modelName}' is not supported."); - } - } - - /// - /// Create tokenizer based on model name - /// - /// Model name - /// Extra special tokens other than the built-in ones for the model - /// To normalize the text before tokenization - /// used to request cancellation of the operation. - /// The tokenizer - public static Task CreateByModelNameAsync( - string modelName, - IReadOnlyDictionary? extraSpecialTokens = null, - Normalizer? normalizer = null, - CancellationToken cancellationToken = default) - { - try - { - return CreateByEncoderNameAsync(modelName, GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken); - } - catch (Exception ex) - { - return Task.FromException(ex); - } - } - - // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py - - private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"; - private const string P50kBaseRegexPattern = /*lang=regex*/ @"'(?:[sdmt]|re|ve|ll)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; - - private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"; - private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"; - private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"; - private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken"; - -#if NET7_0_OR_GREATER - [GeneratedRegex(Cl100kBaseRegexPattern)] - private static partial Regex Cl100kBaseRegex(); - - [GeneratedRegex(P50kBaseRegexPattern)] - internal static partial Regex P50kBaseRegex(); -#else - private static Regex? _cl100kBaseRegex; - private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled); - - private static Regex? _p50kBaseRegex; - internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled); -#endif - - /// - /// Create tokenizer based on encoder name and extra special tokens - /// - /// Model name - /// Encoder label - /// Extra special tokens other than the built-in ones for the encoder - /// To normalize the text before tokenization - /// used to request cancellation of the operation. - /// The tokenizer - /// Throws if the model name is not supported - private static Task CreateByEncoderNameAsync( - string modelName, - ModelEncoding modelEncoding, - IReadOnlyDictionary? extraSpecialTokens, - Normalizer? normalizer, - CancellationToken cancellationToken) - { - switch (modelEncoding) - { - case ModelEncoding.Cl100kBase: - var specialTokens = new Dictionary - { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }; - return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.P50kBase: - specialTokens = new Dictionary { { EndOfText, 50256 } }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.P50kEdit: - specialTokens = new Dictionary - { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.R50kBase: - specialTokens = new Dictionary { { EndOfText, 50256 } }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.GPT2: - specialTokens = new Dictionary { { EndOfText, 50256 }, }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - default: - Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); - throw new NotSupportedException($"The model '{modelName}' is not supported."); - } - } - - private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); - - /// - /// Create tokenizer based on regex pattern, BPE rank file and special tokens - /// - /// Regex to break a long string - /// BPE rank file - /// Special tokens mapping. This may be mutated by the method. - /// Extra special tokens other than the built-in ones for the encoder - /// To normalize the text before tokenization - /// used to request cancellation of the operation. - /// The tokenizer - private static async Task CreateTikTokenTokenizerAsync( - Regex regex, - string mergeableRanksFileUrl, - Dictionary specialTokens, - IReadOnlyDictionary? extraSpecialTokens, - Normalizer? normalizer, - CancellationToken cancellationToken) - { - if (extraSpecialTokens is not null) - { - foreach (var extraSpecialToken in extraSpecialTokens) - { - specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); - } - } - - if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache)) - { - using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false)) - { - cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false); - } - - _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache); - } - - return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer); - } } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs index e73766818f..fc5a0772a2 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs @@ -25,11 +25,11 @@ public class TiktokenTests { IMEnd, 100265}, }; - public static Tokenizer GPT4 { get; } = Tokenizer.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult(); - public static Tokenizer GPT2 { get; } = Tokenizer.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult(); - public static Tokenizer P50kBase { get; } = Tokenizer.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult(); - public static Tokenizer R50kBase { get; } = Tokenizer.CreateByModelNameAsync("ada").GetAwaiter().GetResult(); - public static Tokenizer P50kEdit { get; } = Tokenizer.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult(); + public static Tokenizer GPT4 { get; } = Tiktoken.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult(); + public static Tokenizer GPT2 { get; } = Tiktoken.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult(); + public static Tokenizer P50kBase { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult(); + public static Tokenizer R50kBase { get; } = Tiktoken.CreateByModelNameAsync("ada").GetAwaiter().GetResult(); + public static Tokenizer P50kEdit { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult(); [Fact] public async void TestTokenizerCreation() @@ -298,7 +298,7 @@ public void TestEncodeR50kBase() [InlineData("gpt2")] public async void TestAllSupportedModelNames(string modelName) { - Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync(modelName); + Tokenizer tokenizer = await Tiktoken.CreateByModelNameAsync(modelName); Assert.NotNull(tokenizer.Model); Assert.NotNull(tokenizer.PreTokenizer); } From 1d895069a5fd9698840e3d1000a65f48b9aa1305 Mon Sep 17 00:00:00 2001 From: Tarek Mahmoud Sayed Date: Wed, 6 Mar 2024 16:45:52 -0800 Subject: [PATCH 7/7] Fix the comments --- src/Microsoft.ML.Tokenizers/Model/BPE.cs | 12 ++++++------ src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs | 12 ++++++------ src/Microsoft.ML.Tokenizers/Model/Model.cs | 12 ++++++------ src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs | 12 ++++++------ 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 63bc1e64ea..6bc231f41e 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -176,8 +176,8 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri /// /// Encode a text string to a list of tokens. /// - /// The text to encode. - /// Indicate if the text is a special token. This parameter is ignored in this model. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -192,16 +192,16 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// - /// The text to split. - /// Indicate if the text is a special token. This parameter is ignored in this model. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds); /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. - /// Indicate if the token is special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The number of tokens that the input text will be encoded to. This parameter is ignored in this model. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIdsWithCache(text, null); diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index b4e2d7650f..8c6e5de95d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -176,8 +176,8 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes /// /// Encode a text string to a list of tokens. /// - /// The text to encode. - /// Indicate if the text is a special token. This parameter is ignored in this model. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -224,16 +224,16 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// - /// The text to split. - /// Indicate if the text is a special token. This parameter is ignored in this model. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds); /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. - /// Indicate if the token is special token. This parameter is ignored in this model. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The number of tokens that the input text will be encoded to. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIds(text, null); diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs index d48eca7a6c..df98700029 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Model.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs @@ -16,16 +16,16 @@ public abstract class Model /// /// Encode a text to a list of tokens. /// - /// The text to encode. - /// Indicate if the text is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The list of tokens generated from the text tokenization. public abstract IReadOnlyList Encode(string text, bool isSpecialToken = false); /// /// Encode a text to a list of Ids and add them to the accumulatedIds list. /// - /// The text to encode. - /// Indicate if the text is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The list of accumulated encoded Ids. /// /// This method does the default implementation that uses the Encode method to get the token's Ids. @@ -49,8 +49,8 @@ public virtual void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IL /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. - /// Indicate if the token is special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The number of tokens that the input text will be encoded to. /// /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids. diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 8e54369a0d..ccca3c63c7 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -310,8 +310,8 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : /// /// Encode a split text string to a list of tokens. /// - /// The text to encode. - /// Indicate if the text is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -373,8 +373,8 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Encode text to a list of Ids. /// - /// The text to encode. - /// Indicate if the text is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The list of accumulated Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) { @@ -420,8 +420,8 @@ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, I /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to tokenize. - /// Indicate if the token is special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The number of tokens that the input text will be encoded to. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) {