diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index e28b55ce0..7d85589d6 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -13,10 +13,9 @@ public LLamaContextTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath2)
{
- ContextSize = 128,
+ ContextSize = 512,
BatchSize = 8,
UBatchSize = 8,
- SeqMax = 1,
VocabOnly = false,
GpuLayerCount = Constants.CIGpuLayerCount,
};
@@ -33,7 +32,7 @@ public void Dispose()
[Fact]
public void CheckProperties()
{
- Assert.Equal(128u, _context.ContextSize);
+ Assert.Equal(512u, _context.ContextSize);
Assert.Equal(960, _context.EmbeddingSize);
Assert.Equal(49152, _context.Vocab.Count);
}
diff --git a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
index 1d16b0481..871b6b8cd 100644
--- a/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
+++ b/LLama.Unittest/LLamaContextWithCustomLoggerTests.cs
@@ -30,7 +30,7 @@ public LLamaContextWithCustomLoggerTests()
{
var @params = new ModelParams(Constants.GenerativeModelPath2)
{
- ContextSize = 128,
+ ContextSize = 512,
GpuLayerCount = Constants.CIGpuLayerCount,
};
@@ -55,7 +55,7 @@ public void Dispose()
[Fact]
public void CheckProperties()
{
- Assert.Equal(128u, _context.ContextSize);
+ Assert.Equal(512u, _context.ContextSize);
Assert.Equal(960, _context.EmbeddingSize);
Assert.Equal(49152, _context.Vocab.Count);
}
diff --git a/LLama.Unittest/LLamaRerankerTests.cs b/LLama.Unittest/LLamaRerankerTests.cs
index b8dfcfa8d..9ba0a31c0 100644
--- a/LLama.Unittest/LLamaRerankerTests.cs
+++ b/LLama.Unittest/LLamaRerankerTests.cs
@@ -20,7 +20,6 @@ public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
ContextSize = 0,
PoolingType = LLamaPoolingType.Rank,
GpuLayerCount = Constants.CIGpuLayerCount,
-
};
using var weights = LLamaWeights.LoadFromFile(@params);
_reranker = new LLamaReranker(weights, @params);
diff --git a/LLama.Unittest/SamplingTests.cs b/LLama.Unittest/SamplingTests.cs
index 615a7c79e..297641df3 100644
--- a/LLama.Unittest/SamplingTests.cs
+++ b/LLama.Unittest/SamplingTests.cs
@@ -104,7 +104,7 @@ public void BatchedSampling()
}
}
- // Add " repeat" and test whether next tokens will be "this phrase forever.".
+ // Add " repeat" and test whether next tokens will be "this phrase forever."
for (int i = 0; i < 4; i++)
{
for (int b = 0; b < batch_count; b++)
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index c453aeddf..586db5611 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -102,7 +102,7 @@ public class ModelOptions
public bool NoKqvOffload { get; set; }
///
- public bool FlashAttention { get; set; }
+ public bool? FlashAttention { get; set; }
///
public Encoding Encoding { get; set; } = Encoding.UTF8;
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index f80759c8a..b7abed5ed 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -106,8 +106,8 @@ public interface IContextParams
///
/// Whether to use flash attention
///
- bool FlashAttention { get; }
-
+ bool? FlashAttention { get; }
+
///
/// defragment the KV cache if holes/size > defrag_threshold, Set to <= 0 to disable (default)
///
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 89737faa7..1b7e44308 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,3 +1,4 @@
+using System;
using LLama.Abstractions;
using System.Text;
using System.Text.Json.Serialization;
@@ -95,12 +96,12 @@ public record ModelParams
///
public bool NoKqvOffload { get; set; }
-
+
///
-
- public bool FlashAttention { get; set; }
+ public bool? FlashAttention { get; set; }
///
+ [Obsolete]
public float? DefragThreshold { get; set; }
///
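
With `FlashAttention` widened to `bool?`, callers now get a tri-state switch: `true` forces flash attention on, `false` forces it off, and `null` defers to llama.cpp's auto-detection. A minimal usage sketch (the model path is a placeholder):

```csharp
using LLama.Common;

var @params = new ModelParams("path/to/model.gguf")
{
    ContextSize = 512,
    FlashAttention = null, // null = let llama.cpp decide; true/false force it on/off
};
```
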
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 85e40f7ad..bfef9b9c1 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -37,7 +37,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;
-
+
result.defrag_threshold = @params.DefragThreshold ?? -1;
result.cb_eval = IntPtr.Zero;
@@ -49,9 +49,16 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
- result.flash_attention = @params.FlashAttention;
result.llama_pooling_type = @params.PoolingType;
result.attention_type = @params.AttentionType;
+ result.llama_flash_attn_type = @params.FlashAttention switch
+ {
+ true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
+ false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
+ null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO
+ };
+ result.kv_unified = true;
+ result.n_seq_max = (uint)Math.Min(Math.Max(10, result.n_ctx / 8), 256);
result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
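
The nullable flag maps directly onto the new native enum, and `n_seq_max` is now derived from the context size rather than supplied by the caller, clamped to [10, 256]. A standalone sketch of the same two pieces of logic:

```csharp
using System;
using LLama.Native;

static class ContextParamsSketch
{
    // Mirrors the switch above: bool? -> LLamaFlashAttentionType.
    public static LLamaFlashAttentionType ToFlashAttnType(bool? flashAttention) => flashAttention switch
    {
        true => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED,
        false => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED,
        null => LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO,
    };

    // One sequence per 8 context slots, clamped to [10, 256]:
    // n_ctx = 64 -> 10, n_ctx = 512 -> 64, n_ctx = 4096 -> 256.
    public static uint SeqMax(uint nCtx) => Math.Min(Math.Max(10u, nCtx / 8), 256u);
}
```
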
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index 629d10447..be5d09da3 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -57,7 +57,7 @@
- 11dd5a44eb180e
+ 86587da
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index 76f5d6c77..6dea4de47 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -64,6 +64,11 @@ public struct LLamaContextParams
/// Attention type to use for embeddings
///
public LLamaAttentionType attention_type;
+
+ ///
+ /// When to enable flash attention
+ ///
+ public LLamaFlashAttentionType llama_flash_attn_type;
///
/// RoPE base frequency, 0 = from model
diff --git a/LLama/Native/LLamaFlashAttentionType.cs b/LLama/Native/LLamaFlashAttentionType.cs
new file mode 100644
index 000000000..7138dea93
--- /dev/null
+++ b/LLama/Native/LLamaFlashAttentionType.cs
@@ -0,0 +1,19 @@
+namespace LLama.Native;
+///
+/// When to enable flash attention (mirrors llama.cpp's flash_attn_type)
+///
+public enum LLamaFlashAttentionType
+{
+ ///
+ /// Flash attention selected automatically by llama.cpp
+ ///
+ LLAMA_FLASH_ATTENTION_TYPE_AUTO = -1,
+ ///
+ /// Flash attention disabled
+ ///
+ LLAMA_FLASH_ATTENTION_TYPE_DISABLED = 0,
+ ///
+ /// Flash attention enabled
+ ///
+ LLAMA_FLASH_ATTENTION_TYPE_ENABLED = 1,
+}
\ No newline at end of file
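
The numeric values deliberately match llama.cpp's `llama_flash_attn_type` (-1/0/1), so the managed enum can be assigned into the native struct without translation. A quick sanity check of that correspondence:

```csharp
using System.Diagnostics;
using LLama.Native;

Debug.Assert((int)LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_AUTO == -1);
Debug.Assert((int)LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_DISABLED == 0);
Debug.Assert((int)LLamaFlashAttentionType.LLAMA_FLASH_ATTENTION_TYPE_ENABLED == 1);
```
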
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 705f8032e..813bad1ae 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -201,7 +201,12 @@ public enum LLamaFtype
/// except 1d tensors
///
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37,
-
+
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38,
+
///
/// File type was not specified
///
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index acb024852..4826c96b7 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -100,7 +100,16 @@ public bool check_tensors
set => _check_tensors = Convert.ToSByte(value);
}
private sbyte _check_tensors;
-
+
+ ///
+ /// use extra buffer types (used for weight repacking)
+ ///
+ public bool use_extra_bufts
+ {
+ readonly get => Convert.ToBoolean(_use_extra_bufts);
+ set => _use_extra_bufts = Convert.ToSByte(value);
+ }
+ private sbyte _use_extra_bufts;
///
/// Create a LLamaModelParams with default values
///
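
`use_extra_bufts` follows the same `sbyte`-backed pattern as `check_tensors`, keeping the struct blittable while exposing a `bool` to managed callers. For illustration (a sketch assuming the `LLamaModelParams.Default()` factory documented just below this property):

```csharp
var mParams = LLamaModelParams.Default();
mParams.use_extra_bufts = true;           // stored internally as sbyte 1
bool repacking = mParams.use_extra_bufts; // read back via Convert.ToBoolean
```
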
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index db9e928bd..4aefc8810 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -179,7 +179,7 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
{
return internal_llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length);
 [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_chat_apply_template")]
static extern int internal_llama_chat_apply_template(byte* tmpl, LLamaChatMessage* chat, nuint n_msg, [MarshalAs(UnmanagedType.U1)] bool add_ass, byte* buf, int length);
}
@@ -215,7 +215,8 @@ public static unsafe int llama_chat_apply_template(byte* tmpl, LLamaChatMessage*
/// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
/// If true, special tokens are rendered in the output
/// The length written, or if the buffer is too small a negative that indicates the length required
- public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken, Span buffer, int lstrip, bool special)
+ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LLamaToken llamaToken,
+ Span buffer, int lstrip, bool special)
{
// Handle invalid tokens
if ((int)llamaToken < 0)
@@ -225,12 +226,14 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
{
fixed (byte* bufferPtr = buffer)
{
- return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip, special);
+ return llama_token_to_piece_native(vocab.VocabNative, llamaToken, bufferPtr, buffer.Length, lstrip,
+ special);
}
}
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
- static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken, byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
+ static extern unsafe int llama_token_to_piece_native(LLamaVocabNative* model, LLamaToken llamaToken,
+ byte* buffer, int length, int lstrip, [MarshalAs(UnmanagedType.U1)] bool special);
}
///
@@ -247,7 +250,9 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
/// Returns a negative number on failure - the number of tokens that would have been returned. Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len, LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special, [MarshalAs(UnmanagedType.U1)] bool parse_special);
+ internal static extern unsafe int llama_tokenize(LLamaVocabNative* model, byte* text, int text_len,
+ LLamaToken* tokens, int n_max_tokens, [MarshalAs(UnmanagedType.U1)] bool add_special,
+ [MarshalAs(UnmanagedType.U1)] bool parse_special);
///
/// Convert the provided tokens into text (inverse of llama_tokenize()).
@@ -261,7 +266,8 @@ public static int llama_token_to_piece(SafeLlamaModelHandle.Vocabulary vocab, LL
/// unparse_special If true, special tokens are rendered in the output.
/// Returns the number of chars/bytes on success, no more than textLengthMax. Returns a negative number on failure - the number of chars/bytes that would have been returned.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens, byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);
+ internal static extern unsafe int llama_detokenize(LLamaVocabNative* model, LLamaToken* tokens, int nTokens,
+ byte* textOut, int textLengthMax, bool removeSpecial, bool unparseSpecial);
///
/// Register a callback to receive llama log messages
@@ -272,7 +278,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
{
NativeLogConfig.llama_log_set(logCallback);
}
-
+
///
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
@@ -311,7 +317,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len, int n_embd, int il_start, int il_end);
+ public static extern unsafe int llama_apply_adapter_cvec(SafeLLamaContextHandle ctx, float* data, nuint len,
+ int n_embd, int il_start, int il_end);
///
/// Build a split GGUF final path for this chunk.
@@ -324,7 +331,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
///
/// Returns the split_path length.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no, int split_count);
+ public static extern int llama_split_path(string split_path, nuint maxlen, string path_prefix, int split_no,
+ int split_count);
///
/// Extract the path prefix from the split_path if and only if the split_no and split_count match.
@@ -337,7 +345,8 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
///
/// Returns the split_prefix length.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no, int split_count);
+ public static extern int llama_split_prefix(string split_prefix, nuint maxlen, string split_path, int split_no,
+ int split_count);
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//todo: public static void llama_attach_threadpool(SafeLLamaContextHandle ctx, ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch);
@@ -380,5 +389,41 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// Name of the buffer type
[DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
+
+ ///
+ /// Get the size in bytes needed to serialize the state of a single sequence
+ ///
+ ///
+ ///
+ ///
+ /// The required buffer size in bytes
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern UIntPtr llama_state_seq_get_size_ext(IntPtr ctx, int seq_id, uint flags);
+
+ ///
+ /// Copy the state of a single sequence into the supplied buffer
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// The number of bytes written to the buffer
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern UIntPtr llama_state_seq_get_data_ext(IntPtr ctx, [Out] byte[] dst, UIntPtr size,
+ int seq_id, uint flags);
+
+ ///
+ /// Load the state of a single sequence from the supplied buffer
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// The number of bytes read, or zero on failure
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern UIntPtr llama_state_seq_set_data_ext(IntPtr ctx, byte[] src, UIntPtr size, int dest_seq_id,
+ uint flags);
}
-}
+}
\ No newline at end of file
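
The three `_ext` imports extend the per-sequence state API with a `flags` argument. A hedged round-trip sketch, assuming `ctx` is a raw `llama_context` pointer and that `flags = 0` selects the default behaviour:

```csharp
// Serialize the state of sequence 0, then load it back into sequence 1.
UIntPtr size = NativeApi.llama_state_seq_get_size_ext(ctx, 0, 0);
var buffer = new byte[(ulong)size];
UIntPtr written = NativeApi.llama_state_seq_get_data_ext(ctx, buffer, size, 0, 0);
NativeApi.llama_state_seq_set_data_ext(ctx, buffer, written, 1, 0);
```
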
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index e26619b26..f48e818b7 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -341,6 +341,47 @@ static SafeLLamaContextHandle()
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_set_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter, float scale);
+ ///
+ /// Get metadata value as a string by key name
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// The length of the value string (on success) -1 otherwise
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_adapter_meta_val_str(IntPtr adapter, string key, StringBuilder buf, UIntPtr buf_size);
+
+ ///
+ /// Get the number of metadata key value pairs
+ ///
+ ///
+ /// The count of meta key value pairs
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_adapter_meta_count(IntPtr adapter);
+
+ ///
+ /// Get metadata key name by index
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// The length of string i.e meta key (on success) -1 otherwise
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_adapter_meta_key_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);
+
+ ///
+ /// Get metadata key value by index
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// The length of value string (on success) -1 otherwise
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern int llama_adapter_meta_val_by_index(IntPtr adapter, int i, StringBuilder buf, UIntPtr buf_size);
+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern int llama_rm_adapter_lora(SafeLLamaContextHandle context, IntPtr adapter);
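
Together with `llama_adapter_meta_count`, the by-index imports make it possible to enumerate all metadata on a LoRA adapter. A sketch of a wrapper one could build on top of them (hypothetical helper, not part of this diff; the 256-byte buffers and the skip-on-failure handling are simplifying assumptions):

```csharp
// Hypothetical wrapper over the private imports above.
private static Dictionary<string, string> GetAdapterMetadata(IntPtr adapter)
{
    var metadata = new Dictionary<string, string>();
    var count = llama_adapter_meta_count(adapter);
    for (var i = 0; i < count; i++)
    {
        var key = new StringBuilder(256);
        var val = new StringBuilder(256);
        // A negative return means failure (e.g. buffer too small);
        // a production wrapper would retry with the returned length.
        if (llama_adapter_meta_key_by_index(adapter, i, key, (nuint)key.Capacity) < 0 ||
            llama_adapter_meta_val_by_index(adapter, i, val, (nuint)val.Capacity) < 0)
            continue;
        metadata[key.ToString()] = val.ToString();
    }
    return metadata;
}
```
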
diff --git a/LLama/Native/SafeLLamaSamplerHandle.cs b/LLama/Native/SafeLLamaSamplerHandle.cs
index bad1a1974..a113e1694 100644
--- a/LLama/Native/SafeLLamaSamplerHandle.cs
+++ b/LLama/Native/SafeLLamaSamplerHandle.cs
@@ -616,7 +616,7 @@ static extern unsafe IntPtr llama_sampler_init_logit_bias(
// This is a tricky method to work with!
// It can't return a handle, because that would create a second handle to these resources.
- // Instead It returns the raw pointer, and that can be looked up in the _samplers dictionary.
+ // Instead, it returns the raw pointer, and that can be looked up in the _samplers dictionary.
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr llama_sampler_chain_get(SafeLLamaSamplerChainHandle chain, int i);
// ReSharper restore InconsistentNaming
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index d335a1209..196bb1763 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -80,7 +80,12 @@ public sealed class SafeLlamaModelHandle
/// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
///
public bool IsRecurrent => llama_model_is_recurrent(this);
-
+
+ ///
+ /// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
+ ///
+ public bool IsDiffusion => llama_model_is_diffusion(this);
+
///
/// Get a description of this model
///
@@ -424,6 +429,10 @@ private static int llama_model_meta_val_str(SafeLlamaModelHandle model, string k
[return: MarshalAs(UnmanagedType.U1)]
private static extern bool llama_model_is_recurrent(SafeLlamaModelHandle model);
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ [return: MarshalAs(UnmanagedType.U1)]
+ private static extern bool llama_model_is_diffusion(SafeLlamaModelHandle model);
+
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe LLamaVocabNative* llama_model_get_vocab(SafeLlamaModelHandle model);
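
`IsDiffusion` gives managed callers the same kind of architecture probe already available via `IsRecurrent`. A minimal sketch, assuming the existing `LLamaWeights.NativeHandle` property:

```csharp
using var weights = LLamaWeights.LoadFromFile(@params);
if (weights.NativeHandle.IsDiffusion)
    Console.WriteLine("Diffusion LM (e.g. LLaDA, Dream): the autoregressive pipeline does not apply.");
else if (weights.NativeHandle.IsRecurrent)
    Console.WriteLine("Recurrent model (e.g. Mamba, RWKV).");
```
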
diff --git a/llama.cpp b/llama.cpp
index 11dd5a44e..86587da03 160000
--- a/llama.cpp
+++ b/llama.cpp
@@ -1 +1 @@
-Subproject commit 11dd5a44eb180e1d69fac24d3852b5222d66fb7f
+Subproject commit 86587da03bd78df8f4e7d8b111a0c1d2494d6ed0