diff --git a/LLama.Benchmark/Collections/FixedSizeQueueBenchmark.cs b/LLama.Benchmark/Collections/FixedSizeQueueBenchmark.cs
new file mode 100644
index 000000000..653ffb8cb
--- /dev/null
+++ b/LLama.Benchmark/Collections/FixedSizeQueueBenchmark.cs
@@ -0,0 +1,47 @@
+using System;
+using System.Linq;
+using BenchmarkDotNet.Attributes;
+using BenchmarkDotNet.Engines;
+using BenchmarkDotNet.Jobs;
+using LLama.Common;
+
+namespace LLama.Benchmark.Collections;
+
+[SimpleJob(RunStrategy.Throughput, RuntimeMoniker.Net80)]
+[MemoryDiagnoser]
+[BenchmarkCategory("Collections", "FixedSizeQueue")]
+public class FixedSizeQueueBenchmark
+{
+    [Params(32, 512, 4096)]
+    public int Capacity { get; set; }
+
+    private int[] _values = Array.Empty<int>();
+
+    [GlobalSetup]
+    public void Setup()
+    {
+        _values = Enumerable.Range(0, Capacity * 4).ToArray();
+    }
+
+    [Benchmark]
+    public int EnqueueWrap()
+    {
+        var queue = new FixedSizeQueue<int>(Capacity);
+        foreach (var value in _values)
+            queue.Enqueue(value);
+        return queue.Count;
+    }
+
+    [Benchmark]
+    public int IterateTailSum()
+    {
+        var queue = new FixedSizeQueue<int>(Capacity);
+        foreach (var value in _values)
+            queue.Enqueue(value);
+
+        var sum = 0;
+        foreach (var value in queue)
+            sum += value;
+        return sum;
+    }
+}
diff --git a/LLama/AntipromptProcessor.cs b/LLama/AntipromptProcessor.cs
index c18c0915d..e4ec0f188 100644
--- a/LLama/AntipromptProcessor.cs
+++ b/LLama/AntipromptProcessor.cs
@@ -11,7 +11,7 @@ public sealed class AntipromptProcessor
         private int _longestAntiprompt;
         private readonly List<string> _antiprompts = new();
 
-        private string? _string;
+        private string _buffer = string.Empty;
 
         /// <summary>
@@ -46,6 +46,8 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
             _longestAntiprompt = 0;
             foreach (var antiprompt in _antiprompts)
                 _longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);
+
+            _buffer = string.Empty;
         }
 
         /// <summary>
@@ -55,21 +57,21 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
        /// <returns>true if the text buffer ends with any antiprompt</returns>
        public bool Add(string text)
        {
-            _string += text;
+            _buffer += text;
 
            // When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
            // This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
            // even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
            var maxLength = Math.Max(32, _longestAntiprompt * 4);
            var trimLength = Math.Max(16, _longestAntiprompt * 2);
-            if (_string.Length > maxLength)
-                _string = _string.Substring(_string.Length - trimLength);
+            if (_buffer.Length > maxLength)
+                _buffer = _buffer.Substring(_buffer.Length - trimLength);
 
            foreach (var antiprompt in _antiprompts)
-                if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
+                if (_buffer.EndsWith(antiprompt, StringComparison.CurrentCulture))
                    return true;
 
            return false;
        }
    }
-}
\ No newline at end of file
+}
diff --git a/LLama/Batched/BatchedExecutor.cs b/LLama/Batched/BatchedExecutor.cs
index cdb1835e4..462e9e555 100644
--- a/LLama/Batched/BatchedExecutor.cs
+++ b/LLama/Batched/BatchedExecutor.cs
@@ -1,7 +1,6 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -16,7 +15,12 @@ public sealed class BatchedExecutor
     : IDisposable
 {
     private int _nextSequenceId;
-    private readonly List<IBatch> _batchQueue = [ ];
+    private readonly List<IBatch> _batchQueue = [];
+    private int _batchQueueHead;
+    private int _batchedTokenCount;
+    private bool _batchedTokenCountDirty = true;
+    // Skip compacting the queue until this many processed batches accumulate at the front.
+    private const int CleanupThreshold = 16;
 
    /// <summary>
    /// Set to 1 using interlocked exchange while inference is running
    /// </summary>
@@ -42,12 +46,27 @@ public sealed class BatchedExecutor
    /// <summary>
    /// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
    /// </summary>
-    public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);
+    public int BatchedTokenCount
+    {
+        get
+        {
+            if (_batchedTokenCountDirty)
+            {
+                var total = 0;
+                for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
+                    total += _batchQueue[i].ItemCount;
+                _batchedTokenCount = total;
+                _batchedTokenCountDirty = false;
+            }
+
+            return _batchedTokenCount;
+        }
+    }
 
    /// <summary>
    /// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
    /// </summary>
-    public int BatchQueueCount => _batchQueue.Count;
+    public int BatchQueueCount => _batchQueue.Count - _batchQueueHead;
 
    /// <summary>
    /// Check if this executor has been disposed.
@@ -147,12 +166,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
        // again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
        if (status != DecodeResult.Ok)
        {
-            _batchQueue.Insert(0, next);
+            RequeueFront(next);
            return status;
        }
 
        // Everything was ok, advance the epoch
        Epoch++;
+        CleanupQueue();
 
        return status;
    }
@@ -166,13 +186,45 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
        IBatch? GetNextBatch()
        {
-            if (_batchQueue.Count == 0)
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
                return null;
-
-            var nextBatch = _batchQueue[0];
-            _batchQueue.RemoveAt(0);
+            }
+
+            var nextBatch = _batchQueue[_batchQueueHead];
+            _batchQueueHead++;
+            _batchedTokenCountDirty = true;
            return nextBatch;
        }
+
+        void RequeueFront(IBatch batch)
+        {
+            Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");
+            _batchQueue[--_batchQueueHead] = batch;
+            _batchedTokenCountDirty = true;
+        }
+
+        // Remove batches that have already been consumed so the head index does not grow without bound.
+        void CleanupQueue()
+        {
+            if (_batchQueueHead == 0)
+                return;
+
+            if (_batchQueueHead >= _batchQueue.Count)
+            {
+                _batchQueue.Clear();
+                _batchQueueHead = 0;
+                return;
+            }
+
+            if (_batchQueueHead > CleanupThreshold && _batchQueueHead > _batchQueue.Count / 2)
+            {
+                _batchQueue.RemoveRange(0, _batchQueueHead);
+                _batchQueueHead = 0;
+            }
+        }
    }
 
    /// <summary>
@@ -202,7 +254,7 @@ internal LLamaSeqId GetNextSequenceId()
            throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");
 
        // Find a batch with space for at least minCapacity tokens
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
        {
            var item = _batchQueue[i];
            if (item is not TokenBatch { Batch: var batch })
@@ -213,13 +265,17 @@ internal LLamaSeqId GetNextSequenceId()
                continue;
 
            if (batch.TokenCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
        }
 
        // Add a new batch to the end of the queue
        var end = new LLamaBatch();
        _batchQueue.Add(new TokenBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
    }
 
    /// <summary>
@@ -234,7 +290,7 @@ internal LLamaSeqId GetNextSequenceId()
            throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");
 
        // Find a batch with space for at least minCapacity embeddings
-        for (var i = 0; i < _batchQueue.Count; i++)
+        for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
        {
            var item = _batchQueue[i];
            if (item is not EmbeddingBatch { Batch: var batch })
@@ -245,13 +301,17 @@ internal LLamaSeqId GetNextSequenceId()
                continue;
 
            if (batch.EmbeddingsCount < Context.BatchSize)
-                return (batch, Epoch + (uint)(i + 1) * 2);
+            {
+                _batchedTokenCountDirty = true;
+                return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
+            }
        }
 
        // Add a new batch to the end of the queue
        var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
        _batchQueue.Add(new EmbeddingBatch(end));
-        return (end, Epoch + (uint)_batchQueue.Count * 2);
+        _batchedTokenCountDirty = true;
+        return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
    }
 
    #region batches
@@ -286,4 +346,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
        }
    }
    #endregion
-}
\ No newline at end of file
+}
diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs
index 62056498c..d1f2fb11d 100644
--- a/LLama/Common/FixedSizeQueue.cs
+++ b/LLama/Common/FixedSizeQueue.cs
@@ -6,21 +6,38 @@ namespace LLama.Common
 {
    /// <summary>
-    /// A queue with fixed storage size.
-    /// Currently it's only a naive implementation and needs to be further optimized in the future.
+    /// A queue with fixed storage size backed by a circular buffer.
    /// </summary>
    public class FixedSizeQueue<T>
        : IReadOnlyList<T>
    {
-        private readonly List<T> _storage;
+        private readonly T[] _buffer;
+        private int _start;
+        private int _count;
+        private T[]? _window;
+
+        // Minimum capacity for the temporary buffer used to expose a contiguous view.
+        private const int MinimumWindowSize = 4;
+        // Resize multiplier for the temporary buffer to reduce copy churn as it grows.
+        private const int WindowGrowthFactor = 2;
 
        /// <inheritdoc />
-        public T this[int index] => _storage[index];
+        public T this[int index]
+        {
+            get
+            {
+                if ((uint)index >= (uint)_count)
+                    throw new ArgumentOutOfRangeException(nameof(index));
+
+                var actualIndex = (_start + index) % Capacity;
+                return _buffer[actualIndex];
+            }
+        }
 
        /// <summary>
        /// Number of items in this queue
        /// </summary>
-        public int Count => _storage.Count;
+        public int Count => _count;
 
        /// <summary>
        /// Maximum number of items allowed in this queue
        /// </summary>
@@ -28,53 +45,78 @@ public class FixedSizeQueue<T>
        public int Capacity { get; }
 
        /// <summary>
-        /// Create a new queue
+        /// Create a new queue.
        /// </summary>
-        /// <param name="size">the maximum number of items to store in this queue</param>
+        /// <param name="size">The maximum number of items to store in this queue.</param>
        public FixedSizeQueue(int size)
        {
+            if (size <= 0)
+                throw new ArgumentOutOfRangeException(nameof(size), size, "Capacity must be greater than zero.");
+
            Capacity = size;
-            _storage = new();
+            _buffer = new T[size];
+            _start = 0;
+            _count = 0;
        }
 
        /// <summary>
-        /// Fill the quene with the data. Please ensure that data.Count &lt;= size
+        /// Fill the queue with existing data. Please ensure that data.Count &lt;= size
        /// </summary>
        /// <param name="size"></param>
        /// <param name="data"></param>
        public FixedSizeQueue(int size, IEnumerable<T> data)
+            : this(size)
        {
 #if NET6_0_OR_GREATER
-            // Try to check the size without enumerating the entire IEnumerable. This may not be able to get the count,
-            // in which case we'll have to check later
            if (data.TryGetNonEnumeratedCount(out var dataCount) && dataCount > size)
-                throw new ArgumentException($"The max size set for the quene is {size}, but got {dataCount} initial values.");
+                throw new ArgumentException($"The max size set for the queue is {size}, but got {dataCount} initial values.");
 #endif
-            // Size of "data" is unknown, copy it all into a list
-            Capacity = size;
-            _storage = new List<T>(data);
+            if (data is ICollection<T> collection)
+            {
+                if (collection.Count > size)
+                    throw new ArgumentException($"The max size set for the queue is {size}, but got {collection.Count} initial values.");
+
+                foreach (var item in collection)
+                    Enqueue(item);
+                return;
+            }
 
-            // Now check if that list is a valid size.
-            if (_storage.Count > Capacity)
-                throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
+            var index = 0;
+            foreach (var item in data)
+            {
+                if (index >= size)
+                    throw new ArgumentException($"The max size set for the queue is {size}, but got {index + 1} initial values.");
+
+                Enqueue(item);
+                index++;
+            }
        }
 
        /// <summary>
-        /// Enquene an element.
+        /// Enqueue an element. When the queue is full the oldest element is overwritten.
        /// </summary>
-        /// <param name="item"></param>
        public void Enqueue(T item)
        {
-            _storage.Add(item);
-            if (_storage.Count > Capacity)
-                _storage.RemoveAt(0);
+            if (_count < Capacity)
+            {
+                var tail = (_start + _count) % Capacity;
+                _buffer[tail] = item;
+                _count++;
+            }
+            else
+            {
+                _buffer[_start] = item;
+                _start++;
+                if (_start == Capacity)
+                    _start = 0;
+            }
        }
 
        /// <inheritdoc />
        public IEnumerator<T> GetEnumerator()
        {
-            return _storage.GetEnumerator();
+            return Enumerate().GetEnumerator();
        }
 
        /// <inheritdoc />
@@ -83,17 +125,12 @@ IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }
 
-        internal ReadOnlySpan<T> AsSpan(int count)
+        private IEnumerable<T> Enumerate()
        {
-            // Ensure the request isn't for more tokens than actually exist
-            count = Math.Min(count, Count);
-
-            // Take `count` items from the end
-#if NET8_0_OR_GREATER
-            return CollectionsMarshal.AsSpan(_storage)[^count..];
-#else
-            return _storage.ToArray().AsSpan(_storage.Count - count, count);
-#endif
+            for (var i = 0; i < _count; i++)
+            {
+                yield return this[i];
+            }
        }
    }
 }
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 36989006e..e3efb35a5 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -64,6 +64,11 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
        /// </summary>
        public LLamaContext Context { get; }
 
+        /// <summary>
+        /// Tracks anti-prompts across streamed output.
+        /// </summary>
+        protected AntipromptProcessor AntipromptProcessor { get; }
+
        // LLava Section
        //
        /// <summary>
@@ -98,6 +103,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
            _n_session_consumed = 0;
            _last_n_tokens = new FixedSizeQueue<LLamaToken>((int)Context.ContextSize);
            _decoder = new StreamingTokenDecoder(context);
+            AntipromptProcessor = new AntipromptProcessor();
        }
 
        /// <summary>
@@ -214,7 +220,8 @@ protected virtual void TryReuseMatchingPrefix()
            {
                if (_embeds[i] != _session_tokens[_n_session_consumed])
                {
-                    _session_tokens = _session_tokens.Take(_n_session_consumed).ToList();
+                    if (_session_tokens.Count > _n_session_consumed)
+                        _session_tokens.RemoveRange(_n_session_consumed, _session_tokens.Count - _n_session_consumed);
                    break;
                }
            }
@@ -310,6 +317,8 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                NeedToSaveSession = !string.IsNullOrEmpty(_pathSession) && _n_matching_session_tokens < _embed_inps.Count
            };
 
+            AntipromptProcessor.SetAntiprompts(inferenceParams.AntiPrompts ?? Array.Empty<string>());
+
            await PreprocessInputs(text, args);
 
            while (await GetLoopCondition(args))
@@ -318,12 +327,15 @@ public virtual async IAsyncEnumerable<string> InferAsync(string? text, IInferenc
                {
                    break;
                }
+                args.LastOutput = string.Empty;
 
                await InferInternal(inferenceParams, args);
 
                if (args.ReturnValue)
                {
                    _decoder.AddRange(_embeds);
-                    yield return _decoder.Read();
+                    var decoded = _decoder.Read();
+                    args.LastOutput = decoded;
+                    yield return decoded;
                }
 
                var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
@@ -394,6 +406,11 @@ protected class InferStateArgs
            /// <summary>
            /// </summary>
            public bool NeedToSaveSession { get; set; }
+
+            /// <summary>
+            /// Most recent decoded output from the model.
+            /// </summary>
+            public string LastOutput { get; set; } = string.Empty;
        }
 
 #pragma warning disable CS1591, CS8618 // Missing XML and irrelevant nullable warnings
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 331591fba..f4aec5b6e 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -158,7 +158,7 @@ protected override Task PreprocessInputs(string? text, InferStateArgs args)
        {
            if (_embed_inps.Count <= _consumedTokensCount)
            {
-                if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+                if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
                {
                    args.WaitForInput = true;
                    return (true, Array.Empty<string>());
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 7c9558ee3..1baebfa7e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -207,7 +207,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
        {
            if (_embed_inps.Count <= _consumedTokensCount)
            {
-                if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
+                if (!string.IsNullOrEmpty(args.LastOutput) && AntipromptProcessor.Add(args.LastOutput))
                    args.WaitForInput = true;
 
                if (_pastTokensCount > 0 && args.WaitForInput)
diff --git a/LLama/StreamingTokenDecoder.cs b/LLama/StreamingTokenDecoder.cs
index ead0ee88e..cd821fb3b 100644
--- a/LLama/StreamingTokenDecoder.cs
+++ b/LLama/StreamingTokenDecoder.cs
@@ -1,6 +1,8 @@
 using System.Diagnostics;
 using System;
 using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Native;
 
@@ -181,7 +183,12 @@ public string Read()
            if (_characters.Count == 0)
                return "";
 
-            var str = string.Join("", _characters);
+#if NET5_0_OR_GREATER
+            var span = CollectionsMarshal.AsSpan(_characters);
+            var str = new string(span);
+#else
+            var str = new string(_characters.ToArray());
+#endif
            _characters.Clear();
            return str;
        }
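
Illustration, not part of the patch above: the rewritten FixedSizeQueue<T> is a ring buffer that keeps only the most recent Capacity items, overwriting the oldest slot in place once full. A minimal sketch of the behaviour the new Enqueue and indexer implement, using only the API visible in the diff:

```csharp
using System;
using LLama.Common;

// Once the queue is full, Enqueue overwrites the oldest slot and advances
// the internal start index; nothing is shifted or reallocated.
var queue = new FixedSizeQueue<int>(3);

queue.Enqueue(1);
queue.Enqueue(2);
queue.Enqueue(3);   // queue is now full: [1, 2, 3]
queue.Enqueue(4);   // 1 is overwritten in place: [2, 3, 4]

Console.WriteLine(queue.Count); // 3
Console.WriteLine(queue[0]);    // 2 (oldest surviving item)
Console.WriteLine(queue[2]);    // 4 (newest item)
```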
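For context on the BatchedExecutor change: List<T>.RemoveAt(0) and Insert(0, ...) shift every remaining element, so draining the queue from the front costs O(n) per batch. The patch instead advances _batchQueueHead on dequeue and compacts lazily. A self-contained sketch of the same technique; HeadIndexQueue is a hypothetical name for illustration, not a type in this PR:

```csharp
using System.Collections.Generic;

// Hypothetical illustration of the lazy-compaction scheme behind
// BatchedExecutor._batchQueueHead; not a type introduced by this PR.
public class HeadIndexQueue<T>
{
    // Mirrors BatchedExecutor.CleanupThreshold: tolerate this many dead
    // slots at the front before paying for a RemoveRange.
    private const int CleanupThreshold = 16;

    private readonly List<T> _items = new();
    private int _head;

    public int Count => _items.Count - _head;

    public void Enqueue(T item) => _items.Add(item);

    public bool TryDequeue(out T item)
    {
        if (_head >= _items.Count)
        {
            // Queue fully drained: recycle the backing list entirely.
            _items.Clear();
            _head = 0;
            item = default!;
            return false;
        }

        // O(1) dequeue: advance the head instead of shifting elements.
        item = _items[_head++];

        // Amortised cleanup: only compact when the dead prefix is both large
        // and the majority of the list, mirroring CleanupQueue() above.
        if (_head > CleanupThreshold && _head > _items.Count / 2)
        {
            _items.RemoveRange(0, _head);
            _head = 0;
        }

        return true;
    }
}
```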
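Similarly, the executor changes now feed each decoded chunk through AntipromptProcessor.Add, which appends to a rolling string buffer before the EndsWith check, so an antiprompt split across token boundaries is still detected. A usage sketch against the API shown above (chunk contents invented for the example):

```csharp
using LLama;

var antiprompts = new AntipromptProcessor();
antiprompts.SetAntiprompts(new[] { "User:" });

// The antiprompt "User:" arrives split across two decoded chunks.
var stop = antiprompts.Add("...and that is the answer.\nUse"); // false: buffer ends with "Use"
stop = antiprompts.Add("r:");                                  // true: buffer now ends with "User:"
```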