@@ -656,21 +656,29 @@ async def all_embeddings_are_generated(context):
656
656
assert_embeddings (context .tasks_result .pop ().pop ())
657
657
658
658
659
@step('adding special tokens')
def step_tokenize_set_add_special(context):
    """Turn on the ``tokenize_add_special`` flag on the behave context.

    The 'tokenizing' step reads this flag and, when it is set, forwards it
    as the ``add_special`` field of the /tokenize request.
    """
    setattr(context, 'tokenize_add_special', True)
662
+
663
+
659
664
@step('tokenizing')
@async_run_until_complete
async def step_tokenize(context):
    """POST the step's text to the server's /tokenize endpoint.

    Stores the submitted text on ``context.tokenized_text`` and the
    ``tokens`` field of the JSON response on ``context.tokens``.
    The ``add_special`` field is only included in the request when
    ``context.tokenize_add_special`` has been set to a non-None value.
    """
    context.tokenized_text = context_text(context)

    payload = {"content": context.tokenized_text}
    add_special = getattr(context, 'tokenize_add_special', None)
    if add_special is not None:
        payload['add_special'] = add_special

    async with aiohttp.ClientSession() as session:
        async with session.post(f'{context.base_url}/tokenize',
                                json=payload) as response:
            assert response.status == 200
            body = await response.json()
            context.tokens = body['tokens']
671
679
672
680
673
- @step ('tokens can be detokenize ' )
681
+ @step ('tokens can be detokenized ' )
674
682
@async_run_until_complete
675
683
async def step_detokenize (context ):
676
684
assert len (context .tokens ) > 0
@@ -685,6 +693,21 @@ async def step_detokenize(context):
685
693
assert context .tokenized_text == detokenize_json ['content' ].strip ()
686
694
687
695
696
@step('tokens begin with BOS')
def step_strings_for_tokenization(context):
    """Check that the first entry of ``context.tokens`` equals ``context.bos``.

    Fails with an assertion error when the token list is empty or when it
    starts with any other token.
    """
    # Guard before indexing: an empty list would otherwise raise IndexError
    # instead of producing a readable assertion failure.
    assert len(context.tokens) > 0, "token list is empty, cannot check for BOS"
    first_token = context.tokens[0]
    assert first_token == context.bos, (
        f"expected first token to be BOS ({context.bos}), got {first_token}"
    )
699
+
700
+
701
@step('tokens do not begin with BOS')
def step_strings_for_tokenization(context):
    """Check that the first entry of ``context.tokens`` differs from ``context.bos``."""
    leading_token = context.tokens[0]
    assert not (leading_token == context.bos)
704
+
705
+
706
@step('first token is removed')
def step_strings_for_tokenization(context):
    """Replace ``context.tokens`` with a copy that omits its first entry."""
    remaining_tokens = context.tokens[1:]
    context.tokens = remaining_tokens
709
+
710
+
688
711
@step ('an OPTIONS request is sent from {origin}' )
689
712
@async_run_until_complete
690
713
async def step_options_request (context , origin ):
@@ -1289,4 +1312,6 @@ def server_log(in_stream, out_stream):
1289
1312
thread_stderr = threading .Thread (target = server_log , args = (context .server_process .stderr , sys .stderr ))
1290
1313
thread_stderr .start ()
1291
1314
1315
+ context .bos = 1
1316
+
1292
1317
print (f"server pid={ context .server_process .pid } , behave pid={ os .getpid ()} " )
0 commit comments