From cefb091d6cac4d50af1e232e382a917442bc343a Mon Sep 17 00:00:00 2001 From: bmazzarol-bunnings Date: Mon, 29 Sep 2025 22:34:26 +0800 Subject: [PATCH 1/5] fix: Allow externally managed contexts with LLamaEmbedder Fixes #1259 and potentially #1247 with changes to how the caller manages the LLamaEmbedder. --- LLama.Unittest/LLamaEmbedderTests.cs | 71 +++++++++++------------ LLama/LLamaEmbedder.EmbeddingGenerator.cs | 42 ++++++-------- LLama/LLamaEmbedder.cs | 70 ++++++++++++++++------ 3 files changed, 106 insertions(+), 77 deletions(-) diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 7d7654126..5c01984b0 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -41,43 +41,40 @@ private async Task CompareEmbeddings(string modelPath) var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization(); Assert.DoesNotContain(float.NaN, spoon); - - if (false) - { - //TODO: the below does not work with the new memory efficient context handling - we probably need to define Microsoft.Extensions.AI.IEmbeddingGenerator GetService interface that creates the context on the fly - - var generator = (IEmbeddingGenerator>)embedder; - Assert.NotNull(generator.GetService()); - Assert.Equal(nameof(LLamaEmbedder), generator.GetService()?.ProviderName); - Assert.NotNull(generator.GetService()?.DefaultModelId); - Assert.NotEmpty(generator.GetService()?.DefaultModelId!); - Assert.Same(embedder, generator.GetService()); - Assert.Same(generator, generator.GetService>>()); - Assert.Null(generator.GetService()); - - var embeddings = await generator.GenerateAsync( - [ - "The cat is cute", - "The kitten is cute", - "The spoon is not real" - ]); - Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); - Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); - Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); - - _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); - _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); - _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]"); - - var close = 1 - Dot(cat, kitten); - var far = 1 - Dot(cat, spoon); - - _testOutputHelper.WriteLine(""); - _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}"); - _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}"); - - Assert.True(close < far); - } + + using var context = new LLamaContext(weights, @params); + var managedEmbedder = new LLamaEmbedder(context); + IEmbeddingGenerator> generator = managedEmbedder; + Assert.NotNull(generator.GetService()); + Assert.Equal(nameof(LLamaEmbedder), generator.GetService()?.ProviderName); + Assert.NotNull(generator.GetService()?.DefaultModelId); + Assert.NotEmpty(generator.GetService()?.DefaultModelId!); + Assert.Same(managedEmbedder, generator.GetService()); + Assert.Same(generator, generator.GetService>>()); + Assert.Null(generator.GetService()); + + var embeddings = await generator.GenerateAsync( + [ + "The cat is cute", + "The kitten is cute", + "The spoon is not real" + ]); + Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); + 
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); + Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001)); + + _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); + _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); + _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]"); + + var close = 1 - Dot(cat, kitten); + var far = 1 - Dot(cat, spoon); + + _testOutputHelper.WriteLine(""); + _testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}"); + _testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}"); + + Assert.True(close < far); } [Fact] diff --git a/LLama/LLamaEmbedder.EmbeddingGenerator.cs b/LLama/LLamaEmbedder.EmbeddingGenerator.cs index bce9f8d8b..3960dd227 100644 --- a/LLama/LLamaEmbedder.EmbeddingGenerator.cs +++ b/LLama/LLamaEmbedder.EmbeddingGenerator.cs @@ -3,7 +3,6 @@ using System.Diagnostics; using System.Threading; using System.Threading.Tasks; -using LLama.Native; using Microsoft.Extensions.AI; namespace LLama; @@ -16,25 +15,27 @@ public partial class LLamaEmbedder /// object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey) { - if (serviceKey is null) + if (serviceKey is not null) { - if (serviceType == typeof(EmbeddingGeneratorMetadata)) - { - return _metadata ??= new( - nameof(LLamaEmbedder), - defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null, - defaultModelDimensions: EmbeddingSize); - } + return null; + } + + if (_hasExternalContext && serviceType == typeof(EmbeddingGeneratorMetadata)) + { + return _metadata ??= new( + nameof(LLamaEmbedder), + defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null, + defaultModelDimensions: EmbeddingSize); + } - if (serviceType?.IsInstanceOfType(Context) is true) - { - return Context; - } + if (_hasExternalContext && serviceType?.IsInstanceOfType(Context) is true) + { + return Context; + } - if (serviceType?.IsInstanceOfType(this) is true) - { - return this; - } + if (serviceType?.IsInstanceOfType(this) is true) + { + return this; } return null; @@ -43,11 +44,6 @@ public partial class LLamaEmbedder /// async Task>> IEmbeddingGenerator>.GenerateAsync(IEnumerable values, EmbeddingGenerationOptions? 
options, CancellationToken cancellationToken) { - if (Context.NativeHandle.PoolingType == LLamaPoolingType.None) - { - throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}."); - } - GeneratedEmbeddings> results = new() { Usage = new() { InputTokenCount = 0 }, @@ -56,7 +52,7 @@ async Task>> IEmbeddingGenerator(embeddings[0]) { CreatedAt = DateTime.UtcNow }); diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index eee9a01e9..e831a1724 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -1,14 +1,11 @@ using System; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; using LLama.Abstractions; using LLama.Exceptions; using LLama.Native; -using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; -using static System.Net.Mime.MediaTypeNames; namespace LLama; @@ -26,18 +23,26 @@ public sealed partial class LLamaEmbedder /// /// LLama Context /// + /// + /// If the context was not provided externally, the returned context will be in a disposed state. + /// public LLamaContext Context { get; private set; } - private LLamaWeights _weights; - private IContextParams _params; - private ILogger? _logger; + private readonly LLamaWeights? _weights; + private readonly IContextParams _params; + private readonly ILogger? _logger; + private readonly bool _hasExternalContext; /// - /// Create a new embedder, using the given LLamaWeights + /// Create a new embedder, using the given . + /// This will create and dispose a new for each embedding request. + /// If you want to manage the context lifetime yourself, consider using the other constructor that takes a . /// - /// - /// - /// + /// weights to use for generating embeddings. The weights must be for a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both). + /// context parameters to use when creating the context + /// optional logger + /// raised if the provided context has batch size different from ubatch size + /// raised if the provided context is for an encoder-decoder model public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { if (@params.UBatchSize != @params.BatchSize) @@ -51,12 +56,39 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg _weights = weights; _params = @params; _logger = logger; + _hasExternalContext = false; + } + + /// + /// Creates a new embedder using the given . + /// The caller is responsible for managing the lifetime of the context, and must ensure that the context remains valid + /// for the entire lifetime of this . The context will not be disposed when this embedder is disposed. + /// + /// context to use for generating embeddings. The context must be configured with a model that supports embeddings (i.e. it must have an encoder or a decoder, but not both). + /// optional logger + /// raised if the provided context has batch size different from ubatch size + /// raised if the provided context is for an encoder-decoder model + public LLamaEmbedder(LLamaContext context, ILogger? 
logger = null) + { + if(context.Params.UBatchSize != context.Params.BatchSize) + throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(context)); + + if (context.NativeHandle.ModelHandle is { HasEncoder: true, HasDecoder: true }) + throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported"); + + Context = context; + EmbeddingSize = Context.EmbeddingSize; + NativeApi.llama_set_embeddings(Context.NativeHandle, true); + _params = context.Params; + _logger = logger; + _hasExternalContext = true; } /// public void Dispose() { - Context.Dispose(); + if(!_hasExternalContext && !Context.NativeHandle.IsClosed) + Context.Dispose(); } /// @@ -72,14 +104,17 @@ public void Dispose() public async Task> GetEmbeddings(string input, CancellationToken cancellationToken = default) => (await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings; + private async Task<(IReadOnlyList Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default) { - // Ensure the context from last time is disposed (it always should be) - if (!Context.NativeHandle.IsClosed) - Context.Dispose(); + if (!_hasExternalContext) + { + if (!Context.NativeHandle.IsClosed) + Context.Dispose(); - Context = _weights.CreateContext(_params, _logger); - NativeApi.llama_set_embeddings(Context.NativeHandle, true); + Context = _weights!.CreateContext(_params, _logger); + NativeApi.llama_set_embeddings(Context.NativeHandle, true); + } // Add all of the tokens to the batch var tokens = Context.Tokenize(input, special: true); @@ -150,7 +185,8 @@ public async Task> GetEmbeddings(string input, Cancellati embedding.EuclideanNormalization(); } - Context.Dispose(); + if (!_hasExternalContext) + Context.Dispose(); return (results, tokens.Length); } From 38b8d8be31445d28800ded547e3c1a83dc28e7de Mon Sep 17 00:00:00 2001 From: bmazzarol-bunnings Date: Tue, 4 Nov 2025 22:46:15 +0800 Subject: [PATCH 2/5] feat: Implement LLamaSeqIdManager for sequence ID management in LLama models --- LLama/LLamaEmbedder.cs | 120 +++++++++++++++++------------- LLama/Native/LLamaSeqIdManager.cs | 68 +++++++++++++++++ 2 files changed, 137 insertions(+), 51 deletions(-) create mode 100644 LLama/Native/LLamaSeqIdManager.cs diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index e831a1724..9dbd9bb48 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -32,6 +32,7 @@ public sealed partial class LLamaEmbedder private readonly IContextParams _params; private readonly ILogger? _logger; private readonly bool _hasExternalContext; + private readonly LLamaSeqIdManager? _lamaSeqIdManager; /// /// Create a new embedder, using the given . @@ -57,6 +58,7 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg _params = @params; _logger = logger; _hasExternalContext = false; + _lamaSeqIdManager = null; } /// @@ -82,6 +84,7 @@ public LLamaEmbedder(LLamaContext context, ILogger? 
logger = null) _params = context.Params; _logger = logger; _hasExternalContext = true; + _lamaSeqIdManager = new LLamaSeqIdManager(context.Params.SeqMax); } /// @@ -89,6 +92,7 @@ public void Dispose() { if(!_hasExternalContext && !Context.NativeHandle.IsClosed) Context.Dispose(); + _lamaSeqIdManager?.Dispose(); } /// @@ -116,32 +120,37 @@ public async Task> GetEmbeddings(string input, Cancellati NativeApi.llama_set_embeddings(Context.NativeHandle, true); } - // Add all of the tokens to the batch - var tokens = Context.Tokenize(input, special: true); - if (tokens.Length > Context.ContextSize) - throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(input)); - - // Check if we should cancel the work, just before doing anything expensive (encode/decode) - cancellationToken.ThrowIfCancellationRequested(); - - // Evaluate prompt in batch-size chunks - var n_past = 0; - var batch = new LLamaBatch(); - var batchSize = (int)Context.Params.BatchSize; - for (var i = 0; i < tokens.Length; i += batchSize) + var seqId = _lamaSeqIdManager is not null ? await _lamaSeqIdManager.Next() : LLamaSeqId.Zero; + try { - var n_eval = tokens.Length - i; - if (n_eval > batchSize) - n_eval = batchSize; - - batch.Clear(); - batch.AddRange(tokens.AsSpan(i, n_eval), n_past, LLamaSeqId.Zero, true); - n_past += n_eval; - - // Run model - switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder) + // Add all the tokens to the batch + var tokens = Context.Tokenize(input, special: true); + if (tokens.Length > Context.ContextSize) + throw new ArgumentException( + $"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", + nameof(input)); + + // Check if we should cancel the work, just before doing anything expensive (encode/decode) + cancellationToken.ThrowIfCancellationRequested(); + + // Evaluate prompt in batch-size chunks + var n_past = 0; + var batch = new LLamaBatch(); + var batchSize = (int)Context.Params.BatchSize; + for (var i = 0; i < tokens.Length; i += batchSize) { - case (true, false): + var n_eval = tokens.Length - i; + if (n_eval > batchSize) + n_eval = batchSize; + + batch.Clear(); + batch.AddRange(tokens.AsSpan(i, n_eval), n_past, seqId, true); + n_past += n_eval; + + // Run model + switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder) + { + case (true, false): { var result = await Context.EncodeAsync(batch, cancellationToken); if (result != EncodeResult.Ok) @@ -149,7 +158,7 @@ public async Task> GetEmbeddings(string input, Cancellati break; } - case (false, true): + case (false, true): { var result = await Context.DecodeAsync(batch, cancellationToken); if (result != DecodeResult.Ok) @@ -157,37 +166,46 @@ public async Task> GetEmbeddings(string input, Cancellati break; } - default: - throw new NotSupportedException("Unsupported model type"); + default: + throw new NotSupportedException("Unsupported model type"); + } } - } - // Extract results - var poolingType = Context.NativeHandle.PoolingType; - var resultsCount = poolingType == LLamaPoolingType.None ? tokens.Length : 1; - var results = new List(resultsCount); + // Extract results + var poolingType = Context.NativeHandle.PoolingType; + var resultsCount = poolingType == LLamaPoolingType.None ? 
tokens.Length : 1; + var results = new List(resultsCount); - if (poolingType == LLamaPoolingType.None) - { - var positions = batch.GetLogitPositions(); - foreach (var (_, pos) in positions) - results.Add(Context.NativeHandle.GetEmbeddingsIth(pos).ToArray()); - } - else - { - results.Add(Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero).ToArray()); - } + if (poolingType == LLamaPoolingType.None) + { + var positions = batch.GetLogitPositions(); + foreach (var (_, pos) in positions) + results.Add(Context.NativeHandle.GetEmbeddingsIth(pos).ToArray()); + } + else + { + results.Add(Context.NativeHandle.GetEmbeddingsSeq(seqId).ToArray()); + } - // Normalize the embeddings vector - // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92 - foreach (var embedding in results) - { - embedding.EuclideanNormalization(); - } + // Normalize the embeddings vector + // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92 + foreach (var embedding in results) + { + embedding.EuclideanNormalization(); + } - if (!_hasExternalContext) - Context.Dispose(); + if (!_hasExternalContext) + Context.Dispose(); - return (results, tokens.Length); + return (results, tokens.Length); + } + finally + { + if (_lamaSeqIdManager != null) + { + Context.NativeHandle.MemorySequenceRemove(seqId,0,-1); + _lamaSeqIdManager.Return(seqId); + } + } } } \ No newline at end of file diff --git a/LLama/Native/LLamaSeqIdManager.cs b/LLama/Native/LLamaSeqIdManager.cs new file mode 100644 index 000000000..f16315fda --- /dev/null +++ b/LLama/Native/LLamaSeqIdManager.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; + +namespace LLama.Native; + +/// +/// Provides management for LLama models. +/// Based on the provided max sequence count, it allocates and recycles for use in model operations. +/// +/// +/// The class is thread-safe and allows multiple concurrent requests for sequence IDs. +/// +public sealed class LLamaSeqIdManager : IDisposable +{ + private readonly SemaphoreSlim _semaphore; + private readonly ConcurrentBag _availableIds; + + /// + /// Constructs a new with the specified maximum sequence count. + /// + /// maximum number of sequence IDs to manage. + public LLamaSeqIdManager(uint maxSeqCount) + { + _semaphore = new SemaphoreSlim((int)maxSeqCount, (int)maxSeqCount); + _availableIds = []; + for (var i = 0; i < maxSeqCount; i++) + { + _availableIds.Add(i); + } + } + + /// + /// Returns the next available sequence ID. + /// Callers will asynchronously wait if none are available. + /// + /// >The next available sequence ID. + public async Task Next() + { + await _semaphore.WaitAsync(); + if (_availableIds.TryTake(out var seqId)) + { + return (LLamaSeqId)seqId; + } + + throw new InvalidOperationException("No sequence ID available despite semaphore release"); + } + + /// + /// Returns a sequence ID to the manager, making it available for reuse. + /// + /// + /// It's the caller's responsibility to ensure the sequence ID is in a valid state for reuse. + /// + /// sequence ID to return. 
+ public void Return(LLamaSeqId seqId) + { + _availableIds.Add(seqId.Value); + _semaphore.Release(); + } + + /// + public void Dispose() + { + _semaphore.Dispose(); + } +} \ No newline at end of file From 2fb0355e5cfc9b84ee222f196cdbfdcc2e7c7950 Mon Sep 17 00:00:00 2001 From: bmazzarol-bunnings Date: Wed, 5 Nov 2025 22:44:47 +0800 Subject: [PATCH 3/5] feat: Enhance LLamaSeqIdManager and LLamaContext for improved sequence ID management --- LLama/LLamaContext.cs | 79 ++++++++++++++++ LLama/LLamaEmbedder.cs | 147 ++++++++++++++---------------- LLama/Native/LLamaSeqIdManager.cs | 13 ++- 3 files changed, 155 insertions(+), 84 deletions(-) diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 42d76c514..7d6fa6e1a 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -109,6 +109,84 @@ public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = fal return NativeHandle.Tokenize(text, addBos, special, Encoding); } + #region Sequence ID management + private LLamaSeqIdManager? _seqIdManager; + + /// + /// Get the sequence ID manager for this context. + /// + public LLamaSeqIdManager SequenceManager + { + get + { + var manager = _seqIdManager; + if (manager != null) return manager; + var newManager = new LLamaSeqIdManager(Params.SeqMax); + var original = Interlocked.CompareExchange(ref _seqIdManager, newManager, comparand: null); + manager = original ?? newManager; + return manager; + } + } + + /// + /// Returns the next available sequence ID for use in model operations. + /// Callers will asynchronously wait if none are available. + /// On disposal, the sequence ID is returned to the owning for reuse. + /// + /// + /// Failure to dispose the returned will likely result in undefined behavior. + /// + /// + /// The returned sequence represents an exclusive reservation on the sequence ID within the context. + /// For the duration of the , no other caller will receive the same sequence ID from this context. + /// + /// flag indicating whether to remove memory associated with the sequence ID when it is released back to the manager. + /// optional timeout for acquiring a sequence ID. If null, waits indefinitely. + /// cancellation token to cancel the wait operation. + /// The next available sequence ID. + public async Task AcquireSequenceIdAsync(bool removeMemoryOnRelease = false, TimeSpan? timeout = null, CancellationToken cancellationToken = default) + { + var seqId = await SequenceManager.NextAsync(timeout, cancellationToken).ConfigureAwait(false); + return new ManagedLLamaSeqId(owner: this, seqId, removeMemoryOnRelease); + } + + /// + /// Represents a managed that is returned to the owning when disposed. + /// + public readonly struct ManagedLLamaSeqId : IDisposable + { + private readonly LLamaContext? _owner; + private readonly bool _removeMemoryOnRelease; + + /// + /// The sequence ID. + /// + public LLamaSeqId SeqId { get; } + + /// + /// Implicit conversion to . + /// + /// managed sequence ID. + /// the underlying sequence ID. 
+ public static implicit operator LLamaSeqId(ManagedLLamaSeqId managedSeqId) => managedSeqId.SeqId; + + internal ManagedLLamaSeqId(LLamaContext owner, LLamaSeqId seqId, bool removeMemoryOnRelease) + { + _owner = owner; + SeqId = seqId; + _removeMemoryOnRelease = removeMemoryOnRelease; + } + + /// + public void Dispose() + { + if (_owner == null || _owner.NativeHandle.IsClosed) return; + if (_removeMemoryOnRelease) _owner.NativeHandle.MemorySequenceRemove(SeqId, 0, -1); + _owner.SequenceManager.Return(SeqId); + } + } + #endregion + /// /// Detokenize the tokens to text. /// @@ -441,6 +519,7 @@ public Task DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo public void Dispose() { NativeHandle.Dispose(); + _seqIdManager?.Dispose(); } /// diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 9dbd9bb48..510bbca4d 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -60,7 +60,7 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg _hasExternalContext = false; _lamaSeqIdManager = null; } - + /// /// Creates a new embedder using the given . /// The caller is responsible for managing the lifetime of the context, and must ensure that the context remains valid @@ -72,12 +72,12 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg /// raised if the provided context is for an encoder-decoder model public LLamaEmbedder(LLamaContext context, ILogger? logger = null) { - if(context.Params.UBatchSize != context.Params.BatchSize) + if (context.Params.UBatchSize != context.Params.BatchSize) throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(context)); - + if (context.NativeHandle.ModelHandle is { HasEncoder: true, HasDecoder: true }) throw new NotSupportedException("Computing embeddings in encoder-decoder models is not supported"); - + Context = context; EmbeddingSize = Context.EmbeddingSize; NativeApi.llama_set_embeddings(Context.NativeHandle, true); @@ -90,7 +90,7 @@ public LLamaEmbedder(LLamaContext context, ILogger? logger = null) /// public void Dispose() { - if(!_hasExternalContext && !Context.NativeHandle.IsClosed) + if (!_hasExternalContext && !Context.NativeHandle.IsClosed) Context.Dispose(); _lamaSeqIdManager?.Dispose(); } @@ -120,92 +120,81 @@ public async Task> GetEmbeddings(string input, Cancellati NativeApi.llama_set_embeddings(Context.NativeHandle, true); } - var seqId = _lamaSeqIdManager is not null ? 
await _lamaSeqIdManager.Next() : LLamaSeqId.Zero; - try + using var seqId = await Context.AcquireSequenceIdAsync(removeMemoryOnRelease: true, cancellationToken: cancellationToken); + // Add all the tokens to the batch + var tokens = Context.Tokenize(input, special: true); + if (tokens.Length > Context.ContextSize) + throw new ArgumentException( + $"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", + nameof(input)); + + // Check if we should cancel the work, just before doing anything expensive (encode/decode) + cancellationToken.ThrowIfCancellationRequested(); + + // Evaluate prompt in batch-size chunks + var n_past = 0; + var batch = new LLamaBatch(); + var batchSize = (int)Context.Params.BatchSize; + for (var i = 0; i < tokens.Length; i += batchSize) { - // Add all the tokens to the batch - var tokens = Context.Tokenize(input, special: true); - if (tokens.Length > Context.ContextSize) - throw new ArgumentException( - $"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", - nameof(input)); - - // Check if we should cancel the work, just before doing anything expensive (encode/decode) - cancellationToken.ThrowIfCancellationRequested(); - - // Evaluate prompt in batch-size chunks - var n_past = 0; - var batch = new LLamaBatch(); - var batchSize = (int)Context.Params.BatchSize; - for (var i = 0; i < tokens.Length; i += batchSize) - { - var n_eval = tokens.Length - i; - if (n_eval > batchSize) - n_eval = batchSize; + var n_eval = tokens.Length - i; + if (n_eval > batchSize) + n_eval = batchSize; - batch.Clear(); - batch.AddRange(tokens.AsSpan(i, n_eval), n_past, seqId, true); - n_past += n_eval; + batch.Clear(); + batch.AddRange(tokens.AsSpan(i, n_eval), n_past, seqId, true); + n_past += n_eval; - // Run model - switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder) + // Run model + switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder) + { + case (true, false): { - case (true, false): - { - var result = await Context.EncodeAsync(batch, cancellationToken); - if (result != EncodeResult.Ok) - throw new RuntimeError($"Failed to encode: {result}"); - break; - } - - case (false, true): - { - var result = await Context.DecodeAsync(batch, cancellationToken); - if (result != DecodeResult.Ok) - throw new RuntimeError($"Failed to decode: {result}"); - break; - } - - default: - throw new NotSupportedException("Unsupported model type"); + var result = await Context.EncodeAsync(batch, cancellationToken); + if (result != EncodeResult.Ok) + throw new RuntimeError($"Failed to encode: {result}"); + break; } - } - - // Extract results - var poolingType = Context.NativeHandle.PoolingType; - var resultsCount = poolingType == LLamaPoolingType.None ? 
tokens.Length : 1; - var results = new List(resultsCount); - if (poolingType == LLamaPoolingType.None) - { - var positions = batch.GetLogitPositions(); - foreach (var (_, pos) in positions) - results.Add(Context.NativeHandle.GetEmbeddingsIth(pos).ToArray()); - } - else - { - results.Add(Context.NativeHandle.GetEmbeddingsSeq(seqId).ToArray()); - } + case (false, true): + { + var result = await Context.DecodeAsync(batch, cancellationToken); + if (result != DecodeResult.Ok) + throw new RuntimeError($"Failed to decode: {result}"); + break; + } - // Normalize the embeddings vector - // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92 - foreach (var embedding in results) - { - embedding.EuclideanNormalization(); + default: + throw new NotSupportedException("Unsupported model type"); } + } - if (!_hasExternalContext) - Context.Dispose(); + // Extract results + var poolingType = Context.NativeHandle.PoolingType; + var resultsCount = poolingType == LLamaPoolingType.None ? tokens.Length : 1; + var results = new List(resultsCount); - return (results, tokens.Length); + if (poolingType == LLamaPoolingType.None) + { + var positions = batch.GetLogitPositions(); + foreach (var (_, pos) in positions) + results.Add(Context.NativeHandle.GetEmbeddingsIth(pos).ToArray()); } - finally + else { - if (_lamaSeqIdManager != null) - { - Context.NativeHandle.MemorySequenceRemove(seqId,0,-1); - _lamaSeqIdManager.Return(seqId); - } + results.Add(Context.NativeHandle.GetEmbeddingsSeq(seqId).ToArray()); } + + // Normalize the embeddings vector + // https://github.com/ggerganov/llama.cpp/blob/2891c8aa9af17f4ff636ff3868bc34ff72b56e25/examples/embedding/embedding.cpp#L92 + foreach (var embedding in results) + { + embedding.EuclideanNormalization(); + } + + if (!_hasExternalContext) + Context.Dispose(); + + return (results, tokens.Length); } } \ No newline at end of file diff --git a/LLama/Native/LLamaSeqIdManager.cs b/LLama/Native/LLamaSeqIdManager.cs index f16315fda..4fb1d071d 100644 --- a/LLama/Native/LLamaSeqIdManager.cs +++ b/LLama/Native/LLamaSeqIdManager.cs @@ -23,9 +23,10 @@ public sealed class LLamaSeqIdManager : IDisposable /// maximum number of sequence IDs to manage. public LLamaSeqIdManager(uint maxSeqCount) { - _semaphore = new SemaphoreSlim((int)maxSeqCount, (int)maxSeqCount); + var max = Math.Max((int)maxSeqCount, 1); // Ensure at least one sequence ID is available + _semaphore = new SemaphoreSlim(initialCount: max, maxCount: max); _availableIds = []; - for (var i = 0; i < maxSeqCount; i++) + for (var i = 0; i < max; i++) { _availableIds.Add(i); } @@ -35,10 +36,12 @@ public LLamaSeqIdManager(uint maxSeqCount) /// Returns the next available sequence ID. /// Callers will asynchronously wait if none are available. /// - /// >The next available sequence ID. - public async Task Next() + /// optional timeout for acquiring a sequence ID. If null, waits indefinitely. + /// cancellation token to cancel the wait operation. + /// The next available sequence ID. + public async Task NextAsync(TimeSpan? timeout = null, CancellationToken cancellationToken = default) { - await _semaphore.WaitAsync(); + await _semaphore.WaitAsync(timeout ?? 
TimeSpan.FromMilliseconds(-1), cancellationToken).ConfigureAwait(false); if (_availableIds.TryTake(out var seqId)) { return (LLamaSeqId)seqId; From b1585a730917fb18c510c2314337d5f711142e04 Mon Sep 17 00:00:00 2001 From: bmazzarol-bunnings Date: Wed, 5 Nov 2025 23:14:50 +0800 Subject: [PATCH 4/5] fix: Remove unused LLamaSeqIdManager from LLamaEmbedder --- LLama/LLamaEmbedder.cs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 510bbca4d..fb3c0ca1e 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -32,7 +32,6 @@ public sealed partial class LLamaEmbedder private readonly IContextParams _params; private readonly ILogger? _logger; private readonly bool _hasExternalContext; - private readonly LLamaSeqIdManager? _lamaSeqIdManager; /// /// Create a new embedder, using the given . @@ -58,7 +57,6 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg _params = @params; _logger = logger; _hasExternalContext = false; - _lamaSeqIdManager = null; } /// @@ -84,7 +82,6 @@ public LLamaEmbedder(LLamaContext context, ILogger? logger = null) _params = context.Params; _logger = logger; _hasExternalContext = true; - _lamaSeqIdManager = new LLamaSeqIdManager(context.Params.SeqMax); } /// @@ -92,7 +89,6 @@ public void Dispose() { if (!_hasExternalContext && !Context.NativeHandle.IsClosed) Context.Dispose(); - _lamaSeqIdManager?.Dispose(); } /// From ef52ebb6ae858f9b080574e6cfce4ddc67f42cc6 Mon Sep 17 00:00:00 2001 From: bmazzarol-bunnings Date: Wed, 5 Nov 2025 23:24:48 +0800 Subject: [PATCH 5/5] fix: Improve error handling in NextAsync method of LLamaSeqIdManager --- LLama/Native/LLamaSeqIdManager.cs | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/LLama/Native/LLamaSeqIdManager.cs b/LLama/Native/LLamaSeqIdManager.cs index 4fb1d071d..4cb7c27b5 100644 --- a/LLama/Native/LLamaSeqIdManager.cs +++ b/LLama/Native/LLamaSeqIdManager.cs @@ -42,12 +42,23 @@ public LLamaSeqIdManager(uint maxSeqCount) public async Task NextAsync(TimeSpan? timeout = null, CancellationToken cancellationToken = default) { await _semaphore.WaitAsync(timeout ?? TimeSpan.FromMilliseconds(-1), cancellationToken).ConfigureAwait(false); - if (_availableIds.TryTake(out var seqId)) + + try { - return (LLamaSeqId)seqId; - } + cancellationToken.ThrowIfCancellationRequested(); + + if (_availableIds.TryTake(out var seqId)) + { + return (LLamaSeqId)seqId; + } - throw new InvalidOperationException("No sequence ID available despite semaphore release"); + throw new InvalidOperationException("No sequence ID available despite semaphore release"); + } + catch + { + _semaphore.Release(); + throw; + } } ///
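
Usage sketch for the API added in this series (illustrative only, assuming C# top-level statements): the externally managed context constructor from patch 1 combined with the sequence-ID reservation added in patch 3. The model path and the ModelParams values below are placeholder assumptions, and LLamaWeights.LoadFromFile / ModelParams come from the existing LLamaSharp API rather than this change set; treat this as a sketch of intended usage, not code taken from the patches.

using LLama;
using LLama.Common;
using LLama.Native;

// Placeholder model path and parameter values. The external-context constructor
// requires BatchSize == UBatchSize, and SeqMax controls how many sequence IDs
// the LLamaSeqIdManager can hand out concurrently.
var parameters = new ModelParams("path/to/embedding-model.gguf")
{
    BatchSize = 512,
    UBatchSize = 512,
    SeqMax = 4,
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = new LLamaContext(weights, parameters);

// The embedder does not take ownership of an externally supplied context;
// the caller disposes it (here via the using declarations above).
using var embedder = new LLamaEmbedder(context);

// With pooling enabled the call returns a single embedding vector for the input.
var vector = (await embedder.GetEmbeddings("The cat is cute"))[0];

// The context also exposes the reservation API directly: the returned
// ManagedLLamaSeqId converts implicitly to LLamaSeqId and hands the ID back
// to the SequenceManager (optionally clearing its memory) on dispose.
using (var seq = await context.AcquireSequenceIdAsync(removeMemoryOnRelease: true))
{
    LLamaSeqId id = seq;
}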