// See the LICENSE file in the project root for more information.
44
55using System ;
6+ using System . Buffers ;
67using System . Collections . Generic ;
78
89namespace Microsoft . ML . Tokenizers
@@ -19,63 +20,82 @@ public static int[] BytePairEncode(ReadOnlyMemory<byte> mergingBytes, Dictionary
1920 return [ ranks [ mergingBytes ] ] ;
2021 }
2122
22- var byteIndicesAndRanks = new List < ( int Index , int Rank ) > ( ) ;
23- for ( int i = 0 ; i < mergingBytes . Length + 1 ; i ++ )
23+ ( int Index , int Rank ) [ ] ? arrayPoolArray = null ;
24+ int requiredLength = mergingBytes . Length + 1 ;
25+ Span < ( int Index , int Rank ) > byteIndicesAndRanks = requiredLength <= 64 ?
26+ stackalloc ( int , int ) [ 64 ] :
27+ ( arrayPoolArray = ArrayPool < ( int , int ) > . Shared . Rent ( requiredLength ) ) ;
28+ byteIndicesAndRanks = byteIndicesAndRanks . Slice ( 0 , requiredLength ) ;
29+
30+ for ( int i = 0 ; i < byteIndicesAndRanks . Length ; i ++ )
2431 {
25- byteIndicesAndRanks . Add ( ( i , int . MaxValue ) ) ;
32+ byteIndicesAndRanks [ i ] = ( i , int . MaxValue ) ;
2633 }
27- int GetRank ( int startIndex , int skip = 0 )
34+
35+ int GetRank ( Span < ( int Index , int Rank ) > byteIndicesAndRanks , int startIndex , int skip = 0 )
2836 {
29- if ( startIndex + skip + 2 < byteIndicesAndRanks . Count )
37+ if ( startIndex + skip + 2 < byteIndicesAndRanks . Length )
3038 {
3139 var slice = mergingBytes . SliceStartEnd ( byteIndicesAndRanks [ startIndex ] . Index , byteIndicesAndRanks [ startIndex + skip + 2 ] . Index ) ;
3240 if ( ranks . TryGetValue ( slice , out var rank ) )
3341 {
3442 return rank ;
3543 }
3644 }
45+
3746 return int . MaxValue ;
3847 }
39- for ( int i = 0 ; i < byteIndicesAndRanks . Count - 2 ; i ++ )
48+
49+ for ( int i = 0 ; i < byteIndicesAndRanks . Length - 2 ; i ++ )
4050 {
41- var rank = GetRank ( i ) ;
51+ int rank = GetRank ( byteIndicesAndRanks , i ) ;
4252 if ( rank != int . MaxValue )
4353 {
44- byteIndicesAndRanks [ i ] = ( byteIndicesAndRanks [ i ] . Index , rank ) ;
54+ byteIndicesAndRanks [ i ] . Rank = rank ;
4555 }
4656 }
47- while ( byteIndicesAndRanks . Count > 1 )
57+
58+ while ( byteIndicesAndRanks . Length > 1 )
4859 {
4960 var minRank = ( Index : 0 , Rank : int . MaxValue ) ;
50- for ( int i = 0 ; i < byteIndicesAndRanks . Count - 1 ; i ++ )
61+ for ( int i = 0 ; i < byteIndicesAndRanks . Length - 1 ; i ++ )
5162 {
5263 if ( byteIndicesAndRanks [ i ] . Rank < minRank . Rank )
5364 {
5465 minRank = ( i , byteIndicesAndRanks [ i ] . Rank ) ;
5566 }
5667 }
68+
5769 if ( minRank . Rank != int . MaxValue )
5870 {
5971 int j = minRank . Index ;
60- byteIndicesAndRanks [ j ] = ( byteIndicesAndRanks [ j ] . Index , GetRank ( j , 1 ) ) ;
72+ byteIndicesAndRanks [ j ] . Rank = GetRank ( byteIndicesAndRanks , j , 1 ) ;
6173 if ( j > 0 )
6274 {
63- byteIndicesAndRanks [ j - 1 ] = ( byteIndicesAndRanks [ j - 1 ] . Index , GetRank ( j - 1 , 1 ) ) ;
75+ byteIndicesAndRanks [ j - 1 ] . Rank = GetRank ( byteIndicesAndRanks , j - 1 , 1 ) ;
6476 }
65- byteIndicesAndRanks . RemoveAt ( j + 1 ) ;
77+
78+ byteIndicesAndRanks . Slice ( j + 2 ) . CopyTo ( byteIndicesAndRanks . Slice ( j + 1 ) ) ;
79+ byteIndicesAndRanks = byteIndicesAndRanks . Slice ( 0 , byteIndicesAndRanks . Length - 1 ) ;
6680 }
6781 else
6882 {
6983 break ;
7084 }
7185 }
7286
73- var outList = new int [ byteIndicesAndRanks . Count - 1 ] ;
74- for ( int i = 0 ; i < byteIndicesAndRanks . Count - 1 ; i ++ )
87+ var result = new int [ byteIndicesAndRanks . Length - 1 ] ;
88+ for ( int i = 0 ; i < result . Length ; i ++ )
7589 {
76- outList [ i ] = ranks [ mergingBytes . SliceStartEnd ( byteIndicesAndRanks [ i ] . Index , byteIndicesAndRanks [ i + 1 ] . Index ) ] ;
90+ result [ i ] = ranks [ mergingBytes . SliceStartEnd ( byteIndicesAndRanks [ i ] . Index , byteIndicesAndRanks [ i + 1 ] . Index ) ] ;
7791 }
78- return outList ;
92+
93+ if ( arrayPoolArray is not null )
94+ {
95+ ArrayPool < ( int , int ) > . Shared . Return ( arrayPoolArray ) ;
96+ }
97+
98+ return result ;
7999 }
80100
/// <summary>
/// Returns the portion of <paramref name="memory"/> that begins at <paramref name="start"/>
/// and stops just before <paramref name="end"/>.
/// </summary>
/// <param name="memory">The buffer to take the slice from.</param>
/// <param name="start">Index of the first element to include.</param>
/// <param name="end">Index one past the last element to include.</param>
/// <returns>A slice of length <c>end - start</c> starting at <paramref name="start"/>.</returns>
private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end)
{
    // Slice takes (start, length), so convert the exclusive end index into a length first.
    int length = end - start;
    return memory.Slice(start, length);
}
0 commit comments