From d8de79d697a19744681b2f97fd703f575238586d Mon Sep 17 00:00:00 2001 From: brian khuu Date: Sun, 19 May 2024 00:18:44 +1000 Subject: [PATCH] common.cpp: add --enable-special-out and --disable-special-out to override the default special token handling behavior --- common/common.cpp | 22 ++++++++++++++++++++-- common/common.h | 2 ++ examples/main/main.cpp | 12 +++++++++--- 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index e624fc7f35352..038db8a4becd0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -901,6 +901,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.interactive = true; return true; } + if (arg == "-eso" || arg == "--enable-special-out") { + params.enable_special_token_rendering = true; + if (params.disable_special_token_rendering) { + invalid_param = true; + return true; + } + return true; + } + if (arg == "-dso" || arg == "--disable-special-out") { + params.disable_special_token_rendering = true; + if (params.enable_special_token_rendering) { + invalid_param = true; + return true; + } + return true; + } if (arg == "--interactive-specials") { params.interactive_specials = true; return true; } @@ -1432,6 +1448,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -h, --help show this help message and exit\n"); printf(" --version show version and build info\n"); printf(" -i, --interactive run in interactive mode\n"); + printf(" -eso, --enable-special-out enable special tokens print (overrides default behaviour)\n"); + printf(" -dso, --disable-special-out disable special tokens print (overrides default behaviour)\n"); printf(" --interactive-specials allow special tokens in user text, in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); @@ -1493,8 +1511,8 
@@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" modifies the likelihood of token appearing in the completion,\n"); printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); - printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); - printf(" --grammar-file FNAME file to read grammar from\n"); + printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir) (special tokens disabled by default)\n"); + printf(" --grammar-file FNAME file to read grammar from (special tokens disabled by default)\n"); printf(" -j SCHEMA, --json-schema SCHEMA\n"); printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n"); printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n"); diff --git a/common/common.h b/common/common.h index 566490e2f881a..498397566aab3 100644 --- a/common/common.h +++ b/common/common.h @@ -141,6 +141,8 @@ struct gpt_params { bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode + bool enable_special_token_rendering = false; // override special token rendering to enabled mode regardless of default (useful for debugging) + bool disable_special_token_rendering = false; // override special token rendering to disabled mode regardless of default (useful for scripting) bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool chatml = false; // chatml mode (used for models trained on chatml syntax) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index
8153a71fb5791..ff7fb05e065bd 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -528,7 +528,13 @@ int main(int argc, char ** argv) { exit(1); } - bool should_show_special_tokens = sparams.grammar.empty(); + const bool special_token_render_override = params.enable_special_token_rendering || params.disable_special_token_rendering; + bool special_token_render = sparams.grammar.empty(); + if (params.enable_special_token_rendering) { + special_token_render = true; + } else if (params.disable_special_token_rendering) { + special_token_render = false; + } while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict @@ -742,7 +748,7 @@ int main(int argc, char ** argv) { // display text if (input_echo && display) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation && should_show_special_tokens); + const std::string token_str = llama_token_to_piece(ctx, id, special_token_render_override ? special_token_render : !params.conversation && special_token_render); printf("%s", token_str.c_str()); if (embd.size() > 1) { @@ -908,7 +914,7 @@ int main(int argc, char ** argv) { for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); - output_ss << llama_token_to_piece(ctx, token, should_show_special_tokens); + output_ss << llama_token_to_piece(ctx, token, special_token_render); } n_remain -= line_inp.size();