diff --git a/CMakeLists.txt b/CMakeLists.txt
index 36a2078e4c9fa..c51807dac4722 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,7 @@
 cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
 project("llama.cpp" C CXX)
 include(CheckIncludeFileCXX)
+include(FetchContent)
 
 #set(CMAKE_WARN_DEPRECATED YES)
 set(CMAKE_WARN_UNUSED_CLI YES)
@@ -87,6 +88,30 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
 option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
 
+# Add yaml-cpp dependency
+FetchContent_Declare(
+    yaml-cpp
+    GIT_REPOSITORY https://github.com/jbeder/yaml-cpp.git
+    GIT_TAG yaml-cpp-0.7.0
+)
+
+# Configure yaml-cpp for platform compatibility
+if(WIN32 AND (CMAKE_GENERATOR MATCHES "MSYS Makefiles" OR MSYS))
+    set(YAML_BUILD_SHARED_LIBS ON CACHE BOOL "Build yaml-cpp as shared library for MSYS2")
+    set(YAML_CPP_BUILD_CONTRIB OFF CACHE BOOL "Disable yaml-cpp contrib for MSYS2")
+    set(CMAKE_POLICY_DEFAULT_CMP0077 NEW CACHE STRING "Use NEW policy for option() behavior")
+endif()
+
+# Set CMake policy version minimum for yaml-cpp compatibility
+set(CMAKE_POLICY_VERSION_MINIMUM 3.5 CACHE STRING "Minimum CMake policy version for yaml-cpp compatibility")
+
+# Disable yaml-cpp tests and tools to avoid build issues
+set(YAML_CPP_BUILD_TESTS OFF CACHE BOOL "Disable yaml-cpp tests")
+set(YAML_CPP_BUILD_TOOLS OFF CACHE BOOL "Disable yaml-cpp tools")
+set(YAML_CPP_BUILD_CONTRIB OFF CACHE BOOL "Disable yaml-cpp contrib")
+
+FetchContent_MakeAvailable(yaml-cpp)
+
 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake)
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 0ae4d698f080c..90743d1b4f176 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -135,7 +135,7 @@ endif ()
 target_include_directories(${TARGET} PUBLIC . ../vendor)
 target_compile_features   (${TARGET} PUBLIC cxx_std_17)
-target_link_libraries     (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
+target_link_libraries     (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} yaml-cpp PUBLIC llama Threads::Threads)
 
 #
diff --git a/common/arg.cpp b/common/arg.cpp
index fcee0c4470077..ce4ded15cad2a 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -7,6 +7,8 @@
 #include "log.h"
 #include "sampling.h"
 
+#include <yaml-cpp/yaml.h>
+
 // fix problem with std::min and std::max
 #if defined(_WIN32)
 #define WIN32_LEAN_AND_MEAN
@@ -41,6 +43,289 @@
 
 using json = nlohmann::ordered_json;
 
