@@ -956,15 +956,20 @@ impl<'a, 'b> Pattern<'a> for &'b str {
 
         match self.len().cmp(&haystack.len()) {
             Ordering::Less => {
+                if self.len() == 1 {
+                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
+                }
+
                 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
-                if self.as_bytes().len() <= 8 {
-                    return simd_contains(self, haystack);
+                if self.len() <= 32 {
+                    if let Some(result) = simd_contains(self, haystack) {
+                        return result;
+                    }
                 }
 
                 self.into_searcher(haystack).next_match().is_some()
             }
-            Ordering::Equal => self == haystack,
-            Ordering::Greater => false,
+            _ => self == haystack,
         }
     }
 
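In short: for needles shorter than the haystack, a single-byte needle now goes straight to a byte scan, needles up to 32 bytes try the SSE2 routine (which may decline by returning `None`), and everything else falls through to the existing two-way searcher; the `Equal` and `Greater` arms collapse into one because a needle longer than the haystack also compares unequal. A minimal, self-contained sketch of that dispatch follows (not the library code; `sse2_probe` is a hypothetical stand-in for `simd_contains` that always declines so the snippet stays portable):

```rust
use std::cmp::Ordering;

// Hypothetical stand-in for the library's `simd_contains`; it always declines
// so the sketch runs on any platform.
fn sse2_probe(_needle: &str, _haystack: &str) -> Option<bool> {
    None
}

fn contains_sketch(needle: &str, haystack: &str) -> bool {
    match needle.len().cmp(&haystack.len()) {
        Ordering::Less => {
            // single-byte needles: plain byte scan over the haystack
            if needle.len() == 1 {
                return haystack.as_bytes().contains(&needle.as_bytes()[0]);
            }
            // short needles: try the SIMD probe, which may decline
            if needle.len() <= 32 {
                if let Some(result) = sse2_probe(needle, haystack) {
                    return result;
                }
            }
            // fallback stands in for the two-way searcher
            haystack.contains(needle)
        }
        // equal lengths reduce to equality; a longer needle also compares
        // unequal, so the former `Greater => false` arm can share this branch
        _ => needle == haystack,
    }
}

fn main() {
    assert!(contains_sketch("lo wo", "hello world"));
    assert!(!contains_sketch("worlds", "hello world"));
    assert!(contains_sketch("hello world", "hello world"));
}
```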
@@ -1707,82 +1712,207 @@ impl TwoWayStrategy for RejectAndMatch {
     }
 }
 
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than by the needle length, as
+/// two-way does) by probing the first and last byte of the needle for the whole vector width
+/// and only doing full needle comparisons when the vectorized probe indicates potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search, so this implementation should not be called on larger needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
 #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
 #[inline]
-fn simd_contains(needle: &str, haystack: &str) -> bool {
+fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
     let needle = needle.as_bytes();
     let haystack = haystack.as_bytes();
 
-    if needle.len() == 1 {
-        return haystack.contains(&needle[0]);
-    }
-
-    const CHUNK: usize = 16;
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
 
-    // do a naive search if if the haystack is too small to fit
-    if haystack.len() < CHUNK + needle.len() - 1 {
-        return haystack.windows(needle.len()).any(|c| c == needle);
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + second_probe_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
     }
 
-    use crate::arch::x86_64::{
-        __m128i, _mm_and_si128, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
-    };
-
-    // SAFETY: no preconditions other than sse2 being available
-    let first: __m128i = unsafe { _mm_set1_epi8(needle[0] as i8) };
-    // SAFETY: no preconditions other than sse2 being available
-    let last: __m128i = unsafe { _mm_set1_epi8(*needle.last().unwrap() as i8) };
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // The first byte is already checked by the outer loop. To verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
 
+    // this #[cold] is load-bearing, benchmark before removing it...
     let check_mask = #[cold]
-    |idx, mut mask: u32| -> bool {
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
         while mask != 0 {
             let trailing = mask.trailing_zeros();
             let offset = idx + trailing as usize + 1;
-            let sub = &haystack[offset..][..needle.len() - 2];
-            let trimmed_needle = &needle[1..needle.len() - 1];
-
-            if sub == trimmed_needle {
-                return true;
+            // SAFETY: mask is between 0 and 15 trailing zeroes, we skip one additional byte that was already compared
+            // and then take trimmed_needle.len() bytes. This is within the bounds defined by the outer loop
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
             }
             mask &= !(1 << trailing);
         }
         return false;
     };
 
-    let test_chunk = |i| -> bool {
-        // SAFETY: this requires at least CHUNK bytes being readable at offset i
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx
         // that is ensured by the loop ranges (see comments below)
-        let a: __m128i = unsafe { _mm_loadu_si128(haystack.as_ptr().add(i) as *const _) };
-        let b: __m128i =
-            // SAFETY: this requires CHUNK + needle.len() - 1 bytes being readable at offset i
-            unsafe { _mm_loadu_si128(haystack.as_ptr().add(i + needle.len() - 1) as *const _) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_first: __m128i = unsafe { _mm_cmpeq_epi8(first, a) };
-        // SAFETY: no preconditions other than sse2 being available
-        let eq_last: __m128i = unsafe { _mm_cmpeq_epi8(last, b) };
-
-        // SAFETY: no preconditions other than sse2 being available
-        let mask: u32 = unsafe { _mm_movemask_epi8(_mm_and_si128(eq_first, eq_last)) } as u32;
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
 
-        if mask != 0 {
-            return check_mask(i, mask);
-        }
-        return false;
+        return mask;
     };
 
     let mut i = 0;
     let mut result = false;
-    while !result && i + CHUNK + needle.len() <= haystack.len() {
-        result |= test_chunk(i);
-        i += CHUNK;
+    // The loop condition must ensure that there's enough headroom to read LANES bytes,
+    // and not only at the current index but also at the index shifted by second_probe_offset.
+    const UNROLL: usize = 4;
+    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
     }
 
-    // process the tail that didn't fit into CHUNK-sized steps
-    // this simply repeats the same procedure but as right-aligned chunk instead
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as right-aligned chunk instead
     // of a left-aligned one. The last byte must be exactly flush with the string end so
     // we don't miss a single byte or read out of bounds.
-    result |= test_chunk(haystack.len() + 1 - needle.len() - CHUNK);
+    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp, which is faster on long slices
+/// due to SIMD optimizations but incurs function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
 
-    return result;
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `px` and `py`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, dereferencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
+    }
 }
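The core filtering idea, stripped of the SSE2 details: probe two bytes of the needle across a window of candidate positions and only do a full comparison where both probes hit. Below is a portable scalar sketch of that idea, not the committed code; the committed version additionally picks the second probe at an offset whose byte differs from the first, checks 16 candidate positions per vector, unrolls by four, and handles a right-aligned tail.

```rust
// Scalar sketch of the two-probe filter from Muła's algorithm; this is an
// illustration, not the SSE2 code above. The probes use the first and last
// byte of the needle, the simplest variant of the technique.
fn two_probe_find(needle: &[u8], haystack: &[u8]) -> bool {
    assert!(needle.len() >= 2 && haystack.len() >= needle.len());
    let first = needle[0];
    let last = needle[needle.len() - 1];
    for i in 0..=haystack.len() - needle.len() {
        // cheap probes filter out most candidate positions ...
        if haystack[i] == first && haystack[i + needle.len() - 1] == last {
            // ... and only probe hits pay for comparing the interior bytes
            if haystack[i + 1..i + needle.len() - 1] == needle[1..needle.len() - 1] {
                return true;
            }
        }
    }
    false
}

fn main() {
    assert!(two_probe_find(b"needle", b"hay needle hay"));
    assert!(!two_probe_find(b"needle", b"hay neeedle hay"));
}
```

With 16-byte vectors the same hit/miss information comes back as a bitmask over 16 positions at once, which is what `check_mask` walks bit by bit in the code above.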