Skip to content

Commit c933206

Browse files
committed
- Added low level wrapper methods for new per-sequence state load/save in SafeLLamaContextHandle
- Added high level wrapper methods (save/load with `State` object or memory mapped file) in `LLamaContext` - Moved native methods for per-sequence state load/save into `SafeLLamaContextHandle`
1 parent ae5ad71 commit c933206

File tree

3 files changed

+282
-85
lines changed

3 files changed

+282
-85
lines changed

LLama/LLamaContext.cs

Lines changed: 196 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
152152
return decoder.Read();
153153
}
154154

155+
#region state load/save
155156
/// <summary>
156157
/// Save the state to specified path.
157158
/// </summary>
@@ -163,7 +164,7 @@ public void SaveState(string filename)
163164
File.Delete(filename);
164165

165166
// Estimate size of state to write to disk, this is always equal to or greater than the actual size
166-
var estimatedStateSize = (long)NativeApi.llama_state_get_size(NativeHandle);
167+
var estimatedStateSize = checked((long)NativeHandle.GetStateSize());
167168

168169
// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
169170
long writtenBytes;
@@ -174,8 +175,53 @@ public void SaveState(string filename)
174175
{
175176
byte* ptr = null;
176177
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
177-
writtenBytes = (long)NativeApi.llama_state_get_data(NativeHandle, ptr);
178-
view.SafeMemoryMappedViewHandle.ReleasePointer();
178+
try
179+
{
180+
writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize);
181+
}
182+
finally
183+
{
184+
view.SafeMemoryMappedViewHandle.ReleasePointer();
185+
}
186+
}
187+
}
188+
189+
// Truncate the file to the actual size of data that was written
190+
using (var fileStream = new FileStream(filename, FileMode.Open))
191+
fileStream.SetLength(writtenBytes);
192+
}
193+
194+
/// <summary>
195+
/// Save the state of a particular sequence to specified path.
196+
/// </summary>
197+
/// <param name="filename"></param>
198+
/// <param name="sequence"></param>
199+
public void SaveState(string filename, LLamaSeqId sequence)
200+
{
201+
// Delete that file before overwriting it
202+
if (File.Exists(filename))
203+
File.Delete(filename);
204+
205+
// Estimate size of state to write to disk, this is always equal to or greater than the actual size
206+
var estimatedStateSize = checked((long)NativeHandle.GetStateSize(sequence));
207+
208+
// Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
209+
long writtenBytes;
210+
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, estimatedStateSize))
211+
using (var view = file.CreateViewAccessor(0, estimatedStateSize))
212+
{
213+
unsafe
214+
{
215+
byte* ptr = null;
216+
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
217+
try
218+
{
219+
writtenBytes = (long)NativeHandle.GetState(ptr, (ulong)estimatedStateSize);
220+
}
221+
finally
222+
{
223+
view.SafeMemoryMappedViewHandle.ReleasePointer();
224+
}
179225
}
180226
}
181227

@@ -187,7 +233,7 @@ public void SaveState(string filename)
187233
/// <summary>
188234
/// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
189235
/// </summary>
190-
/// <remarks>Use <see cref="SaveState"/> if you intend to save this state to disk.</remarks>
236+
/// <remarks>Use <see cref="SaveState(string)"/> if you intend to save this state to disk.</remarks>
191237
/// <returns></returns>
192238
public State GetState()
193239
{
@@ -198,7 +244,11 @@ public State GetState()
198244
try
199245
{
200246
// Copy the state data into memory, discover the actual size required
201-
var actualSize = NativeHandle.GetState(memory, stateSize);
247+
ulong actualSize;
248+
unsafe
249+
{
250+
actualSize = NativeHandle.GetState((byte*)memory, stateSize);
251+
}
202252

203253
// Shrink to size
204254
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
@@ -218,11 +268,48 @@ public State GetState()
218268
}
219269
}
220270

