@@ -95,6 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
9595
9696 ( Dictionary < string , int > ? vocab1 , Vec < ( string , string ) > merges ) = ReadFile ( vocabFile , mergesFile ) ;
9797 Vocab = vocab1 ?? new Dictionary < string , int > ( ) ;
98+ Cache = new Cache < string , Word > ( ) ;
9899
99100 VocabReverse = new ( ) ;
100101
@@ -146,23 +147,33 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
/// <summary>
/// Tokenize a sequence string to a list of tokens.
/// </summary>
/// <param name="sequence">The sequence to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
{
    // An empty input produces the shared immutable empty list; everything else
    // goes through the cache-aware tokenization path.
    return sequence.Length == 0 ? EmptyTokensList : TokenizeWithCache(sequence);
}
161161
162- Word word = MergeWord ( sequence ) ;
/// <summary>
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds)
{
    // Delegates to the cache-aware path; the symbol count it returns is not needed here.
    TokenizeToIdsWithCache(sequence, accumulatedIds);
}
/// <summary>
/// Get the number of tokens that the input sequence will be encoded to.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
public override int CountTokens(string sequence, bool isSpecialToken)
{
    // Passing a null accumulator makes the helper count symbols without collecting ids.
    return TokenizeToIdsWithCache(sequence, null);
}
166177
167178 /// <summary>
168179 /// Map the token to tokenized Id.
@@ -195,14 +206,6 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
195206 return null ;
196207 }
197208
198- /// <summary>
199- /// Map the tokenized Id to the token.
200- /// </summary>
201- /// <param name="id">The Id to map to the token.</param>
202- /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
203- /// <returns>The mapped token of the Id.</returns>
204- public override string ? IdToString ( int id , bool skipSpecialTokens = false ) => throw new NotImplementedException ( ) ;
205-
206209 /// <summary>
207210 /// Gets the dictionary mapping tokens to Ids.
208211 /// </summary>
@@ -332,7 +335,7 @@ internal string CharToString(char c)
332335
333336 internal Word MergeWord ( string w )
334337 {
335- Word word = Word . WithCapacity ( ( int ) w . Length ) ;
338+ Word word = Word . WithCapacity ( w . Length ) ;
336339 ( int Id , int Len ) ? unk = null ;
337340 int i = 0 ;
338341
@@ -344,7 +347,7 @@ internal Word MergeWord(string w)
344347 if ( Char . IsHighSurrogate ( w [ i ] ) && i < w . Length - 1 && Char . IsLowSurrogate ( w [ i + 1 ] ) )
345348 {
346349 length = 2 ;
347- s = w . Substring ( i , ( int ) length ) ;
350+ s = w . Substring ( i , length ) ;
348351 }
349352 else
350353 {
@@ -403,7 +406,7 @@ internal Word MergeWord(string w)
403406 }
404407 }
405408
406- i += ( int ) length ;
409+ i += length ;
407410 }
408411
409412 if ( unk . HasValue )
@@ -415,45 +418,59 @@ internal Word MergeWord(string w)
415418 return word ;
416419 }
417420
/// <summary>
/// Convert the symbols of a merged word back into Token objects using the reverse vocabulary.
/// </summary>
/// <param name="word">The merged word whose symbols are materialized as tokens.</param>
/// <returns>The list of tokens corresponding to the word's symbols.</returns>
internal List<Token> WordToTokens(ref Word word)
{
    return word.ToTokens(VocabReverse);
}
422+
/// <summary>
/// Tokenize a sequence into Token objects, consulting the cache when one is configured.
/// </summary>
/// <param name="sequence">The sequence to tokenize.</param>
/// <returns>The tokens produced by BPE-merging the sequence.</returns>
internal List<Token> TokenizeWithCache(string sequence)
{
    // Fast path: a previously merged word can be reused directly.
    if (Cache is not null && Cache.TryGet(sequence, out Word word))
    {
        return WordToTokens(ref word);
    }

    // Cache miss (or no cache): merge now and remember the result when possible.
    word = MergeWord(sequence);
    Cache?.Set(sequence, word);

    return WordToTokens(ref word);
}
430443
/// <summary>
/// Report the number of symbols in a merged word and, optionally, append their Ids.
/// </summary>
/// <param name="word">The merged word to read symbols from.</param>
/// <param name="accumulatedIds">Destination for the symbol Ids, or null to only count.</param>
/// <returns>The number of symbols in the word.</returns>
internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
{
    // Counting-only callers (e.g. CountTokens) pass null and skip id materialization.
    if (accumulatedIds is null)
    {
        return word.SymbolsCount;
    }

    word.PopulateIds(accumulatedIds);
    return word.SymbolsCount;
}
453+
/// <summary>
/// Tokenize a sequence to Ids (appended to accumulatedIds when non-null), consulting the cache.
/// </summary>
/// <param name="sequence">The sequence to tokenize.</param>
/// <param name="accumulatedIds">Destination for the Ids, or null to only count tokens.</param>
/// <returns>The number of tokens the sequence encodes to.</returns>
internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)
{
    // Fast path: reuse a previously merged word when the cache has one.
    if (Cache is not null && Cache.TryGet(sequence, out Word word))
    {
        return WordToIds(ref word, accumulatedIds);
    }

    // Cache miss (or no cache): merge now and remember the result when possible.
    word = MergeWord(sequence);
    Cache?.Set(sequence, word);

    return WordToIds(ref word, accumulatedIds);
}
458475
459476 internal static readonly List < Token > EmptyTokensList = new ( ) ;
0 commit comments