46 changes: 46 additions & 0 deletions LLama.Benchmark/Collections/FixedSizeQueueBenchmark.cs
@@ -0,0 +1,46 @@
using System;
using System.Linq;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Engines;
using BenchmarkDotNet.Jobs;
using LLama.Common;

namespace LLama.Benchmark.Collections;

[SimpleJob(RunStrategy.Throughput, RuntimeMoniker.Net80)]
[MemoryDiagnoser]
[BenchmarkCategory("Collections", "FixedSizeQueue")]
public class FixedSizeQueueBenchmark
{
[Params(32, 512, 4096)]
public int Capacity { get; set; }

private int[] _values = Array.Empty<int>();

[GlobalSetup]
public void Setup()
{
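// Four times the capacity, so each benchmark overfills the queue and
// exercises the wrap/discard path repeatedly.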
_values = Enumerable.Range(0, Capacity * 4).ToArray();
}

[Benchmark]
public int EnqueueWrap()
{
var queue = new FixedSizeQueue<int>(Capacity);
foreach (var value in _values)
queue.Enqueue(value);
return queue.Count;
}

[Benchmark]
public int IterateTailSum()
{
var queue = new FixedSizeQueue<int>(Capacity);
foreach (var value in _values)
queue.Enqueue(value);

var sum = 0;
foreach (var value in queue)
sum += value;
return sum;
}
}
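To run just these benchmarks, a minimal entry-point sketch, assuming BenchmarkDotNet's standard BenchmarkRunner API (the actual LLama.Benchmark harness may use a BenchmarkSwitcher with command-line filters instead):

using BenchmarkDotNet.Running;
using LLama.Benchmark.Collections;

public static class Program
{
    public static void Main()
    {
        // Runs only the FixedSizeQueue benchmarks declared above.
        BenchmarkRunner.Run<FixedSizeQueueBenchmark>();
    }
}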
14 changes: 8 additions & 6 deletions LLama/AntipromptProcessor.cs
@@ -11,7 +11,7 @@ public sealed class AntipromptProcessor
private int _longestAntiprompt;
private readonly List<string> _antiprompts = new();

private string? _string;
private string _buffer = string.Empty;


/// <summary>
@@ -46,6 +46,8 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
_longestAntiprompt = 0;
foreach (var antiprompt in _antiprompts)
_longestAntiprompt = Math.Max(_longestAntiprompt, antiprompt.Length);

_buffer = string.Empty;
}

/// <summary>
@@ -55,21 +57,21 @@ public void SetAntiprompts(IEnumerable<string> antiprompts)
/// <returns>true if the text buffer ends with any antiprompt</returns>
public bool Add(string text)
{
_string += text;
_buffer += text;

// When the string gets very long (4x antiprompt length) trim it down (to 2x antiprompt length).
// This trimming leaves a lot of extra characters because two sequences can be considered "equal" in unicode
// even with different numbers of characters. Hopefully there are enough characters here to handle all those weird circumstances!
var maxLength = Math.Max(32, _longestAntiprompt * 4);
var trimLength = Math.Max(16, _longestAntiprompt * 2);
if (_string.Length > maxLength)
_string = _string.Substring(_string.Length - trimLength);
if (_buffer.Length > maxLength)
_buffer = _buffer.Substring(_buffer.Length - trimLength);

foreach (var antiprompt in _antiprompts)
if (_string.EndsWith(antiprompt, StringComparison.CurrentCulture))
if (_buffer.EndsWith(antiprompt, StringComparison.CurrentCulture))
return true;

return false;
}
}
}
}
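The rename from _string to _buffer, together with the reset added in SetAntiprompts, ensures leftover text from a previous generation cannot trigger a newly installed antiprompt. A minimal usage sketch; the parameterless constructor and the LLama namespace are assumptions based on the file path:

using System;
using LLama; // assumed namespace, based on the file's location

var processor = new AntipromptProcessor(); // constructor shape assumed
processor.SetAntiprompts(new[] { "User:" });

// Add() accumulates streamed fragments, so a match is detected even
// when the antiprompt is split across chunk boundaries.
foreach (var fragment in new[] { "Hello ", "Us", "er:" })
{
    if (processor.Add(fragment))
        Console.WriteLine("Antiprompt detected - stop generating.");
}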
92 changes: 76 additions & 16 deletions LLama/Batched/BatchedExecutor.cs
@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
@@ -16,7 +15,12 @@ public sealed class BatchedExecutor
: IDisposable
{
private int _nextSequenceId;
private readonly List<IBatch> _batchQueue = [ ];
private readonly List<IBatch> _batchQueue = [];
private int _batchQueueHead;
private int _batchedTokenCount;
private bool _batchedTokenCountDirty = true;
// Skip compacting the queue until this many processed batches accumulate at the front.
private const int CleanupThreshold = 16;

/// <summary>
/// Set to 1 using interlocked exchange while inference is running
@@ -42,12 +46,27 @@ public sealed class BatchedExecutor
/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);
public int BatchedTokenCount
{
get
{
if (_batchedTokenCountDirty)
{
var total = 0;
for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
total += _batchQueue[i].ItemCount;
_batchedTokenCount = total;
_batchedTokenCountDirty = false;
}

return _batchedTokenCount;
}
}

/// <summary>
/// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchQueueCount => _batchQueue.Count;
public int BatchQueueCount => _batchQueue.Count - _batchQueueHead;

/// <summary>
/// Check if this executor has been disposed.
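BatchedTokenCount previously ran a LINQ Sum over the whole queue on every read; the new property caches the total and recomputes it only after a mutation sets the dirty flag. A self-contained sketch of the same dirty-flag pattern, with hypothetical names:

using System.Collections.Generic;

public sealed class CachedSum
{
    private readonly List<int> _items = new();
    private int _cached;
    private bool _dirty = true;

    public void Add(int value)
    {
        _items.Add(value);
        _dirty = true; // any mutation invalidates the cache
    }

    public int Sum
    {
        get
        {
            if (_dirty)
            {
                var total = 0;
                foreach (var item in _items)
                    total += item;
                _cached = total;
                _dirty = false; // valid until the next mutation
            }
            return _cached;
        }
    }
}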
@@ -147,12 +166,13 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
// again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
if (status != DecodeResult.Ok)
{
_batchQueue.Insert(0, next);
RequeueFront(next);
return status;
}

// Everything was ok, advance the epoch
Epoch++;
CleanupQueue();

return status;
}
@@ -166,13 +186,45 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)

IBatch? GetNextBatch()
{
if (_batchQueue.Count == 0)
if (_batchQueueHead >= _batchQueue.Count)
{
_batchQueue.Clear();
_batchQueueHead = 0;
return null;

var nextBatch = _batchQueue[0];
_batchQueue.RemoveAt(0);
}

var nextBatch = _batchQueue[_batchQueueHead];
_batchQueueHead++;
_batchedTokenCountDirty = true;
return nextBatch;
}

void RequeueFront(IBatch batch)
{
Debug.Assert(_batchQueueHead > 0, "Cannot requeue batch when queue head is at zero.");
_batchQueue[--_batchQueueHead] = batch;
_batchedTokenCountDirty = true;
}

// Remove batches that have already been consumed so the head index does not grow without bound.
void CleanupQueue()
{
if (_batchQueueHead == 0)
return;

if (_batchQueueHead >= _batchQueue.Count)
{
_batchQueue.Clear();
_batchQueueHead = 0;
return;
}

if (_batchQueueHead > CleanupThreshold && _batchQueueHead > _batchQueue.Count / 2)
{
_batchQueue.RemoveRange(0, _batchQueueHead);
_batchQueueHead = 0;
}
}
}
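GetNextBatch, RequeueFront, and CleanupQueue together turn the List<IBatch> into a head-indexed queue: consuming a batch advances _batchQueueHead instead of calling RemoveAt(0), which is O(n) per dequeue, and the consumed prefix is compacted only once it exceeds CleanupThreshold and outnumbers the live tail. A standalone sketch of that pattern, with hypothetical names:

using System.Collections.Generic;

public sealed class HeadIndexedQueue<T>
{
    private const int CleanupThreshold = 16;
    private readonly List<T> _items = new();
    private int _head;

    public int Count => _items.Count - _head;

    public void Enqueue(T item) => _items.Add(item);

    public bool TryDequeue(out T? item)
    {
        if (_head >= _items.Count)
        {
            // Fully drained: reset so the backing list can be reused.
            _items.Clear();
            _head = 0;
            item = default;
            return false;
        }

        item = _items[_head++];
        Compact();
        return true;
    }

    private void Compact()
    {
        // Pay the O(n) RemoveRange only when the dead prefix is long
        // enough and outnumbers the live entries.
        if (_head > CleanupThreshold && _head > _items.Count / 2)
        {
            _items.RemoveRange(0, _head);
            _head = 0;
        }
    }
}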

/// <inheritdoc />
@@ -202,7 +254,7 @@ internal LLamaSeqId GetNextSequenceId()
throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

// Find a batch with space for at least minCapacity tokens
for (var i = 0; i < _batchQueue.Count; i++)
for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
{
var item = _batchQueue[i];
if (item is not TokenBatch { Batch: var batch })
Expand All @@ -213,13 +265,17 @@ internal LLamaSeqId GetNextSequenceId()
continue;

if (batch.TokenCount < Context.BatchSize)
return (batch, Epoch + (uint)(i + 1) * 2);
{
_batchedTokenCountDirty = true;
return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
}
}

// Add a new batch to the end of the queue
var end = new LLamaBatch();
_batchQueue.Add(new TokenBatch(end));
return (end, Epoch + (uint)_batchQueue.Count * 2);
_batchedTokenCountDirty = true;
return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
}

/// <summary>
@@ -234,7 +290,7 @@ internal LLamaSeqId GetNextSequenceId()
throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Request batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

// Find a batch with space for at least minCapacity embeddings
for (var i = 0; i < _batchQueue.Count; i++)
for (var i = _batchQueueHead; i < _batchQueue.Count; i++)
{
var item = _batchQueue[i];
if (item is not EmbeddingBatch { Batch: var batch })
Expand All @@ -245,13 +301,17 @@ internal LLamaSeqId GetNextSequenceId()
continue;

if (batch.EmbeddingsCount < Context.BatchSize)
return (batch, Epoch + (uint)(i + 1) * 2);
{
_batchedTokenCountDirty = true;
return (batch, Epoch + (uint)(i - _batchQueueHead + 1) * 2);
}
}

// Add a new batch to the end of the queue
var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
_batchQueue.Add(new EmbeddingBatch(end));
return (end, Epoch + (uint)_batchQueue.Count * 2);
_batchedTokenCountDirty = true;
return (end, Epoch + (uint)(_batchQueue.Count - _batchQueueHead) * 2);
}

#region batches
@@ -286,4 +346,4 @@ public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
}
}
#endregion
}
}