271+
/// <summary>
272+
/// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
273+
/// </summary>
274+
/// <remarks>Use <see cref="SaveState(string, LLamaSeqId)"/> if you intend to save this state to disk.</remarks>
275+
/// <returns></returns>
276+
public SequenceState GetState(LLamaSeqId sequence)
277+
{
278+
var stateSize = NativeHandle.GetStateSize(sequence);
279+
280+
// Allocate a chunk of memory large enough to hold the entire state
281+
var memory = Marshal.AllocHGlobal((nint)stateSize);
282+
try
283+
{
284+
// Copy the state data into memory, discover the actual size required
285+
ulong actualSize;
286+
unsafe
287+
{
288+
actualSize = NativeHandle.GetState((byte*)memory, stateSize, sequence);
289+
}
290+
291+
// Shrink to size
292+
memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
293+
294+
// Wrap memory in a "state"
295+
var state = new SequenceState(memory, actualSize);
296+
297+
// Set memory to zero, to prevent it being freed in finally block
298+
memory = IntPtr.Zero;
299+
300+
return state;
301+
}
302+
finally
303+
{
304+
if (memory != IntPtr.Zero)
305+
Marshal.FreeHGlobal(memory);
306+
}
307+
}
308+
221309
/// <summary>
222310
/// Load the state from specified path.
223311
/// </summary>
224312
/// <param name="filename"></param>
225-
/// <exception cref="RuntimeError"></exception>
226313
public void LoadState(string filename)
227314
{
228315
// Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
@@ -233,8 +320,41 @@ public void LoadState(string filename)
233320
{
234321
byte* ptr = null;
235322
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
236-
NativeApi.llama_state_set_data(NativeHandle, ptr);
237-
view.SafeMemoryMappedViewHandle.ReleasePointer();
323+
try
324+
{
325+
NativeHandle.SetState(ptr);
326+
}
327+
finally
328+
{
329+
view.SafeMemoryMappedViewHandle.ReleasePointer();
330+
}
331+
}
332+
}
333+
}
334+
335+
/// <summary>
336+
/// Load the state from specified path into a particular sequence
337+
/// </summary>
338+
/// <param name="filename"></param>
339+
/// <param name="sequence"></param>
340+
public void LoadState(string filename, LLamaSeqId sequence)
341+
{
342+
// Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
343+
using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
344+
using (var view = file.CreateViewAccessor())
345+
{
346+
unsafe
347+
{
348+
byte* ptr = null;
349+
view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
350+
try
351+
{
352+
NativeHandle.SetState(ptr, sequence);
353+
}
354+
finally
355+
{
356+
view.SafeMemoryMappedViewHandle.ReleasePointer();
357+
}
238358
}
239359
}
240360
}
@@ -248,10 +368,25 @@ public void LoadState(State state)
248368
{
249369
unsafe
250370
{
251-
NativeHandle.SetState((byte*)state.DangerousGetHandle().ToPointer());
371+
NativeHandle.SetState((byte*)state.DangerousGetHandle());
252372
}
253373
}
254374

375+
/// <summary>
376+
/// Load the state from memory into a particular sequence
377+
/// </summary>
378+
/// <param name="state"></param>
379+
/// <param name="sequence"></param>
380+
/// <exception cref="RuntimeError"></exception>
381+
public void LoadState(SequenceState state, LLamaSeqId sequence)
382+
{
383+
unsafe
384+
{
385+
NativeHandle.SetState((byte*)state.DangerousGetHandle(), sequence);
386+
}
387+
}
388+
#endregion
389+
255390
/// <summary>
256391
/// Sample a single token from this context, using the given sampling pipeline
257392
/// </summary>
@@ -417,12 +552,16 @@ public void Dispose()
417552
}
418553

