@@ -6,16 +6,34 @@ enum LlamaError: Error {
     case couldNotInitializeContext
 }
 
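+// Swift port of the llama_batch_clear helper from common/common.cpp:
+// reset the batch so it can be refilled.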
+func llama_batch_clear(_ batch: inout llama_batch) {
+    batch.n_tokens = 0
+}
+
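+// Swift port of the llama_batch_add helper from common/common.cpp: append one
+// token with its position, owning sequence ids, and whether logits are wanted.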
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+    batch.token   [Int(batch.n_tokens)] = id
+    batch.pos     [Int(batch.n_tokens)] = pos
+    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
+    for i in 0..<seq_ids.count {
+        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
+    }
+    batch.logits  [Int(batch.n_tokens)] = logits ? 1 : 0
+
+    batch.n_tokens += 1
+}
+
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
+
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]
 
-    var n_len: Int32 = 512
+    var n_len: Int32 = 64
     var n_cur: Int32 = 0
+
     var n_decode: Int32 = 0
 
     init(model: OpaquePointer, context: OpaquePointer) {
@@ -27,25 +45,34 @@ actor LlamaContext {
     }
 
     deinit {
+        llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
         llama_backend_free()
     }
 
-    static func createContext(path: String) throws -> LlamaContext {
+    static func create_context(path: String) throws -> LlamaContext {
         llama_backend_init(false)
-        let model_params = llama_model_default_params()
+        var model_params = llama_model_default_params()
 
+#if targetEnvironment(simulator)
+        model_params.n_gpu_layers = 0
+        print("Running on simulator, force use n_gpu_layers = 0")
+#endif
         let model = llama_load_model_from_file(path, model_params)
         guard let model else {
             print("Could not load model at \(path)")
             throw LlamaError.couldNotInitializeContext
         }
+
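+        // use most of the available cores, but leave a couple free for the UI; cap at 8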
+        let n_threads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
+        print("Using \(n_threads) threads")
+
         var ctx_params = llama_context_default_params()
-        ctx_params.seed = 1234
+        ctx_params.seed  = 1234
         ctx_params.n_ctx = 2048
-        ctx_params.n_threads = 8
-        ctx_params.n_threads_batch = 8
+        ctx_params.n_threads       = UInt32(n_threads)
+        ctx_params.n_threads_batch = UInt32(n_threads)
 
         let context = llama_new_context_with_model(model, ctx_params)
         guard let context else {
@@ -56,6 +83,26 @@ actor LlamaContext {
         return LlamaContext(model: model, context: context)
     }
 
+    func model_info() -> String {
+        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 256)
+        result.initialize(repeating: Int8(0), count: 256)
+        defer {
+            result.deallocate()
+        }
+
+        // TODO: this is probably very stupid way to get the string from C
+
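+        // llama_model_desc fills the buffer with a NUL-terminated description
+        // and returns the number of characters written (snprintf-style)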
+        let nChars = llama_model_desc(model, result, 256)
+        let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nChars))
+
+        var SwiftString = ""
+        for char in bufferPointer {
+            SwiftString.append(Character(UnicodeScalar(UInt8(char))))
+        }
+
+        return SwiftString
+    }
+
     func get_n_tokens() -> Int32 {
         return batch.n_tokens;
     }
@@ -79,16 +126,11 @@ actor LlamaContext {
             print(String(cString: token_to_piece(token: id) + [0]))
         }
 
-        // batch = llama_batch_init(512, 0) // done in init()
-        batch.n_tokens = Int32(tokens_list.count)
+        llama_batch_clear(&batch)
 
-        for i1 in 0..<batch.n_tokens {
+        for i1 in 0..<tokens_list.count {
             let i = Int(i1)
-            batch.token[i] = tokens_list[i]
-            batch.pos[i] = i1
-            batch.n_seq_id[Int(i)] = 1
-            batch.seq_id[Int(i)]![0] = 0
-            batch.logits[i] = 0
+            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
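+        // llama_decode will output logits only for the last token of the prompt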
         batch.logits[Int(batch.n_tokens) - 1] = 1 // true
 
@@ -141,18 +183,11 @@ actor LlamaContext {
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
-        batch.n_tokens = 0
-
-        batch.token[Int(batch.n_tokens)] = new_token_id
-        batch.pos[Int(batch.n_tokens)] = n_cur
-        batch.n_seq_id[Int(batch.n_tokens)] = 1
-        batch.seq_id[Int(batch.n_tokens)]![0] = 0
-        batch.logits[Int(batch.n_tokens)] = 1 // true
-        batch.n_tokens += 1
+        llama_batch_clear(&batch)
+        llama_batch_add(&batch, new_token_id, n_cur, [0], true)
 
         n_decode += 1
-
-        n_cur += 1
+        n_cur    += 1
 
         if llama_decode(context, batch) != 0 {
             print("failed to evaluate llama!")
@@ -161,14 +196,111 @@ actor LlamaContext {
         return new_token_str
     }
 
+    func bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) -> String {
+        var pp_avg: Double = 0
+        var tg_avg: Double = 0
+
+        var pp_std: Double = 0
+        var tg_std: Double = 0
+
+        for r in 0..<nr {
+            // bench prompt processing
+
+            llama_batch_clear(&batch)
+
+            let n_tokens = pp
+
+            for i in 0..<n_tokens {
+                llama_batch_add(&batch, 0, Int32(i), [0], false)
+            }
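+            // request logits only for the last token of the prompt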
+            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+
+            llama_kv_cache_clear(context)
+
+            let t_pp_start = ggml_time_us()
+
+            if llama_decode(context, batch) != 0 {
+                print("llama_decode() failed during prompt")
+            }
+
+            let t_pp_end = ggml_time_us()
+
+            // bench text generation
+
+            llama_kv_cache_clear(context)
+
+            let t_tg_start = ggml_time_us()
+
+            for i in 0..<tg {
+                llama_batch_clear(&batch)
+
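+                // queue one token on each of the pl parallel sequences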
+                for j in 0..<pl {
+                    llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+                }
+
+                if llama_decode(context, batch) != 0 {
+                    print("llama_decode() failed during text generation")
+                }
+            }
+
+            let t_tg_end = ggml_time_us()
+
+            llama_kv_cache_clear(context)
+
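+            // ggml_time_us() reports microseconds; convert to seconds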
+            let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
+            let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
+
+            let speed_pp = Double(pp) / t_pp
+            let speed_tg = Double(pl*tg) / t_tg
+
+            pp_avg += speed_pp
+            tg_avg += speed_tg
+
+            pp_std += speed_pp * speed_pp
+            tg_std += speed_tg * speed_tg
+
+            print("pp \(speed_pp) t/s, tg \(speed_tg) t/s")
+        }
+
+        pp_avg /= Double(nr)
+        tg_avg /= Double(nr)
+
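+        // sample standard deviation, computed from the running sums of squares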
+        if nr > 1 {
+            pp_std = sqrt(pp_std / Double(nr - 1) - pp_avg * pp_avg * Double(nr) / Double(nr - 1))
+            tg_std = sqrt(tg_std / Double(nr - 1) - tg_avg * tg_avg * Double(nr) / Double(nr - 1))
+        } else {
+            pp_std = 0
+            tg_std = 0
+        }
+
+        let model_desc     = model_info();
+        let model_size     = String(format: "%.2f GiB", Double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0);
+        let model_n_params = String(format: "%.2f B", Double(llama_model_n_params(model)) / 1e9);
+        let backend        = "Metal";
+        let pp_avg_str     = String(format: "%.2f", pp_avg);
+        let tg_avg_str     = String(format: "%.2f", tg_avg);
+        let pp_std_str     = String(format: "%.2f", pp_std);
+        let tg_std_str     = String(format: "%.2f", tg_std);
+
+        var result = ""
+
+        result += String("| model | size | params | backend | test | t/s |\n")
+        result += String("| --- | --- | --- | --- | --- | --- |\n")
+        result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | pp \(pp) | \(pp_avg_str) ± \(pp_std_str) |\n")
+        result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | tg \(tg) | \(tg_avg_str) ± \(tg_std_str) |\n")
+
+        return result;
+    }
+
     func clear() {
         tokens_list.removeAll()
         temporary_invalid_cchars.removeAll()
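+        // also reset the KV cache so the next prompt starts from a clean state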
+        llama_kv_cache_clear(context)
     }
 
     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
         let utf8Count = text.utf8.count
-        let n_tokens = utf8Count + (add_bos ? 1 : 0)
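+        // reserve one extra slot in case the tokenizer emits an additional token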
+        let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
         let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
         let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
 