diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 9935dd6428..74d95df23a 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -9,6 +9,7 @@ using System.IO;
 using System.Linq;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
@@ -104,9 +105,11 @@ private Tiktoken(int cacheSize)
         ///
         /// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
         /// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>Map of byte[] to integer token id</returns>
         ///
-        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
+        internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
+            Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
         {
             var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
             var vocab = new Dictionary<string, int>();
@@ -119,7 +122,7 @@ private Tiktoken(int cacheSize)
                 while (true)
                 {
                     string? line = useAsync ?
-                        await reader.ReadLineAsync().ConfigureAwait(false) :
+                        await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
                         reader.ReadLine();
                     if (string.IsNullOrWhiteSpace(line))
                     {
@@ -136,10 +139,10 @@ await reader.ReadLineAsync().ConfigureAwait(false) :
                         throw new FormatException($"Invalid format in the BPE encoder file stream");
                     }
 
-                    byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
-
                     if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
                     {
+                        byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
+
                         encoder[tokenBytes] = rank;
                         decoder[rank] = tokenBytes;
@@ -214,7 +217,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
             // cache miss
             if (_vocab.TryGetValue(sequence, out int mappedId))
             {
-                return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
+                return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
             }
 
             byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
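For context on the parsing loop above: each line of a tiktoken BPE rank file is a base64-encoded token followed by a single space and the token's integer rank, which is what `spaceIndex`, `Helpers.FromBase64String`, and `Helpers.TryParseInt32` pick apart. The entries below are illustrative only (made-up ranks, not taken from a real vocabulary file):

```text
IQ== 0
aGVsbG8= 31373
IHdvcmxk 99999
```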
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index d002f55833..d29766d65c 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -9,6 +9,7 @@ using System.IO;
 using System.Net.Http;
 using System.Text.RegularExpressions;
+using System.Threading;
 using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
         /// <param name="modelName">Model name</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>The tokenizer</returns>
-        public static async Task<Tokenizer> CreateByModelNameAsync(
+        public static Task<Tokenizer> CreateByModelNameAsync(
             string modelName,
             IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
-            Normalizer? normalizer = null)
+            Normalizer? normalizer = null,
+            CancellationToken cancellationToken = default)
         {
-            ModelEncoding encoder;
-
-            if (!_modelToEncoding.TryGetValue(modelName, out encoder))
+            try
             {
-                foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
+                ModelEncoding encoder;
+
+                if (!_modelToEncoding.TryGetValue(modelName, out encoder))
                 {
-                    if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+                    foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
                     {
-                        encoder = Encoding;
-                        break;
+                        if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+                        {
+                            encoder = Encoding;
+                            break;
+                        }
                     }
                 }
-            }
 
-            if (encoder == ModelEncoding.None)
+                if (encoder == ModelEncoding.None)
+                {
+                    throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+                }
+
+                return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
+            }
+            catch (Exception ex)
             {
-                throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+                return Task.FromException<Tokenizer>(ex);
             }
-
-            return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
         }
 
         private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
         /// <param name="modelEncoding">Encoder label</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>The tokenizer</returns>
         /// Throws if the encoder is not supported
-        private static async Task<Tokenizer> CreateByEncoderNameAsync(
+        private static Task<Tokenizer> CreateByEncoderNameAsync(
             ModelEncoding modelEncoding,
             IReadOnlyDictionary<string, int>? extraSpecialTokens,
-            Normalizer? normalizer)
+            Normalizer? normalizer,
+            CancellationToken cancellationToken)
         {
             switch (modelEncoding)
             {
                 case ModelEncoding.Cl100kBase:
                     var specialTokens = new Dictionary<string, int>
                         { { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
-                    return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.P50kBase:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.P50kEdit:
                     specialTokens = new Dictionary<string, int>
                         { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.R50kBase:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 case ModelEncoding.GPT2:
                     specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
-                    return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+                    return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
 
                 default:
                     Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
         /// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
         /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
         /// <param name="normalizer">To normalize the text before tokenization</param>
+        /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
         /// <returns>The tokenizer</returns>
         private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
             Regex regex,
             string mergeableRanksFileUrl,
             Dictionary<string, int> specialTokens,
             IReadOnlyDictionary<string, int>? extraSpecialTokens,
-            Normalizer? normalizer)
+            Normalizer? normalizer,
+            CancellationToken cancellationToken)
         {
             if (extraSpecialTokens is not null)
             {
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
             if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
             {
-                using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
+                using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
                 {
-                    cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
+                    cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
                 }
 
                 _tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
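A minimal caller-side sketch of the new cancellable entry point. The model name and timeout are arbitrary example values, and only the public surface shown in the hunks above is assumed:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;

internal static class TokenizerExample
{
    internal static async Task RunAsync()
    {
        // Illustrative: give the vocabulary download and BPE-rank parsing at most 30 seconds.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

        // The optional cancellationToken flows through CreateByEncoderNameAsync into the
        // HTTP download and Tiktoken.LoadTikTokenBpeAsync.
        Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync("gpt-4", cancellationToken: cts.Token);
    }
}
```

Returning `Task.FromException<Tokenizer>(ex)` instead of throwing keeps the "unknown model" failure on the returned task, which is the usual expectation for a `Task`-returning method that is no longer marked `async`.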
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
index 99d764a9cf..b64531431f 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -1,26 +1,41 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Buffers.Text;
+using System.Diagnostics;
 using System.Globalization;
+using System.IO;
+using System.Threading.Tasks;
+using System.Threading;
+using System.Net.Http;
 
 namespace Microsoft.ML.Tokenizers
 {
     internal static class Helpers
     {
+        public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
+            reader.ReadLineAsync(cancellationToken);
+
+        public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
+            client.GetStreamAsync(url, cancellationToken);
+
         public static byte[] FromBase64String(string base64String, int offset, int length)
         {
-            Span<byte> bytes = stackalloc byte[300];
-            if (!Convert.TryFromBase64Chars(base64String.AsSpan().Slice(offset, length), bytes, out int bytesWritten))
+            if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))
             {
-                throw new System.FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
+                throw new FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
             }
-            return bytes.Slice(0, bytesWritten).ToArray();
+
+            byte[] bytes = new byte[decodedLength];
+            bool success = Convert.TryFromBase64Chars(base64String.AsSpan(offset, length), bytes, out int bytesWritten);
+            Debug.Assert(success);
+            Debug.Assert(bytes.Length == bytesWritten);
+            return bytes;
         }
 
         internal static bool TryParseInt32(string s, int offset, out int result) =>
             int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
     }
 }
-
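A standalone sketch of the exact-size decode pattern the new `FromBase64String` uses: on .NET 8, `Base64.IsValid` reports the decoded byte count up front, so the buffer can be sized exactly rather than relying on the previous 300-byte `stackalloc` cap. The token text and rank are made up for illustration:

```csharp
using System;
using System.Buffers.Text;
using System.Text;

string line = "aGVsbG8= 31373";               // hypothetical "<base64 token> <rank>" entry
int spaceIndex = line.IndexOf(' ');

if (Base64.IsValid(line.AsSpan(0, spaceIndex), out int decodedLength))
{
    byte[] bytes = new byte[decodedLength];    // exact-size allocation, no fixed upper bound
    Convert.TryFromBase64Chars(line.AsSpan(0, spaceIndex), bytes, out int bytesWritten);
    Console.WriteLine($"{Encoding.UTF8.GetString(bytes)} ({bytesWritten} bytes)"); // "hello (5 bytes)"
}
```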
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
index 4f354cda5a..2979c99b6e 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -1,13 +1,30 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.IO;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace Microsoft.ML.Tokenizers
 {
     internal static class Helpers
     {
+        public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken)
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+            return new ValueTask<string?>(reader.ReadLineAsync());
+        }
+
+        public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
+        {
+            HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
+            response.EnsureSuccessStatusCode();
+            return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
+        }
+
         public static byte[] FromBase64String(string base64String, int offset, int length) => Convert.FromBase64String(base64String.Substring(offset, length));
 
         // Not support signed number