@@ -152,6 +152,7 @@ public string DeTokenize(IReadOnlyList<LLamaToken> tokens)
152152 return decoder . Read ( ) ;
153153 }
154154
155+ #region state load/save
155156 /// <summary>
156157 /// Save the state to specified path.
157158 /// </summary>
@@ -163,7 +164,7 @@ public void SaveState(string filename)
163164 File . Delete ( filename ) ;
164165
165166 // Estimate size of state to write to disk, this is always equal to or greater than the actual size
166- var estimatedStateSize = ( long ) NativeApi . llama_state_get_size ( NativeHandle ) ;
167+ var estimatedStateSize = checked ( ( long ) NativeHandle . GetStateSize ( ) ) ;
167168
168169 // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
169170 long writtenBytes ;
@@ -174,8 +175,53 @@ public void SaveState(string filename)
174175 {
175176 byte * ptr = null ;
176177 view . SafeMemoryMappedViewHandle . AcquirePointer ( ref ptr ) ;
177- writtenBytes = ( long ) NativeApi . llama_state_get_data ( NativeHandle , ptr ) ;
178- view . SafeMemoryMappedViewHandle . ReleasePointer ( ) ;
178+ try
179+ {
180+ writtenBytes = ( long ) NativeHandle . GetState ( ptr , ( ulong ) estimatedStateSize ) ;
181+ }
182+ finally
183+ {
184+ view . SafeMemoryMappedViewHandle . ReleasePointer ( ) ;
185+ }
186+ }
187+ }
188+
189+ // Truncate the file to the actual size of data that was written
190+ using ( var fileStream = new FileStream ( filename , FileMode . Open ) )
191+ fileStream . SetLength ( writtenBytes ) ;
192+ }
193+
194+ /// <summary>
195+ /// Save the state of a particular sequence to specified path.
196+ /// </summary>
197+ /// <param name="filename"></param>
198+ /// <param name="sequence"></param>
199+ public void SaveState ( string filename , LLamaSeqId sequence )
200+ {
201+ // Delete that file before overwriting it
202+ if ( File . Exists ( filename ) )
203+ File . Delete ( filename ) ;
204+
205+ // Estimate size of state to write to disk, this is always equal to or greater than the actual size
206+ var estimatedStateSize = checked ( ( long ) NativeHandle . GetStateSize ( sequence ) ) ;
207+
208+ // Map the file and write the bytes directly to it. This saves copying the bytes into a C# array
209+ long writtenBytes ;
210+ using ( var file = MemoryMappedFile . CreateFromFile ( filename , FileMode . Create , null , estimatedStateSize ) )
211+ using ( var view = file . CreateViewAccessor ( 0 , estimatedStateSize ) )
212+ {
213+ unsafe
214+ {
215+ byte * ptr = null ;
216+ view . SafeMemoryMappedViewHandle . AcquirePointer ( ref ptr ) ;
217+ try
218+ {
219+ writtenBytes = ( long ) NativeHandle . GetState ( ptr , ( ulong ) estimatedStateSize ) ;
220+ }
221+ finally
222+ {
223+ view . SafeMemoryMappedViewHandle . ReleasePointer ( ) ;
224+ }
179225 }
180226 }
181227
@@ -187,7 +233,7 @@ public void SaveState(string filename)
187233 /// <summary>
188234 /// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
189235 /// </summary>
190- /// <remarks>Use <see cref="SaveState"/> if you intend to save this state to disk.</remarks>
236+ /// <remarks>Use <see cref="SaveState(string) "/> if you intend to save this state to disk.</remarks>
191237 /// <returns></returns>
192238 public State GetState ( )
193239 {
@@ -198,7 +244,11 @@ public State GetState()
198244 try
199245 {
200246 // Copy the state data into memory, discover the actual size required
201- var actualSize = NativeHandle . GetState ( memory , stateSize ) ;
247+ ulong actualSize ;
248+ unsafe
249+ {
250+ actualSize = NativeHandle . GetState ( ( byte * ) memory , stateSize ) ;
251+ }
202252
203253 // Shrink to size
204254 memory = Marshal . ReAllocHGlobal ( memory , ( nint ) actualSize ) ;
@@ -218,11 +268,48 @@ public State GetState()
218268 }
219269 }
220270
271+ /// <summary>
272+ /// Get the state data as an opaque handle, which can be loaded later using <see cref="LoadState(State)"/>
273+ /// </summary>
274+ /// <remarks>Use <see cref="SaveState(string, LLamaSeqId)"/> if you intend to save this state to disk.</remarks>
275+ /// <returns></returns>
276+ public SequenceState GetState ( LLamaSeqId sequence )
277+ {
278+ var stateSize = NativeHandle . GetStateSize ( sequence ) ;
279+
280+ // Allocate a chunk of memory large enough to hold the entire state
281+ var memory = Marshal . AllocHGlobal ( ( nint ) stateSize ) ;
282+ try
283+ {
284+ // Copy the state data into memory, discover the actual size required
285+ ulong actualSize ;
286+ unsafe
287+ {
288+ actualSize = NativeHandle . GetState ( ( byte * ) memory , stateSize , sequence ) ;
289+ }
290+
291+ // Shrink to size
292+ memory = Marshal . ReAllocHGlobal ( memory , ( nint ) actualSize ) ;
293+
294+ // Wrap memory in a "state"
295+ var state = new SequenceState ( memory , actualSize ) ;
296+
297+ // Set memory to zero, to prevent it being freed in finally block
298+ memory = IntPtr . Zero ;
299+
300+ return state ;
301+ }
302+ finally
303+ {
304+ if ( memory != IntPtr . Zero )
305+ Marshal . FreeHGlobal ( memory ) ;
306+ }
307+ }
308+
221309 /// <summary>
222310 /// Load the state from specified path.
223311 /// </summary>
224312 /// <param name="filename"></param>
225- /// <exception cref="RuntimeError"></exception>
226313 public void LoadState ( string filename )
227314 {
228315 // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
@@ -233,8 +320,41 @@ public void LoadState(string filename)
233320 {
234321 byte * ptr = null ;
235322 view . SafeMemoryMappedViewHandle . AcquirePointer ( ref ptr ) ;
236- NativeApi . llama_state_set_data ( NativeHandle , ptr ) ;
237- view . SafeMemoryMappedViewHandle . ReleasePointer ( ) ;
323+ try
324+ {
325+ NativeHandle . SetState ( ptr ) ;
326+ }
327+ finally
328+ {
329+ view . SafeMemoryMappedViewHandle . ReleasePointer ( ) ;
330+ }
331+ }
332+ }
333+ }
334+
335+ /// <summary>
336+ /// Load the state from specified path into a particular sequence
337+ /// </summary>
338+ /// <param name="filename"></param>
339+ /// <param name="sequence"></param>
340+ public void LoadState ( string filename , LLamaSeqId sequence )
341+ {
342+ // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
343+ using ( var file = MemoryMappedFile . CreateFromFile ( filename , FileMode . Open , null ) )
344+ using ( var view = file . CreateViewAccessor ( ) )
345+ {
346+ unsafe
347+ {
348+ byte * ptr = null ;
349+ view . SafeMemoryMappedViewHandle . AcquirePointer ( ref ptr ) ;
350+ try
351+ {
352+ NativeHandle . SetState ( ptr , sequence ) ;
353+ }
354+ finally
355+ {
356+ view . SafeMemoryMappedViewHandle . ReleasePointer ( ) ;
357+ }
238358 }
239359 }
240360 }
@@ -248,10 +368,25 @@ public void LoadState(State state)
248368 {
249369 unsafe
250370 {
251- NativeHandle . SetState ( ( byte * ) state . DangerousGetHandle ( ) . ToPointer ( ) ) ;
371+ NativeHandle . SetState ( ( byte * ) state . DangerousGetHandle ( ) ) ;
252372 }
253373 }
254374
375+ /// <summary>
376+ /// Load the state from memory into a particular sequence
377+ /// </summary>
378+ /// <param name="state"></param>
379+ /// <param name="sequence"></param>
380+ /// <exception cref="RuntimeError"></exception>
381+ public void LoadState ( SequenceState state , LLamaSeqId sequence )
382+ {
383+ unsafe
384+ {
385+ NativeHandle . SetState ( ( byte * ) state . DangerousGetHandle ( ) , sequence ) ;
386+ }
387+ }
388+ #endregion
389+
255390 /// <summary>
256391 /// Sample a single token from this context, using the given sampling pipeline
257392 /// </summary>
@@ -417,12 +552,16 @@ public void Dispose()
417552 }
418553
419554 /// <summary>
420- /// The state of this model , which can be reloaded later
555+ /// The state of this context , which can be reloaded later
421556 /// </summary>
422557 public class State
423558 : SafeLLamaHandleBase
424559 {
425- private ulong _size ;
560+ private readonly ulong _size ;
561+ /// <summary>
562+ /// Get the size in bytes of this state object
563+ /// </summary>
564+ public ulong Size => _size ;
426565
427566 internal State ( IntPtr memory , ulong size )
428567 : base ( memory , true )
@@ -441,6 +580,7 @@ protected override bool ReleaseHandle()
441580 /// Convert this state to a byte array
442581 /// </summary>
443582 /// <returns></returns>
583+ [ Obsolete ( "It is not generally safe to convert a state into a byte array - it will fail if the state is very large" ) ]
444584 public byte [ ] ToByteArray ( )
445585 {
446586 var bytes = new byte [ _size ] ;
@@ -453,12 +593,57 @@ public byte[] ToByteArray()
453593 /// </summary>
454594 /// <param name="bytes"></param>
455595 /// <returns></returns>
596+ [ Obsolete ( "It is not generally safe to convert a state into a byte array - it will fail if the state is very large" ) ]
456597 public static State FromByteArray ( byte [ ] bytes )
457598 {
458599 var memory = Marshal . AllocHGlobal ( bytes . Length ) ;
459600 Marshal . Copy ( bytes , 0 , memory , bytes . Length ) ;
460601 return new State ( memory , ( ulong ) bytes . Length ) ;
461602 }
462603 }
604+
605+ /// <summary>
606+ /// The state of a single sequence, which can be reloaded later
607+ /// </summary>
608+ public class SequenceState
609+ : SafeLLamaHandleBase
610+ {
611+ private readonly ulong _size ;
612+ /// <summary>
613+ /// Get the size in bytes of this state object
614+ /// </summary>
615+ public ulong Size => _size ;
616+
617+ internal SequenceState ( IntPtr memory , ulong size )
618+ : base ( memory , true )
619+ {
620+ _size = size ;
621+ }
622+
623+ /// <inheritdoc />
624+ protected override bool ReleaseHandle ( )
625+ {
626+ Marshal . FreeHGlobal ( handle ) ;
627+ return true ;
628+ }
629+
630+ /// <summary>
631+ /// Copy bytes to a desintation pointer.
632+ /// </summary>
633+ /// <param name="dst">Destination to write to</param>
634+ /// <param name="length">Length of the destination buffer</param>
635+ /// <param name="offset">Offset from start of src to start copying from</param>
636+ /// <returns>Number of bytes written to destination</returns>
637+ public unsafe ulong CopyTo ( byte * dst , ulong length , ulong offset = 0 )
638+ {
639+ var copy = Math . Min ( length , _size - offset ) ;
640+
641+ var src = ( byte * ) DangerousGetHandle ( ) ;
642+ src += offset ;
643+
644+ Buffer . MemoryCopy ( src , dst , length , copy ) ;
645+ return copy ;
646+ }
647+ }
463648 }
464649}
0 commit comments