diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
index 9935dd6428..74d95df23a 100644
--- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
+++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -9,6 +9,7 @@
using System.IO;
using System.Linq;
using System.Text;
+using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
@@ -104,9 +105,11 @@ private Tiktoken(int cacheSize)
/// </summary>
/// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
/// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Map of byte[] to integer token id</returns>
///
- internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
+ internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
+ Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<string, int>();
@@ -119,7 +122,7 @@ private Tiktoken(int cacheSize)
while (true)
{
string? line = useAsync ?
- await reader.ReadLineAsync().ConfigureAwait(false) :
+ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
reader.ReadLine();
if (string.IsNullOrWhiteSpace(line))
{
@@ -136,10 +139,10 @@ await reader.ReadLineAsync().ConfigureAwait(false) :
throw new FormatException($"Invalid format in the BPE encoder file stream");
}
- byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
-
if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
{
+ byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);
+
encoder[tokenBytes] = rank;
decoder[rank] = tokenBytes;
@@ -214,7 +217,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
// cache miss
if (_vocab.TryGetValue(sequence, out int mappedId))
{
- return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
+ return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
}
byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
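
The rank file consumed by LoadTikTokenBpeAsync is line oriented: a base64-encoded token, a space, and an integer rank. Below is a minimal standalone sketch of that parse; it is illustrative only, omits the encoder/vocab bookkeeping and comparer used by the real method, and the ParseRanks helper name is made up.

using System;
using System.Collections.Generic;

internal static class TikTokenFormatSketch
{
    // Sketch of the "<base64 token> <rank>" line format parsed by LoadTikTokenBpeAsync.
    internal static Dictionary<int, byte[]> ParseRanks(IEnumerable<string> lines)
    {
        var decoder = new Dictionary<int, byte[]>();
        foreach (string line in lines)
        {
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            int spaceIndex = line.IndexOf(' ');
            if (spaceIndex <= 0 || spaceIndex == line.Length - 1)
            {
                throw new FormatException("Invalid format in the BPE encoder file stream");
            }

            // Parse the rank before decoding the token, mirroring the reordering above,
            // so malformed lines never pay for a base64 allocation.
            if (int.TryParse(line.Substring(spaceIndex + 1), out int rank))
            {
                decoder[rank] = Convert.FromBase64String(line.Substring(0, spaceIndex));
            }
        }

        return decoder;
    }
}
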
diff --git a/src/Microsoft.ML.Tokenizers/Tokenizer.cs b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
index d002f55833..d29766d65c 100644
--- a/src/Microsoft.ML.Tokenizers/Tokenizer.cs
+++ b/src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -9,6 +9,7 @@
using System.IO;
using System.Net.Http;
using System.Text.RegularExpressions;
+using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
- public static async Task<Tokenizer> CreateByModelNameAsync(
+ public static Task<Tokenizer> CreateByModelNameAsync(
string modelName,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
- Normalizer? normalizer = null)
+ Normalizer? normalizer = null,
+ CancellationToken cancellationToken = default)
{
- ModelEncoding encoder;
-
- if (!_modelToEncoding.TryGetValue(modelName, out encoder))
+ try
{
- foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
+ ModelEncoding encoder;
+
+ if (!_modelToEncoding.TryGetValue(modelName, out encoder))
{
- if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+ foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
{
- encoder = Encoding;
- break;
+ if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
+ {
+ encoder = Encoding;
+ break;
+ }
}
}
- }
- if (encoder == ModelEncoding.None)
+ if (encoder == ModelEncoding.None)
+ {
+ throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+ }
+
+ return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
+ }
+ catch (Exception ex)
{
- throw new NotImplementedException($"Doesn't support this model [{modelName}]");
+ return Task.FromException<Tokenizer>(ex);
}
-
- return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
}
private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
/// <param name="modelEncoding">Encoder label</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
/// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
- private static async Task<Tokenizer> CreateByEncoderNameAsync(
+ private static Task<Tokenizer> CreateByEncoderNameAsync(
ModelEncoding modelEncoding,
IReadOnlyDictionary<string, int>? extraSpecialTokens,
- Normalizer? normalizer)
+ Normalizer? normalizer,
+ CancellationToken cancellationToken)
{
switch (modelEncoding)
{
case ModelEncoding.Cl100kBase:
var specialTokens = new Dictionary<string, int>
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
- return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+ return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.P50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.P50kEdit:
specialTokens = new Dictionary<string, int>
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.R50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
case ModelEncoding.GPT2:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
- return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
+ return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);
default:
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
/// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
+ /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
Regex regex,
string mergeableRanksFileUrl,
Dictionary<string, int> specialTokens,
IReadOnlyDictionary<string, int>? extraSpecialTokens,
- Normalizer? normalizer)
+ Normalizer? normalizer,
+ CancellationToken cancellationToken)
{
if (extraSpecialTokens is not null)
{
@@ -467,9 +481,9 @@ private static async Task CreateTikTokenTokenizerAsync(
if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
{
- using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
+ using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
{
- cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
+ cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
}
_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
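
A caller-side sketch of the new cancellation support; the model name and timeout below are illustrative assumptions, not part of this change.

using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;

internal static class CreateTokenizerSketch
{
    internal static async Task<Tokenizer> CreateWithTimeoutAsync()
    {
        // Bound the vocabulary download and BPE rank parsing; the token now flows through
        // Helpers.GetStreamAsync and Tiktoken.LoadTikTokenBpeAsync.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
        return await Tokenizer.CreateByModelNameAsync("gpt-4", cancellationToken: cts.Token).ConfigureAwait(false);
    }
}

Note that because CreateByModelNameAsync is no longer async and routes failures through Task.FromException, an unrecognized model name surfaces as a faulted task rather than a synchronous throw.
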
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
index 99d764a9cf..b64531431f 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -1,26 +1,41 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
+using System.Buffers.Text;
+using System.Diagnostics;
using System.Globalization;
+using System.IO;
+using System.Threading.Tasks;
+using System.Threading;
+using System.Net.Http;
namespace Microsoft.ML.Tokenizers
{
internal static class Helpers
{
+ public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
+ reader.ReadLineAsync(cancellationToken);
+
+ public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
+ client.GetStreamAsync(url, cancellationToken);
+
public static byte[] FromBase64String(string base64String, int offset, int length)
{
- Span<byte> bytes = stackalloc byte[300];
- if (!Convert.TryFromBase64Chars(base64String.AsSpan().Slice(offset, length), bytes, out int bytesWritten))
+ if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))
{
- throw new System.FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
+ throw new FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
}
- return bytes.Slice(0, bytesWritten).ToArray();
+
+ byte[] bytes = new byte[decodedLength];
+ bool success = Convert.TryFromBase64Chars(base64String.AsSpan(offset, length), bytes, out int bytesWritten);
+ Debug.Assert(success);
+ Debug.Assert(bytes.Length == bytesWritten);
+ return bytes;
}
internal static bool TryParseInt32(string s, int offset, out int result)
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
}
}
-
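
A quick sanity sketch of the new decode path: sizing the buffer from Base64.IsValid removes the previous 300-byte stackalloc ceiling. The 512-byte input is an arbitrary illustration, and since Helpers is internal the sketch assumes it lives inside the tokenizers assembly.

using System;
using System.Diagnostics;

namespace Microsoft.ML.Tokenizers
{
    internal static class Base64SizingSketch
    {
        internal static void Check()
        {
            // 512 decoded bytes would not have fit the old stackalloc byte[300] buffer, so
            // Convert.TryFromBase64Chars would have reported failure and a FormatException
            // would have been thrown even though the input is valid base64.
            string longToken = Convert.ToBase64String(new byte[512]);
            byte[] decoded = Helpers.FromBase64String(longToken, 0, longToken.Length);
            Debug.Assert(decoded.Length == 512);
        }
    }
}
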
diff --git a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
index 4f354cda5a..2979c99b6e 100644
--- a/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
+++ b/src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -1,13 +1,30 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
+using System.IO;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
namespace Microsoft.ML.Tokenizers
{
internal static class Helpers
{
+ public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+ return new ValueTask<string?>(reader.ReadLineAsync());
+ }
+
+ public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
+ {
+ HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
+ response.EnsureSuccessStatusCode();
+ return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
+ }
+
public static byte[] FromBase64String(string base64String, int offset, int length) => Convert.FromBase64String(base64String.Substring(offset, length));
// Signed numbers are not supported
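
On netstandard, StreamReader.ReadLineAsync has no CancellationToken overload, so the shim above only observes cancellation between reads. A caller-side sketch of that behavior follows; the stream and counting logic are illustrative, and it likewise assumes access to the internal Helpers type.

using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
{
    internal static class ReadLinesSketch
    {
        internal static async Task<int> CountLinesAsync(Stream stream, CancellationToken cancellationToken)
        {
            using var reader = new StreamReader(stream);
            int count = 0;
            while (true)
            {
                // Cancellation is checked before each read starts, not while a read is in flight.
                string? line = await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false);
                if (line is null)
                {
                    break;
                }

                count++;
            }

            return count;
        }
    }
}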