diff --git a/src/Microsoft.ML.Tokenizers/AddedToken.cs b/src/Microsoft.ML.Tokenizers/AddedToken.cs deleted file mode 100644 index 01684cf49f..0000000000 --- a/src/Microsoft.ML.Tokenizers/AddedToken.cs +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.ML.Tokenizers -{ - /// - /// Represent a token added by the user on top of the existing Model vocabulary. - /// AddedToken can be configured to specify the behavior they should have in various situations - /// like: - /// - Whether they should only match single words - /// - Whether to include any WhiteSpace on its left or right - /// - public struct AddedToken : IEquatable - { - /// - /// Gets or sets the content of the added token - /// - public string Content { get; set; } - - /// - /// Gets or sets whether this token must be a single word or can break words - /// - internal bool SingleWord { get; set; } - - /// - /// Gets or sets whether this token should strip WhiteSpaces on its left - /// - internal bool LeftStrip { get; set; } - - /// - /// Gets or sets whether this token should strip WhiteSpaces on its right - /// - internal bool RightStrip { get; set; } - - /// - /// Gets or sets whether this token should be normalized - /// - internal bool Normalized { get; set; } - - /// - /// Gets or sets whether this token is special - /// - internal bool Special { get; set; } - - /// - /// Create a new AddedToken object. - /// - public AddedToken() - { - Content = ""; - SingleWord = LeftStrip = RightStrip = Special = false; - Normalized = true; - } - - /// - /// Create a new AddedToken object from the given content, specifying if it is intended to be a - /// special token. Special tokens are not normalized by default. 
- /// - /// The content of the added token. - /// Indicate whether this token is special. - public AddedToken(string content, bool special = false) : this() - { - Content = content ?? ""; - Special = special; - Normalized = !special; - } - - /// - /// Determines whether two token instances are equal. - /// - /// The token to compare with the current token. - public bool Equals(AddedToken other) => Content == other.Content; - - // We only want to hash on the content. AddedToken cannot be added multiple times with different options - /// - /// Returns the hash code for the current token. - /// - public override int GetHashCode() => Content.GetHashCode(); - - - /// - /// Defines an implicit conversion of a string object to AddedToken. - /// - public static implicit operator AddedToken(string token) => new AddedToken(token); - } -} diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs index 20cfe7f38b..6bc231f41e 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs @@ -176,8 +176,8 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri /// /// Encode a text string to a list of tokens. /// - /// The text to encode. - /// Indicate if the token is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. This parameter is ignored in this model. /// The list of tokens generated from the text tokenization. public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -192,17 +192,17 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// - /// The text to split. - /// Indicate if the token is a special token. + /// The text to encode. 
If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire text is considered a special token. This parameter is ignored in this model. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds); /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. - /// Indicate if the token is special token. - /// The number of tokens that the input text will be encoded to. + /// The text to encode. This model does not use isSpecialToken, so its value has no effect on the count. + /// Specifies whether the entire text is considered a special token. This parameter is ignored in this model. + /// The number of tokens that the input text will be encoded to. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIdsWithCache(text, null); /// diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs index 3155c778ec..8c6e5de95d 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs @@ -176,8 +176,8 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes /// /// Encode a text string to a list of tokens. /// - /// The text to encode. - /// Indicate if the token is a special token. + /// The text to encode. If the value of the parameter isSpecialToken is true, the entire text will be treated as a special token. + /// Specifies whether the entire text is considered a special token. This parameter is ignored in this model. /// The list of tokens generated from the text tokenization.
public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { @@ -224,16 +224,16 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f /// /// Encode a split text string to a list of Ids and add them to the accumulatedIds list. /// - /// The text to split. - /// Indicate if the token is a special token. + /// The text to encode. This model does not use isSpecialToken, so its value has no effect on the encoding. + /// Specifies whether the entire text is considered a special token. This parameter is ignored in this model. /// The list of accumulated encoded Ids. public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds); /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. - /// Indicate if the token is special token. + /// The text to encode. This model does not use isSpecialToken, so its value has no effect on the count. + /// Specifies whether the entire text is considered a special token. This parameter is ignored in this model. /// The number of tokens that the input text will be encoded to. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIds(text, null); diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs index 815bd04a0b..df98700029 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Model.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs @@ -16,16 +16,16 @@ public abstract class Model /// /// Encode a text to a list of tokens. /// - /// The text to encode. - /// Indicate if the token is a special token. + /// The text to encode. If the value of the parameter isSpecialToken is true, the entire text will be treated as a special token. + /// Specifies whether the entire text is considered a special token. /// The list of tokens generated from the text tokenization.
public abstract IReadOnlyList Encode(string text, bool isSpecialToken = false); /// /// Encode a text to a list of Ids and add them to the accumulatedIds list. /// - /// The text to encode. - /// Indicate if the token is a special token. + /// The text to encode. If the value of the parameter isSpecialToken is true, the entire text will be treated as a special token. + /// Specifies whether the entire text is considered a special token. /// The list of accumulated encoded Ids. /// /// This method does the default implementation that uses the Encode method to get the token's Ids. @@ -49,8 +49,8 @@ public virtual void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IL /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to encode. - /// Indicate if the token is special token. + /// The text to encode. If the value of the parameter isSpecialToken is true, the entire text will be treated as a special token. + /// Specifies whether the entire text is considered a special token. /// The number of tokens that the input text will be encoded to. /// /// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids. diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 60e9282a81..ccca3c63c7 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -4,11 +4,14 @@ using System; using System.Buffers; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; +using System.Net.Http; using System.Text; +using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; @@ -17,7 +20,7 @@ namespace Microsoft.ML.Tokenizers /// /// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken.
/// - public sealed class Tiktoken : Model + public sealed partial class Tiktoken : Model { private readonly Dictionary, int> _encoder; private readonly Dictionary> _decoder; @@ -100,6 +103,83 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTo } } + /// + /// Create a Tiktoken tokenizer based on model name and vocab file. + /// + /// Model name + /// The stream to the BPE vocab file. + /// Extra special tokens other than the built-in ones for the model + /// The size of the cache to use. + /// To normalize the text before tokenization + /// The tokenizer + public static Tokenizer CreateByModelName( + string modelName, + Stream vocabStream, + IReadOnlyDictionary? extraSpecialTokens = null, + int cacheSize = LruCache.DefaultCacheSize, + Normalizer? normalizer = null) + { + if (string.IsNullOrEmpty(modelName)) + { + throw new ArgumentNullException(nameof(modelName)); + } + + (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName); + + if (extraSpecialTokens is not null) + { + foreach (var extraSpecialToken in extraSpecialTokens) + { + tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); + } + } + + return new Tokenizer( + new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize), + new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + normalizer); + } + + /// + /// Create a Tiktoken tokenizer based on model name and vocab file. + /// + /// Model name + /// The stream to the BPE vocab file. + /// Extra special tokens other than the built-in ones for the model + /// The size of the cache to use. + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + public static async Task CreateByModelNameAsync( + string modelName, + Stream vocabStream, + IReadOnlyDictionary? extraSpecialTokens = null, + int cacheSize = LruCache.DefaultCacheSize, + Normalizer? 
normalizer = null, + CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(modelName)) + { + throw new ArgumentNullException(nameof(modelName)); + } + + (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName); + + if (extraSpecialTokens is not null) + { + foreach (var extraSpecialToken in extraSpecialTokens) + { + tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); + } + } + + return new Tokenizer( + await CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false), + new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens), + normalizer); + } + + private static (Dictionary?, Dictionary?) CreateEncoderDecoder(IReadOnlyDictionary? specialTokens) { if (specialTokens is not null) @@ -230,10 +310,10 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) : /// /// Encode a split text string to a list of tokens. /// - /// The text to encode. - /// Indicate if the token is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The list of tokens generated from the text tokenization. - public override IReadOnlyList Encode(string text, bool isSpecialToken) + public override IReadOnlyList Encode(string text, bool isSpecialToken = false) { Token[] tokens; @@ -293,8 +373,8 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken) /// /// Encode text to a list of Ids. /// - /// The text to encode. - /// Indicate if the token is a special token. + /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token. + /// Specifies whether the entire is considered a special token. /// The list of accumulated Ids. 
public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) { @@ -340,8 +420,8 @@ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, I /// /// Get the number of tokens that the input text will be encoded to. /// - /// The text to tokenize. - /// Indicate if the token is special token. + /// The text to encode. If the value of the parameter isSpecialToken is true, the entire text will be treated as a special token. + /// Specifies whether the entire text is considered a special token. /// The number of tokens that the input text will be encoded to. public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) { @@ -462,12 +542,14 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) /// The decoded string. public override string? Decode(IEnumerable ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true) { - // Tiktoken does not ensure a one-to-one mapping between IDs and tokens. Consequently, decoding individual IDs into tokens is not supported; - // instead, decoding all IDs must be done collectively. - // Here is example of case that map one character to multiple Ids: - // '⭐' U-2B50 is mapped to Ids [2928, 99834] in the Tiktoken model. - // In other words, the character '⭐' has UTF-8 code point 0xE2, 0xAD, 0x90, Tiktoken will map 0xE2 to [2928] and 0xAD, 0x90 to [99834]. + // Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words. + // Consequently, decoding individual IDs into a UTF-16 string is not supported; instead, decoding all IDs must be performed collectively. + // Here's an example case that maps one character to multiple IDs: + // '⭐' U-2B50 is mapped to IDs [2928, 99834] in the Tiktoken model. + // In other words, the character '⭐' with UTF-8 byte sequence 0xE2, 0xAD, 0x90 will be mapped by Tiktoken as follows: 0xE2 to [2928] + // and 0xAD, 0x90 to [99834].
Decoding 2928 and 99834 individually won't reconstruct the original UTF-16 string '⭐' U-2B50; + // decoding all IDs together is required to get the expected result. if (ids is null) { return null; @@ -556,6 +638,271 @@ static void ArrayPoolGrow(ref Span utf8Bytes, ref byte[]? arrayPoolArray, /// public IReadOnlyDictionary> Decoder => _decoder; + private const string EndOfText = "<|endoftext|>"; + private const string FimPrefix = "<|fim_prefix|>"; + private const string FimMiddle = "<|fim_middle|>"; + private const string FimSuffix = "<|fim_suffix|>"; + private const string EndOfPrompt = "<|endofprompt|>"; + + private static readonly HttpClient _httpClient = new HttpClient(); + + private enum ModelEncoding + { + None, + Cl100kBase, + P50kBase, + P50kEdit, + R50kBase, + GPT2 + } + + private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding = + [ + // chat + ("gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k + ("gpt-3.5-turbo-", ModelEncoding.Cl100kBase) // e.g, gpt-3.5-turbo-0301, -0401, etc. 
+ ]; + + private static readonly Dictionary _modelToEncoding = + new Dictionary(StringComparer.OrdinalIgnoreCase) + { + // chat + { "gpt-4", ModelEncoding.Cl100kBase }, + { "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, + + // text + { "text-davinci-003", ModelEncoding.P50kBase }, + { "text-davinci-002", ModelEncoding.P50kBase }, + { "text-davinci-001", ModelEncoding.R50kBase }, + { "text-curie-001", ModelEncoding.R50kBase }, + { "text-babbage-001", ModelEncoding.R50kBase }, + { "text-ada-001", ModelEncoding.R50kBase }, + { "davinci", ModelEncoding.R50kBase }, + { "curie", ModelEncoding.R50kBase }, + { "babbage", ModelEncoding.R50kBase }, + { "ada", ModelEncoding.R50kBase }, + + // code + { "code-davinci-002", ModelEncoding.P50kBase }, + { "code-davinci-001", ModelEncoding.P50kBase }, + { "code-cushman-002", ModelEncoding.P50kBase }, + { "code-cushman-001", ModelEncoding.P50kBase }, + { "davinci-codex", ModelEncoding.P50kBase }, + { "cushman-codex", ModelEncoding.P50kBase }, + + // edit + { "text-davinci-edit-001", ModelEncoding.P50kEdit }, + { "code-davinci-edit-001", ModelEncoding.P50kEdit }, + + // embeddings + // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings + { "text-embedding-ada-002", ModelEncoding.Cl100kBase }, + { "text-embedding-3-small", ModelEncoding.Cl100kBase }, + { "text-embedding-3-large", ModelEncoding.Cl100kBase }, + + // old embeddings + { "text-similarity-davinci-001", ModelEncoding.R50kBase }, + { "text-similarity-curie-001", ModelEncoding.R50kBase }, + { "text-similarity-babbage-001", ModelEncoding.R50kBase }, + { "text-similarity-ada-001", ModelEncoding.R50kBase }, + { "text-search-davinci-doc-001", ModelEncoding.R50kBase }, + { "text-search-curie-doc-001", ModelEncoding.R50kBase }, + { "text-search-babbage-doc-001", ModelEncoding.R50kBase }, + { "text-search-ada-doc-001", ModelEncoding.R50kBase }, + { "code-search-babbage-code-001", ModelEncoding.R50kBase }, + { "code-search-ada-code-001", ModelEncoding.R50kBase }, + + 
// open source + { "gpt2", ModelEncoding.GPT2 } + }; + + private static ModelEncoding GetModelEncoding(string modelName) + { + if (!_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoder)) + { + foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) + { + if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)) + { + encoder = Encoding; + break; + } + } + } + + if (encoder == ModelEncoding.None) + { + throw new NotSupportedException($"The model '{modelName}' is not supported."); + } + + return encoder; + } + + internal static (Dictionary SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName) + { + ModelEncoding modelEncoding = GetModelEncoding(modelName); + + switch (modelEncoding) + { + case ModelEncoding.Cl100kBase: + return (new Dictionary + { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex()); + + case ModelEncoding.P50kBase: + return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); + + case ModelEncoding.P50kEdit: + return (new Dictionary + { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex()); + + case ModelEncoding.R50kBase: + return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex()); + + case ModelEncoding.GPT2: + return (new Dictionary { { EndOfText, 50256 }, }, P50kBaseRegex()); + + default: + Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); + } + } + + /// + /// Create tokenizer based on model name + /// + /// Model name + /// Extra special tokens other than the built-in ones for the model + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + public static Task CreateByModelNameAsync( + string modelName, + IReadOnlyDictionary? extraSpecialTokens = null, + Normalizer? 
normalizer = null, + CancellationToken cancellationToken = default) + { + try + { + return CreateByEncoderNameAsync(modelName, GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken); + } + catch (Exception ex) + { + return Task.FromException(ex); + } + } + + // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py + + private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"; + private const string P50kBaseRegexPattern = /*lang=regex*/ @"'(?:[sdmt]|re|ve|ll)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; + + private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"; + private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"; + private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"; + private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken"; + +#if NET7_0_OR_GREATER + [GeneratedRegex(Cl100kBaseRegexPattern)] + private static partial Regex Cl100kBaseRegex(); + + [GeneratedRegex(P50kBaseRegexPattern)] + internal static partial Regex P50kBaseRegex(); +#else + private static Regex? _cl100kBaseRegex; + private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled); + + private static Regex? 
_p50kBaseRegex; + internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled); +#endif + + /// + /// Create tokenizer based on encoder name and extra special tokens + /// + /// Model name + /// Encoder label + /// Extra special tokens other than the built-in ones for the encoder + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + /// Throws if the model name is not supported + private static Task CreateByEncoderNameAsync( + string modelName, + ModelEncoding modelEncoding, + IReadOnlyDictionary? extraSpecialTokens, + Normalizer? normalizer, + CancellationToken cancellationToken) + { + switch (modelEncoding) + { + case ModelEncoding.Cl100kBase: + var specialTokens = new Dictionary + { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }; + return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.P50kBase: + specialTokens = new Dictionary { { EndOfText, 50256 } }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.P50kEdit: + specialTokens = new Dictionary + { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.R50kBase: + specialTokens = new Dictionary { { EndOfText, 50256 } }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); + + case ModelEncoding.GPT2: + specialTokens = new Dictionary { { EndOfText, 50256 }, }; + return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, 
extraSpecialTokens, normalizer, cancellationToken); + + default: + Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); + throw new NotSupportedException($"The model '{modelName}' is not supported."); + } + } + + private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); + + /// + /// Create tokenizer based on regex pattern, BPE rank file and special tokens + /// + /// Regex to break a long string + /// BPE rank file + /// Special tokens mapping. This may be mutated by the method. + /// Extra special tokens other than the built-in ones for the encoder + /// To normalize the text before tokenization + /// used to request cancellation of the operation. + /// The tokenizer + private static async Task CreateTikTokenTokenizerAsync( + Regex regex, + string mergeableRanksFileUrl, + Dictionary specialTokens, + IReadOnlyDictionary? extraSpecialTokens, + Normalizer? normalizer, + CancellationToken cancellationToken) + { + if (extraSpecialTokens is not null) + { + foreach (var extraSpecialToken in extraSpecialTokens) + { + specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); + } + } + + if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache)) + { + using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false)) + { + cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false); + } + + _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache); + } + + return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer); + } + private static unsafe int GetUtf8Bytes(ReadOnlySpan source, Span destination) { #if NETCOREAPP diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs 
b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs index 81b4eb642f..9e8db6c9c8 100644 --- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs +++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs @@ -30,7 +30,7 @@ public override IEnumerable PreTokenize(string text, bool considerSpecial return Array.Empty(); } - return SplitText(text, Tokenizer.P50kBaseRegex()); + return SplitText(text, Tiktoken.P50kBaseRegex()); } } } diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs index c64ebf256e..254ffa0801 100644 --- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs @@ -3,11 +3,10 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.Net.Http; +using System.Linq; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; @@ -137,251 +136,86 @@ public int CountTokens(string text, bool considerSpecialTokens = true) } /// - /// Decodes the Id to the mapped token. + /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit. /// - /// The id to map to the token. - /// Indicate if want to consider the special tokens during the decoding. - /// The decoded string or null if there is no token mapped to the input id. - public string? Decode(int id, bool considerSpecialTokens = true) => Model.MapIdToToken(id, considerSpecialTokens); + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text. + /// The token count can be generated which should be smaller than the maximum token count. 
+ /// Indicates whether special tokens should be considered during the encoding. + /// + /// The index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be the length of processedText. + /// + /// The input text is null. + /// The maximum token count must be greater than 0. + public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true) + => IndexOf(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount); /// - /// Decode the given ids, back to a String. + /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit. /// - /// The list of ids that we want to decode. - /// Whether the special tokens should be kept in the decoded string. - /// The decoded string. - public string? Decode(IEnumerable ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens); - - private const string EndOfText = "<|endoftext|>"; - private const string FimPrefix = "<|fim_prefix|>"; - private const string FimMiddle = "<|fim_middle|>"; - private const string FimSuffix = "<|fim_suffix|>"; - private const string EndOfPrompt = "<|endofprompt|>"; - - private static readonly HttpClient _httpClient = new HttpClient(); - - private enum ModelEncoding + /// The text to encode. + /// The maximum token count to limit the encoding capacity. + /// If the tokenizer's normalization is enabled, the input text will be represented in its normalized form; otherwise, it will remain unchanged as the input text. + /// The number of tokens counted for the returned range; it never exceeds the maximum token count. + /// Indicates whether special tokens should be considered during the encoding.
+ /// + /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit. + /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of processedText; conversely, if all tokens fit, the result will be 0. + /// + /// The input text is null. + /// The maximum token count must be greater than 0. + /// + /// If the whole text can be encoded within the token limit, the returned index will be 0. + /// + public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true) + => IndexOf(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount); + + private int IndexOf(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount) { - None, - Cl100kBase, - P50kBase, - P50kEdit, - R50kBase, - GPT2 - } - - private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding = - [ - // chat - ("gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k - ("gpt-3.5-turbo-", ModelEncoding.Cl100kBase) // e.g, gpt-3.5-turbo-0301, -0401, etc.
- ]; - - private static readonly Dictionary _modelToEncoding = - new Dictionary(StringComparer.OrdinalIgnoreCase) - { - // chat - { "gpt-4", ModelEncoding.Cl100kBase }, - { "gpt-3.5-turbo", ModelEncoding.Cl100kBase }, - - // text - { "text-davinci-003", ModelEncoding.P50kBase }, - { "text-davinci-002", ModelEncoding.P50kBase }, - { "text-davinci-001", ModelEncoding.R50kBase }, - { "text-curie-001", ModelEncoding.R50kBase }, - { "text-babbage-001", ModelEncoding.R50kBase }, - { "text-ada-001", ModelEncoding.R50kBase }, - { "davinci", ModelEncoding.R50kBase }, - { "curie", ModelEncoding.R50kBase }, - { "babbage", ModelEncoding.R50kBase }, - { "ada", ModelEncoding.R50kBase }, - - // code - { "code-davinci-002", ModelEncoding.P50kBase }, - { "code-davinci-001", ModelEncoding.P50kBase }, - { "code-cushman-002", ModelEncoding.P50kBase }, - { "code-cushman-001", ModelEncoding.P50kBase }, - { "davinci-codex", ModelEncoding.P50kBase }, - { "cushman-codex", ModelEncoding.P50kBase }, - - // edit - { "text-davinci-edit-001", ModelEncoding.P50kEdit }, - { "code-davinci-edit-001", ModelEncoding.P50kEdit }, - - // embeddings - // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings - { "text-embedding-ada-002", ModelEncoding.Cl100kBase }, - { "text-embedding-3-small", ModelEncoding.Cl100kBase }, - { "text-embedding-3-large", ModelEncoding.Cl100kBase }, - - // old embeddings - { "text-similarity-davinci-001", ModelEncoding.R50kBase }, - { "text-similarity-curie-001", ModelEncoding.R50kBase }, - { "text-similarity-babbage-001", ModelEncoding.R50kBase }, - { "text-similarity-ada-001", ModelEncoding.R50kBase }, - { "text-search-davinci-doc-001", ModelEncoding.R50kBase }, - { "text-search-curie-doc-001", ModelEncoding.R50kBase }, - { "text-search-babbage-doc-001", ModelEncoding.R50kBase }, - { "text-search-ada-doc-001", ModelEncoding.R50kBase }, - { "code-search-babbage-code-001", ModelEncoding.R50kBase }, - { "code-search-ada-code-001", ModelEncoding.R50kBase }, - - 
// open source - { "gpt2", ModelEncoding.GPT2 } - }; - + if (text is null) + { + throw new ArgumentNullException(nameof(text)); + } - /// - /// Create tokenizer based on model name - /// - /// Model name - /// Extra special tokens other than the built-in ones for the model - /// To normalize the text before tokenization - /// used to request cancellation of the operation. - /// The tokenizer - public static Task CreateByModelNameAsync( - string modelName, - IReadOnlyDictionary? extraSpecialTokens = null, - Normalizer? normalizer = null, - CancellationToken cancellationToken = default) - { - try + if (maxTokenCount <= 0) { - ModelEncoding encoder; + throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0."); + } - if (!_modelToEncoding.TryGetValue(modelName, out encoder)) - { - foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding) - { - if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase)) - { - encoder = Encoding; - break; - } - } - } + processedText = Normalizer is not null ? Normalizer.Normalize(text) : text; + tokenCount = 0; - if (encoder == ModelEncoding.None) + IEnumerable splits = PreTokenizer.PreTokenize(processedText, considerSpecialTokens); + foreach (Split split in (fromStart ? splits : splits.Reverse())) + { + int count = Model.CountTokens(split.TokenSpan, split.IsSpecialToken); + if (tokenCount > maxTokenCount - count) { - throw new NotImplementedException($"Doesn't support this model [{modelName}]"); + return fromStart ? 
split.Offset.Index : split.Offset.Index + split.Offset.Length; } - return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken); - } - catch (Exception ex) - { - return Task.FromException(ex); + tokenCount += count; } - } - - // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py - - private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"; - private const string P50kBaseRegexPattern = /*lang=regex*/ @"'(?:[sdmt]|re|ve|ll)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; - - private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"; - private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"; - private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"; - private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken"; - -#if NET7_0_OR_GREATER - [GeneratedRegex(Cl100kBaseRegexPattern)] - private static partial Regex Cl100kBaseRegex(); - [GeneratedRegex(P50kBaseRegexPattern)] - internal static partial Regex P50kBaseRegex(); -#else - private static Regex? _cl100kBaseRegex; - private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled); - - private static Regex? _p50kBaseRegex; - internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled); -#endif + return fromStart ? processedText.Length : 0; + } /// - /// Create tokenizer based on encoder name and extra special tokens + /// Decodes the Id to the mapped token. 
/// - /// Encoder label - /// Extra special tokens other than the built-in ones for the encoder - /// To normalize the text before tokenization - /// used to request cancellation of the operation. - /// The tokenizer - /// Throws if the encoder is not supported - private static Task CreateByEncoderNameAsync( - ModelEncoding modelEncoding, - IReadOnlyDictionary? extraSpecialTokens, - Normalizer? normalizer, - CancellationToken cancellationToken) - { - switch (modelEncoding) - { - case ModelEncoding.Cl100kBase: - var specialTokens = new Dictionary - { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }; - return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.P50kBase: - specialTokens = new Dictionary { { EndOfText, 50256 } }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.P50kEdit: - specialTokens = new Dictionary - { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.R50kBase: - specialTokens = new Dictionary { { EndOfText, 50256 } }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - case ModelEncoding.GPT2: - specialTokens = new Dictionary { { EndOfText, 50256 }, }; - return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken); - - default: - Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]"); - throw new NotImplementedException($"Doesn't support this encoder [{modelEncoding}]"); - } - } - - private static readonly 
ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase); + /// The id to map to the token. + /// Indicate if want to consider the special tokens during the decoding. + /// The decoded string or null if there is no token mapped to the input id. + public string? Decode(int id, bool considerSpecialTokens = true) => Model.MapIdToToken(id, considerSpecialTokens); /// - /// Create tokenizer based on regex pattern, BPE rank file and special tokens + /// Decode the given ids, back to a String. /// - /// Regex to break a long string - /// BPE rank file - /// Special tokens mapping. This may be mutated by the method. - /// Extra special tokens other than the built-in ones for the encoder - /// To normalize the text before tokenization - /// used to request cancellation of the operation. - /// The tokenizer - private static async Task CreateTikTokenTokenizerAsync( - Regex regex, - string mergeableRanksFileUrl, - Dictionary specialTokens, - IReadOnlyDictionary? extraSpecialTokens, - Normalizer? normalizer, - CancellationToken cancellationToken) - { - if (extraSpecialTokens is not null) - { - foreach (var extraSpecialToken in extraSpecialTokens) - { - specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value); - } - } - - if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache)) - { - using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false)) - { - cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false); - } - - _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache); - } - - return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer); - } + /// The list of ids that we want to decode. 
+ /// Whether the special tokens should be kept in the decoded string. + /// The decoded string. + public string? Decode(IEnumerable ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens); } } diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs index 2959184b5d..9ae8517494 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs @@ -225,6 +225,8 @@ public async void TestGpt2Vocab() Assert.Equal(12, encoding.Ids.Count); Assert.Equal(encoding.Ids, ids); Assert.Equal(12, tokenizer.CountTokens(text)); + + TokenizerTests.TestTokenLimits(tokenizer); } private static string WriteToMergeFile((string, string)[] mergeEntries) diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs index ccf0e66ef9..4518f169bf 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs @@ -96,6 +96,7 @@ public async void TokenizationTest() Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance); TestTokenizer(tokenizer); + TokenizerTests.TestTokenLimits(tokenizer); tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance); TestTokenizer(tokenizer); diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs index af6581401a..fc5a0772a2 100644 --- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs +++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs @@ -25,11 +25,11 @@ public class TiktokenTests { IMEnd, 100265}, }; - public static Tokenizer GPT4 { get; } = Tokenizer.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult(); - public static Tokenizer GPT2 { get; } = 
Tokenizer.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult(); - public static Tokenizer P50kBase { get; } = Tokenizer.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult(); - public static Tokenizer R50kBase { get; } = Tokenizer.CreateByModelNameAsync("ada").GetAwaiter().GetResult(); - public static Tokenizer P50kEdit { get; } = Tokenizer.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult(); + public static Tokenizer GPT4 { get; } = Tiktoken.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult(); + public static Tokenizer GPT2 { get; } = Tiktoken.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult(); + public static Tokenizer P50kBase { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult(); + public static Tokenizer R50kBase { get; } = Tiktoken.CreateByModelNameAsync("ada").GetAwaiter().GetResult(); + public static Tokenizer P50kEdit { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult(); [Fact] public async void TestTokenizerCreation() @@ -61,6 +61,18 @@ public async void TestTokenizerCreation() tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer); } TestGPT4TokenizationEncoding(tokenizer); + + using (Stream stream = File.OpenRead(tokenizerDataFileName)) + { + tokenizer = Tiktoken.CreateByModelName("gpt-4", stream); + } + TestGPT4TokenizationEncoding(tokenizer); + + using (Stream stream = File.OpenRead(tokenizerDataFileName)) + { + tokenizer = await Tiktoken.CreateByModelNameAsync("gpt-3.5-turbo", stream); + } + TestGPT4TokenizationEncoding(tokenizer); } finally { @@ -82,6 +94,8 @@ private void TestGPT4TokenizationEncoding(Tokenizer tokenizer) Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets); Assert.Equal(encoded.Count, idsCount); Assert.Equal(encoded, result.Ids); + + TestGPT4Tokenizer(tokenizer); } [Fact] @@ -101,13 +115,12 @@ public void TestEncode1() 
Assert.Equal(encoded, result.Ids); } - [Fact] - public void TestEncode2() + private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer) { string text = ReadAndSanitizeFile("./Data/lib.rs.txt"); - IReadOnlyList encoded = GPT4.EncodeToIds(text, considerSpecialTokens: false); + IReadOnlyList encoded = gpt4Tokenizer.EncodeToIds(text, considerSpecialTokens: false); Assert.Equal(5584, encoded.Count); - int idsCount = GPT4.CountTokens(text, considerSpecialTokens: false); + int idsCount = gpt4Tokenizer.CountTokens(text, considerSpecialTokens: false); Assert.Equal(encoded.Count, idsCount); using (Stream stream = File.OpenRead("./Data/tokens.json")) @@ -116,8 +129,10 @@ public void TestEncode2() Assert.Equal(expected!, encoded.ToArray()); } - string? decoded = GPT4.Decode(encoded.ToArray()); + string? decoded = gpt4Tokenizer.Decode(encoded.ToArray()); Assert.Equal(text, decoded!); + + TokenizerTests.TestTokenLimits(gpt4Tokenizer); } [Fact] @@ -283,7 +298,7 @@ public void TestEncodeR50kBase() [InlineData("gpt2")] public async void TestAllSupportedModelNames(string modelName) { - Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync(modelName); + Tokenizer tokenizer = await Tiktoken.CreateByModelNameAsync(modelName); Assert.NotNull(tokenizer.Model); Assert.NotNull(tokenizer.PreTokenizer); } diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs new file mode 100644 index 0000000000..48efe20f78 --- /dev/null +++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Tokenizers; +using System; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tokenizers.Tests +{ + public class TokenizerTests + { + internal static void TestTokenLimits(Tokenizer tokenizer) + { + string input = @" + OpenAI's large language models (sometimes referred to as GPT's) process text using tokens, which are common sequences of characters found in a set of text. + The models learn to understand the statistical relationships between these tokens, and excel at producing the next token in a sequence of tokens. + You can use the tool below to understand how a piece of text might be tokenized by a language model, and the total count of tokens in that piece of text. + It's important to note that the exact tokenization process varies between models. Newer models like GPT-3.5 and GPT-4 use a different tokenizer than previous models, + and will produce different tokens for the same input text. + "; + + IReadOnlyList fullIdsList = tokenizer.EncodeToIds(input); + + for (int i = 1; i <= fullIdsList.Count; i++) + { + int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1); + int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2); + + + IReadOnlyList? prefixIds = null; + IReadOnlyList? 
suffixIds = null; + + if (tokenCount1 > 0) + { + string prefixString = processedText1.Substring(0, index1); + prefixIds = tokenizer.EncodeToIds(prefixString); + Assert.Equal(tokenCount1, prefixIds.Count); + Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count)); + } + + if (tokenCount2 > 0) + { + string suffixString = processedText2.Substring(index2); + suffixIds = tokenizer.EncodeToIds(suffixString); + Assert.Equal(tokenCount2, suffixIds.Count); + Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count)); + } + + if (i == fullIdsList.Count) + { + Assert.Equal(processedText1.Length, index1); + Assert.Equal(0, index2); + Assert.Equal(fullIdsList, prefixIds); + Assert.Equal(fullIdsList, suffixIds); + } + } + + Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); + Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); + Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _)); + Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _)); + + Assert.Throws(() => tokenizer.IndexOfTokenCount(null!, maxTokenCount: 10, out _, out _)); + Assert.Throws(() => tokenizer.LastIndexOfTokenCount(null!, maxTokenCount: 10, out _, out _)); + } + } +} \ No newline at end of file