419554
/// <summary>
420-
/// The state of this model, which can be reloaded later
555+
/// The state of this context, which can be reloaded later
421556
/// </summary>
422557
public class State
423558
: SafeLLamaHandleBase
424559
{
425-
private ulong _size;
560+
private readonly ulong _size;
561+
/// <summary>
562+
/// Get the size in bytes of this state object
563+
/// </summary>
564+
public ulong Size => _size;
426565

427566
internal State(IntPtr memory, ulong size)
428567
: base(memory, true)
@@ -441,6 +580,7 @@ protected override bool ReleaseHandle()
441580
/// Convert this state to a byte array
442581
/// </summary>
443582
/// <returns></returns>
583+
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
444584
public byte[] ToByteArray()
445585
{
446586
var bytes = new byte[_size];
@@ -453,12 +593,57 @@ public byte[] ToByteArray()
453593
/// </summary>
454594
/// <param name="bytes"></param>
455595
/// <returns></returns>
596+
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
456597
public static State FromByteArray(byte[] bytes)
457598
{
458599
var memory = Marshal.AllocHGlobal(bytes.Length);
459600
Marshal.Copy(bytes, 0, memory, bytes.Length);
460601
return new State(memory, (ulong)bytes.Length);
461602
}
462603
}
604+
605+
/// <summary>
606+
/// The state of a single sequence, which can be reloaded later
607+
/// </summary>
608+
public class SequenceState
609+
: SafeLLamaHandleBase
610+
{
611+
private readonly ulong _size;
612+
/// <summary>
613+
/// Get the size in bytes of this state object
614+
/// </summary>
615+
public ulong Size => _size;
616+
617+
internal SequenceState(IntPtr memory, ulong size)
618+
: base(memory, true)
619+
{
620+
_size = size;
621+
}
622+
623+
/// <inheritdoc />
624+
protected override bool ReleaseHandle()
625+
{
626+
Marshal.FreeHGlobal(handle);
627+
return true;
628+
}
629+
630+
/// <summary>
631+
/// Copy bytes to a desintation pointer.
632+
/// </summary>
633+
/// <param name="dst">Destination to write to</param>
634+
/// <param name="length">Length of the destination buffer</param>
635+
/// <param name="offset">Offset from start of src to start copying from</param>
636+
/// <returns>Number of bytes written to destination</returns>
637+
public unsafe ulong CopyTo(byte* dst, ulong length, ulong offset = 0)
638+
{
639+
var copy = Math.Min(length, _size - offset);
640+
641+
var src = (byte*)DangerousGetHandle();
642+
src += offset;
643+
644+
Buffer.MemoryCopy(src, dst, length, copy);
645+
return copy;
646+
}
647+
}
463648
}
464649
}

LLama/Native/NativeApi.cs

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -67,34 +67,6 @@ public static void llama_empty_call()
6767
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
6868
//public static extern void llama_numa_init(ggml_numa_strategy numa);
6969

70-
/// <summary>
71-
/// Returns the maximum size in bytes of the state (rng, logits, embedding
72-
/// and kv_cache) - will often be smaller after compacting tokens
73-
/// </summary>
74-
/// <param name="ctx"></param>
75-
/// <returns></returns>
76-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
77-
public static extern ulong llama_state_get_size(SafeLLamaContextHandle ctx);
78-
79-
/// <summary>
80-
/// Copies the state to the specified destination address.
81-
/// Destination needs to have allocated enough memory.
82-
/// </summary>
83-
/// <param name="ctx"></param>
84-
/// <param name="dest"></param>
85-
/// <returns>the number of bytes copied</returns>
86-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
87-
public static extern unsafe ulong llama_state_get_data(SafeLLamaContextHandle ctx, byte* dest);
88-
89-
/// <summary>
90-
/// Set the state reading from the specified address
91-
/// </summary>
92-
/// <param name="ctx"></param>
93-
/// <param name="src"></param>
94-
/// <returns>the number of bytes read</returns>
95-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
96-
public static extern unsafe ulong llama_state_set_data(SafeLLamaContextHandle ctx, byte* src);
97-
9870
/// <summary>
9971
/// Load session file
10072
/// </summary>
@@ -118,35 +90,6 @@ public static void llama_empty_call()
11890
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
11991
public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);
12092

121-
/// <summary>
122-
/// Get the exact size needed to copy the KV cache of a single sequence
123-
/// </summary>
124-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
125-
public static extern nuint llama_state_seq_get_size(SafeLLamaContextHandle ctx, LLamaSeqId seq_id);
126-
127-
/// <summary>
128-
/// Copy the KV cache of a single sequence into the specified buffer
129-
/// </summary>
130-
/// <param name="ctx"></param>
131-
/// <param name="dst"></param>
132-
/// <param name="seq_id"></param>
133-
/// <returns></returns>
134-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
135-
public static extern unsafe nuint llama_state_seq_get_data(SafeLLamaContextHandle ctx, byte* dst, LLamaSeqId seq_id);
136-
137-
/// <summary>
138-
/// Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
139-
/// </summary>
140-
/// <param name="ctx"></param>
141-
/// <param name="src"></param>
142-
/// <param name="dest_seq_id"></param>
143-
/// <returns>
144-
/// - Positive: Ok
145-
/// - Zero: Failed to load
146-
/// </returns>
147-
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
148-
public static extern unsafe nuint llama_state_seq_set_data(SafeLLamaContextHandle ctx, byte* src, LLamaSeqId dest_seq_id);
149-
15093
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
15194
public static extern unsafe nuint llama_state_seq_save_file(SafeLLamaContextHandle ctx, string filepath, LLamaSeqId seq_id, LLamaToken* tokens, nuint n_token_count);
15295

0 commit comments

Comments
 (0)