diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index 7d7654126..5c01984b0 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -41,43 +41,40 @@ private async Task CompareEmbeddings(string modelPath)
         var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, spoon);
-
-        if (false)
-        {
-            //TODO: the below does not work with the new memory efficient context handling - we probably need to define Microsoft.Extensions.AI.IEmbeddingGenerator GetService interface that creates the context on the fly
-
-            var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
-            Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
-            Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
-            Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId);
-            Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId!);
-            Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
-            Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
-            Assert.Null(generator.GetService<string>());
-
-            var embeddings = await generator.GenerateAsync(
-                [
-                    "The cat is cute",
-                    "The kitten is cute",
-                    "The spoon is not real"
-                ]);
-            Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
-            Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
-            Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
-
-            _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
-            _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
-            _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
-
-            var close = 1 - Dot(cat, kitten);
-            var far = 1 - Dot(cat, spoon);
-
-            _testOutputHelper.WriteLine("");
-            _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
-            _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");
-
-            Assert.True(close < far);
-        }
+
+        using var context = new LLamaContext(weights, @params);
+        var managedEmbedder = new LLamaEmbedder(context);
+        IEmbeddingGenerator<string, Embedding<float>> generator = managedEmbedder;
+        Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
+        Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
+        Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId);
+        Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId!);
+        Assert.Same(managedEmbedder, generator.GetService<LLamaEmbedder>());
+        Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
+        Assert.Null(generator.GetService<string>());
+
+        var embeddings = await generator.GenerateAsync(
+            [
+                "The cat is cute",
+                "The kitten is cute",
+                "The spoon is not real"
+            ]);
+        Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+
+        _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
+        _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
+        _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
+
+        var close = 1 - Dot(cat, kitten);
+        var far = 1 - Dot(cat, spoon);
+
+        _testOutputHelper.WriteLine("");
+        _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
+        _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");
+
+        Assert.True(close < far);
     }
 
     [Fact]
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 42d76c514..7d6fa6e1a 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -109,6 +109,84 @@ public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = fal
         return NativeHandle.Tokenize(text, addBos, special, Encoding);
     }
 
+    #region Sequence ID management
+    private LLamaSeqIdManager? _seqIdManager;
+
+    /// <summary>
+    /// Get the sequence ID manager for this context.
+    /// </summary>
+    public LLamaSeqIdManager SequenceManager
+    {
+        get
+        {
+            var manager = _seqIdManager;
+            if (manager != null) return manager;
+            var newManager = new LLamaSeqIdManager(Params.SeqMax);
+            var original = Interlocked.CompareExchange(ref _seqIdManager, newManager, comparand: null);
+            manager = original ?? newManager;
+            return manager;
+        }
+    }
+
+    /// <summary>
+    /// Returns the next available sequence ID for use in model operations.
+    /// Callers will asynchronously wait if none are available.
+    /// On disposal, the sequence ID is returned to the owning <see cref="LLamaSeqIdManager"/> for reuse.
+    /// </summary>
+    /// <remarks>
+    /// Failure to dispose the returned <see cref="ManagedLLamaSeqId"/> will likely result in undefined behavior.
+    /// </remarks>
+    /// <remarks>
+    /// The returned sequence represents an exclusive reservation on the sequence ID within the context.
+    /// For the duration of the <see cref="ManagedLLamaSeqId"/>, no other caller will receive the same sequence ID from this context.
+    /// </remarks>
+    /// <param name="removeMemoryOnRelease">flag indicating whether to remove memory associated with the sequence ID when it is released back to the manager.</param>
+    /// <param name="timeout">optional timeout for acquiring a sequence ID. If null, waits indefinitely.</param>
+    /// <param name="cancellationToken">cancellation token to cancel the wait operation.</param>
+    /// <returns>The next available sequence ID.</returns>
+    public async Task<ManagedLLamaSeqId> AcquireSequenceIdAsync(bool removeMemoryOnRelease = false, TimeSpan? timeout = null, CancellationToken cancellationToken = default)
+    {
+        var seqId = await SequenceManager.NextAsync(timeout, cancellationToken).ConfigureAwait(false);
+        return new ManagedLLamaSeqId(owner: this, seqId, removeMemoryOnRelease);
+    }
+
+    /// <summary>
+    /// Represents a managed <see cref="LLamaSeqId"/> that is returned to the owning <see cref="LLamaContext"/> when disposed.
+    /// </summary>
+    public readonly struct ManagedLLamaSeqId : IDisposable
+    {
+        private readonly LLamaContext? _owner;
+        private readonly bool _removeMemoryOnRelease;
+
+        /// <summary>
+        /// The sequence ID.
+        /// </summary>
+        public LLamaSeqId SeqId { get; }
+
+        /// <summary>
+        /// Implicit conversion to <see cref="LLamaSeqId"/>.
+        /// </summary>
+        /// <param name="managedSeqId">managed sequence ID.</param>
+        /// <returns>the underlying sequence ID.</returns>
+        public static implicit operator LLamaSeqId(ManagedLLamaSeqId managedSeqId) => managedSeqId.SeqId;
+
+        internal ManagedLLamaSeqId(LLamaContext owner, LLamaSeqId seqId, bool removeMemoryOnRelease)
+        {
+            _owner = owner;
+            SeqId = seqId;
+            _removeMemoryOnRelease = removeMemoryOnRelease;
+        }
+
+        /// <inheritdoc />
+        public void Dispose()
+        {
+            if (_owner == null || _owner.NativeHandle.IsClosed) return;
+            if (_removeMemoryOnRelease) _owner.NativeHandle.MemorySequenceRemove(SeqId, 0, -1);
+            _owner.SequenceManager.Return(SeqId);
+        }
+    }
+    #endregion
+
     /// <summary>
     /// Detokenize the tokens to text.
     /// </summary>
@@ -441,6 +519,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
     public void Dispose()
     {
         NativeHandle.Dispose();
+        _seqIdManager?.Dispose();
     }
 
     /// <summary>
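A minimal sketch of how the sequence ID reservation added to LLamaContext above might be used from calling code. The model path and ModelParams values are placeholders; AcquireSequenceIdAsync, SequenceManager and the implicit LLamaSeqId conversion are taken from the hunk above.

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Common;
using LLama.Native;

internal static class SequenceReservationSketch
{
    public static async Task RunAsync()
    {
        // Placeholder model path and parameters - adjust for a real model file.
        var @params = new ModelParams("models/example.gguf");
        using var weights = LLamaWeights.LoadFromFile(@params);
        using var context = new LLamaContext(weights, @params);

        // Reserve a sequence ID. Disposing the reservation returns the ID to the
        // context's SequenceManager; removeMemoryOnRelease also clears its memory.
        using var reservation = await context.AcquireSequenceIdAsync(removeMemoryOnRelease: true);

        // The reservation converts implicitly to LLamaSeqId wherever one is needed.
        LLamaSeqId seqId = reservation;
        Console.WriteLine($"Reserved sequence ID {seqId.Value} of {context.Params.SeqMax}");
    }
}
```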
diff --git a/LLama/LLamaEmbedder.EmbeddingGenerator.cs b/LLama/LLamaEmbedder.EmbeddingGenerator.cs
index bce9f8d8b..3960dd227 100644
--- a/LLama/LLamaEmbedder.EmbeddingGenerator.cs
+++ b/LLama/LLamaEmbedder.EmbeddingGenerator.cs
@@ -3,7 +3,6 @@
 using System.Diagnostics;
 using System.Threading;
 using System.Threading.Tasks;
-using LLama.Native;
 using Microsoft.Extensions.AI;
 
 namespace LLama;
@@ -16,25 +15,27 @@ public partial class LLamaEmbedder
     /// <inheritdoc />
     object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey)
     {
-        if (serviceKey is null)
+        if (serviceKey is not null)
         {
-            if (serviceType == typeof(EmbeddingGeneratorMetadata))
-            {
-                return _metadata ??= new(
-                    nameof(LLamaEmbedder),
-                    defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
-                    defaultModelDimensions: EmbeddingSize);
-            }
+            return null;
+        }
+
+        if (_hasExternalContext && serviceType == typeof(EmbeddingGeneratorMetadata))
+        {
+            return _metadata ??= new(
+                nameof(LLamaEmbedder),
+                defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
+                defaultModelDimensions: EmbeddingSize);
+        }
 
-            if (serviceType?.IsInstanceOfType(Context) is true)
-            {
-                return Context;
-            }
+        if (_hasExternalContext && serviceType?.IsInstanceOfType(Context) is true)
+        {
+            return Context;
+        }
 
-            if (serviceType?.IsInstanceOfType(this) is true)
-            {
-                return this;
-            }
-        }
+        if (serviceType?.IsInstanceOfType(this) is true)
+        {
+            return this;
+        }
 
         return null;
     }
@@ -43,11 +44,6 @@ public partial class LLamaEmbedder
     /// <inheritdoc />
     async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
    {
-        if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
-        {
-            throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
-        }
-
         GeneratedEmbeddings<Embedding<float>> results = new()
         {
             Usage = new() { InputTokenCount = 0 },
@@ -56,7 +52,7 @@ async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Em
             results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
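A sketch of consuming the embedder through the Microsoft.Extensions.AI abstraction implemented above, mirroring the GetService and GenerateAsync calls exercised in the updated test; the model path is a placeholder.

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Common;
using Microsoft.Extensions.AI;

internal static class EmbeddingGeneratorSketch
{
    public static async Task RunAsync()
    {
        // Placeholder path to an embedding-capable GGUF model.
        var @params = new ModelParams("models/example-embeddings.gguf");
        using var weights = LLamaWeights.LoadFromFile(@params);
        using var context = new LLamaContext(weights, @params);
        using var embedder = new LLamaEmbedder(context);

        IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

        // Metadata is only surfaced when the context was supplied externally (see GetService above).
        var metadata = generator.GetService<EmbeddingGeneratorMetadata>();
        Console.WriteLine($"Provider: {metadata?.ProviderName}, dimensions: {metadata?.DefaultModelDimensions}");

        var embeddings = await generator.GenerateAsync(["The cat is cute", "The spoon is not real"]);
        Console.WriteLine($"Generated {embeddings.Count} embeddings, first has {embeddings[0].Vector.Length} values");
    }
}
```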
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index eee9a01e9..fb3c0ca1e 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -1,14 +1,11 @@
 using System;
 using System.Collections.Generic;
-using System.Linq;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
 using LLama.Exceptions;
 using LLama.Native;
-using Microsoft.Extensions.AI;
 using Microsoft.Extensions.Logging;
-using static System.Net.Mime.MediaTypeNames;
 
 namespace LLama;
 
@@ -26,18 +23,26 @@ public sealed partial class LLamaEmbedder
     /// <summary>
     /// LLama Context
     /// </summary>
+    /// <remarks>
+    /// If the context was not provided externally, the returned context will be in a disposed state.
+    /// </remarks>
     public LLamaContext Context { get; private set; }
 
-    private LLamaWeights _weights;
-    private IContextParams _params;
-    private ILogger? _logger;
+    private readonly LLamaWeights? _weights;
+    private readonly IContextParams _params;
+    private readonly ILogger? _logger;
+    private readonly bool _hasExternalContext;
 
     /// <summary>
-    /// Create a new embedder, using the given LLamaWeights
+    /// Create a new embedder, using the given <see cref="LLamaWeights"/>.
+    /// This will create and dispose a new <see cref="LLamaContext"/> for each embedding request.
+    /// If you want to manage the context lifetime yourself, consider using the other constructor that takes a <see cref="LLamaContext"/>.
     /// </summary>
-    /// <param name="weights"></param>
-    /// <param name="params"></param>
-    /// <param name="logger"></param>
+    /// <param name="weights">weights to use for generating embeddings. The weights must be for a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both).</param>
+    /// <param name="params">context parameters to use when creating the context</param>
+    /// <param name="logger">optional logger</param>
+    /// <exception cref="ArgumentException">raised if the provided context has batch size different from ubatch size</exception>
+    /// <exception cref="NotSupportedException">raised if the provided context is for an encoder-decoder model</exception>
     public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
     {
         if (@params.UBatchSize != @params.BatchSize)
@@ -51,12 +56,39 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg
         _weights = weights;
         _params = @params;
         _logger = logger;
+        _hasExternalContext = false;
+    }
+
+    /// <summary>
+    /// Creates a new embedder using the given <see cref="LLamaContext"/>.
+    /// The caller is responsible for managing the lifetime of the context, and must ensure that the context remains valid
+    /// for the entire lifetime of this <see cref="LLamaEmbedder"/>. The context will not be disposed when this embedder is disposed.
+    /// </summary>
+    /// <param name="context">context to use for generating embeddings. The context must be configured with a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both).</param>
+    /// <param name="logger">optional logger</param>
+    /// <exception cref="ArgumentException">raised if the provided context has batch size different from ubatch size</exception>
+    /// <exception cref="NotSupportedException">raised if the provided context is for an encoder-decoder model</exception>
+    public LLamaEmbedder(LLamaContext context, ILogger? logger = null)
+    {
+        if (context.Params.UBatchSize != context.Params.BatchSize)
+            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(context));
+
+        if (context.NativeHandle.ModelHandle is { HasEncoder: true, HasDecoder: true })
+            throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported");
+
+        Context = context;
+        EmbeddingSize = Context.EmbeddingSize;
+        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
+        _params = context.Params;
+        _logger = logger;
+        _hasExternalContext = true;
     }
 
     /// <inheritdoc />
     public void Dispose()
     {
-        Context.Dispose();
+        if (!_hasExternalContext && !Context.NativeHandle.IsClosed)
+            Context.Dispose();
     }
 
     /// <summary>
@@ -72,19 +104,25 @@ public void Dispose()
     public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default) =>
         (await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings;
 
+
     private async Task<(IReadOnlyList<float[]> Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default)
     {
-        // Ensure the context from last time is disposed (it always should be)
-        if (!Context.NativeHandle.IsClosed)
-            Context.Dispose();
+        if (!_hasExternalContext)
+        {
+            if (!Context.NativeHandle.IsClosed)
+                Context.Dispose();
 
-        Context = _weights.CreateContext(_params, _logger);
-        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
+            Context = _weights!.CreateContext(_params, _logger);
+            NativeApi.llama_set_embeddings(Context.NativeHandle, true);
+        }
 
-        // Add all of the tokens to the batch
+        using var seqId = await Context.AcquireSequenceIdAsync(removeMemoryOnRelease: true, cancellationToken: cancellationToken);
+        // Add all the tokens to the batch
         var tokens = Context.Tokenize(input, special: true);
         if (tokens.Length > Context.ContextSize)
-            throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(input));
+            throw new ArgumentException(
+                $"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})",
+                nameof(input));
 
         // Check if we should cancel the work, just before doing anything expensive (encode/decode)
         cancellationToken.ThrowIfCancellationRequested();
@@ -100,27 +138,27 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
                 n_eval = batchSize;
 
             batch.Clear();
-            batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true);
+            batch.AddRange(tokens.AsSpan(i, n_eval), n_past, seqId, true);
             n_past += n_eval;
 
             // Run model
             switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
             {
                 case (true, false):
-                {
-                    var result = await Context.EncodeAsync(batch, cancellationToken);
-                    if (result != EncodeResult.Ok)
-                        throw new RuntimeError($"Failed to encode: {result}");
-                    break;
-                }
+                    {
+                        var result = await Context.EncodeAsync(batch, cancellationToken);
+                        if (result != EncodeResult.Ok)
+                            throw new RuntimeError($"Failed to encode: {result}");
+                        break;
+                    }
 
                 case (false, true):
-                {
-                    var result = await Context.DecodeAsync(batch, cancellationToken);
-                    if (result != DecodeResult.Ok)
-                        throw new RuntimeError($"Failed to decode: {result}");
-                    break;
-                }
+                    {
+                        var result = await Context.DecodeAsync(batch, cancellationToken);
+                        if (result != DecodeResult.Ok)
+                            throw new RuntimeError($"Failed to decode: {result}");
+                        break;
+                    }
 
                 default:
                     throw new NotSupportedException("Unsupported model type");
@@ -140,7 +178,7 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
         }
         else
         {
-            results.Add(Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero).ToArray());
+            results.Add(Context.NativeHandle.GetEmbeddingsSeq(seqId).ToArray());
         }
 
         // Normalize the embeddings vector
@@ -150,7 +188,8 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
             embedding.EuclideanNormalization();
         }
 
-        Context.Dispose();
+        if (!_hasExternalContext)
+            Context.Dispose();
 
         return (results, tokens.Length);
     }
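Because GetEmbeddingsWithTokenCount now reserves a per-request sequence ID, an embedder built over an externally managed context can serve overlapping requests against the same context; a sketch under that assumption (placeholder model path, SeqMax chosen arbitrarily):

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;
using LLama;
using LLama.Common;

internal static class SharedContextEmbeddingSketch
{
    public static async Task RunAsync()
    {
        // Placeholder model path; SeqMax bounds how many sequence IDs are available concurrently.
        var @params = new ModelParams("models/example-embeddings.gguf") { SeqMax = 4 };
        using var weights = LLamaWeights.LoadFromFile(@params);
        using var context = new LLamaContext(weights, @params);
        using var embedder = new LLamaEmbedder(context);

        string[] inputs = ["The cat is cute", "The kitten is cute", "The spoon is not real"];

        // Each call acquires its own sequence ID from the shared context; callers beyond
        // SeqMax wait in LLamaSeqIdManager.NextAsync until an ID is returned.
        var results = await Task.WhenAll(inputs.Select(text => embedder.GetEmbeddings(text)));

        for (var i = 0; i < inputs.Length; i++)
            Console.WriteLine($"\"{inputs[i]}\" -> {results[i].Single().Length} dimensions");
    }
}
```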
diff --git a/LLama/Native/LLamaSeqIdManager.cs b/LLama/Native/LLamaSeqIdManager.cs
new file mode 100644
index 000000000..4cb7c27b5
--- /dev/null
+++ b/LLama/Native/LLamaSeqIdManager.cs
@@ -0,0 +1,82 @@
+using System;
+using System.Collections.Concurrent;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace LLama.Native;
+
+/// <summary>
+/// Provides <see cref="LLamaSeqId"/> management for LLama models.
+/// Based on the provided max sequence count, it allocates and recycles <see cref="LLamaSeqId"/>s for use in model operations.
+/// </summary>
+/// <remarks>
+/// The class is thread-safe and allows multiple concurrent requests for sequence IDs.
+/// </remarks>
+public sealed class LLamaSeqIdManager : IDisposable
+{
+    private readonly SemaphoreSlim _semaphore;
+    private readonly ConcurrentBag<int> _availableIds;
+
+    /// <summary>
+    /// Constructs a new <see cref="LLamaSeqIdManager"/> with the specified maximum sequence count.
+    /// </summary>
+    /// <param name="maxSeqCount">maximum number of sequence IDs to manage.</param>
+    public LLamaSeqIdManager(uint maxSeqCount)
+    {
+        var max = Math.Max((int)maxSeqCount, 1); // Ensure at least one sequence ID is available
+        _semaphore = new SemaphoreSlim(initialCount: max, maxCount: max);
+        _availableIds = [];
+        for (var i = 0; i < max; i++)
+        {
+            _availableIds.Add(i);
+        }
+    }
+
+    /// <summary>
+    /// Returns the next available sequence ID.
+    /// Callers will asynchronously wait if none are available.
+    /// </summary>
+    /// <param name="timeout">optional timeout for acquiring a sequence ID. If null, waits indefinitely.</param>
+    /// <param name="cancellationToken">cancellation token to cancel the wait operation.</param>
+    /// <returns>The next available sequence ID.</returns>
+    public async Task<LLamaSeqId> NextAsync(TimeSpan? timeout = null, CancellationToken cancellationToken = default)
+    {
+        await _semaphore.WaitAsync(timeout ?? TimeSpan.FromMilliseconds(-1), cancellationToken).ConfigureAwait(false);
+
+        try
+        {
+            cancellationToken.ThrowIfCancellationRequested();
+
+            if (_availableIds.TryTake(out var seqId))
+            {
+                return (LLamaSeqId)seqId;
+            }
+
+            throw new InvalidOperationException("No sequence ID available despite semaphore release");
+        }
+        catch
+        {
+            _semaphore.Release();
+            throw;
+        }
+    }
+
+    /// <summary>
+    /// Returns a sequence ID to the manager, making it available for reuse.
+    /// </summary>
+    /// <remarks>
+    /// It's the caller's responsibility to ensure the sequence ID is in a valid state for reuse.
+    /// </remarks>
+    /// <param name="seqId">sequence ID to return.</param>
+    public void Return(LLamaSeqId seqId)
+    {
+        _availableIds.Add(seqId.Value);
+        _semaphore.Release();
+    }
+
+    /// <inheritdoc />
+    public void Dispose()
+    {
+        _semaphore.Dispose();
+    }
+}
\ No newline at end of file
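A small sketch of the new manager used directly, showing the acquire/return contract described in its XML docs; the counts and timeout are arbitrary.

```csharp
using System;
using System.Threading.Tasks;
using LLama.Native;

internal static class SeqIdManagerSketch
{
    public static async Task RunAsync()
    {
        using var manager = new LLamaSeqIdManager(maxSeqCount: 2);

        // Take both available IDs; a third caller would wait inside NextAsync.
        var first = await manager.NextAsync();
        var second = await manager.NextAsync();
        Console.WriteLine($"Acquired {first.Value} and {second.Value}");

        // Returning an ID releases the semaphore and wakes any waiting caller.
        manager.Return(first);
        var third = await manager.NextAsync(timeout: TimeSpan.FromSeconds(5));
        Console.WriteLine($"Reacquired {third.Value}");

        manager.Return(second);
        manager.Return(third);
    }
}
```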