-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Improve performance of set_bits by avoiding to set individual bits #6288
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6e8c864
a81ba56
57634f2
32ff203
f94f312
e9cd77a
842c2b1
06de184
03b0db8
7faa5f3
e3d812d
13dec63
68cdaf2
1e9de38
f1e1bbd
f294663
39719c4
9fbb87d
74b9d80
25c309e
7905330
6dd9771
0e956cc
272ecbb
08ebf20
d751a7f
ef2864f
e69cf9a
b5f8bca
9c15417
533381a
dca9ab8
7f3c3fb
6ccedd2
ff2f3ca
fb46cb0
3fd5e3e
be3076e
58868c1
a15db14
fefafa7
d8c3f08
cc5ec2b
4c39dc8
f4789be
59fd805
7d81076
f185a19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,48 +17,144 @@ | |
|
|
||
| //! Utils for working with packed bit masks | ||
|
|
||
| use crate::bit_chunk_iterator::BitChunks; | ||
| use crate::bit_util::{ceil, get_bit, set_bit}; | ||
| use crate::bit_util::ceil; | ||
|
|
||
| /// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the | ||
| /// bits in `data` in the range `[offset_read..offset_read+len]` | ||
| /// returns the number of `0` bits `data[offset_read..offset_read+len]` | ||
| /// `offset_write`, `offset_read`, and `len` are in terms of bits | ||
| pub fn set_bits( | ||
| write_data: &mut [u8], | ||
| data: &[u8], | ||
| offset_write: usize, | ||
| offset_read: usize, | ||
| len: usize, | ||
| ) -> usize { | ||
| assert!(offset_write + len <= write_data.len() * 8); | ||
| assert!(offset_read + len <= data.len() * 8); | ||
| let mut null_count = 0; | ||
|
|
||
| let mut bits_to_align = offset_write % 8; | ||
| if bits_to_align > 0 { | ||
| bits_to_align = std::cmp::min(len, 8 - bits_to_align); | ||
| let mut acc = 0; | ||
| while len > acc { | ||
| // SAFETY: the arguments to `set_upto_64bits` are within the valid range because | ||
| // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8 | ||
| // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8 | ||
| let (n, len_set) = unsafe { | ||
| set_upto_64bits( | ||
| write_data, | ||
| data, | ||
| offset_write + acc, | ||
| offset_read + acc, | ||
| len - acc, | ||
| ) | ||
| }; | ||
| null_count += n; | ||
| acc += len_set; | ||
| } | ||
| let mut write_byte_index = ceil(offset_write + bits_to_align, 8); | ||
|
|
||
| // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time) | ||
| let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); | ||
| chunks.iter().for_each(|chunk| { | ||
| null_count += chunk.count_zeros(); | ||
| write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes()); | ||
| write_byte_index += 8; | ||
| }); | ||
|
|
||
| // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator | ||
| let remainder_offset = len - chunks.remainder_len(); | ||
| (0..bits_to_align) | ||
| .chain(remainder_offset..len) | ||
| .for_each(|i| { | ||
| if get_bit(data, offset_read + i) { | ||
| set_bit(write_data, offset_write + i); | ||
|
|
||
| null_count | ||
| } | ||
|
|
||
| /// Similar to `set_bits` but sets only upto 64 bits, actual number of bits set may vary. | ||
| /// Returns a pair of the number of `0` bits and the number of bits set | ||
| /// | ||
| /// # Safety | ||
kazuyukitanimura marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /// The caller must ensure all arguments are within the valid range. | ||
| #[inline] | ||
| unsafe fn set_upto_64bits( | ||
| write_data: &mut [u8], | ||
| data: &[u8], | ||
| offset_write: usize, | ||
| offset_read: usize, | ||
| len: usize, | ||
kazuyukitanimura marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) -> (usize, usize) { | ||
| let read_byte = offset_read / 8; | ||
| let read_shift = offset_read % 8; | ||
| let write_byte = offset_write / 8; | ||
| let write_shift = offset_write % 8; | ||
|
|
||
| if len >= 64 { | ||
| let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() }; | ||
| if read_shift == 0 { | ||
| if write_shift == 0 { | ||
kazuyukitanimura marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // no shifting necessary | ||
| let len = 64; | ||
| let null_count = chunk.count_zeros() as usize; | ||
| unsafe { write_u64_bytes(write_data, write_byte, chunk) }; | ||
| (null_count, len) | ||
| } else { | ||
| null_count += 1; | ||
| // only write shifting necessary | ||
| let len = 64 - write_shift; | ||
| let chunk = chunk << write_shift; | ||
| let null_count = len - chunk.count_ones() as usize; | ||
| unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; | ||
| (null_count, len) | ||
| } | ||
| }); | ||
| } else if write_shift == 0 { | ||
| // only read shifting necessary | ||
| let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0 | ||
| let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask | ||
| let null_count = len - chunk.count_ones() as usize; | ||
| unsafe { write_u64_bytes(write_data, write_byte, chunk) }; | ||
| (null_count, len) | ||
| } else { | ||
| let len = 64 - std::cmp::max(read_shift, write_shift); | ||
| let chunk = (chunk >> read_shift) << write_shift; | ||
| let null_count = len - chunk.count_ones() as usize; | ||
| unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; | ||
| (null_count, len) | ||
| } | ||
| } else if len == 1 { | ||
| let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1; | ||
| unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift }; | ||
| ((byte_chunk ^ 1) as usize, 1) | ||
| } else { | ||
| let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift)); | ||
| let bytes = ceil(len + read_shift, 8); | ||
| // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len() | ||
| let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) }; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add some
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated |
||
| let mask = u64::MAX >> (64 - len); | ||
| let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only | ||
| let chunk = chunk << write_shift; // shifting back to align with `write_data` | ||
| let null_count = len - chunk.count_ones() as usize; | ||
| let bytes = ceil(len + write_shift, 8); | ||
| for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) { | ||
| unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c }; | ||
| } | ||
| (null_count, len) | ||
| } | ||
| } | ||
|
|
||
| null_count as usize | ||
| /// # Safety | ||
| /// The caller must ensure all arguments are within the valid range. | ||
| #[inline] | ||
| unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function doesn't limit reading bytes to be up 8 bytes. Do you want to add an assert?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @viirya updated |
||
| debug_assert!(count <= 8); | ||
| let mut tmp = std::mem::MaybeUninit::<u64>::new(0); | ||
| let src = data.as_ptr().add(offset); | ||
| unsafe { | ||
| std::ptr::copy_nonoverlapping(src, tmp.as_mut_ptr() as *mut u8, count); | ||
| tmp.assume_init() | ||
| } | ||
| } | ||
|
|
||
| /// # Safety | ||
| /// The caller must ensure `data` has `offset..(offset + 8)` range | ||
| #[inline] | ||
kazuyukitanimura marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { | ||
| let ptr = data.as_mut_ptr().add(offset) as *mut u64; | ||
| ptr.write_unaligned(chunk); | ||
| } | ||
|
|
||
| /// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk` | ||
| /// instead of overwriting | ||
| /// | ||
| /// # Safety | ||
| /// The caller must ensure `data` has `offset..(offset + 8)` range | ||
| #[inline] | ||
kazuyukitanimura marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { | ||
| let ptr = data.as_mut_ptr().add(offset); | ||
| let chunk = chunk | (*ptr) as u64; | ||
| (ptr as *mut u64).write_unaligned(chunk); | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
|
|
@@ -185,4 +281,40 @@ mod tests { | |
| assert_eq!(destination, expected_data); | ||
| assert_eq!(result, expected_null_count); | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_set_upto_64bits() { | ||
| // len >= 64 | ||
| let write_data: &mut [u8] = &mut [0; 9]; | ||
| let data: &[u8] = &[ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please also add a test that is greater than 64 bits (not just = 64 bits)?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I will add it. BTW there is an existing test https://github.com/apache/arrow-rs/blob/master/arrow-buffer/src/util/bit_mask.rs#L170
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am working on some more tests too. Stay tuned... |
||
| 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, | ||
| 0b00000001, 0b00000001, | ||
| ]; | ||
| let offset_write = 1; | ||
| let offset_read = 0; | ||
| let len = 65; | ||
| let (n, len_set) = | ||
| unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) }; | ||
| assert_eq!(n, 55); | ||
| assert_eq!(len_set, 63); | ||
| assert_eq!( | ||
| write_data, | ||
| &[ | ||
| 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, | ||
| 0b00000010, 0b00000000 | ||
| ] | ||
| ); | ||
|
|
||
| // len = 1 | ||
| let write_data: &mut [u8] = &mut [0b00000000]; | ||
| let data: &[u8] = &[0b00000001]; | ||
| let offset_write = 1; | ||
| let offset_read = 0; | ||
| let len = 1; | ||
| let (n, len_set) = | ||
| unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) }; | ||
| assert_eq!(n, 0); | ||
| assert_eq!(len_set, 1); | ||
| assert_eq!(write_data, &[0b00000010]); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.