+// YAML configuration parsing functions
+// NOTE(review): the .as<T>() template arguments below are reconstructed from the
+// common_params field types — confirm against common.h before merging.
+static void parse_yaml_sampling(const YAML::Node& node, common_params_sampling& sampling) {
+    if (node["seed"] && node["seed"].IsScalar()) sampling.seed = node["seed"].as<uint32_t>();
+    if (node["n_prev"] && node["n_prev"].IsScalar()) sampling.n_prev = node["n_prev"].as<int32_t>();
+    if (node["n_probs"] && node["n_probs"].IsScalar()) sampling.n_probs = node["n_probs"].as<int32_t>();
+    if (node["min_keep"] && node["min_keep"].IsScalar()) sampling.min_keep = node["min_keep"].as<int32_t>();
+    if (node["top_k"] && node["top_k"].IsScalar()) sampling.top_k = node["top_k"].as<int32_t>();
+    if (node["top_p"] && node["top_p"].IsScalar()) sampling.top_p = node["top_p"].as<float>();
+    if (node["min_p"] && node["min_p"].IsScalar()) sampling.min_p = node["min_p"].as<float>();
+    if (node["xtc_probability"] && node["xtc_probability"].IsScalar()) sampling.xtc_probability = node["xtc_probability"].as<float>();
+    if (node["xtc_threshold"] && node["xtc_threshold"].IsScalar()) sampling.xtc_threshold = node["xtc_threshold"].as<float>();
+    if (node["typ_p"] && node["typ_p"].IsScalar()) sampling.typ_p = node["typ_p"].as<float>();
+    if (node["temp"] && node["temp"].IsScalar()) sampling.temp = node["temp"].as<float>();
+    if (node["dynatemp_range"] && node["dynatemp_range"].IsScalar()) sampling.dynatemp_range = node["dynatemp_range"].as<float>();
+    if (node["dynatemp_exponent"] && node["dynatemp_exponent"].IsScalar()) sampling.dynatemp_exponent = node["dynatemp_exponent"].as<float>();
+    if (node["penalty_last_n"] && node["penalty_last_n"].IsScalar()) sampling.penalty_last_n = node["penalty_last_n"].as<int32_t>();
+    if (node["penalty_repeat"] && node["penalty_repeat"].IsScalar()) sampling.penalty_repeat = node["penalty_repeat"].as<float>();
+    if (node["penalty_freq"] && node["penalty_freq"].IsScalar()) sampling.penalty_freq = node["penalty_freq"].as<float>();
+    if (node["penalty_present"] && node["penalty_present"].IsScalar()) sampling.penalty_present = node["penalty_present"].as<float>();
+    if (node["dry_multiplier"] && node["dry_multiplier"].IsScalar()) sampling.dry_multiplier = node["dry_multiplier"].as<float>();
+    if (node["dry_base"] && node["dry_base"].IsScalar()) sampling.dry_base = node["dry_base"].as<float>();
+    if (node["dry_allowed_length"] && node["dry_allowed_length"].IsScalar()) sampling.dry_allowed_length = node["dry_allowed_length"].as<int32_t>();
+    if (node["dry_penalty_last_n"] && node["dry_penalty_last_n"].IsScalar()) sampling.dry_penalty_last_n = node["dry_penalty_last_n"].as<int32_t>();
+    if (node["mirostat"] && node["mirostat"].IsScalar()) sampling.mirostat = node["mirostat"].as<int32_t>();
+    if (node["top_n_sigma"] && node["top_n_sigma"].IsScalar()) sampling.top_n_sigma = node["top_n_sigma"].as<float>();
+    if (node["mirostat_tau"] && node["mirostat_tau"].IsScalar()) sampling.mirostat_tau = node["mirostat_tau"].as<float>();
+    if (node["mirostat_eta"] && node["mirostat_eta"].IsScalar()) sampling.mirostat_eta = node["mirostat_eta"].as<float>();
+    if (node["ignore_eos"] && node["ignore_eos"].IsScalar()) sampling.ignore_eos = node["ignore_eos"].as<bool>();
+    if (node["no_perf"] && node["no_perf"].IsScalar()) sampling.no_perf = node["no_perf"].as<bool>();
+    if (node["timing_per_token"] && node["timing_per_token"].IsScalar()) sampling.timing_per_token = node["timing_per_token"].as<bool>();
+    if (node["grammar"] && node["grammar"].IsScalar()) sampling.grammar = node["grammar"].as<std::string>();
+    if (node["grammar_lazy"] && node["grammar_lazy"].IsScalar()) sampling.grammar_lazy = node["grammar_lazy"].as<bool>();
+
+    if (node["dry_sequence_breakers"] && node["dry_sequence_breakers"].IsSequence()) {
+        sampling.dry_sequence_breakers.clear();
+        const auto& breakers = node["dry_sequence_breakers"];
+        sampling.dry_sequence_breakers.reserve(breakers.size());
+        for (const auto& breaker : breakers) {
+            if (breaker && breaker.IsScalar()) {
+                sampling.dry_sequence_breakers.push_back(breaker.as<std::string>());
+            }
+        }
+    }
+}
+
+static void parse_yaml_model(const YAML::Node& node, common_params_model& model) {
+    if (node["path"] && node["path"].IsScalar()) model.path = node["path"].as<std::string>();
+    if (node["url"] && node["url"].IsScalar()) model.url = node["url"].as<std::string>();
+    if (node["hf_repo"] && node["hf_repo"].IsScalar()) model.hf_repo = node["hf_repo"].as<std::string>();
+    if (node["hf_file"] && node["hf_file"].IsScalar()) model.hf_file = node["hf_file"].as<std::string>();
+}
+
+static void parse_yaml_speculative(const YAML::Node& node, common_params_speculative& spec) {
+    if (node["n_ctx"] && node["n_ctx"].IsScalar()) spec.n_ctx = node["n_ctx"].as<int32_t>();
+    if (node["n_max"] && node["n_max"].IsScalar()) spec.n_max = node["n_max"].as<int32_t>();
+    if (node["n_min"] && node["n_min"].IsScalar()) spec.n_min = node["n_min"].as<int32_t>();
+    if (node["n_gpu_layers"] && node["n_gpu_layers"].IsScalar()) spec.n_gpu_layers = node["n_gpu_layers"].as<int32_t>();
+    if (node["p_split"] && node["p_split"].IsScalar()) spec.p_split = node["p_split"].as<float>();
+    if (node["p_min"] && node["p_min"].IsScalar()) spec.p_min = node["p_min"].as<float>();
+    if (node["cache_type_k"] && node["cache_type_k"].IsScalar()) {
+        std::string cache_type = node["cache_type_k"].as<std::string>();
+        if (cache_type == "f16") spec.cache_type_k = GGML_TYPE_F16;
+        else if (cache_type == "f32") spec.cache_type_k = GGML_TYPE_F32;
+        else if (cache_type == "q4_0") spec.cache_type_k = GGML_TYPE_Q4_0;
+        else if (cache_type == "q4_1") spec.cache_type_k = GGML_TYPE_Q4_1;
+        else if (cache_type == "q5_0") spec.cache_type_k = GGML_TYPE_Q5_0;
+        else if (cache_type == "q5_1") spec.cache_type_k = GGML_TYPE_Q5_1;
+        else if (cache_type == "q8_0") spec.cache_type_k = GGML_TYPE_Q8_0;
+    }
+    if (node["cache_type_v"] && node["cache_type_v"].IsScalar()) {
+        std::string cache_type = node["cache_type_v"].as<std::string>();
+        if (cache_type == "f16") spec.cache_type_v = GGML_TYPE_F16;
+        else if (cache_type == "f32") spec.cache_type_v = GGML_TYPE_F32;
+        else if (cache_type == "q4_0") spec.cache_type_v = GGML_TYPE_Q4_0;
+        else if (cache_type == "q4_1") spec.cache_type_v = GGML_TYPE_Q4_1;
+        else if (cache_type == "q5_0") spec.cache_type_v = GGML_TYPE_Q5_0;
+        else if (cache_type == "q5_1") spec.cache_type_v = GGML_TYPE_Q5_1;
+        else if (cache_type == "q8_0") spec.cache_type_v = GGML_TYPE_Q8_0;
+    }
+    if (node["model"] && node["model"].IsMap()) {
+        parse_yaml_model(node["model"], spec.model);
+    }
+}
+
+static void parse_yaml_vocoder(const YAML::Node& node, common_params_vocoder& vocoder) {
+    if (node["speaker_file"] && node["speaker_file"].IsScalar()) vocoder.speaker_file = node["speaker_file"].as<std::string>();
+    if (node["use_guide_tokens"] && node["use_guide_tokens"].IsScalar()) vocoder.use_guide_tokens = node["use_guide_tokens"].as<bool>();
+    if (node["model"] && node["model"].IsMap()) {
+        parse_yaml_model(node["model"], vocoder.model);
+    }
+}
+
+static void parse_yaml_diffusion(const YAML::Node& node, common_params_diffusion& diffusion) {
+    if (node["steps"] && node["steps"].IsScalar()) diffusion.steps = node["steps"].as<int32_t>();
+    if (node["visual_mode"] && node["visual_mode"].IsScalar()) diffusion.visual_mode = node["visual_mode"].as<bool>();
+    if (node["eps"] && node["eps"].IsScalar()) diffusion.eps = node["eps"].as<float>();
+    if (node["block_length"] && node["block_length"].IsScalar()) diffusion.block_length = node["block_length"].as<int32_t>();
+    if (node["algorithm"] && node["algorithm"].IsScalar()) diffusion.algorithm = node["algorithm"].as<int32_t>();
+    if (node["alg_temp"] && node["alg_temp"].IsScalar()) diffusion.alg_temp = node["alg_temp"].as<float>();
+    if (node["cfg_scale"] && node["cfg_scale"].IsScalar()) diffusion.cfg_scale = node["cfg_scale"].as<float>();
+    if (node["add_gumbel_noise"] && node["add_gumbel_noise"].IsScalar()) diffusion.add_gumbel_noise = node["add_gumbel_noise"].as<bool>();
+}
+
+static bool load_yaml_config(const std::string& config_path, common_params& params) {
+    try {
+        YAML::Node config = YAML::LoadFile(config_path);
+
+        // Parse main parameters with bounds checking
+        if (config["n_predict"] && config["n_predict"].IsScalar()) {
+            params.n_predict = config["n_predict"].as<int32_t>();
+        }
+        if (config["n_ctx"] && config["n_ctx"].IsScalar()) {
+            params.n_ctx = config["n_ctx"].as<int32_t>();
+        }
+        if (config["n_batch"] && config["n_batch"].IsScalar()) {
+            params.n_batch = config["n_batch"].as<int32_t>();
+        }
+        if (config["n_ubatch"] && config["n_ubatch"].IsScalar()) {
+            params.n_ubatch = config["n_ubatch"].as<int32_t>();
+        }
+        if (config["n_keep"] && config["n_keep"].IsScalar()) {
+            params.n_keep = config["n_keep"].as<int32_t>();
+        }
+        if (config["n_chunks"] && config["n_chunks"].IsScalar()) {
+            params.n_chunks = config["n_chunks"].as<int32_t>();
+        }
+        if (config["n_parallel"] && config["n_parallel"].IsScalar()) {
+            params.n_parallel = config["n_parallel"].as<int32_t>();
+        }
+        if (config["n_sequences"] && config["n_sequences"].IsScalar()) {
+            params.n_sequences = config["n_sequences"].as<int32_t>();
+        }
+        if (config["grp_attn_n"] && config["grp_attn_n"].IsScalar()) {
+            params.grp_attn_n = config["grp_attn_n"].as<int32_t>();
+        }
+        if (config["grp_attn_w"] && config["grp_attn_w"].IsScalar()) {
+            params.grp_attn_w = config["grp_attn_w"].as<int32_t>();
+        }
+        if (config["n_print"] && config["n_print"].IsScalar()) {
+            params.n_print = config["n_print"].as<int32_t>();
+        }
+        if (config["rope_freq_base"] && config["rope_freq_base"].IsScalar()) {
+            params.rope_freq_base = config["rope_freq_base"].as<float>();
+        }
+        if (config["rope_freq_scale"] && config["rope_freq_scale"].IsScalar()) {
+            params.rope_freq_scale = config["rope_freq_scale"].as<float>();
+        }
+        if (config["yarn_ext_factor"] && config["yarn_ext_factor"].IsScalar()) {
+            params.yarn_ext_factor = config["yarn_ext_factor"].as<float>();
+        }
+        if (config["yarn_attn_factor"] && config["yarn_attn_factor"].IsScalar()) {
+            params.yarn_attn_factor = config["yarn_attn_factor"].as<float>();
+        }
+        if (config["yarn_beta_fast"] && config["yarn_beta_fast"].IsScalar()) {
+            params.yarn_beta_fast = config["yarn_beta_fast"].as<float>();
+        }
+        if (config["yarn_beta_slow"] && config["yarn_beta_slow"].IsScalar()) {
+            params.yarn_beta_slow = config["yarn_beta_slow"].as<float>();
+        }
+        if (config["yarn_orig_ctx"] && config["yarn_orig_ctx"].IsScalar()) {
+            params.yarn_orig_ctx = config["yarn_orig_ctx"].as<int32_t>();
+        }
+        if (config["n_gpu_layers"] && config["n_gpu_layers"].IsScalar()) {
+            params.n_gpu_layers = config["n_gpu_layers"].as<int32_t>();
+        }
+        if (config["main_gpu"] && config["main_gpu"].IsScalar()) {
+            params.main_gpu = config["main_gpu"].as<int32_t>();
+        }
+
+        // Parse string parameters with type checking
+        if (config["model_alias"] && config["model_alias"].IsScalar()) {
+            params.model_alias = config["model_alias"].as<std::string>();
+        }
+        if (config["hf_token"] && config["hf_token"].IsScalar()) {
+            params.hf_token = config["hf_token"].as<std::string>();
+        }
+        if (config["prompt"] && config["prompt"].IsScalar()) {
+            params.prompt = config["prompt"].as<std::string>();
+        }
+        if (config["system_prompt"] && config["system_prompt"].IsScalar()) {
+            params.system_prompt = config["system_prompt"].as<std::string>();
+        }
+        if (config["prompt_file"] && config["prompt_file"].IsScalar()) {
+            params.prompt_file = config["prompt_file"].as<std::string>();
+        }
+        if (config["path_prompt_cache"] && config["path_prompt_cache"].IsScalar()) {
+            params.path_prompt_cache = config["path_prompt_cache"].as<std::string>();
+        }
+        if (config["input_prefix"] && config["input_prefix"].IsScalar()) {
+            params.input_prefix = config["input_prefix"].as<std::string>();
+        }
+        if (config["input_suffix"] && config["input_suffix"].IsScalar()) {
+            params.input_suffix = config["input_suffix"].as<std::string>();
+        }
+        if (config["lookup_cache_static"] && config["lookup_cache_static"].IsScalar()) {
+            params.lookup_cache_static = config["lookup_cache_static"].as<std::string>();
+        }
+        if (config["lookup_cache_dynamic"] && config["lookup_cache_dynamic"].IsScalar()) {
+            params.lookup_cache_dynamic = config["lookup_cache_dynamic"].as<std::string>();
+        }
+        if (config["logits_file"] && config["logits_file"].IsScalar()) {
+            params.logits_file = config["logits_file"].as<std::string>();
+        }
+
+        // Parse boolean parameters with type checking
+        if (config["lora_init_without_apply"] && config["lora_init_without_apply"].IsScalar()) {
+            params.lora_init_without_apply = config["lora_init_without_apply"].as<bool>();
+        }
+        if (config["offline"] && config["offline"].IsScalar()) {
+            params.offline = config["offline"].as<bool>();
+        }
+
+        // Parse integer parameters with type checking
+        if (config["verbosity"] && config["verbosity"].IsScalar()) {
+            params.verbosity = config["verbosity"].as<int32_t>();
+        }
+        if (config["control_vector_layer_start"] && config["control_vector_layer_start"].IsScalar()) {
+            params.control_vector_layer_start = config["control_vector_layer_start"].as<int32_t>();
+        }
+        if (config["control_vector_layer_end"] && config["control_vector_layer_end"].IsScalar()) {
+            params.control_vector_layer_end = config["control_vector_layer_end"].as<int32_t>();
+        }
+        if (config["ppl_stride"] && config["ppl_stride"].IsScalar()) {
+            params.ppl_stride = config["ppl_stride"].as<int32_t>();
+        }
+        if (config["ppl_output_type"] && config["ppl_output_type"].IsScalar()) {
+            params.ppl_output_type = config["ppl_output_type"].as<int32_t>();
+        }
+
+        // Parse array parameters with proper bounds checking
+        if (config["in_files"] && config["in_files"].IsSequence()) {
+            params.in_files.clear();
+            params.in_files.reserve(config["in_files"].size());
+            for (const auto& file : config["in_files"]) {
+                if (file.IsScalar()) {
+                    params.in_files.push_back(file.as<std::string>());
+                }
+            }
+        }
+
+        if (config["antiprompt"] && config["antiprompt"].IsSequence()) {
+            params.antiprompt.clear();
+            params.antiprompt.reserve(config["antiprompt"].size());
+            for (const auto& prompt : config["antiprompt"]) {
+                if (prompt.IsScalar()) {
+                    params.antiprompt.push_back(prompt.as<std::string>());
+                }
+            }
+        }
+
+        if (config["sampling"] && config["sampling"].IsMap()) {
+            parse_yaml_sampling(config["sampling"], params.sampling);
+        }
+
+        if (config["model"] && config["model"].IsMap()) {
+            parse_yaml_model(config["model"], params.model);
+        }
+
+        if (config["speculative"] && config["speculative"].IsMap()) {
+            parse_yaml_speculative(config["speculative"], params.speculative);
+        }
+
+        if (config["vocoder"] && config["vocoder"].IsMap()) {
+            parse_yaml_vocoder(config["vocoder"], params.vocoder);
+        }
+
+        if (config["diffusion"] && config["diffusion"].IsMap()) {
+            parse_yaml_diffusion(config["diffusion"], params.diffusion);
+        }
+
+        return true;
+    } catch (const YAML::Exception& e) {
+        fprintf(stderr, "YAML parsing error: %s\n", e.what());
+        return false;
+    } catch (const std::exception& e) {
+        fprintf(stderr, "Error loading YAML config: %s\n", e.what());
+        return false;
+    } catch (...) {
+        fprintf(stderr, "Unknown error loading YAML config\n");
+        return false;
+    }
+}
+
 std::initializer_list<enum llama_example> mmproj_examples = {
     LLAMA_EXAMPLE_MTMD,
     LLAMA_EXAMPLE_SERVER,
@@ -1223,9 +1508,40 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
     const common_params params_org = ctx_arg.params; // the example can modify the default params
 
     try {
-        if (!common_params_parse_ex(argc, argv, ctx_arg)) {
-            ctx_arg.params = params_org;
-            return false;
+        bool has_config = false;
+        for (int i = 1; i < argc; i++) {
+            if (strcmp(argv[i], "--config") == 0 && i + 1 < argc) {
+                if (!load_yaml_config(argv[i + 1], ctx_arg.params)) {
+                    fprintf(stderr, "Failed to load YAML config: %s\n", argv[i + 1]);
+                    ctx_arg.params = params_org;
+                    return false;
+                }
+                has_config = true;
+                break; // Only process first --config for now
+            }
+        }
+
+        if (has_config) {
+            std::vector<char *> filtered_argv;
+            filtered_argv.push_back(argv[0]); // Keep program name
+
+            for (int i = 1; i < argc; i++) {
+                if (strcmp(argv[i], "--config") == 0 && i + 1 < argc) {
+                    i++; // Skip both --config and filename
+                } else {
+                    filtered_argv.push_back(argv[i]);
+                }
+            }
+
+            if (!common_params_parse_ex(filtered_argv.size(), filtered_argv.data(), ctx_arg)) {
+                ctx_arg.params = params_org;
+                return false;
+            }
+        } else {
+            if (!common_params_parse_ex(argc, argv, ctx_arg)) {
+                ctx_arg.params = params_org;
+                return false;
+            }
         }
         if (ctx_arg.params.usage) {
             common_params_print_usage(ctx_arg);
@@ -1294,6 +1610,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     };
 
+    add_opt(common_arg(
+        {"--config"},
+        "FNAME",
+        "path to YAML configuration file",
+        [](common_params & params, const std::string & value) {
+            params.config_file = value;
+        }
+    ));
+
     add_opt(common_arg(
         {"-h", "--help", "--usage"},
         "print usage and exit",
diff --git a/common/common.h b/common/common.h
index 85b3b879d4536..a902dbc8d6e7e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -348,6 +348,8 @@ struct common_params {
     int32_t control_vector_layer_end   = -1; // layer range for control vector
     bool    offline                    = false;
 
+    std::string config_file = ""; // path to YAML configuration file
+
     int32_t ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
     int32_t ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                      //                                       (which is more convenient to use for plotting)
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 91719577564a9..869d8cabbd43a 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -190,6 +190,10 @@ llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyll
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
     llama_build_and_test(test-arg-parser.cpp)
+
+# YAML configuration tests
+llama_build_and_test(test-yaml-config.cpp)
+llama_build_and_test(test-yaml-backward-compat.cpp)
 endif()
 
 if (NOT LLAMA_SANITIZE_ADDRESS)
diff --git a/tests/test-yaml-backward-compat.cpp b/tests/test-yaml-backward-compat.cpp
new file mode 100644
index 0000000000000..84e6e0e71c903
--- /dev/null
+++ b/tests/test-yaml-backward-compat.cpp
@@ -0,0 +1,179 @@
+#include "common.h"
+#include "arg.h"
+#include <iostream>
+#include <fstream>
+#include <filesystem>
+#include <string>
+#include <vector>
+#include <cassert>
+#include <cmath>
+
+struct TestCase {
+    std::vector<std::string> args;
+    std::string description;
+};
+
+static void test_cli_args_without_yaml() {
+    std::cout << "Testing CLI arguments without YAML..." << std::endl;
+
+    std::vector<TestCase> test_cases = {
+        {{"test", "-n", "100"}, "Basic n_predict"},
+        {{"test", "-p", "Hello world"}, "Basic prompt"},
+        {{"test", "--temp", "0.8"}, "Temperature setting"},
+        {{"test", "-c", "2048"}, "Context size"},
+        {{"test", "-b", "512"}, "Batch size"},
+        {{"test", "--top-k", "40"}, "Top-k sampling"},
+        {{"test", "--top-p", "0.9"}, "Top-p sampling"},
+        {{"test", "-s", "42"}, "Random seed"},
+        {{"test", "-n", "50", "-p", "Test", "--temp", "0.7"}, "Multiple arguments"},
+        {{"test", "--help"}, "Help flag (should exit)"},
+    };
+
+    for (const auto& test_case : test_cases) {
+        if (test_case.description == "Help flag (should exit)") {
+            continue;
+        }
+
+        std::cout << "  Testing: " << test_case.description << std::endl;
+
+        common_params params;
+        std::vector<char*> argv;
+        for (const auto& arg : test_case.args) {
+            argv.push_back(const_cast<char*>(arg.c_str()));
+        }
+
+        bool result = common_params_parse(argv.size(), argv.data(), params, LLAMA_EXAMPLE_COMMON);
+
+        if (!result && test_case.description != "Help flag (should exit)") {
+            std::cout << "  Warning: " << test_case.description << " failed to parse" << std::endl;
+        }
+    }
+
+    std::cout << "CLI arguments without YAML test completed!" << std::endl;
+}
+
+static void test_equivalent_yaml_and_cli() {
+    std::cout << "Testing equivalent YAML and CLI produce same results..." << std::endl;
+
+    std::ofstream yaml_file("equivalent_test.yaml");
+    yaml_file << R"(
+n_predict: 100
+n_ctx: 2048
+n_batch: 512
+prompt: "Test prompt"
+sampling:
+  seed: 42
+  temp: 0.8
+  top_k: 40
+  top_p: 0.9
+  penalty_repeat: 1.1
+)";
+    yaml_file.close();
+
+    common_params yaml_params;
+    const char* yaml_argv[] = {"test", "--config", "equivalent_test.yaml"};
+    bool yaml_result = common_params_parse(3, const_cast<char**>(yaml_argv), yaml_params, LLAMA_EXAMPLE_COMMON);
+
+    common_params cli_params;
+    const char* cli_argv[] = {
+        "test",
+        "-n", "100",
+        "-c", "2048",
+        "-b", "512",
+        "-p", "Test prompt",
+        "-s", "42",
+        "--temp", "0.8",
+        "--top-k", "40",
+        "--top-p", "0.9",
+        "--repeat-penalty", "1.1"
+    };
+    const int cli_argc = sizeof(cli_argv) / sizeof(cli_argv[0]);
+
+    bool cli_result = common_params_parse(cli_argc, const_cast<char**>(cli_argv), cli_params, LLAMA_EXAMPLE_COMMON);
+
+    assert(yaml_result == true);
+    assert(cli_result == true);
+    (void)yaml_result; // Suppress unused variable warning
+    (void)cli_result;  // Suppress unused variable warning
+
+    assert(yaml_params.n_predict == cli_params.n_predict);
+    assert(yaml_params.n_ctx == cli_params.n_ctx);
+    assert(yaml_params.n_batch == cli_params.n_batch);
+    assert(yaml_params.prompt == cli_params.prompt);
+    assert(yaml_params.sampling.seed == cli_params.sampling.seed);
+    assert(yaml_params.sampling.temp == cli_params.sampling.temp);
+    assert(yaml_params.sampling.top_k == cli_params.sampling.top_k);
+    assert(yaml_params.sampling.top_p == cli_params.sampling.top_p);
+
+    const float epsilon = 1e-6f;
+    assert(std::abs(yaml_params.sampling.penalty_repeat - cli_params.sampling.penalty_repeat) < epsilon);
+
+    std::filesystem::remove("equivalent_test.yaml");
+    std::cout << "Equivalent YAML and CLI test passed!" << std::endl;
+}
+
+static void test_all_major_cli_options() {
+    std::cout << "Testing all major CLI options still work..." << std::endl;
+
+    struct CliTest {
+        std::vector<std::string> args;
+        std::string param_name;
+        bool should_succeed;
+    };
+
+    std::vector<CliTest> cli_tests = {
+        {{"test", "-m", "model.gguf"}, "model path", true},
+        {{"test", "-n", "200"}, "n_predict", true},
+        {{"test", "-c", "4096"}, "context size", true},
+        {{"test", "-b", "1024"}, "batch size", true},
+        {{"test", "-p", "Hello"}, "prompt", true},
+        {{"test", "-s", "123"}, "seed", true},
+        {{"test", "--temp", "0.7"}, "temperature", true},
+        {{"test", "--top-k", "50"}, "top_k", true},
+        {{"test", "--top-p", "0.95"}, "top_p", true},
+        {{"test", "--repeat-penalty", "1.05"}, "repeat penalty", true},
+        {{"test", "-t", "4"}, "threads", true},
+        {{"test", "-ngl", "32"}, "gpu layers", true},
+        {{"test", "--interactive"}, "interactive mode", true},
+        {{"test", "--color"}, "color output", true},
+        {{"test", "--verbose"}, "verbose mode", true},
+    };
+
+    for (const auto& test : cli_tests) {
+        std::cout << "  Testing: " << test.param_name << std::endl;
+
+        common_params params;
+        std::vector<char*> argv;
+        for (const auto& arg : test.args) {
+            argv.push_back(const_cast<char*>(arg.c_str()));
+        }
+
+        bool result = common_params_parse(argv.size(), argv.data(), params, LLAMA_EXAMPLE_COMMON);
+
+        if (result != test.should_succeed) {
+            std::cout << "  Unexpected result for " << test.param_name
+                      << ": expected " << test.should_succeed << ", got " << result << std::endl;
+        }
+    }
+
+    std::cout << "Major CLI options test completed!" << std::endl;
+}
+
+int main() {
+    std::cout << "Running backward compatibility tests..." << std::endl;
+
+    try {
+        test_cli_args_without_yaml();
+        test_equivalent_yaml_and_cli();
+        test_all_major_cli_options();
+
+        std::cout << "All backward compatibility tests completed!" << std::endl;
+        return 0;
+    } catch (const std::exception& e) {
+        std::cerr << "Test failed with exception: " << e.what() << std::endl;
+        return 1;
+    } catch (...) {
+        std::cerr << "Test failed with unknown exception" << std::endl;
+        return 1;
+    }
+}
diff --git a/tests/test-yaml-config.cpp b/tests/test-yaml-config.cpp
new file mode 100644
index 0000000000000..c3618ff303ea1
--- /dev/null
+++ b/tests/test-yaml-config.cpp
@@ -0,0 +1,216 @@
+#include "common.h"
+#include "arg.h"
+#include <iostream>
+#include <fstream>
+#include <filesystem>
+#include <cassert>
+
+static void write_test_yaml(const std::string& filename, const std::string& content) {
+    std::ofstream file(filename);
+    file << content;
+    file.close();
+}
+
+static void test_basic_yaml_parsing() {
+    std::cout << "Testing basic YAML parsing..." << std::endl;
+
+    const std::string yaml_content = R"(
+n_predict: 100
+n_ctx: 2048
+n_batch: 512
+prompt: "Hello, world!"
+model:
+  path: "test-model.gguf"
+sampling:
+  seed: 42
+  temp: 0.7
+  top_k: 50
+  top_p: 0.9
+)";
+
+    write_test_yaml("test_basic.yaml", yaml_content);
+
+    common_params params;
+    const char* argv[] = {"test", "--config", "test_basic.yaml"};
+    int argc = 3;
+
+    bool result = common_params_parse(argc, const_cast<char**>(argv), params, LLAMA_EXAMPLE_COMMON);
+    assert(result == true);
+    (void)result; // Suppress unused variable warning
+    assert(params.n_predict == 100);
+    assert(params.n_ctx == 2048);
+    assert(params.n_batch == 512);
+    assert(params.prompt == "Hello, world!");
+    assert(params.model.path == "test-model.gguf");
+    assert(params.sampling.seed == 42);
+    assert(params.sampling.temp == 0.7f);
+    assert(params.sampling.top_k == 50);
+    assert(params.sampling.top_p == 0.9f);
+
+    std::filesystem::remove("test_basic.yaml");
+    std::cout << "Basic YAML parsing test passed!" << std::endl;
+}
+
+static void test_cli_override_yaml() {
+    std::cout << "Testing CLI override of YAML values..." << std::endl;
+
+    const std::string yaml_content = R"(
+n_predict: 100
+n_ctx: 2048
+prompt: "YAML prompt"
+sampling:
+  temp: 0.7
+)";
+
+    write_test_yaml("test_override.yaml", yaml_content);
+
+    common_params params;
+    const char* argv[] = {"test", "--config", "test_override.yaml", "-n", "200", "-p", "CLI prompt"};
+    int argc = 7;
+
+    bool result = common_params_parse(argc, const_cast<char**>(argv), params, LLAMA_EXAMPLE_COMMON);
+    assert(result == true);
+    (void)result;
+    assert(params.n_predict == 200);
+    assert(params.n_ctx == 2048);
+    assert(params.prompt == "CLI prompt");
+    assert(params.sampling.temp == 0.7f);
+
+    std::filesystem::remove("test_override.yaml");
+    std::cout << "CLI override test passed!" << std::endl;
+}
+
+static void test_invalid_yaml() {
+    std::cout << "Testing invalid YAML handling..." << std::endl;
+
+    const std::string invalid_yaml = R"(
+n_predict: 100
+invalid_yaml: [unclosed array
+)";
+
+    write_test_yaml("test_invalid.yaml", invalid_yaml);
+
+    common_params params;
+    const char* argv[] = {"test", "--config", "test_invalid.yaml"};
+    int argc = 3;
+
+    bool result = common_params_parse(argc, const_cast<char**>(argv), params, LLAMA_EXAMPLE_COMMON);
+    assert(result == false); // Should fail with invalid YAML
+    (void)result; // Suppress unused variable warning
+
+    std::filesystem::remove("test_invalid.yaml");
+    std::cout << "Invalid YAML test passed!" << std::endl;
+}
+
+static void test_missing_config_file() {
+    std::cout << "Testing missing config file handling..." << std::endl;
+
+    common_params params;
+    const char* argv[] = {"test", "--config", "nonexistent.yaml"};
+    int argc = 3;
+
+    bool result = common_params_parse(argc, const_cast<char**>(argv), params, LLAMA_EXAMPLE_COMMON);
+    assert(result == false); // Should fail with missing file
+    (void)result; // Suppress unused variable warning
+
+    std::cout << "Missing config file test passed!" << std::endl;
+}
+
+static void test_backward_compatibility() {
+    std::cout << "Testing backward compatibility..." << std::endl;
+
+    common_params params;
+    const char* argv[] = {"test", "-n", "150", "-p", "Test prompt"};
+    int argc = 5;
+
+    bool result = common_params_parse(argc, const_cast<char**>(argv), params, LLAMA_EXAMPLE_COMMON);
+    assert(result == true);
+    (void)result; // Suppress unused variable warning
+    assert(params.n_predict == 150);
+    assert(params.prompt == "Test prompt");
+
+    std::cout << "Backward compatibility test passed!" << std::endl;
+}
+
+static void test_complex_yaml_structure() {
+    std::cout << "Testing complex YAML structure..." << std::endl;
+
+    const std::string complex_yaml = R"(
+n_predict: 200
+n_ctx: 4096
+model:
+  path: "complex-model.gguf"
+sampling:
+  seed: 123
+  temp: 0.6
+  top_k: 40
+  top_p: 0.95
+  penalty_repeat: 1.1
+  dry_sequence_breakers:
+    - "\n"
+    - ":"
+    - ";"
+speculative:
+  n_max: 16
+  p_split: 0.1
+in_files:
+  - "file1.txt"
+  - "file2.txt"
+antiprompt:
+  - "User:"
+  - "Assistant:"
+)";
+
+    write_test_yaml("test_complex.yaml", complex_yaml);
+
+    common_params params;
+    const char* argv[] = {"test", "--config", "test_complex.yaml"};
+    int argc = 3;
+
+    bool result = common_params_parse(argc, const_cast<char**>(argv), params, LLAMA_EXAMPLE_COMMON);
+    assert(result == true);
+    (void)result; // Suppress unused variable warning
+    assert(params.n_predict == 200);
+    assert(params.n_ctx == 4096);
+    assert(params.model.path == "complex-model.gguf");
+    assert(params.sampling.seed == 123);
+    assert(params.sampling.temp == 0.6f);
+    assert(params.sampling.penalty_repeat == 1.1f);
+    assert(params.sampling.dry_sequence_breakers.size() == 3);
+    assert(params.sampling.dry_sequence_breakers[0] == "\n");
+    assert(params.sampling.dry_sequence_breakers[1] == ":");
+    assert(params.sampling.dry_sequence_breakers[2] == ";");
+    assert(params.speculative.n_max == 16);
+    assert(params.speculative.p_split == 0.1f);
+    assert(params.in_files.size() == 2);
+    assert(params.in_files[0] == "file1.txt");
+    assert(params.in_files[1] == "file2.txt");
+    assert(params.antiprompt.size() == 2);
+    assert(params.antiprompt[0] == "User:");
+    assert(params.antiprompt[1] == "Assistant:");
+
+    std::filesystem::remove("test_complex.yaml");
+    std::cout << "Complex YAML structure test passed!" << std::endl;
+}
+
+int main() {
+    std::cout << "Running YAML configuration tests..." << std::endl;
+
+    try {
+        test_basic_yaml_parsing();
+        test_cli_override_yaml();
+        test_invalid_yaml();
+        test_missing_config_file();
+        test_backward_compatibility();
+        test_complex_yaml_structure();
+
+        std::cout << "All YAML configuration tests passed!" << std::endl;
+        return 0;
+    } catch (const std::exception& e) {
+        std::cerr << "Test failed with exception: " << e.what() << std::endl;
+        return 1;
+    } catch (...) {
+        std::cerr << "Test failed with unknown exception" << std::endl;
+        return 1;
+    }
+}