24 changes: 14 additions & 10 deletions DraftRetriever/src/lib.rs
@@ -1,4 +1,4 @@
// The code for retrieval is adapted from https://github.com/Intsights/PySubstringSearch;
// The code for the draft buffer is adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py#L31-L124
use ahash::AHashSet;
use byteorder::{ReadBytesExt, WriteBytesExt, ByteOrder, LittleEndian};
@@ -67,6 +67,7 @@ impl Writer {
let index_file = File::create(index_file_path)?;
let index_file = BufWriter::new(index_file);

// max_chunk_len can be anything, but the draftretriever Reader() seems to work fastest with a large value (e.g., 2^27 or more)
let max_chunk_len = max_chunk_len.unwrap_or(512 * 1024 * 1024);
let vocab_size = vocab_size.unwrap_or(35000);

Expand Down Expand Up @@ -111,10 +112,13 @@ impl Writer {
return Ok(());
}

self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 2) as u32)?;
self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 4) as u32)?; // self.buffer.len() is the buffer length in integers; it varies because we sometimes dump_data() early, so it's not always self.buffer.capacity().
// * 4 because this value records how many bytes the buffer occupies in the file, and each integer is stored as 4 bytes

// For larger vocabularies (i.e., > 65,535), we must write the integers as i32 instead of u16
// Keeping i32 instead of u32 so negative values can be used as pad tokens (e.g., pad_path(path, max_length, -2))
for &item in &self.buffer {
self.index_file.write_u16::<LittleEndian>(item as u16)?;
self.index_file.write_i32::<LittleEndian>(item as i32)?;
}

let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size);
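In other words, each dumped chunk now has the layout `[u32 byte length][i32 tokens...]`, all little-endian. Here is a minimal sketch of that layout, assuming the byteorder crate already imported above; `write_record` is a hypothetical helper for illustration, not part of this PR:

```rust
use byteorder::{LittleEndian, WriteBytesExt};
use std::io::{self, Write};

// Hypothetical helper mirroring the new chunk layout: a u32 header holding
// the payload size in bytes (token count * 4), then each token as a
// 4-byte little-endian i32.
fn write_record<W: Write>(out: &mut W, buffer: &[i32]) -> io::Result<()> {
    out.write_u32::<LittleEndian>((buffer.len() * 4) as u32)?;
    for &token in buffer {
        // i32 leaves room for negative sentinel values such as the -2 pad token
        out.write_i32::<LittleEndian>(token)?;
    }
    Ok(())
}
```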
@@ -188,8 +192,9 @@ impl Reader {

let mut data: Vec<i32> = Vec::new();

for i in (0..data_u8.len()).step_by(2) {
let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32;
// Step by 4 to read each 4-byte int (i32) from the index file
for i in (0..data_u8.len()).step_by(4) {
let int = LittleEndian::read_i32(&data_u8[i..i+4]) as i32;
data.push(int);
}

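The read side is the exact inverse: the payload length is a multiple of 4 (guaranteed by the u32 byte-length header), and each 4-byte group decodes to one i32. A small sketch, with `read_tokens` as a hypothetical helper using the same byteorder crate:

```rust
use byteorder::{ByteOrder, LittleEndian};

// Hypothetical inverse of the write path: decode a payload of 4-byte
// little-endian values back into i32 token ids. Assumes data_u8.len()
// is a multiple of 4.
fn read_tokens(data_u8: &[u8]) -> Vec<i32> {
    data_u8.chunks_exact(4).map(LittleEndian::read_i32).collect()
}
```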
@@ -259,7 +264,7 @@ impl Reader {
if start_of_indices.is_none() {
return;
}

// This binary search finds the end of the range of matching suffixes
let mut right_anchor = sub_index.suffixes_file_end - 4;
while left_anchor <= right_anchor {
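For context on this step: the suffixes file stores fixed 4-byte entries, and the start and end of the matching range are each located with a binary search. Below is a generic sketch of the upper-bound half; the names (`upper_bound`, `suffixes`, `data`, `query`) are illustrative, not the crate's actual internals:

```rust
// Sketch: find the first suffix-array index whose suffix no longer starts
// with `query`, i.e., the end of the matching range. `suffixes` holds the
// start offset of each suffix into `data`, in sorted suffix order.
fn upper_bound(suffixes: &[i32], data: &[i32], query: &[i32]) -> usize {
    let (mut lo, mut hi) = (0usize, suffixes.len());
    while lo < hi {
        let mid = (lo + hi) / 2;
        let start = suffixes[mid] as usize;
        let end = (start + query.len()).min(data.len());
        if &data[start..end] <= query {
            lo = mid + 1; // prefix sorts at or before the query: look right
        } else {
            hi = mid; // prefix sorts after the query: look left
        }
    }
    lo
}
```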
@@ -300,7 +305,7 @@ impl Reader {
let data_index = LittleEndian::read_i32(suffix);
if matches_ranges.insert(data_index) {
let sub_string_plus = &sub_index.data[data_index as usize + substring_i32.len() ..std::cmp::min(data_index as usize + substring_i32.len() + long as usize, sub_index.data.len())];

local_results.push(sub_string_plus.to_vec());
cnt += 1;
if cnt >= k as usize {
@@ -328,7 +333,7 @@ impl Reader {
*counter += 1;
}
}

let choices = choices.unwrap_or(64);
// The items in the heap must be a Trie.
let mut heap = BinaryHeap::new();
@@ -348,7 +353,7 @@
let verified: Vec<_> = verified.into_iter().collect();

// Because multiple nodes in the Trie may have the same weight around the threshold, the number of draft tokens may exceed choices
// We roughly cut the nodes so that we stay below choices in most cases.
let paths = cut_to_choices(verified, choices);

let (draft_choices, max_branch) = get_draft_choices(paths.clone());
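`cut_to_choices` is defined outside this hunk, so its body is not shown here. One plausible reading of "roughly cut", sketched as a greedy prune by weight; the signature and the per-path weights are assumptions, not the actual implementation:

```rust
// Assumed shape: each verified path carries a weight (e.g., a visit count);
// keep the heaviest paths until the token budget `choices` would be
// exceeded. A sketch of the pruning idea, not the crate's real cut_to_choices.
fn cut_to_choices(mut paths: Vec<(Vec<i32>, u32)>, choices: usize) -> Vec<Vec<i32>> {
    paths.sort_by(|a, b| b.1.cmp(&a.1)); // heaviest paths first
    let mut kept = Vec::new();
    let mut budget = choices;
    for (path, _weight) in paths {
        if path.len() > budget {
            break; // the next path would exceed the draft-token budget
        }
        budget -= path.len();
        kept.push(path);
    }
    kept
}
```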
@@ -562,4 +567,3 @@ fn draftretriever(

Ok(())
}
