diff --git a/src/Microsoft.ML.Tokenizers/AddedToken.cs b/src/Microsoft.ML.Tokenizers/AddedToken.cs
deleted file mode 100644
index 01684cf49f..0000000000
--- a/src/Microsoft.ML.Tokenizers/AddedToken.cs
+++ /dev/null
@@ -1,91 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace Microsoft.ML.Tokenizers
-{
- ///
- /// Represent a token added by the user on top of the existing Model vocabulary.
- /// AddedToken can be configured to specify the behavior they should have in various situations
- /// like:
- /// - Whether they should only match single words
- /// - Whether to include any WhiteSpace on its left or right
- ///
- public struct AddedToken : IEquatable
- {
- ///
- /// Gets or sets the content of the added token
- ///
- public string Content { get; set; }
-
- ///
- /// Gets or sets whether this token must be a single word or can break words
- ///
- internal bool SingleWord { get; set; }
-
- ///
- /// Gets or sets whether this token should strip WhiteSpaces on its left
- ///
- internal bool LeftStrip { get; set; }
-
- ///
- /// Gets or sets whether this token should strip WhiteSpaces on its right
- ///
- internal bool RightStrip { get; set; }
-
- ///
- /// Gets or sets whether this token should be normalized
- ///
- internal bool Normalized { get; set; }
-
- ///
- /// Gets or sets whether this token is special
- ///
- internal bool Special { get; set; }
-
- ///
- /// Create a new AddedToken object.
- ///
- public AddedToken()
- {
- Content = "";
- SingleWord = LeftStrip = RightStrip = Special = false;
- Normalized = true;
- }
-
- ///
- /// Create a new AddedToken object from the given content, specifying if it is intended to be a
- /// special token. Special tokens are not normalized by default.
- ///
- /// The content of the added token.
- /// Indicate whether this token is special.
- public AddedToken(string content, bool special = false) : this()
- {
- Content = content ?? "";
- Special = special;
- Normalized = !special;
- }
-
- ///
- /// Determines whether two token instances are equal.
- ///
- /// The token to compare with the current token.
- public bool Equals(AddedToken other) => Content == other.Content;
-
- // We only want to hash on the content. AddedToken cannot be added multiple times with different options
- ///
- /// Returns the hash code for the current token.
- ///
- public override int GetHashCode() => Content.GetHashCode();
-
-
- ///
- /// Defines an implicit conversion of a string object to AddedToken.
- ///
- public static implicit operator AddedToken(string token) => new AddedToken(token);
- }
-}
diff --git a/src/Microsoft.ML.Tokenizers/Model/BPE.cs b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
index 20cfe7f38b..6bc231f41e 100644
--- a/src/Microsoft.ML.Tokenizers/Model/BPE.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -176,8 +176,8 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
///
/// Encode a text string to a list of tokens.
///
- /// The text to encode.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.
/// The list of tokens generated from the text tokenization.
public override IReadOnlyList Encode(string text, bool isSpecialToken = false)
{
@@ -192,17 +192,17 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
///
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
///
- /// The text to split.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.
/// The list of accumulated encoded Ids.
public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
///
/// Get the number of tokens that the input text will be encoded to.
///
- /// The text to encode.
- /// Indicate if the token is special token.
- /// The number of tokens that the input text will be encoded to.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.
+ /// The number of tokens that the input text will be encoded to.
public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
///
diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
index 3155c778ec..8c6e5de95d 100644
--- a/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -176,8 +176,8 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
///
/// Encode a text string to a list of tokens.
///
- /// The text to encode.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.
/// The list of tokens generated from the text tokenization.
public override IReadOnlyList Encode(string text, bool isSpecialToken = false)
{
@@ -224,16 +224,16 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken = f
///
/// Encode a split text string to a list of Ids and add them to the accumulatedIds list.
///
- /// The text to split.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.
/// The list of accumulated encoded Ids.
public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds) => EncodeToIds(text, accumulatedIds);
///
/// Get the number of tokens that the input text will be encoded to.
///
- /// The text to encode.
- /// Indicate if the token is special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token. This parameter is ignored in this model.
/// The number of tokens that the input text will be encoded to.
public override int CountTokens(ReadOnlySpan text, bool isSpecialToken) => EncodeToIds(text, null);
diff --git a/src/Microsoft.ML.Tokenizers/Model/Model.cs b/src/Microsoft.ML.Tokenizers/Model/Model.cs
index 815bd04a0b..df98700029 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Model.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Model.cs
@@ -16,16 +16,16 @@ public abstract class Model
///
/// Encode a text to a list of tokens.
///
- /// The text to encode.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token.
/// The list of tokens generated from the text tokenization.
public abstract IReadOnlyList Encode(string text, bool isSpecialToken = false);
///
/// Encode a text to a list of Ids and add them to the accumulatedIds list.
///
- /// The text to encode.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token.
/// The list of accumulated encoded Ids.
///
/// This method does the default implementation that uses the Encode method to get the token's Ids.
@@ -49,8 +49,8 @@ public virtual void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IL
///
/// Get the number of tokens that the input text will be encoded to.
///
- /// The text to encode.
- /// Indicate if the token is special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token.
/// The number of tokens that the input text will be encoded to.
///
/// This method does the default implementation that uses the EncodeToIds method to get the number of token's Ids.
diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 60e9282a81..ccca3c63c7 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -4,11 +4,14 @@
using System;
using System.Buffers;
+using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
+using System.Net.Http;
using System.Text;
+using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
@@ -17,7 +20,7 @@ namespace Microsoft.ML.Tokenizers
///
/// Represent the rapid Byte Pair Encoding model commonly referred to as Tiktoken.
///
- public sealed class Tiktoken : Model
+ public sealed partial class Tiktoken : Model
{
private readonly Dictionary, int> _encoder;
private readonly Dictionary> _decoder;
@@ -100,6 +103,83 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary? specialTo
}
}
+ ///
+ /// Create a Tiktoken tokenizer based on model name and vocab file.
+ ///
+ /// Model name
+ /// The stream to the BPE vocab file.
+ /// Extra special tokens other than the built-in ones for the model
+ /// The size of the cache to use.
+ /// To normalize the text before tokenization
+ /// The tokenizer
+ public static Tokenizer CreateByModelName(
+ string modelName,
+ Stream vocabStream,
+ IReadOnlyDictionary? extraSpecialTokens = null,
+ int cacheSize = LruCache.DefaultCacheSize,
+ Normalizer? normalizer = null)
+ {
+ if (string.IsNullOrEmpty(modelName))
+ {
+ throw new ArgumentNullException(nameof(modelName));
+ }
+
+ (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
+
+ if (extraSpecialTokens is not null)
+ {
+ foreach (var extraSpecialToken in extraSpecialTokens)
+ {
+ tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
+ }
+ }
+
+ return new Tokenizer(
+ new Tiktoken(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize),
+ new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+ normalizer);
+ }
+
+ ///
+ /// Create a Tiktoken tokenizer based on model name and vocab file.
+ ///
+ /// Model name
+ /// The stream to the BPE vocab file.
+ /// Extra special tokens other than the built-in ones for the model
+ /// The size of the cache to use.
+ /// To normalize the text before tokenization
+ /// used to request cancellation of the operation.
+ /// The tokenizer
+ public static async Task CreateByModelNameAsync(
+ string modelName,
+ Stream vocabStream,
+ IReadOnlyDictionary? extraSpecialTokens = null,
+ int cacheSize = LruCache.DefaultCacheSize,
+ Normalizer? normalizer = null,
+ CancellationToken cancellationToken = default)
+ {
+ if (string.IsNullOrEmpty(modelName))
+ {
+ throw new ArgumentNullException(nameof(modelName));
+ }
+
+ (Dictionary SpecialTokens, Regex Regex) tiktokenConfiguration = GetTiktokenConfigurations(modelName);
+
+ if (extraSpecialTokens is not null)
+ {
+ foreach (var extraSpecialToken in extraSpecialTokens)
+ {
+ tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
+ }
+ }
+
+ return new Tokenizer(
+ await CreateAsync(vocabStream, tiktokenConfiguration.SpecialTokens, cacheSize, cancellationToken).ConfigureAwait(false),
+ new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
+ normalizer);
+ }
+
+
private static (Dictionary?, Dictionary?) CreateEncoderDecoder(IReadOnlyDictionary? specialTokens)
{
if (specialTokens is not null)
@@ -230,10 +310,10 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
///
/// Encode a split text string to a list of tokens.
///
- /// The text to encode.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token.
/// The list of tokens generated from the text tokenization.
- public override IReadOnlyList Encode(string text, bool isSpecialToken)
+ public override IReadOnlyList Encode(string text, bool isSpecialToken = false)
{
Token[] tokens;
@@ -293,8 +373,8 @@ public override IReadOnlyList Encode(string text, bool isSpecialToken)
///
/// Encode text to a list of Ids.
///
- /// The text to encode.
- /// Indicate if the token is a special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token.
/// The list of accumulated Ids.
public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, IList accumulatedIds)
{
@@ -340,8 +420,8 @@ public override void EncodeToIds(ReadOnlySpan text, bool isSpecialToken, I
///
/// Get the number of tokens that the input text will be encoded to.
///
- /// The text to tokenize.
- /// Indicate if the token is special token.
+ /// The text to encode. If the value of the parameter is true, the entire text will be treated as a special token.
+ /// Specifies whether the entire <paramref name="text"/> is considered a special token.
/// The number of tokens that the input text will be encoded to.
public override int CountTokens(ReadOnlySpan text, bool isSpecialToken)
{
@@ -462,12 +542,14 @@ public override int CountTokens(ReadOnlySpan text, bool isSpecialToken)
/// The decoded string.
public override string? Decode(IEnumerable ids, TokenizerDecoder? decoder = null, bool considerSpecialTokens = true)
{
- // Tiktoken does not ensure a one-to-one mapping between IDs and tokens. Consequently, decoding individual IDs into tokens is not supported;
- // instead, decoding all IDs must be done collectively.
- // Here is example of case that map one character to multiple Ids:
- // '⭐' U-2B50 is mapped to Ids [2928, 99834] in the Tiktoken model.
- // In other words, the character '⭐' has UTF-8 code point 0xE2, 0xAD, 0x90, Tiktoken will map 0xE2 to [2928] and 0xAD, 0x90 to [99834].
+ // Tiktoken doesn't guarantee a one-to-one correspondence between IDs and UTF-16 words.
+ // Consequently, decoding individual IDs into UTF-16 string is not supported; instead, decoding all IDs must be performed collectively.
+ // Here's an example case that maps one character to multiple IDs:
+ // '⭐' U-2B50 is mapped to IDs [2928, 99834] in the Tiktoken model.
+ // In other words, the character '⭐' with UTF-8 code point 0xE2, 0xAD, 0x90 will be mapped by Tiktoken as follows: 0xE2 to [2928]
+ // and 0xAD, 0x90 to [99834]. Decoding 2928 and 99834 individually won't reconstruct the original UTF-16 string '⭐' U-2B50;
+ // decoding all IDs together is required to get the expected result.
if (ids is null)
{
return null;
@@ -556,6 +638,271 @@ static void ArrayPoolGrow(ref Span utf8Bytes, ref byte[]? arrayPoolArray,
///
public IReadOnlyDictionary> Decoder => _decoder;
+ private const string EndOfText = "<|endoftext|>";
+ private const string FimPrefix = "<|fim_prefix|>";
+ private const string FimMiddle = "<|fim_middle|>";
+ private const string FimSuffix = "<|fim_suffix|>";
+ private const string EndOfPrompt = "<|endofprompt|>";
+
+ private static readonly HttpClient _httpClient = new HttpClient();
+
+ private enum ModelEncoding
+ {
+ None,
+ Cl100kBase,
+ P50kBase,
+ P50kEdit,
+ R50kBase,
+ GPT2
+ }
+
+ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding =
+ [
+ // chat
+ ("gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k
+ ("gpt-3.5-turbo-", ModelEncoding.Cl100kBase) // e.g., gpt-3.5-turbo-0301, -0401, etc.
+ ];
+
+ private static readonly Dictionary _modelToEncoding =
+ new Dictionary(StringComparer.OrdinalIgnoreCase)
+ {
+ // chat
+ { "gpt-4", ModelEncoding.Cl100kBase },
+ { "gpt-3.5-turbo", ModelEncoding.Cl100kBase },
+
+ // text
+ { "text-davinci-003", ModelEncoding.P50kBase },
+ { "text-davinci-002", ModelEncoding.P50kBase },
+ { "text-davinci-001", ModelEncoding.R50kBase },
+ { "text-curie-001", ModelEncoding.R50kBase },
+ { "text-babbage-001", ModelEncoding.R50kBase },
+ { "text-ada-001", ModelEncoding.R50kBase },
+ { "davinci", ModelEncoding.R50kBase },
+ { "curie", ModelEncoding.R50kBase },
+ { "babbage", ModelEncoding.R50kBase },
+ { "ada", ModelEncoding.R50kBase },
+
+ // code
+ { "code-davinci-002", ModelEncoding.P50kBase },
+ { "code-davinci-001", ModelEncoding.P50kBase },
+ { "code-cushman-002", ModelEncoding.P50kBase },
+ { "code-cushman-001", ModelEncoding.P50kBase },
+ { "davinci-codex", ModelEncoding.P50kBase },
+ { "cushman-codex", ModelEncoding.P50kBase },
+
+ // edit
+ { "text-davinci-edit-001", ModelEncoding.P50kEdit },
+ { "code-davinci-edit-001", ModelEncoding.P50kEdit },
+
+ // embeddings
+ // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
+ { "text-embedding-ada-002", ModelEncoding.Cl100kBase },
+ { "text-embedding-3-small", ModelEncoding.Cl100kBase },
+ { "text-embedding-3-large", ModelEncoding.Cl100kBase },
+
+ // old embeddings
+ { "text-similarity-davinci-001", ModelEncoding.R50kBase },
+ { "text-similarity-curie-001", ModelEncoding.R50kBase },
+ { "text-similarity-babbage-001", ModelEncoding.R50kBase },
+ { "text-similarity-ada-001", ModelEncoding.R50kBase },
+ { "text-search-davinci-doc-001", ModelEncoding.R50kBase },
+ { "text-search-curie-doc-001", ModelEncoding.R50kBase },
+ { "text-search-babbage-doc-001", ModelEncoding.R50kBase },
+ { "text-search-ada-doc-001", ModelEncoding.R50kBase },
+ { "code-search-babbage-code-001", ModelEncoding.R50kBase },
+ { "code-search-ada-code-001", ModelEncoding.R50kBase },
+
+ // open source
+ { "gpt2", ModelEncoding.GPT2 }
+ };
+
+ private static ModelEncoding GetModelEncoding(string modelName)
+ {
+ if (!_modelToEncoding.TryGetValue(modelName, out ModelEncoding encoder))
+ {
+ foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
+ {
+ if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+ {
+ encoder = Encoding;
+ break;
+ }
+ }
+ }
+
+ if (encoder == ModelEncoding.None)
+ {
+ throw new NotSupportedException($"The model '{modelName}' is not supported.");
+ }
+
+ return encoder;
+ }
+
+ internal static (Dictionary SpecialTokens, Regex Regex) GetTiktokenConfigurations(string modelName)
+ {
+ ModelEncoding modelEncoding = GetModelEncoding(modelName);
+
+ switch (modelEncoding)
+ {
+ case ModelEncoding.Cl100kBase:
+ return (new Dictionary
+ { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} }, Cl100kBaseRegex());
+
+ case ModelEncoding.P50kBase:
+ return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex());
+
+ case ModelEncoding.P50kEdit:
+ return (new Dictionary
+ { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex());
+
+ case ModelEncoding.R50kBase:
+ return (new Dictionary { { EndOfText, 50256 } }, P50kBaseRegex());
+
+ case ModelEncoding.GPT2:
+ return (new Dictionary { { EndOfText, 50256 }, }, P50kBaseRegex());
+
+ default:
+ Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
+ throw new NotSupportedException($"The model '{modelName}' is not supported.");
+ }
+ }
+
+ ///
+ /// Create tokenizer based on model name
+ ///
+ /// Model name
+ /// Extra special tokens other than the built-in ones for the model
+ /// To normalize the text before tokenization
+ /// used to request cancellation of the operation.
+ /// The tokenizer
+ public static Task CreateByModelNameAsync(
+ string modelName,
+ IReadOnlyDictionary? extraSpecialTokens = null,
+ Normalizer? normalizer = null,
+ CancellationToken cancellationToken = default)
+ {
+ try
+ {
+ return CreateByEncoderNameAsync(modelName, GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken);
+ }
+ catch (Exception ex)
+ {
+ return Task.FromException(ex);
+ }
+ }
+
+ // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
+
+ private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
+ private const string P50kBaseRegexPattern = /*lang=regex*/ @"'(?:[sdmt]|re|ve|ll)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
+
+ private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
+ private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
+ private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
+ private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken";
+
+#if NET7_0_OR_GREATER
+ [GeneratedRegex(Cl100kBaseRegexPattern)]
+ private static partial Regex Cl100kBaseRegex();
+
+ [GeneratedRegex(P50kBaseRegexPattern)]
+ internal static partial Regex P50kBaseRegex();
+#else
+ private static Regex? _cl100kBaseRegex;
+ private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled);
+
+ private static Regex? _p50kBaseRegex;
+ internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled);
+#endif
+
+ ///
+ /// Create tokenizer based on encoder name and extra special tokens
+ ///
+ /// Model name
+ /// Encoder label
+ /// Extra special tokens other than the built-in ones for the encoder
+ /// To normalize the text before tokenization
+ /// used to request cancellation of the operation.
+ /// The tokenizer
+ /// Throws if the model name is not supported
+ private static Task CreateByEncoderNameAsync(
+ string modelName,
+ ModelEncoding modelEncoding,
+ IReadOnlyDictionary? extraSpecialTokens,
+ Normalizer? normalizer,
+ CancellationToken cancellationToken)
+ {
+ switch (modelEncoding)
+ {
+ case ModelEncoding.Cl100kBase:
+ var specialTokens = new Dictionary
+ { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
+ return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
+
+ case ModelEncoding.P50kBase:
+ specialTokens = new Dictionary { { EndOfText, 50256 } };
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
+
+ case ModelEncoding.P50kEdit:
+ specialTokens = new Dictionary
+ { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
+
+ case ModelEncoding.R50kBase:
+ specialTokens = new Dictionary { { EndOfText, 50256 } };
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
+
+ case ModelEncoding.GPT2:
+ specialTokens = new Dictionary { { EndOfText, 50256 }, };
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
+
+ default:
+ Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
+ throw new NotSupportedException($"The model '{modelName}' is not supported.");
+ }
+ }
+
+ private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+
+ ///
+ /// Create tokenizer based on regex pattern, BPE rank file and special tokens
+ ///
+ /// Regex to break a long string
+ /// BPE rank file
+ /// Special tokens mapping. This may be mutated by the method.
+ /// Extra special tokens other than the built-in ones for the encoder
+ /// To normalize the text before tokenization
+ /// used to request cancellation of the operation.
+ /// The tokenizer
+ private static async Task CreateTikTokenTokenizerAsync(
+ Regex regex,
+ string mergeableRanksFileUrl,
+ Dictionary specialTokens,
+ IReadOnlyDictionary? extraSpecialTokens,
+ Normalizer? normalizer,
+ CancellationToken cancellationToken)
+ {
+ if (extraSpecialTokens is not null)
+ {
+ foreach (var extraSpecialToken in extraSpecialTokens)
+ {
+ specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
+ }
+ }
+
+ if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache))
+ {
+ using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
+ {
+ cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
+ }
+
+ _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
+ }
+
+ return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer);
+ }
+
private static unsafe int GetUtf8Bytes(ReadOnlySpan source, Span destination)
{
#if NETCOREAPP
diff --git a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
index 81b4eb642f..9e8db6c9c8 100644
--- a/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
+++ b/src/Microsoft.ML.Tokenizers/PreTokenizer/Roberta.cs
@@ -30,7 +30,7 @@ public override IEnumerable PreTokenize(string text, bool considerSpecial
return Array.Empty();
}
- return SplitText(text, Tokenizer.P50kBaseRegex());
+ return SplitText(text, Tiktoken.P50kBaseRegex());
}
}
}
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index c64ebf256e..254ffa0801 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -3,11 +3,10 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
-using System.Net.Http;
+using System.Linq;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
@@ -137,251 +136,86 @@ public int CountTokens(string text, bool considerSpecialTokens = true)
}
///
- /// Decodes the Id to the mapped token.
+ /// Find the index of the maximum encoding capacity from the start within the text without surpassing the token limit.
///
- /// The id to map to the token.
- /// Indicate if want to consider the special tokens during the decoding.
- /// The decoded string or null if there is no token mapped to the input id.
- public string? Decode(int id, bool considerSpecialTokens = true) => Model.MapIdToToken(id, considerSpecialTokens);
+ /// The text to encode.
+ /// The maximum token count to limit the encoding capacity.
+ /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.
+ /// The number of tokens actually generated, which will not exceed the maximum token count.
+ /// Indicate if want to consider the special tokens during the encoding.
+ ///
+ /// The index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index immediately following the last character to be included. In cases where no tokens fit, the result will be 0; conversely, if all tokens fit, the result will be the length of the <paramref name="processedText"/>.
+ ///
+ /// The input text is null.
+ /// The maximum token count must be greater than 0.
+ public int IndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true)
+ => IndexOf(text, maxTokenCount, fromStart: true, considerSpecialTokens, out processedText, out tokenCount);
///
- /// Decode the given ids, back to a String.
+ /// Find the index of the maximum encoding capacity from the end within the text without surpassing the token limit.
///
- /// The list of ids that we want to decode.
- /// Whether the special tokens should be kept in the decoded string.
- /// The decoded string.
- public string? Decode(IEnumerable ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens);
-
- private const string EndOfText = "<|endoftext|>";
- private const string FimPrefix = "<|fim_prefix|>";
- private const string FimMiddle = "<|fim_middle|>";
- private const string FimSuffix = "<|fim_suffix|>";
- private const string EndOfPrompt = "<|endofprompt|>";
-
- private static readonly HttpClient _httpClient = new HttpClient();
-
- private enum ModelEncoding
+ /// The text to encode.
+ /// The maximum token count to limit the encoding capacity.
+ /// If the tokenizer's normalization is enabled, the input text will be represented in its normalization form; otherwise, it will remain unchanged as the input text.
+ /// The token count that can be generated, which will not exceed the maximum token count.
+ /// Indicates whether to consider the special tokens during the encoding.
+ ///
+ /// The start index of the maximum encoding capacity within the processed text without surpassing the token limit.
+ /// It represents the index of the first character to be included. In cases where no tokens fit, the result will be the length of the text; conversely, if all tokens fit, the result will be 0.
+ ///
+ /// The input text is null.
+ /// The maximum token count must be greater than 0.
+ ///
+ /// If the whole text can be encoded within the token limit, the returned index will be 0.
+ ///
+ public int LastIndexOfTokenCount(string text, int maxTokenCount, out string processedText, out int tokenCount, bool considerSpecialTokens = true)
+ => IndexOf(text, maxTokenCount, fromStart: false, considerSpecialTokens, out processedText, out tokenCount);
+
+ private int IndexOf(string text, int maxTokenCount, bool fromStart, bool considerSpecialTokens, out string processedText, out int tokenCount)
{
- None,
- Cl100kBase,
- P50kBase,
- P50kEdit,
- R50kBase,
- GPT2
- }
-
- private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixToEncoding =
- [
- // chat
- ("gpt-4-", ModelEncoding.Cl100kBase), // e.g., gpt-4-0314, etc., plus gpt-4-32k
- ("gpt-3.5-turbo-", ModelEncoding.Cl100kBase) // e.g, gpt-3.5-turbo-0301, -0401, etc.
- ];
-
- private static readonly Dictionary _modelToEncoding =
- new Dictionary(StringComparer.OrdinalIgnoreCase)
- {
- // chat
- { "gpt-4", ModelEncoding.Cl100kBase },
- { "gpt-3.5-turbo", ModelEncoding.Cl100kBase },
-
- // text
- { "text-davinci-003", ModelEncoding.P50kBase },
- { "text-davinci-002", ModelEncoding.P50kBase },
- { "text-davinci-001", ModelEncoding.R50kBase },
- { "text-curie-001", ModelEncoding.R50kBase },
- { "text-babbage-001", ModelEncoding.R50kBase },
- { "text-ada-001", ModelEncoding.R50kBase },
- { "davinci", ModelEncoding.R50kBase },
- { "curie", ModelEncoding.R50kBase },
- { "babbage", ModelEncoding.R50kBase },
- { "ada", ModelEncoding.R50kBase },
-
- // code
- { "code-davinci-002", ModelEncoding.P50kBase },
- { "code-davinci-001", ModelEncoding.P50kBase },
- { "code-cushman-002", ModelEncoding.P50kBase },
- { "code-cushman-001", ModelEncoding.P50kBase },
- { "davinci-codex", ModelEncoding.P50kBase },
- { "cushman-codex", ModelEncoding.P50kBase },
-
- // edit
- { "text-davinci-edit-001", ModelEncoding.P50kEdit },
- { "code-davinci-edit-001", ModelEncoding.P50kEdit },
-
- // embeddings
- // https://platform.openai.com/docs/guides/embeddings/what-are-embeddings
- { "text-embedding-ada-002", ModelEncoding.Cl100kBase },
- { "text-embedding-3-small", ModelEncoding.Cl100kBase },
- { "text-embedding-3-large", ModelEncoding.Cl100kBase },
-
- // old embeddings
- { "text-similarity-davinci-001", ModelEncoding.R50kBase },
- { "text-similarity-curie-001", ModelEncoding.R50kBase },
- { "text-similarity-babbage-001", ModelEncoding.R50kBase },
- { "text-similarity-ada-001", ModelEncoding.R50kBase },
- { "text-search-davinci-doc-001", ModelEncoding.R50kBase },
- { "text-search-curie-doc-001", ModelEncoding.R50kBase },
- { "text-search-babbage-doc-001", ModelEncoding.R50kBase },
- { "text-search-ada-doc-001", ModelEncoding.R50kBase },
- { "code-search-babbage-code-001", ModelEncoding.R50kBase },
- { "code-search-ada-code-001", ModelEncoding.R50kBase },
-
- // open source
- { "gpt2", ModelEncoding.GPT2 }
- };
-
+ if (text is null)
+ {
+ throw new ArgumentNullException(nameof(text));
+ }
- ///
- /// Create tokenizer based on model name
- ///
- /// Model name
- /// Extra special tokens other than the built-in ones for the model
- /// To normalize the text before tokenization
- /// used to request cancellation of the operation.
- /// The tokenizer
- public static Task CreateByModelNameAsync(
- string modelName,
- IReadOnlyDictionary? extraSpecialTokens = null,
- Normalizer? normalizer = null,
- CancellationToken cancellationToken = default)
- {
- try
+ if (maxTokenCount <= 0)
{
- ModelEncoding encoder;
+ throw new ArgumentOutOfRangeException(nameof(maxTokenCount), "The max token count must be greater than 0.");
+ }
- if (!_modelToEncoding.TryGetValue(modelName, out encoder))
- {
- foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
- {
- if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
- {
- encoder = Encoding;
- break;
- }
- }
- }
+ processedText = Normalizer is not null ? Normalizer.Normalize(text) : text;
+ tokenCount = 0;
- if (encoder == ModelEncoding.None)
+ IEnumerable splits = PreTokenizer.PreTokenize(processedText, considerSpecialTokens);
+ foreach (Split split in (fromStart ? splits : splits.Reverse()))
+ {
+ int count = Model.CountTokens(split.TokenSpan, split.IsSpecialToken);
+ if (tokenCount > maxTokenCount - count)
{
- throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+ return fromStart ? split.Offset.Index : split.Offset.Index + split.Offset.Length;
}
- return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
- }
- catch (Exception ex)
- {
- return Task.FromException(ex);
+ tokenCount += count;
}
- }
-
- // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
-
- private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
- private const string P50kBaseRegexPattern = /*lang=regex*/ @"'(?:[sdmt]|re|ve|ll)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";
-
- private const string Cl100kBaseVocabUrl = @"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken";
- private const string P50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken";
- private const string R50RanksUrl = @"https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken";
- private const string GPT2Url = @"https://pythia.blob.core.windows.net/public/encoding/gpt2.tiktoken";
-
-#if NET7_0_OR_GREATER
- [GeneratedRegex(Cl100kBaseRegexPattern)]
- private static partial Regex Cl100kBaseRegex();
- [GeneratedRegex(P50kBaseRegexPattern)]
- internal static partial Regex P50kBaseRegex();
-#else
- private static Regex? _cl100kBaseRegex;
- private static Regex Cl100kBaseRegex() => _cl100kBaseRegex ??= new Regex(Cl100kBaseRegexPattern, RegexOptions.Compiled);
-
- private static Regex? _p50kBaseRegex;
- internal static Regex P50kBaseRegex() => _p50kBaseRegex ??= new Regex(P50kBaseRegexPattern, RegexOptions.Compiled);
-#endif
+ return fromStart ? processedText.Length : 0;
+ }
///
- /// Create tokenizer based on encoder name and extra special tokens
+ /// Decodes the Id to the mapped token.
///
- /// Encoder label
- /// Extra special tokens other than the built-in ones for the encoder
- /// To normalize the text before tokenization
- /// used to request cancellation of the operation.
- /// The tokenizer
- /// Throws if the encoder is not supported
- private static Task CreateByEncoderNameAsync(
- ModelEncoding modelEncoding,
- IReadOnlyDictionary? extraSpecialTokens,
- Normalizer? normalizer,
- CancellationToken cancellationToken)
- {
- switch (modelEncoding)
- {
- case ModelEncoding.Cl100kBase:
- var specialTokens = new Dictionary
- { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
- return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
-
- case ModelEncoding.P50kBase:
- specialTokens = new Dictionary { { EndOfText, 50256 } };
- return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
-
- case ModelEncoding.P50kEdit:
- specialTokens = new Dictionary
- { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
- return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
-
- case ModelEncoding.R50kBase:
- specialTokens = new Dictionary { { EndOfText, 50256 } };
- return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
-
- case ModelEncoding.GPT2:
- specialTokens = new Dictionary { { EndOfText, 50256 }, };
- return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
-
- default:
- Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
- throw new NotImplementedException($"Doesn't support this encoder [{modelEncoding}]");
- }
- }
-
- private static readonly ConcurrentDictionary, int> encoder, Dictionary vocab, Dictionary> decoder)> _tiktokenCache = new(StringComparer.OrdinalIgnoreCase);
+ /// The id to map to the token.
+ /// Indicates whether to consider the special tokens during the decoding.
+ /// The decoded string or null if there is no token mapped to the input id.
+ public string? Decode(int id, bool considerSpecialTokens = true) => Model.MapIdToToken(id, considerSpecialTokens);
///
- /// Create tokenizer based on regex pattern, BPE rank file and special tokens
+ /// Decode the given ids, back to a String.
///
- /// Regex to break a long string
- /// BPE rank file
- /// Special tokens mapping. This may be mutated by the method.
- /// Extra special tokens other than the built-in ones for the encoder
- /// To normalize the text before tokenization
- /// used to request cancellation of the operation.
- /// The tokenizer
- private static async Task CreateTikTokenTokenizerAsync(
- Regex regex,
- string mergeableRanksFileUrl,
- Dictionary specialTokens,
- IReadOnlyDictionary? extraSpecialTokens,
- Normalizer? normalizer,
- CancellationToken cancellationToken)
- {
- if (extraSpecialTokens is not null)
- {
- foreach (var extraSpecialToken in extraSpecialTokens)
- {
- specialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
- }
- }
-
- if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary, int> encoder, Dictionary vocab, Dictionary> decoder) cache))
- {
- using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
- {
- cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
- }
-
- _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
- }
-
- return new Tokenizer(new Tiktoken(cache.encoder, cache.decoder, cache.vocab, specialTokens), new TikTokenPreTokenizer(regex, specialTokens), normalizer);
- }
+ /// The list of ids that we want to decode.
+ /// Whether the special tokens should be kept in the decoded string.
+ /// The decoded string.
+ public string? Decode(IEnumerable ids, bool considerSpecialTokens = true) => Model.Decode(ids, Decoder, considerSpecialTokens);
}
}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
index 2959184b5d..9ae8517494 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs
@@ -225,6 +225,8 @@ public async void TestGpt2Vocab()
Assert.Equal(12, encoding.Ids.Count);
Assert.Equal(encoding.Ids, ids);
Assert.Equal(12, tokenizer.CountTokens(text));
+
+ TokenizerTests.TestTokenLimits(tokenizer);
}
private static string WriteToMergeFile((string, string)[] mergeEntries)
diff --git a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
index ccf0e66ef9..4518f169bf 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
@@ -96,6 +96,7 @@ public async void TokenizationTest()
Tokenizer tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
+ TokenizerTests.TestTokenLimits(tokenizer);
tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile, filterUnsupportedChars: false), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
index af6581401a..fc5a0772a2 100644
--- a/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
+++ b/test/Microsoft.ML.Tokenizers.Tests/TitokenTests.cs
@@ -25,11 +25,11 @@ public class TiktokenTests
{ IMEnd, 100265},
};
- public static Tokenizer GPT4 { get; } = Tokenizer.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult();
- public static Tokenizer GPT2 { get; } = Tokenizer.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult();
- public static Tokenizer P50kBase { get; } = Tokenizer.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult();
- public static Tokenizer R50kBase { get; } = Tokenizer.CreateByModelNameAsync("ada").GetAwaiter().GetResult();
- public static Tokenizer P50kEdit { get; } = Tokenizer.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult();
+ public static Tokenizer GPT4 { get; } = Tiktoken.CreateByModelNameAsync("gpt-4", _specialTokens).GetAwaiter().GetResult();
+ public static Tokenizer GPT2 { get; } = Tiktoken.CreateByModelNameAsync("gpt2").GetAwaiter().GetResult();
+ public static Tokenizer P50kBase { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-003").GetAwaiter().GetResult();
+ public static Tokenizer R50kBase { get; } = Tiktoken.CreateByModelNameAsync("ada").GetAwaiter().GetResult();
+ public static Tokenizer P50kEdit { get; } = Tiktoken.CreateByModelNameAsync("text-davinci-edit-001").GetAwaiter().GetResult();
[Fact]
public async void TestTokenizerCreation()
@@ -61,6 +61,18 @@ public async void TestTokenizerCreation()
tokenizer = new Tokenizer(await Tiktoken.CreateAsync(stream, specialTokensEncoder), GPT4.PreTokenizer);
}
TestGPT4TokenizationEncoding(tokenizer);
+
+ using (Stream stream = File.OpenRead(tokenizerDataFileName))
+ {
+ tokenizer = Tiktoken.CreateByModelName("gpt-4", stream);
+ }
+ TestGPT4TokenizationEncoding(tokenizer);
+
+ using (Stream stream = File.OpenRead(tokenizerDataFileName))
+ {
+ tokenizer = await Tiktoken.CreateByModelNameAsync("gpt-3.5-turbo", stream);
+ }
+ TestGPT4TokenizationEncoding(tokenizer);
}
finally
{
@@ -82,6 +94,8 @@ private void TestGPT4TokenizationEncoding(Tokenizer tokenizer)
Assert.Equal(new List<(int, int)> { (0, 5), (5, 6) }, result.Offsets);
Assert.Equal(encoded.Count, idsCount);
Assert.Equal(encoded, result.Ids);
+
+ TestGPT4Tokenizer(tokenizer);
}
[Fact]
@@ -101,13 +115,12 @@ public void TestEncode1()
Assert.Equal(encoded, result.Ids);
}
- [Fact]
- public void TestEncode2()
+ private void TestGPT4Tokenizer(Tokenizer gpt4Tokenizer)
{
string text = ReadAndSanitizeFile("./Data/lib.rs.txt");
- IReadOnlyList encoded = GPT4.EncodeToIds(text, considerSpecialTokens: false);
+ IReadOnlyList encoded = gpt4Tokenizer.EncodeToIds(text, considerSpecialTokens: false);
Assert.Equal(5584, encoded.Count);
- int idsCount = GPT4.CountTokens(text, considerSpecialTokens: false);
+ int idsCount = gpt4Tokenizer.CountTokens(text, considerSpecialTokens: false);
Assert.Equal(encoded.Count, idsCount);
using (Stream stream = File.OpenRead("./Data/tokens.json"))
@@ -116,8 +129,10 @@ public void TestEncode2()
Assert.Equal(expected!, encoded.ToArray());
}
- string? decoded = GPT4.Decode(encoded.ToArray());
+ string? decoded = gpt4Tokenizer.Decode(encoded.ToArray());
Assert.Equal(text, decoded!);
+
+ TokenizerTests.TestTokenLimits(gpt4Tokenizer);
}
[Fact]
@@ -283,7 +298,7 @@ public void TestEncodeR50kBase()
[InlineData("gpt2")]
public async void TestAllSupportedModelNames(string modelName)
{
- Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync(modelName);
+ Tokenizer tokenizer = await Tiktoken.CreateByModelNameAsync(modelName);
Assert.NotNull(tokenizer.Model);
Assert.NotNull(tokenizer.PreTokenizer);
}
diff --git a/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs
new file mode 100644
index 0000000000..48efe20f78
--- /dev/null
+++ b/test/Microsoft.ML.Tokenizers.Tests/TokenizerTests.cs
@@ -0,0 +1,70 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Tokenizers;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Xunit;
+
+namespace Microsoft.ML.Tokenizers.Tests
+{
+ public class TokenizerTests
+ {
+ internal static void TestTokenLimits(Tokenizer tokenizer)
+ {
+ string input = @"
+ OpenAI's large language models (sometimes referred to as GPT's) process text using tokens, which are common sequences of characters found in a set of text.
+ The models learn to understand the statistical relationships between these tokens, and excel at producing the next token in a sequence of tokens.
+ You can use the tool below to understand how a piece of text might be tokenized by a language model, and the total count of tokens in that piece of text.
+ It's important to note that the exact tokenization process varies between models. Newer models like GPT-3.5 and GPT-4 use a different tokenizer than previous models,
+ and will produce different tokens for the same input text.
+ ";
+
+ IReadOnlyList fullIdsList = tokenizer.EncodeToIds(input);
+
+ for (int i = 1; i <= fullIdsList.Count; i++)
+ {
+ int index1 = tokenizer.IndexOfTokenCount(input, maxTokenCount: i, out string processedText1, out int tokenCount1);
+ int index2 = tokenizer.LastIndexOfTokenCount(input, maxTokenCount: i, out string processedText2, out int tokenCount2);
+
+
+ IReadOnlyList? prefixIds = null;
+ IReadOnlyList? suffixIds = null;
+
+ if (tokenCount1 > 0)
+ {
+ string prefixString = processedText1.Substring(0, index1);
+ prefixIds = tokenizer.EncodeToIds(prefixString);
+ Assert.Equal(tokenCount1, prefixIds.Count);
+ Assert.Equal(prefixIds, fullIdsList.Take(prefixIds.Count));
+ }
+
+ if (tokenCount2 > 0)
+ {
+ string suffixString = processedText2.Substring(index2);
+ suffixIds = tokenizer.EncodeToIds(suffixString);
+ Assert.Equal(tokenCount2, suffixIds.Count);
+ Assert.Equal(suffixIds, fullIdsList.Skip(fullIdsList.Count - suffixIds.Count));
+ }
+
+ if (i == fullIdsList.Count)
+ {
+ Assert.Equal(processedText1.Length, index1);
+ Assert.Equal(0, index2);
+ Assert.Equal(fullIdsList, prefixIds);
+ Assert.Equal(fullIdsList, suffixIds);
+ }
+ }
+
+ Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
+ Assert.Throws(() => tokenizer.IndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
+ Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: 0, out _, out _));
+ Assert.Throws(() => tokenizer.LastIndexOfTokenCount(input, maxTokenCount: -1, out _, out _));
+
+ Assert.Throws(() => tokenizer.IndexOfTokenCount(null!, maxTokenCount: 10, out _, out _));
+ Assert.Throws(() => tokenizer.LastIndexOfTokenCount(null!, maxTokenCount: 10, out _, out _));
+ }
+ }
+}
\ No newline at end of file