
Commit ff9d5e4

Support DraftRetriever datastore read/write for large vocab sizes (i.e. llama3+) and REST inference for llama3 (#24)
* support large vocab sizes (i.e. llama3) for DraftRetriever datastore
* updates comments to explain implementation changes
* modify modeling_llama_kv for llama3 compatibility
1 parent 1cd2772 commit ff9d5e4

2 files changed (+441, -567 lines)


DraftRetriever/src/lib.rs

Lines changed: 14 additions & 10 deletions
@@ -1,4 +1,4 @@
-// The code for retrival is adapted from https://github.com/Intsights/PySubstringSearch;
+// The code for retrival is adapted from https://github.com/Intsights/PySubstringSearch;
 // The code for drafft buffer is adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py#L31-L124
 use ahash::AHashSet;
 use byteorder::{ReadBytesExt, WriteBytesExt, ByteOrder, LittleEndian};
@@ -67,6 +67,7 @@ impl Writer {
         let index_file = File::create(index_file_path)?;
         let index_file = BufWriter::new(index_file);
 
+        // max_chunk_len can be whatever we want it to be, but the draftretriever Reader() seems to work fastest when we choose something large (i.e. 2e27)
         let max_chunk_len = max_chunk_len.unwrap_or(512 * 1024 * 1024);
         let vocab_size = vocab_size.unwrap_or(35000);
 
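
The 35,000 default above fits comfortably in 16 bits, but llama3-class tokenizers use roughly 128k token ids, which is what forces the storage change in the next hunk. A minimal standalone sketch of the failure mode (the 128_000 id is just an illustrative value, not taken from this repo):

```rust
fn main() {
    // Token ids around 128k (llama3-scale vocab) exceed u16::MAX (65_535),
    // so the old 2-byte on-disk encoding would silently wrap them.
    let token_id: i32 = 128_000;          // illustrative llama3-scale id
    let truncated = token_id as u16;      // wraps to 62_464
    assert_ne!(truncated as i32, token_id);
    println!("{token_id} stored as u16 becomes {truncated}");
}
```
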

@@ -111,10 +112,13 @@ impl Writer {
             return Ok(());
         }
 
-        self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 2) as u32)?;
+        self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 4) as u32)?; // self.buffer.len() is the length of the buffer (in # of integers). This is variable because sometimes we dump_data() early, its not always self.buffer.capacity().
+        // * 4 because this value will actually tell us how much space is needed for this buffer in file, and we store each as 4 bytes
 
+        // For larger vocabularies (ie > 65,535), we should write the integers as i32 instead of u16
+        // Keeping i32 instead of u32 so negative values can be used as pad tokens (i.e. pad_path(path, max_length, -2))
         for &item in &self.buffer {
-            self.index_file.write_u16::<LittleEndian>(item as u16)?;
+            self.index_file.write_i32::<LittleEndian>(item as i32)?;
         }
 
         let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size);
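
For reference, a self-contained sketch of the chunk layout this hunk settles on: a u32 byte-count header followed by each buffered token as a little-endian i32. The function name and signature here are illustrative, not the crate's actual dump_data:

```rust
use byteorder::{LittleEndian, WriteBytesExt};
use std::io::{self, Write};

// Sketch of the on-disk chunk format after this change: a u32 header holding
// the payload size in bytes (4 bytes per token), then each token as an i32.
fn write_chunk<W: Write>(out: &mut W, buffer: &[i32]) -> io::Result<()> {
    out.write_u32::<LittleEndian>((buffer.len() * 4) as u32)?;
    for &token in buffer {
        // i32 rather than u32 so negative sentinels (e.g. the -2 padding value
        // mentioned in the comment above) round-trip unchanged.
        out.write_i32::<LittleEndian>(token)?;
    }
    Ok(())
}
```
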
@@ -188,8 +192,9 @@ impl Reader {
 
         let mut data: Vec<i32> = Vec::new();
 
-        for i in (0..data_u8.len()).step_by(2) {
-            let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32;
+        // Step by 4 to read in each 4-byte int (i32) from index file
+        for i in (0..data_u8.len()).step_by(4) {
+            let int = LittleEndian::read_i32(&data_u8[i..i+4]) as i32;
             data.push(int);
         }
 
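
The matching read-side sketch, again illustrative rather than the crate's exact code: the payload is walked 4 bytes at a time and decoded with LittleEndian::read_i32.

```rust
use byteorder::{ByteOrder, LittleEndian};

// Decode a chunk payload whose length is a multiple of 4 back into i32 tokens.
fn read_chunk(data_u8: &[u8]) -> Vec<i32> {
    let mut data = Vec::with_capacity(data_u8.len() / 4);
    for i in (0..data_u8.len()).step_by(4) {
        data.push(LittleEndian::read_i32(&data_u8[i..i + 4]));
    }
    data
}
```
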

@@ -259,7 +264,7 @@ impl Reader {
         if start_of_indices.is_none() {
             return;
         }
-
+
         // this binary search finds the end of the matching suffixes
         let mut right_anchor = sub_index.suffixes_file_end - 4;
         while left_anchor <= right_anchor {
@@ -300,7 +305,7 @@ impl Reader {
         let data_index = LittleEndian::read_i32(suffix);
         if matches_ranges.insert(data_index) {
             let sub_string_plus = &sub_index.data[data_index as usize + substring_i32.len() ..std::cmp::min(data_index as usize + substring_i32.len() + long as usize, sub_index.data.len())];
-
+
             local_results.push(sub_string_plus.to_vec());
             cnt += 1;
             if cnt >= k as usize {
@@ -328,7 +333,7 @@ impl Reader {
                 *counter += 1;
             }
         }
-
+
         let choices = choices.unwrap_or(64);
         // The items in the heap must be a Trie.
         let mut heap = BinaryHeap::new();
@@ -348,7 +353,7 @@ impl Reader {
         let verified: Vec<_> = verified.into_iter().collect();
 
         // Because multiple nodes in the Trie may have same weights around the threshold, the number of draft tokens may exceed choices
-        // We roughly cut nodes to be less than choices in most cases.
+        // We roughly cut nodes to be less than choices in most cases.
         let paths = cut_to_choices(verified, choices);
 
         let (draft_choices, max_branch) = get_draft_choices(paths.clone());
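
The comment above explains why a hard top-`choices` cut is not exact: several candidates can tie at the threshold weight. A toy illustration of that effect, assuming each path carries an integer weight (this is not the crate's cut_to_choices):

```rust
// Keep everything whose weight reaches the choices-th largest weight; ties at
// that threshold can push the result past `choices`, which is why the real
// cut is described as "rough".
fn cut_by_threshold(mut weighted: Vec<(Vec<i32>, u32)>, choices: usize) -> Vec<(Vec<i32>, u32)> {
    if weighted.len() <= choices {
        return weighted;
    }
    weighted.sort_by(|a, b| b.1.cmp(&a.1)); // heaviest first
    let threshold = weighted[choices - 1].1;
    weighted.into_iter().filter(|(_, w)| *w >= threshold).collect()
}
```
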
@@ -562,4 +567,3 @@ fn draftretriever(
 
     Ok(())
 }
-
