diff --git a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs index 9935dd6428..e74c8f6293 100644 --- a/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs +++ b/src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs @@ -20,7 +20,7 @@ public sealed class Tiktoken : Model { private readonly Dictionary, int> _encoder = null!; private readonly IReadOnlyDictionary _decoder = null!; - private readonly LruCache _cache; + private readonly LruCache? _cache; private readonly IReadOnlyDictionary? _specialTokensEncoder; private readonly Dictionary? _specialTokensDecoder; private readonly Dictionary _vocab = null!; @@ -96,7 +96,14 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary? private Tiktoken(int cacheSize) { - _cache = new LruCache(cacheSize); + if (cacheSize < 0) + { + throw new ArgumentOutOfRangeException(nameof(cacheSize)); + } + else if (cacheSize > 0) + { + _cache = new LruCache(cacheSize); + } } /// @@ -198,7 +205,7 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok throw new InvalidOperationException($"The special token {sequence} doesn't exist in the tokenizer"); } - if (_cache.Lookup(sequence, out int[] ids)) + if (_cache?.Lookup(sequence, out int[] ids) is true) { tokens = new Token[ids.Length]; tokens[0] = new Token(ids[0], sequence, (0, sequence.Length)); @@ -222,7 +229,7 @@ public override IReadOnlyList Tokenize(string sequence, bool isSpecialTok int[] encodedIds = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder); Debug.Assert(encodedIds.Length > 0); - _cache.Add(sequence, encodedIds); + _cache?.Add(sequence, encodedIds); tokens = new Token[encodedIds.Length]; tokens[0] = new Token(encodedIds[0], sequence, (0, sequence.Length)); @@ -259,7 +266,7 @@ public override void TokenizeToIds(string sequence, bool isSpecialToken, IList.Shared.Return(arrayPoolArray); return encodedIds.Length; @@ -346,7 +353,7 @@ public override int CountTokens(string sequence, bool isSpecialToken) return specialTokenId; } - if (_cache.Lookup(token, out int[] ids)) + if (_cache?.Lookup(token, out int[] ids) is true) { if (ids.Length == 1) { @@ -367,7 +374,7 @@ public override int CountTokens(string sequence, bool isSpecialToken) int encodedLength = GetUtf8Bytes(token.AsSpan(), arrayPoolArray); int[] idsToCache = BytePairEncoder.BytePairEncode(arrayPoolArray.AsMemory(0, encodedLength), _encoder); - _cache.Add(token, idsToCache); + _cache?.Add(token, idsToCache); if (idsToCache.Length == 1) {