6 changes: 6 additions & 0 deletions LLama.Web/Common/ModelOptions.cs
@@ -103,5 +103,11 @@ public class ModelOptions

/// <inheritdoc />
public bool VocabOnly { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }

/// <inheritdoc />
public bool DoPooling { get; set; }
}
}
10 changes: 10 additions & 0 deletions LLama/Abstractions/IContextParams.cs
@@ -98,4 +98,14 @@ public interface IContextParams
/// Whether to disable offloading the KQV cache to the GPU
/// </summary>
bool NoKqvOffload { get; }

/// <summary>
/// Defragment the KV cache if holes/size &gt; defrag_threshold. Set to &lt; 0 to disable (default)
/// </summary>
float DefragThreshold { get; }

/// <summary>
/// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
/// </summary>
bool DoPooling { get; }
}
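The two new properties mirror llama.cpp's `defrag_threshold` and `do_pooling` context settings. A minimal sketch of wiring them up through `ModelParams` (the model path and values are illustrative placeholders, not part of this PR):

```csharp
using LLama;
using LLama.Common;

var parameters = new ModelParams("model.gguf")
{
    ContextSize = 2048,
    DefragThreshold = 0.1f, // defragment the KV cache once holes/size exceeds 10%
    DoPooling = true        // sum embedding results per sequence id
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
```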
24 changes: 12 additions & 12 deletions LLama/Abstractions/IModelParams.cs
@@ -251,7 +251,7 @@ public MetadataOverride(string key, int value)
{
Key = key;
_valueInt = value;
- Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT;
+ Type = LLamaModelKvOverrideType.Int;
}

/// <summary>
@@ -263,7 +263,7 @@ public MetadataOverride(string key, float value)
{
Key = key;
_valueFloat = value;
- Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT;
+ Type = LLamaModelKvOverrideType.Float;
}

/// <summary>
@@ -275,20 +275,20 @@ public MetadataOverride(string key, bool value)
{
Key = key;
_valueBool = value;
- Type = LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL;
+ Type = LLamaModelKvOverrideType.Bool;
}

internal void WriteValue(ref LLamaModelMetadataOverride dest)
{
switch (Type)
{
- case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+ case LLamaModelKvOverrideType.Int:
dest.IntValue = _valueInt;
break;
- case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+ case LLamaModelKvOverrideType.Float:
dest.FloatValue = _valueFloat;
break;
- case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+ case LLamaModelKvOverrideType.Bool:
dest.BoolValue = _valueBool ? -1L : 0;
break;
default:
@@ -300,13 +300,13 @@ internal void WriteValue(Utf8JsonWriter writer)
{
switch (Type)
{
- case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT:
+ case LLamaModelKvOverrideType.Int:
writer.WriteNumberValue(_valueInt);
break;
- case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT:
+ case LLamaModelKvOverrideType.Float:
writer.WriteNumberValue(_valueFloat);
break;
- case LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL:
+ case LLamaModelKvOverrideType.Bool:
writer.WriteBooleanValue(_valueBool);
break;
default:
@@ -328,9 +328,9 @@ public override MetadataOverride Read(ref Utf8JsonReader reader, Type typeToConv

return ((LLamaModelKvOverrideType)ktv.Type) switch
{
- LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_INT => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
- LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_FLOAT => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
- LLamaModelKvOverrideType.LLAMA_KV_OVERRIDE_BOOL => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
+ LLamaModelKvOverrideType.Int => new MetadataOverride(ktv.Key, ktv.Value.GetInt32()),
+ LLamaModelKvOverrideType.Float => new MetadataOverride(ktv.Key, ktv.Value.GetSingle()),
+ LLamaModelKvOverrideType.Bool => new MetadataOverride(ktv.Key, ktv.Value.GetBoolean()),
_ => throw new JsonException(),
};
}
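The enum renames above (`LLAMA_KV_OVERRIDE_INT` → `Int`, and so on) stay inside the override plumbing; the public `MetadataOverride` constructors are unchanged. A hedged sketch of supplying overrides, assuming `ModelParams` exposes the `MetadataOverrides` collection declared on `IModelParams` (the keys are illustrative):

```csharp
using LLama.Abstractions;
using LLama.Common;

var parameters = new ModelParams("model.gguf");

// Each constructor overload selects the matching LLamaModelKvOverrideType
// (Int, Float or Bool) internally:
parameters.MetadataOverrides.Add(new MetadataOverride("tokenizer.ggml.add_bos_token", false));
parameters.MetadataOverrides.Add(new MetadataOverride("example.integer.key", 42));
```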
4 changes: 2 additions & 2 deletions LLama/Batched/Conversation.cs
@@ -262,9 +262,9 @@ public void Remove(LLamaPos start, int count)
/// <param name="start">Start position (inclusive)</param>
/// <param name="end">End position (exclusive)</param>
/// <param name="delta">Amount to add on to each token position</param>
- public void Shift(LLamaPos start, LLamaPos end, int delta)
+ public void Add(LLamaPos start, LLamaPos end, int delta)
{
- _conversation.Executor.Context.NativeHandle.KvCacheSequenceShift(_conversation.ConversationId, start, end, delta);
+ _conversation.Executor.Context.NativeHandle.KvCacheSequenceAdd(_conversation.ConversationId, start, end, delta);
}
#endregion

2 changes: 1 addition & 1 deletion LLama/Batched/ConversationExtensions.cs
@@ -50,7 +50,7 @@ public static void ShiftLeft(this Conversation conversation, int count, int keep
kv.Remove(keep, count);

// Shift the C's
- kv.Shift(keep + count, end, -count);
+ kv.Add(keep + count, end, -count);

// Update total count
return end.Value - count;
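`kv.Add` here is the renamed `Shift` from `Conversation.cs` above: it adds `delta` to every cached position in `[start, end)`. A worked sketch of the arithmetic with assumed values:

```csharp
// Assume a conversation holding positions 0..99, with keep = 10 and count = 20.
// ShiftLeft then performs:
//
//   kv.Remove(10, 20);     // drop cached positions 10..29
//   kv.Add(30, 100, -20);  // re-index positions 30..99 down to 10..79
//
// leaving a contiguous cache of 80 tokens, so the kept prefix never has to
// be re-evaluated.
conversation.ShiftLeft(count: 20, keep: 10);
```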
6 changes: 6 additions & 0 deletions LLama/Common/ModelParams.cs
@@ -93,6 +93,12 @@ public record ModelParams
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }

/// <inheritdoc />
public bool DoPooling { get; set; }

/// <inheritdoc />
public bool VocabOnly { get; set; }

5 changes: 4 additions & 1 deletion LLama/Extensions/IContextParamsExtensions.cs
@@ -34,14 +34,17 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.yarn_beta_fast = @params.YarnBetaFast ?? 32f;
result.yarn_beta_slow = @params.YarnBetaSlow ?? 1f;
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
- result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
+ result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.Unspecified;

result.defrag_threshold = @params.DefragThreshold;

result.cb_eval = IntPtr.Zero;
result.cb_eval_user_data = IntPtr.Zero;

result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
result.do_pooling = @params.DoPooling;

result.n_threads = Threads(@params.Threads);
result.n_threads_batch = Threads(@params.BatchThreads);
33 changes: 33 additions & 0 deletions LLama/LLamaContext.cs
@@ -56,6 +56,35 @@ public sealed class LLamaContext
/// </summary>
public Encoding Encoding { get; }

private uint _generationThreads;
private uint _batchThreads;

/// <summary>
/// Get or set the number of threads to use for generation
/// </summary>
public uint GenerationThreads
{
get => _generationThreads;
set
{
_generationThreads = value;
NativeHandle.SetThreads(_generationThreads, _batchThreads);
}
}

/// <summary>
/// Get or set the number of threads to use for batch processing
/// </summary>
public uint BatchThreads
{
get => _batchThreads;
set
{
_batchThreads = value;
NativeHandle.SetThreads(_generationThreads, _batchThreads);
}
}

/// <summary>
/// Create a new LLamaContext for the given LLamaWeights
/// </summary>
@@ -75,6 +104,10 @@ public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger

@params.ToLlamaContextParams(out var lparams);
NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);

// It's not possible to get these values back from llama.cpp, so store a copy of them here.
_generationThreads = lparams.n_threads;
_batchThreads = lparams.n_threads_batch;
}

/// <summary>
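Since llama.cpp offers no way to read the thread counts back, the constructor caches the values from `LLamaContextParams` and the two properties push updates through a single `SetThreads` call. A small usage sketch (reusing the `context` from the configuration example earlier):

```csharp
// Tune thread usage at runtime, e.g. in response to system load:
context.GenerationThreads = 8;  // threads used for single-token generation
context.BatchThreads = 16;      // threads used for prompt/batch processing
```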
14 changes: 12 additions & 2 deletions LLama/LLamaQuantizer.cs
@@ -59,7 +59,7 @@ public static bool Quantize(string srcFileName, string dstFilename, string ftype
private static bool ValidateFtype(LLamaFtype ftype)
{
// Validation copied from here:
- // https://github.com/ggerganov/llama.cpp/blob/d71ac90985854b0905e1abba778e407e17f9f887/llama.cpp#L9613
+ // https://github.com/ggerganov/llama.cpp/blob/3ab8b3a92ede46df88bc5a2dfca3777de4a2b2b6/llama.cpp#L10965

switch (ftype)
{
@@ -74,7 +74,7 @@ private static bool ValidateFtype(LLamaFtype ftype)
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q2_K:

- case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_XS:
+ case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_K_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_M:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q3_K_L:
@@ -89,8 +89,18 @@ private static bool ValidateFtype(LLamaFtype ftype)

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XXS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_XS:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ2_M:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_XXS:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ1_S:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_NL:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ4_XS:

case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_S:
case LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_M:
return true;

case LLamaFtype.LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
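With the new IQ formats whitelisted by `ValidateFtype`, they can be requested through the quantizer. A hedged sketch, assuming a `Quantize` overload that accepts an `LLamaFtype` directly (file names are placeholders):

```csharp
using System;
using LLama;
using LLama.Native;

// Re-quantize an f16 GGUF into the newly supported IQ3_S format:
bool ok = LLamaQuantizer.Quantize(
    "model-f16.gguf",
    "model-iq3_s.gguf",
    LLamaFtype.LLAMA_FTYPE_MOSTLY_IQ3_S);

if (!ok)
    Console.WriteLine("Quantization was not performed");
```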
2 changes: 1 addition & 1 deletion LLama/LLamaStatelessExecutor.cs
@@ -134,7 +134,7 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
var n_discard = n_left / 2;

NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1);
- NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);
+ NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard);

n_past -= n_discard;
}
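Only the native entry point is renamed (`seq_shift` → `seq_add`); the context-shift arithmetic is untouched. A worked example with assumed numbers, using the signatures shown above:

```csharp
// Assume inferenceParams.TokensKeep = 4, n_past = 1024, and that the code
// above the hunk computed n_discard = 509.

// Remove cached positions 5..513 (509 tokens) from sequence 0:
NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, 5, 514);

// Slide positions 514..1023 down by 509 so the cache is contiguous again:
NativeApi.llama_kv_cache_seq_add(Context.NativeHandle, (LLamaSeqId)0, 514, 1024, -509);

// n_past becomes 1024 - 509 = 515 and generation continues in the freed space.
```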
11 changes: 11 additions & 0 deletions LLama/Native/LLamaChatMessage.cs
@@ -0,0 +1,11 @@
namespace LLama.Native;

/// <summary>
/// A single message in a chat conversation
/// </summary>
/// <remarks>llama_chat_message</remarks>
public unsafe struct LLamaChatMessage
{
    /// <summary>
    /// Pointer to a null-terminated UTF-8 string holding the role of this message (e.g. "user", "assistant")
    /// </summary>
    public byte* role;

    /// <summary>
    /// Pointer to a null-terminated UTF-8 string holding the content of this message
    /// </summary>
    public byte* content;
}
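Because the struct holds raw pointers, callers must keep the UTF-8 bytes pinned for as long as the struct is in use. A hedged sketch of one way to do that (this helper is hypothetical, not part of the PR):

```csharp
using System;
using System.Text;
using LLama.Native;

static unsafe void WithChatMessage(string role, string content, Action<LLamaChatMessage> use)
{
    // Null-terminate, since the native side expects C strings:
    var roleBytes = Encoding.UTF8.GetBytes(role + '\0');
    var contentBytes = Encoding.UTF8.GetBytes(content + '\0');

    fixed (byte* rolePtr = roleBytes)
    fixed (byte* contentPtr = contentBytes)
    {
        use(new LLamaChatMessage { role = rolePtr, content = contentPtr });
    }
}
```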
35 changes: 25 additions & 10 deletions LLama/Native/LLamaContextParams.cs
@@ -51,32 +51,42 @@ public struct LLamaContextParams
/// RoPE base frequency, 0 = from model
/// </summary>
public float rope_freq_base;

/// <summary>
/// RoPE frequency scaling factor, 0 = from model
/// </summary>
public float rope_freq_scale;

/// <summary>
/// YaRN extrapolation mix factor, negative = from model
/// </summary>
public float yarn_ext_factor;

/// <summary>
/// YaRN magnitude scaling factor
/// </summary>
public float yarn_attn_factor;

/// <summary>
/// YaRN low correction dim
/// </summary>
public float yarn_beta_fast;

/// <summary>
/// YaRN high correction dim
/// </summary>
public float yarn_beta_slow;

/// <summary>
/// YaRN original context size
/// </summary>
public uint yarn_orig_ctx;

/// <summary>
/// Defragment the KV cache if holes/size &gt; defrag_threshold. Set to &lt; 0 to disable (default)
/// </summary>
public float defrag_threshold;

/// <summary>
/// ggml_backend_sched_eval_callback
/// </summary>
@@ -97,11 +107,6 @@ public struct LLamaContextParams
/// </summary>
public GGMLType type_v;

- /// <summary>
- /// Deprecated!
- /// </summary>
- private sbyte _mul_mat_q;

/// <summary>
/// Deprecated!
/// </summary>
@@ -126,6 +131,16 @@ public bool offload_kqv
set => _offload_kqv = Convert.ToSByte(value);
}
private sbyte _offload_kqv;

/// <summary>
/// Whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
/// </summary>
public bool do_pooling
{
readonly get => Convert.ToBoolean(_do_pooling);
set => _do_pooling = Convert.ToSByte(value);
}
private sbyte _do_pooling;
}
}

37 changes: 36 additions & 1 deletion LLama/Native/LLamaFtype.cs
@@ -124,13 +124,48 @@ public enum LLamaFtype
/// <summary>
/// except 1d tensors
/// </summary>
- LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22,
+ LLAMA_FTYPE_MOSTLY_IQ3_K_XS = 22,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ1_S = 24,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ4_NL = 25,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ3_S = 26,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ3_M = 27,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ2_S = 28,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ2_M = 29,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30,

/// <summary>
/// File type was not specified
/// </summary>