5 changes: 2 additions & 3 deletions LLama.Unittest/LLamaContextTests.cs
@@ -13,10 +13,9 @@ public LLamaContextTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath2)
{
-ContextSize = 128,
+ContextSize = 512,
BatchSize = 8,
UBatchSize = 8,
-SeqMax = 1,
VocabOnly = false,
GpuLayerCount = Constants.CIGpuLayerCount,
};
@@ -33,7 +32,7 @@ public void Dispose()
[Fact]
public void CheckProperties()
{
-Assert.Equal(128u, _context.ContextSize);
+Assert.Equal(512u, _context.ContextSize);
Assert.Equal(960, _context.EmbeddingSize);
Assert.Equal(49152, _context.Vocab.Count);
}
4 changes: 2 additions & 2 deletions LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
@@ -30,7 +30,7 @@ public LLamaContextWithCustomLoggerTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath2)
{
-ContextSize = 128,
+ContextSize = 512,
GpuLayerCount = Constants.CIGpuLayerCount,
};

@@ -55,7 +55,7 @@ public void Dispose()
[Fact]
public void CheckProperties()
{
-Assert.Equal(128u, _context.ContextSize);
+Assert.Equal(512u, _context.ContextSize);
Assert.Equal(960, _context.EmbeddingSize);
Assert.Equal(49152, _context.Vocab.Count);
}
1 change: 0 additions & 1 deletion LLama.Unittest/LLamaRerankerTests.cs
@@ -20,7 +20,6 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
ContextSize = 0,
PoolingType = LLamaPoolingType.Rank,
GpuLayerCount = Constants.CIGpuLayerCount,
-
};
using var weights = LLamaWeights.LoadFromFile(@params);
_reranker = new LLamaReranker(weights, @params);
2 changes: 1 addition & 1 deletion LLama.Unittest/SamplingTests.cs
@@ -104,7 +104,7 @@ public void BatchedSampling()
}
}

// Add " repeat" and test whether next tokens will be "this phrase forever.".
// Add " repeat" and test whether next tokens will be "this phrase forever."
for (int i = 0; i < 4; i++)
{
for (int b = 0; b < batch_count; b++)
2 changes: 1 addition & 1 deletion LLama.Web/Common/ModelOptions.cs
@@ -102,7 +102,7 @@ public class ModelOptions
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
-public bool FlashAttention { get; set; }
+public bool? FlashAttention { get; set; }

/// <inheritdoc />
public Encoding Encoding { get; set; } = Encoding.UTF8;
4 changes: 2 additions & 2 deletions LLama/Abstractions/IContextParams.cs
@@ -106,8 +106,8 @@ public interface IContextParams
/// <summary>
/// Whether to use flash attention
/// </summary>
-bool FlashAttention { get; }
-
+bool? FlashAttention { get; }
/// <summary>
/// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt;= 0 to disable (default)
/// </summary>
7 changes: 4 additions & 3 deletions LLama/Common/ModelParams.cs
@@ -1,3 +1,4 @@
+using System;
using LLama.Abstractions;
using System.Text;
using System.Text.Json.Serialization;
@@ -95,12 +96,12 @@ public record ModelParams

/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />

public bool FlashAttention { get; set; }
public bool? FlashAttention { get; set; }

/// <inheritdoc />
[Obsolete]
public float? DefragThreshold { get; set; }

/// <inheritdoc />
11 changes: 9 additions & 2 deletions LLama/Extensions/IContextParamsExtensions.cs
@@ -37,7 +37,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;

result.defrag_threshold = @params.DefragThreshold ?? -1;

result.cb_eval = IntPtr.Zero;
@@ -49,9 +49,16 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = [email protected];
-result.flash_attention = @params.FlashAttention;
Review comment from @Lyrcaxis (Contributor), Oct 10, 2025:
Instead of completely removing the option to use flash attention, can you pass to llama_flash_attn_type?
I would suggest keeping the previous FlashAttention bool as it was -- but turn it to nullable, so null == Auto.

result.llama_flash_attn_type = @params.FlashAttention switch
{
    true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
    false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
    null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
};
result.kv_unified = true; // if we wanna hardcode it here instead of in `Default()`.

result.llama_pooling_type = @params.PoolingType;
result.attention_type = @params.AttentionType;
+result.llama_flash_attn_type = @params.FlashAttention switch
+{
+    true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
+    false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
+    null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
+};
+result.kv_unified = true;
+result.n_seq_max = (uint)Math.Min(Math.Max(10, result.n_ctx / 8), 256);

result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
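For anyone exercising the new tri-state flag from user code, a minimal sketch (the model path is a placeholder, and CreateContext is assumed to be the usual LLamaWeights API):

    using LLama;
    using LLama.Common;

    // null  -> LLAMA_FLASH_ATTENTION_TYPE_AUTO (llama.cpp decides)
    // true  -> LLAMA_FLASH_ATTENTION_TYPE_ENABLED
    // false -> LLAMA_FLASH_ATTENTION_TYPE_DISABLED
    var parameters = new ModelParams("models/example.gguf")
    {
        ContextSize = 512,
        FlashAttention = null, // let the backend decide
    };
    using var weights = LLamaWeights.LoadFromFile(parameters);
    using var context = weights.CreateContext(parameters);

Note that with ContextSize = 512, the new n_seq_max heuristic above works out to (uint)Math.Min(Math.Max(10, 512 / 8), 256) = 64 sequences.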
2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
@@ -57,7 +57,7 @@
</ItemGroup>

<PropertyGroup>
<BinaryReleaseId>11dd5a44eb180e</BinaryReleaseId>
<BinaryReleaseId>86587da</BinaryReleaseId>
</PropertyGroup>

<PropertyGroup>
5 changes: 5 additions & 0 deletions LLama/Native/LLamaContextParams.cs
@@ -64,6 +64,11 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
/// </summary>
public LLamaAttentionType attention_type;

/// <summary>
/// when to enable Flash Attention
/// </summary>
public LLamaFlashAttentionType llama_flash_attn_type;

/// <summary>
/// RoPE base frequency, 0 = from model
19 changes: 19 additions & 0 deletions LLama/Native/LLamaFlashAttentionType.cs
@@ -0,0 +1,19 @@
namespace LLama.Native;

/// <summary>
/// Mirrors the native llama_flash_attn_type enum
/// </summary>
public enum LLamaFlashAttentionType
{
/// <summary>
/// Let llama.cpp decide whether to use flash attention
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1,

/// <summary>
/// Flash attention disabled
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0,

/// <summary>
/// Flash attention enabled
/// </summary>
LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1,
}
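Taken together with the new struct field above, native-level callers end up with something like this sketch (Default() is assumed to mirror llama_context_default_params):

    var cparams = LLamaContextParams.Default();
    cparams.llama_flash_attn_type = LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO;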
7 changes: 6 additions & 1 deletion LLama/Native/LLamaFtype.cs
@@ -201,7 +201,12 @@ public enum LLamaFtype
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,

+/// <summary>
+/// except 1d tensors
+/// </summary>
+LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,

/// <summary>
/// File type was not specified
/// </summary>
11 changes: 10 additions & 1 deletion LLama/Native/LLamaModelParams.cs
@@ -100,7 +100,16 @@ public bool check_tensors
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;

+/// <summary>
+/// use extra buffer types (used for weight repacking)
+/// </summary>
+public bool use_extra_bufts
+{
+    readonly get => Convert.ToBoolean(_use_extra_bufts);
+    set => _use_extra_bufts = Convert.ToSByte(value);
+}
+private sbyte _use_extra_bufts;

/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
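The new property follows the same sbyte-backed marshalling pattern as check_tensors above; a quick usage sketch (assuming the Default() factory documented above keeps its current name):

    // C 'bool' fields in llama_model_params are marshalled as sbyte;
    // the property hides the conversion in both directions.
    var mparams = LLamaModelParams.Default();
    mparams.use_extra_bufts = true;          // stored internally as (sbyte)1
    bool enabled = mparams.use_extra_bufts;  // read back via Convert.ToBoolean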
67 changes: 56 additions & 11 deletions LLama/Native/NativeApi.cs
@@ -179,7 +179,7 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
{
return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);

-[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
+[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl,EntryPoint = "llama_chat_apply_template")]
static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
}

@@ -215,7 +215,8 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
/// <param name="lstrip">User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')</param>
/// <param name="special">If true, special tokens are rendered in the output</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
-public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span<byte> buffer, int lstrip, bool special)
+public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken,
+    Span<byte> buffer, int lstrip, bool special)
{
// Handle invalid tokens
if ((int)llamaToken < 0)
@@ -225,12 +226,14 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
{
fixed (byte* bufferPtr = buffer)
{
-return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
+return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip,
+    special);
}
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
-static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
+static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken,
+    byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
}

/// <summary>
@@ -247,7 +250,9 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
/// Returns a negative number on failure - the number of tokens that would have been returned. Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
/// </returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special);
+internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len,
+    LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special,
+    [MarshalAs(UnmanagedType.U1)] bool parse_special);

/// <summary>
/// Convert the provided tokens into text (inverse of llama_tokenize()).
@@ -261,7 +266,8 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
/// <param name="unparseSpecial">unparse_special If true, special tokens are rendered in the output.</param>
/// <returns>Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);
+internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens,
+    byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);

/// <summary>
/// Register a callback to receive llama log messages
@@ -272,7 +278,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
{
NativeLogConfig.llama_log_set(logCallback);
}

/// <summary>
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
@@ -311,7 +317,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="il_end"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end);
+public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len,
+    int n_embd, int il_start, int il_end);

/// <summary>
/// Build a split GGUF final path for this chunk.
@@ -324,7 +331,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="split_count"></param>
/// <returns>Returns the split_path length.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
+public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no,
+    int split_count);

/// <summary>
/// Extract the path prefix from the split_path if and only if the split_no and split_count match.
@@ -337,7 +345,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="split_count"></param>
/// <returns>Returns the split_prefix length.</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no,
+    int split_count);

//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
@@ -380,5 +389,41 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <returns>Name of the buffer type</returns>
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);

/// <summary>
///
/// </summary>
/// <param name="ctx"></param>
/// <param name="seq_id"></param>
/// <param name="flags"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags);

/// <summary>
///
/// </summary>
/// <param name="ctx"></param>
/// <param name="dst"></param>
/// <param name="size"></param>
/// <param name="seq_id"></param>
/// <param name="flags"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size,
int seq_id, uint flags);

/// <summary>
///
/// </summary>
/// <param name="ctx"></param>
/// <param name="src"></param>
/// <param name="size"></param>
/// <param name="dest_seq_id"></param>
/// <param name="flags"></param>
/// <returns></returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id,
uint flags);
}
}
}
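A sketch of how the three new per-sequence state entry points fit together; the helper type is illustrative (not part of this PR) and flags = 0 is an assumed default:

    using System;
    using LLama.Native;

    public static class SequenceStateHelper
    {
        // Snapshot the state of a single sequence into a managed buffer.
        public static byte[] Save(IntPtr ctx, int seqId, uint flags = 0)
        {
            var size = NativeApi.llama_state_seq_get_size_ext(ctx, seqId, flags);
            var buffer = new byte[(ulong)size];
            var copied = NativeApi.llama_state_seq_get_data_ext(ctx, buffer, size, seqId, flags);
            if (copied != size)
                throw new InvalidOperationException("Sequence state copy was truncated.");
            return buffer;
        }

        // Restore a previously saved snapshot into a (possibly different) sequence id.
        public static void Restore(IntPtr ctx, byte[] state, int destSeqId, uint flags = 0)
        {
            var read = NativeApi.llama_state_seq_set_data_ext(ctx, state, (UIntPtr)(uint)state.Length, destSeqId, flags);
            if (read == UIntPtr.Zero)
                throw new InvalidOperationException("Failed to restore sequence state.");
        }
    }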
41 changes: 41 additions & 0 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -341,6 +341,47 @@ static SafeLLamaContextHandle()
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_set_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter, float scale);

/// <summary>
/// Get metadata value as a string by key name
/// </summary>
/// <param name="adapter"></param>
/// <param name="key"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of the value string (on success) -1 otherwise </returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);

/// <summary>
/// Get the number of metadata key value pairs
/// </summary>
/// <param name="adapter"></param>
/// <returns>The count of meta key value pairs</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_count(IntPtr adapter);

/// <summary>
/// Get metadata key name by index
/// </summary>
/// <param name="adapter"></param>
/// <param name="i"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of string i.e meta key (on success) -1 otherwise</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

/// <summary>
/// Get metadata key value by index
/// </summary>
/// <param name="adapter"></param>
/// <param name="i"></param>
/// <param name="buf"></param>
/// <param name="buf_size"></param>
/// <returns>The length of value string (on success) -1 otherwise</returns>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_rm_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter);

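Because these adapter-metadata externs are private, a wrapper along the following lines could surface them (hypothetical, not part of this PR; the fixed buffer sizes are an assumption, and a robust version would resize and retry using the returned lengths):

    // Hypothetical helper inside SafeLLamaContextHandle.
    private static Dictionary<string, string> GetAdapterMetadata(IntPtr adapter)
    {
        var metadata = new Dictionary<string, string>();
        var count = llama_adapter_meta_count(adapter);
        for (var i = 0; i < count; i++)
        {
            var key = new StringBuilder(256);
            if (llama_adapter_meta_key_by_index(adapter, i, key, (UIntPtr)(uint)key.Capacity) < 0)
                continue; // index invalid or key longer than the buffer
            var value = new StringBuilder(1024);
            if (llama_adapter_meta_val_by_index(adapter, i, value, (UIntPtr)(uint)value.Capacity) < 0)
                continue; // value longer than the buffer
            metadata[key.ToString()] = value.ToString();
        }
        return metadata;
    }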
2 changes: 1 addition & 1 deletion LLama/Native/SafeLLamaSamplerHandle.cs
@@ -616,7 +616,7 @@ static extern unsafe IntPtr llama_sampler_init_logit_bias(

// This is a tricky method to work with!
// It can't return a handle, because that would create a second handle to these resources.
-// Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
+// Instead, it returns the raw pointer, and that can be looked up in the _samplers dictionary.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerChainHandle chain, int i);
// ReSharper restore InconsistentNaming