From 0e8cfcf5aed46f13a245efe122b471e5fcd336c4 Mon Sep 17 00:00:00 2001 From: Mohammad Hadi Noori Vareno Date: Sun, 6 Jul 2025 11:27:10 +0200 Subject: [PATCH] dynamont_NT modified / canonical states --- dynamont-polya/include/argparse.hpp | 2482 +++++++++++++++++++++++++++ dynamont-polya/include/polyA.hpp | 0 dynamont-polya/include/utils.hpp | 213 +++ dynamont-polya/main.cpp | 0 dynamont-polya/src/polyA.cpp | 596 +++++++ dynamont-polya/src/polyA.py | 160 ++ dynamont-polya/src/utils.cpp | 197 +++ 7 files changed, 3648 insertions(+) create mode 100644 dynamont-polya/include/argparse.hpp create mode 100644 dynamont-polya/include/polyA.hpp create mode 100644 dynamont-polya/include/utils.hpp create mode 100644 dynamont-polya/main.cpp create mode 100644 dynamont-polya/src/polyA.cpp create mode 100755 dynamont-polya/src/polyA.py create mode 100644 dynamont-polya/src/utils.cpp diff --git a/dynamont-polya/include/argparse.hpp b/dynamont-polya/include/argparse.hpp new file mode 100644 index 0000000..c684892 --- /dev/null +++ b/dynamont-polya/include/argparse.hpp @@ -0,0 +1,2482 @@ +/* + __ _ _ __ __ _ _ __ __ _ _ __ ___ ___ + / _` | '__/ _` | '_ \ / _` | '__/ __|/ _ \ Argument Parser for Modern C++ +| (_| | | | (_| | |_) | (_| | | \__ \ __/ http://github.com/p-ranav/argparse + \__,_|_| \__, | .__/ \__,_|_| |___/\___| + |___/|_| + +Licensed under the MIT License . +SPDX-License-Identifier: MIT +Copyright (c) 2019-2022 Pranav Srinivas Kumar +and other contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ +#pragma once + +#include + +#ifndef ARGPARSE_MODULE_USE_STD_MODULE +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#ifndef ARGPARSE_CUSTOM_STRTOF +#define ARGPARSE_CUSTOM_STRTOF strtof +#endif + +#ifndef ARGPARSE_CUSTOM_STRTOD +#define ARGPARSE_CUSTOM_STRTOD strtod +#endif + +#ifndef ARGPARSE_CUSTOM_STRTOLD +#define ARGPARSE_CUSTOM_STRTOLD strtold +#endif + +namespace argparse { + +namespace details { // namespace for helper methods + +template +struct HasContainerTraits : std::false_type {}; + +template <> struct HasContainerTraits : std::false_type {}; + +template <> struct HasContainerTraits : std::false_type {}; + +template +struct HasContainerTraits< + T, std::void_t().begin()), + decltype(std::declval().end()), + decltype(std::declval().size())>> : std::true_type {}; + +template +inline constexpr bool IsContainer = HasContainerTraits::value; + +template +struct HasStreamableTraits : std::false_type {}; + +template +struct HasStreamableTraits< + T, + std::void_t() << std::declval())>> + : std::true_type {}; + +template +inline constexpr bool IsStreamable = HasStreamableTraits::value; + +constexpr std::size_t repr_max_container_size = 5; + +template std::string repr(T const &val) { + if constexpr (std::is_same_v) { + return val ? "true" : "false"; + } else if constexpr (std::is_convertible_v) { + return '"' + std::string{std::string_view{val}} + '"'; + } else if constexpr (IsContainer) { + std::stringstream out; + out << "{"; + const auto size = val.size(); + if (size > 1) { + out << repr(*val.begin()); + std::for_each( + std::next(val.begin()), + std::next( + val.begin(), + static_cast( + std::min(size, repr_max_container_size) - 1)), + [&out](const auto &v) { out << " " << repr(v); }); + if (size <= repr_max_container_size) { + out << " "; + } else { + out << "..."; + } + } + if (size > 0) { + out << repr(*std::prev(val.end())); + } + out << "}"; + return out.str(); + } else if constexpr (IsStreamable) { + std::stringstream out; + out << val; + return out.str(); + } else { + return ""; + } +} + +namespace { + +template constexpr bool standard_signed_integer = false; +template <> constexpr bool standard_signed_integer = true; +template <> constexpr bool standard_signed_integer = true; +template <> constexpr bool standard_signed_integer = true; +template <> constexpr bool standard_signed_integer = true; +template <> constexpr bool standard_signed_integer = true; + +template constexpr bool standard_unsigned_integer = false; +template <> constexpr bool standard_unsigned_integer = true; +template <> constexpr bool standard_unsigned_integer = true; +template <> constexpr bool standard_unsigned_integer = true; +template <> constexpr bool standard_unsigned_integer = true; +template <> +constexpr bool standard_unsigned_integer = true; + +} // namespace + +constexpr int radix_2 = 2; +constexpr int radix_8 = 8; +constexpr int radix_10 = 10; +constexpr int radix_16 = 16; + +template +constexpr bool standard_integer = + standard_signed_integer || standard_unsigned_integer; + +template +constexpr decltype(auto) +apply_plus_one_impl(F &&f, Tuple &&t, Extra &&x, + std::index_sequence /*unused*/) { + return std::invoke(std::forward(f), std::get(std::forward(t))..., + std::forward(x)); +} + +template +constexpr decltype(auto) apply_plus_one(F &&f, Tuple &&t, Extra &&x) { + return details::apply_plus_one_impl( + std::forward(f), std::forward(t), std::forward(x), + std::make_index_sequence< + std::tuple_size_v>>{}); +} + +constexpr auto pointer_range(std::string_view s) noexcept { + return std::tuple(s.data(), s.data() + s.size()); +} + +template +constexpr bool starts_with(std::basic_string_view prefix, + std::basic_string_view s) noexcept { + return s.substr(0, prefix.size()) == prefix; +} + +enum class chars_format { + scientific = 0xf1, + fixed = 0xf2, + hex = 0xf4, + binary = 0xf8, + general = fixed | scientific +}; + +struct ConsumeBinaryPrefixResult { + bool is_binary; + std::string_view rest; +}; + +constexpr auto consume_binary_prefix(std::string_view s) + -> ConsumeBinaryPrefixResult { + if (starts_with(std::string_view{"0b"}, s) || + starts_with(std::string_view{"0B"}, s)) { + s.remove_prefix(2); + return {true, s}; + } + return {false, s}; +} + +struct ConsumeHexPrefixResult { + bool is_hexadecimal; + std::string_view rest; +}; + +using namespace std::literals; + +constexpr auto consume_hex_prefix(std::string_view s) + -> ConsumeHexPrefixResult { + if (starts_with("0x"sv, s) || starts_with("0X"sv, s)) { + s.remove_prefix(2); + return {true, s}; + } + return {false, s}; +} + +template +inline auto do_from_chars(std::string_view s) -> T { + T x; + auto [first, last] = pointer_range(s); + auto [ptr, ec] = std::from_chars(first, last, x, Param); + if (ec == std::errc()) { + if (ptr == last) { + return x; + } + throw std::invalid_argument{"pattern '" + std::string(s) + + "' does not match to the end"}; + } + if (ec == std::errc::invalid_argument) { + throw std::invalid_argument{"pattern '" + std::string(s) + "' not found"}; + } + if (ec == std::errc::result_out_of_range) { + throw std::range_error{"'" + std::string(s) + "' not representable"}; + } + return x; // unreachable +} + +template struct parse_number { + auto operator()(std::string_view s) -> T { + return do_from_chars(s); + } +}; + +template struct parse_number { + auto operator()(std::string_view s) -> T { + if (auto [ok, rest] = consume_binary_prefix(s); ok) { + return do_from_chars(rest); + } + throw std::invalid_argument{"pattern not found"}; + } +}; + +template struct parse_number { + auto operator()(std::string_view s) -> T { + if (starts_with("0x"sv, s) || starts_with("0X"sv, s)) { + if (auto [ok, rest] = consume_hex_prefix(s); ok) { + try { + return do_from_chars(rest); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } + } + } else { + // Allow passing hex numbers without prefix + // Shape 'x' already has to be specified + try { + return do_from_chars(s); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } + } + + throw std::invalid_argument{"pattern '" + std::string(s) + + "' not identified as hexadecimal"}; + } +}; + +template struct parse_number { + auto operator()(std::string_view s) -> T { + auto [ok, rest] = consume_hex_prefix(s); + if (ok) { + try { + return do_from_chars(rest); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as hexadecimal: " + err.what()); + } + } + + auto [ok_binary, rest_binary] = consume_binary_prefix(s); + if (ok_binary) { + try { + return do_from_chars(rest_binary); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as binary: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as binary: " + err.what()); + } + } + + if (starts_with("0"sv, s)) { + try { + return do_from_chars(rest); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as octal: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as octal: " + err.what()); + } + } + + try { + return do_from_chars(rest); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + std::string(s) + + "' as decimal integer: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + std::string(s) + + "' as decimal integer: " + err.what()); + } + } +}; + +namespace { + +template inline const auto generic_strtod = nullptr; +template <> inline const auto generic_strtod = ARGPARSE_CUSTOM_STRTOF; +template <> inline const auto generic_strtod = ARGPARSE_CUSTOM_STRTOD; +template <> +inline const auto generic_strtod = ARGPARSE_CUSTOM_STRTOLD; + +} // namespace + +template inline auto do_strtod(std::string const &s) -> T { + if (isspace(static_cast(s[0])) || s[0] == '+') { + throw std::invalid_argument{"pattern '" + s + "' not found"}; + } + + auto [first, last] = pointer_range(s); + char *ptr; + + errno = 0; + auto x = generic_strtod(first, &ptr); + if (errno == 0) { + if (ptr == last) { + return x; + } + throw std::invalid_argument{"pattern '" + s + + "' does not match to the end"}; + } + if (errno == ERANGE) { + throw std::range_error{"'" + s + "' not representable"}; + } + return x; // unreachable +} + +template struct parse_number { + auto operator()(std::string const &s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::general does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::general does not parse binfloat"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as number: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + s + + "' as number: " + err.what()); + } + } +}; + +template struct parse_number { + auto operator()(std::string const &s) -> T { + if (auto r = consume_hex_prefix(s); !r.is_hexadecimal) { + throw std::invalid_argument{"chars_format::hex parses hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{"chars_format::hex does not parse binfloat"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as hexadecimal: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + s + + "' as hexadecimal: " + err.what()); + } + } +}; + +template struct parse_number { + auto operator()(std::string const &s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::binary does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); !r.is_binary) { + throw std::invalid_argument{"chars_format::binary parses binfloat"}; + } + + return do_strtod(s); + } +}; + +template struct parse_number { + auto operator()(std::string const &s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::scientific does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::scientific does not parse binfloat"}; + } + if (s.find_first_of("eE") == std::string::npos) { + throw std::invalid_argument{ + "chars_format::scientific requires exponent part"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as scientific notation: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + s + + "' as scientific notation: " + err.what()); + } + } +}; + +template struct parse_number { + auto operator()(std::string const &s) -> T { + if (auto r = consume_hex_prefix(s); r.is_hexadecimal) { + throw std::invalid_argument{ + "chars_format::fixed does not parse hexfloat"}; + } + if (auto r = consume_binary_prefix(s); r.is_binary) { + throw std::invalid_argument{ + "chars_format::fixed does not parse binfloat"}; + } + if (s.find_first_of("eE") != std::string::npos) { + throw std::invalid_argument{ + "chars_format::fixed does not parse exponent part"}; + } + + try { + return do_strtod(s); + } catch (const std::invalid_argument &err) { + throw std::invalid_argument("Failed to parse '" + s + + "' as fixed notation: " + err.what()); + } catch (const std::range_error &err) { + throw std::range_error("Failed to parse '" + s + + "' as fixed notation: " + err.what()); + } + } +}; + +template +std::string join(StrIt first, StrIt last, const std::string &separator) { + if (first == last) { + return ""; + } + std::stringstream value; + value << *first; + ++first; + while (first != last) { + value << separator << *first; + ++first; + } + return value.str(); +} + +template struct can_invoke_to_string { + template + static auto test(int) + -> decltype(std::to_string(std::declval()), std::true_type{}); + + template static auto test(...) -> std::false_type; + + static constexpr bool value = decltype(test(0))::value; +}; + +template struct IsChoiceTypeSupported { + using CleanType = typename std::decay::type; + static const bool value = std::is_integral::value || + std::is_same::value || + std::is_same::value || + std::is_same::value; +}; + +template +std::size_t get_levenshtein_distance(const StringType &s1, + const StringType &s2) { + std::vector> dp( + s1.size() + 1, std::vector(s2.size() + 1, 0)); + + for (std::size_t i = 0; i <= s1.size(); ++i) { + for (std::size_t j = 0; j <= s2.size(); ++j) { + if (i == 0) { + dp[i][j] = j; + } else if (j == 0) { + dp[i][j] = i; + } else if (s1[i - 1] == s2[j - 1]) { + dp[i][j] = dp[i - 1][j - 1]; + } else { + dp[i][j] = 1 + std::min({dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]}); + } + } + } + + return dp[s1.size()][s2.size()]; +} + +template +std::string get_most_similar_string(const std::map &map, + const std::string &input) { + std::string most_similar{}; + std::size_t min_distance = std::numeric_limits::max(); + + for (const auto &entry : map) { + std::size_t distance = get_levenshtein_distance(entry.first, input); + if (distance < min_distance) { + min_distance = distance; + most_similar = entry.first; + } + } + + return most_similar; +} + +} // namespace details + +enum class nargs_pattern { optional, any, at_least_one }; + +enum class default_arguments : unsigned int { + none = 0, + help = 1, + version = 2, + all = help | version, +}; + +inline default_arguments operator&(const default_arguments &a, + const default_arguments &b) { + return static_cast( + static_cast::type>(a) & + static_cast::type>(b)); +} + +class ArgumentParser; + +class Argument { + friend class ArgumentParser; + friend auto operator<<(std::ostream &stream, const ArgumentParser &parser) + -> std::ostream &; + + template + explicit Argument(std::string_view prefix_chars, + std::array &&a, + std::index_sequence /*unused*/) + : m_accepts_optional_like_value(false), + m_is_optional((is_optional(a[I], prefix_chars) || ...)), + m_is_required(false), m_is_repeatable(false), m_is_used(false), + m_is_hidden(false), m_prefix_chars(prefix_chars) { + ((void)m_names.emplace_back(a[I]), ...); + std::sort( + m_names.begin(), m_names.end(), [](const auto &lhs, const auto &rhs) { + return lhs.size() == rhs.size() ? lhs < rhs : lhs.size() < rhs.size(); + }); + } + +public: + template + explicit Argument(std::string_view prefix_chars, + std::array &&a) + : Argument(prefix_chars, std::move(a), std::make_index_sequence{}) {} + + Argument &help(std::string help_text) { + m_help = std::move(help_text); + return *this; + } + + Argument &metavar(std::string metavar) { + m_metavar = std::move(metavar); + return *this; + } + + template Argument &default_value(T &&value) { + m_num_args_range = NArgsRange{0, m_num_args_range.get_max()}; + m_default_value_repr = details::repr(value); + + if constexpr (std::is_convertible_v) { + m_default_value_str = std::string{std::string_view{value}}; + } else if constexpr (details::can_invoke_to_string::value) { + m_default_value_str = std::to_string(value); + } + + m_default_value = std::forward(value); + return *this; + } + + Argument &default_value(const char *value) { + return default_value(std::string(value)); + } + + Argument &required() { + m_is_required = true; + return *this; + } + + Argument &implicit_value(std::any value) { + m_implicit_value = std::move(value); + m_num_args_range = NArgsRange{0, 0}; + return *this; + } + + // This is shorthand for: + // program.add_argument("foo") + // .default_value(false) + // .implicit_value(true) + Argument &flag() { + default_value(false); + implicit_value(true); + return *this; + } + + template + auto action(F &&callable, Args &&... bound_args) + -> std::enable_if_t, + Argument &> { + using action_type = std::conditional_t< + std::is_void_v>, + void_action, valued_action>; + if constexpr (sizeof...(Args) == 0) { + m_action.emplace(std::forward(callable)); + } else { + m_action.emplace( + [f = std::forward(callable), + tup = std::make_tuple(std::forward(bound_args)...)]( + std::string const &opt) mutable { + return details::apply_plus_one(f, tup, opt); + }); + } + return *this; + } + + auto &store_into(bool &var) { + flag(); + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const auto & /*unused*/) { var = true; }); + return *this; + } + + auto &store_into(int &var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const auto &s) { + var = details::parse_number()(s); + }); + return *this; + } + + auto &store_into(double &var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const auto &s) { + var = details::parse_number()(s); + }); + return *this; + } + + auto &store_into(std::string &var) { + if (m_default_value.has_value()) { + var = std::any_cast(m_default_value); + } + action([&var](const std::string &s) { var = s; }); + return *this; + } + + auto &store_into(std::vector &var) { + if (m_default_value.has_value()) { + var = std::any_cast>(m_default_value); + } + action([this, &var](const std::string &s) { + if (!m_is_used) { + var.clear(); + } + m_is_used = true; + var.push_back(s); + }); + return *this; + } + + auto &append() { + m_is_repeatable = true; + return *this; + } + + // Cause the argument to be invisible in usage and help + auto &hidden() { + m_is_hidden = true; + return *this; + } + + template + auto scan() -> std::enable_if_t, Argument &> { + static_assert(!(std::is_const_v || std::is_volatile_v), + "T should not be cv-qualified"); + auto is_one_of = [](char c, auto... x) constexpr { + return ((c == x) || ...); + }; + + if constexpr (is_one_of(Shape, 'd') && details::standard_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'i') && + details::standard_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'u') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'b') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'o') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'x', 'X') && + details::standard_unsigned_integer) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'a', 'A') && + std::is_floating_point_v) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'e', 'E') && + std::is_floating_point_v) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'f', 'F') && + std::is_floating_point_v) { + action(details::parse_number()); + } else if constexpr (is_one_of(Shape, 'g', 'G') && + std::is_floating_point_v) { + action(details::parse_number()); + } else { + static_assert(alignof(T) == 0, "No scan specification for T"); + } + + return *this; + } + + Argument &nargs(std::size_t num_args) { + m_num_args_range = NArgsRange{num_args, num_args}; + return *this; + } + + Argument &nargs(std::size_t num_args_min, std::size_t num_args_max) { + m_num_args_range = NArgsRange{num_args_min, num_args_max}; + return *this; + } + + Argument &nargs(nargs_pattern pattern) { + switch (pattern) { + case nargs_pattern::optional: + m_num_args_range = NArgsRange{0, 1}; + break; + case nargs_pattern::any: + m_num_args_range = + NArgsRange{0, (std::numeric_limits::max)()}; + break; + case nargs_pattern::at_least_one: + m_num_args_range = + NArgsRange{1, (std::numeric_limits::max)()}; + break; + } + return *this; + } + + Argument &remaining() { + m_accepts_optional_like_value = true; + return nargs(nargs_pattern::any); + } + + template void add_choice(T &&choice) { + static_assert(details::IsChoiceTypeSupported::value, + "Only string or integer type supported for choice"); + static_assert(std::is_convertible_v || + details::can_invoke_to_string::value, + "Choice is not convertible to string_type"); + if (!m_choices.has_value()) { + m_choices = std::vector{}; + } + + if constexpr (std::is_convertible_v) { + m_choices.value().push_back( + std::string{std::string_view{std::forward(choice)}}); + } else if constexpr (details::can_invoke_to_string::value) { + m_choices.value().push_back(std::to_string(std::forward(choice))); + } + } + + Argument &choices() { + if (!m_choices.has_value()) { + throw std::runtime_error("Zero choices provided"); + } + return *this; + } + + template + Argument &choices(T &&first, U &&... rest) { + add_choice(std::forward(first)); + choices(std::forward(rest)...); + return *this; + } + + void find_default_value_in_choices_or_throw() const { + + const auto &choices = m_choices.value(); + + if (m_default_value.has_value()) { + if (std::find(choices.begin(), choices.end(), m_default_value_str) == + choices.end()) { + // provided arg not in list of allowed choices + // report error + + std::string choices_as_csv = + std::accumulate(choices.begin(), choices.end(), std::string(), + [](const std::string &a, const std::string &b) { + return a + (a.empty() ? "" : ", ") + b; + }); + + throw std::runtime_error( + std::string{"Invalid default value "} + m_default_value_repr + + " - allowed options: {" + choices_as_csv + "}"); + } + } + } + + template + void find_value_in_choices_or_throw(Iterator it) const { + + const auto &choices = m_choices.value(); + + if (std::find(choices.begin(), choices.end(), *it) == choices.end()) { + // provided arg not in list of allowed choices + // report error + + std::string choices_as_csv = + std::accumulate(choices.begin(), choices.end(), std::string(), + [](const std::string &a, const std::string &b) { + return a + (a.empty() ? "" : ", ") + b; + }); + + throw std::runtime_error(std::string{"Invalid argument "} + + details::repr(*it) + " - allowed options: {" + + choices_as_csv + "}"); + } + } + + /* The dry_run parameter can be set to true to avoid running the actions, + * and setting m_is_used. This may be used by a pre-processing step to do + * a first iteration over arguments. + */ + template + Iterator consume(Iterator start, Iterator end, + std::string_view used_name = {}, bool dry_run = false) { + if (!m_is_repeatable && m_is_used) { + throw std::runtime_error( + std::string("Duplicate argument ").append(used_name)); + } + m_used_name = used_name; + + if (m_choices.has_value()) { + // Check each value in (start, end) and make sure + // it is in the list of allowed choices/options + std::size_t i = 0; + auto max_number_of_args = m_num_args_range.get_max(); + for (auto it = start; it != end; ++it) { + if (i == max_number_of_args) { + break; + } + find_value_in_choices_or_throw(it); + i += 1; + } + } + + const auto num_args_max = m_num_args_range.get_max(); + const auto num_args_min = m_num_args_range.get_min(); + std::size_t dist = 0; + if (num_args_max == 0) { + if (!dry_run) { + m_values.emplace_back(m_implicit_value); + std::visit([](const auto &f) { f({}); }, m_action); + m_is_used = true; + } + return start; + } + if ((dist = static_cast(std::distance(start, end))) >= + num_args_min) { + if (num_args_max < dist) { + end = std::next(start, static_cast( + num_args_max)); + } + if (!m_accepts_optional_like_value) { + end = std::find_if( + start, end, + std::bind(is_optional, std::placeholders::_1, m_prefix_chars)); + dist = static_cast(std::distance(start, end)); + if (dist < num_args_min) { + throw std::runtime_error("Too few arguments"); + } + } + + struct ActionApply { + void operator()(valued_action &f) { + std::transform(first, last, std::back_inserter(self.m_values), f); + } + + void operator()(void_action &f) { + std::for_each(first, last, f); + if (!self.m_default_value.has_value()) { + if (!self.m_accepts_optional_like_value) { + self.m_values.resize( + static_cast(std::distance(first, last))); + } + } + } + + Iterator first, last; + Argument &self; + }; + if (!dry_run) { + std::visit(ActionApply{start, end, *this}, m_action); + m_is_used = true; + } + return end; + } + if (m_default_value.has_value()) { + if (!dry_run) { + m_is_used = true; + } + return start; + } + throw std::runtime_error("Too few arguments for '" + + std::string(m_used_name) + "'."); + } + + /* + * @throws std::runtime_error if argument values are not valid + */ + void validate() const { + if (m_is_optional) { + // TODO: check if an implicit value was programmed for this argument + if (!m_is_used && !m_default_value.has_value() && m_is_required) { + throw_required_arg_not_used_error(); + } + if (m_is_used && m_is_required && m_values.empty()) { + throw_required_arg_no_value_provided_error(); + } + } else { + if (!m_num_args_range.contains(m_values.size()) && + !m_default_value.has_value()) { + throw_nargs_range_validation_error(); + } + } + + if (m_choices.has_value()) { + // Make sure the default value (if provided) + // is in the list of choices + find_default_value_in_choices_or_throw(); + } + } + + std::string get_names_csv(char separator = ',') const { + return std::accumulate( + m_names.begin(), m_names.end(), std::string{""}, + [&](const std::string &result, const std::string &name) { + return result.empty() ? name : result + separator + name; + }); + } + + std::string get_usage_full() const { + std::stringstream usage; + + usage << get_names_csv('/'); + const std::string metavar = !m_metavar.empty() ? m_metavar : "VAR"; + if (m_num_args_range.get_max() > 0) { + usage << " " << metavar; + if (m_num_args_range.get_max() > 1) { + usage << "..."; + } + } + return usage.str(); + } + + std::string get_inline_usage() const { + std::stringstream usage; + // Find the longest variant to show in the usage string + std::string longest_name = m_names.front(); + for (const auto &s : m_names) { + if (s.size() > longest_name.size()) { + longest_name = s; + } + } + if (!m_is_required) { + usage << "["; + } + usage << longest_name; + const std::string metavar = !m_metavar.empty() ? m_metavar : "VAR"; + if (m_num_args_range.get_max() > 0) { + usage << " " << metavar; + if (m_num_args_range.get_max() > 1 && + m_metavar.find("> <") == std::string::npos) { + usage << "..."; + } + } + if (!m_is_required) { + usage << "]"; + } + if (m_is_repeatable) { + usage << "..."; + } + return usage.str(); + } + + std::size_t get_arguments_length() const { + + std::size_t names_size = std::accumulate( + std::begin(m_names), std::end(m_names), std::size_t(0), + [](const auto &sum, const auto &s) { return sum + s.size(); }); + + if (is_positional(m_names.front(), m_prefix_chars)) { + // A set metavar means this replaces the names + if (!m_metavar.empty()) { + // Indent and metavar + return 2 + m_metavar.size(); + } + + // Indent and space-separated + return 2 + names_size + (m_names.size() - 1); + } + // Is an option - include both names _and_ metavar + // size = text + (", " between names) + std::size_t size = names_size + 2 * (m_names.size() - 1); + if (!m_metavar.empty() && m_num_args_range == NArgsRange{1, 1}) { + size += m_metavar.size() + 1; + } + return size + 2; // indent + } + + friend std::ostream &operator<<(std::ostream &stream, + const Argument &argument) { + std::stringstream name_stream; + name_stream << " "; // indent + if (argument.is_positional(argument.m_names.front(), + argument.m_prefix_chars)) { + if (!argument.m_metavar.empty()) { + name_stream << argument.m_metavar; + } else { + name_stream << details::join(argument.m_names.begin(), + argument.m_names.end(), " "); + } + } else { + name_stream << details::join(argument.m_names.begin(), + argument.m_names.end(), ", "); + // If we have a metavar, and one narg - print the metavar + if (!argument.m_metavar.empty() && + argument.m_num_args_range == NArgsRange{1, 1}) { + name_stream << " " << argument.m_metavar; + } + else if (!argument.m_metavar.empty() && + argument.m_num_args_range.get_min() == argument.m_num_args_range.get_max() && + argument.m_metavar.find("> <") != std::string::npos) { + name_stream << " " << argument.m_metavar; + } + } + + // align multiline help message + auto stream_width = stream.width(); + auto name_padding = std::string(name_stream.str().size(), ' '); + auto pos = std::string::size_type{}; + auto prev = std::string::size_type{}; + auto first_line = true; + auto hspace = " "; // minimal space between name and help message + stream << name_stream.str(); + std::string_view help_view(argument.m_help); + while ((pos = argument.m_help.find('\n', prev)) != std::string::npos) { + auto line = help_view.substr(prev, pos - prev + 1); + if (first_line) { + stream << hspace << line; + first_line = false; + } else { + stream.width(stream_width); + stream << name_padding << hspace << line; + } + prev += pos - prev + 1; + } + if (first_line) { + stream << hspace << argument.m_help; + } else { + auto leftover = help_view.substr(prev, argument.m_help.size() - prev); + if (!leftover.empty()) { + stream.width(stream_width); + stream << name_padding << hspace << leftover; + } + } + + // print nargs spec + if (!argument.m_help.empty()) { + stream << " "; + } + stream << argument.m_num_args_range; + + bool add_space = false; + if (argument.m_default_value.has_value() && + argument.m_num_args_range != NArgsRange{0, 0}) { + stream << "[default: " << argument.m_default_value_repr << "]"; + add_space = true; + } else if (argument.m_is_required) { + stream << "[required]"; + add_space = true; + } + if (argument.m_is_repeatable) { + if (add_space) { + stream << " "; + } + stream << "[may be repeated]"; + } + stream << "\n"; + return stream; + } + + template bool operator!=(const T &rhs) const { + return !(*this == rhs); + } + + /* + * Compare to an argument value of known type + * @throws std::logic_error in case of incompatible types + */ + template bool operator==(const T &rhs) const { + if constexpr (!details::IsContainer) { + return get() == rhs; + } else { + using ValueType = typename T::value_type; + auto lhs = get(); + return std::equal(std::begin(lhs), std::end(lhs), std::begin(rhs), + std::end(rhs), [](const auto &a, const auto &b) { + return std::any_cast(a) == b; + }); + } + } + + /* + * positional: + * _empty_ + * '-' + * '-' decimal-literal + * !'-' anything + */ + static bool is_positional(std::string_view name, + std::string_view prefix_chars) { + auto first = lookahead(name); + + if (first == eof) { + return true; + } + if (prefix_chars.find(static_cast(first)) != + std::string_view::npos) { + name.remove_prefix(1); + if (name.empty()) { + return true; + } + return is_decimal_literal(name); + } + return true; + } + +private: + class NArgsRange { + std::size_t m_min; + std::size_t m_max; + + public: + NArgsRange(std::size_t minimum, std::size_t maximum) + : m_min(minimum), m_max(maximum) { + if (minimum > maximum) { + throw std::logic_error("Range of number of arguments is invalid"); + } + } + + bool contains(std::size_t value) const { + return value >= m_min && value <= m_max; + } + + bool is_exact() const { return m_min == m_max; } + + bool is_right_bounded() const { + return m_max < (std::numeric_limits::max)(); + } + + std::size_t get_min() const { return m_min; } + + std::size_t get_max() const { return m_max; } + + // Print help message + friend auto operator<<(std::ostream &stream, const NArgsRange &range) + -> std::ostream & { + if (range.m_min == range.m_max) { + if (range.m_min != 0 && range.m_min != 1) { + stream << "[nargs: " << range.m_min << "] "; + } + } else { + if (range.m_max == (std::numeric_limits::max)()) { + stream << "[nargs: " << range.m_min << " or more] "; + } else { + stream << "[nargs=" << range.m_min << ".." << range.m_max << "] "; + } + } + return stream; + } + + bool operator==(const NArgsRange &rhs) const { + return rhs.m_min == m_min && rhs.m_max == m_max; + } + + bool operator!=(const NArgsRange &rhs) const { return !(*this == rhs); } + }; + + void throw_nargs_range_validation_error() const { + std::stringstream stream; + if (!m_used_name.empty()) { + stream << m_used_name << ": "; + } else { + stream << m_names.front() << ": "; + } + if (m_num_args_range.is_exact()) { + stream << m_num_args_range.get_min(); + } else if (m_num_args_range.is_right_bounded()) { + stream << m_num_args_range.get_min() << " to " + << m_num_args_range.get_max(); + } else { + stream << m_num_args_range.get_min() << " or more"; + } + stream << " argument(s) expected. " << m_values.size() << " provided."; + throw std::runtime_error(stream.str()); + } + + void throw_required_arg_not_used_error() const { + std::stringstream stream; + stream << m_names.front() << ": required."; + throw std::runtime_error(stream.str()); + } + + void throw_required_arg_no_value_provided_error() const { + std::stringstream stream; + stream << m_used_name << ": no value provided."; + throw std::runtime_error(stream.str()); + } + + static constexpr int eof = std::char_traits::eof(); + + static auto lookahead(std::string_view s) -> int { + if (s.empty()) { + return eof; + } + return static_cast(static_cast(s[0])); + } + + /* + * decimal-literal: + * '0' + * nonzero-digit digit-sequence_opt + * integer-part fractional-part + * fractional-part + * integer-part '.' exponent-part_opt + * integer-part exponent-part + * + * integer-part: + * digit-sequence + * + * fractional-part: + * '.' post-decimal-point + * + * post-decimal-point: + * digit-sequence exponent-part_opt + * + * exponent-part: + * 'e' post-e + * 'E' post-e + * + * post-e: + * sign_opt digit-sequence + * + * sign: one of + * '+' '-' + */ + static bool is_decimal_literal(std::string_view s) { + auto is_digit = [](auto c) constexpr { + switch (c) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + return true; + default: + return false; + } + }; + + // precondition: we have consumed or will consume at least one digit + auto consume_digits = [=](std::string_view sd) { + // NOLINTNEXTLINE(readability-qualified-auto) + auto it = std::find_if_not(std::begin(sd), std::end(sd), is_digit); + return sd.substr(static_cast(it - std::begin(sd))); + }; + + switch (lookahead(s)) { + case '0': { + s.remove_prefix(1); + if (s.empty()) { + return true; + } + goto integer_part; + } + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': { + s = consume_digits(s); + if (s.empty()) { + return true; + } + goto integer_part_consumed; + } + case '.': { + s.remove_prefix(1); + goto post_decimal_point; + } + default: + return false; + } + + integer_part: + s = consume_digits(s); + integer_part_consumed: + switch (lookahead(s)) { + case '.': { + s.remove_prefix(1); + if (is_digit(lookahead(s))) { + goto post_decimal_point; + } else { + goto exponent_part_opt; + } + } + case 'e': + case 'E': { + s.remove_prefix(1); + goto post_e; + } + default: + return false; + } + + post_decimal_point: + if (is_digit(lookahead(s))) { + s = consume_digits(s); + goto exponent_part_opt; + } + return false; + + exponent_part_opt: + switch (lookahead(s)) { + case eof: + return true; + case 'e': + case 'E': { + s.remove_prefix(1); + goto post_e; + } + default: + return false; + } + + post_e: + switch (lookahead(s)) { + case '-': + case '+': + s.remove_prefix(1); + } + if (is_digit(lookahead(s))) { + s = consume_digits(s); + return s.empty(); + } + return false; + } + + static bool is_optional(std::string_view name, + std::string_view prefix_chars) { + return !is_positional(name, prefix_chars); + } + + /* + * Get argument value given a type + * @throws std::logic_error in case of incompatible types + */ + template T get() const { + if (!m_values.empty()) { + if constexpr (details::IsContainer) { + return any_cast_container(m_values); + } else { + return std::any_cast(m_values.front()); + } + } + if (m_default_value.has_value()) { + return std::any_cast(m_default_value); + } + if constexpr (details::IsContainer) { + if (!m_accepts_optional_like_value) { + return any_cast_container(m_values); + } + } + + throw std::logic_error("No value provided for '" + m_names.back() + "'."); + } + + /* + * Get argument value given a type. + * @pre The object has no default value. + * @returns The stored value if any, std::nullopt otherwise. + */ + template auto present() const -> std::optional { + if (m_default_value.has_value()) { + throw std::logic_error("Argument with default value always presents"); + } + if (m_values.empty()) { + return std::nullopt; + } + if constexpr (details::IsContainer) { + return any_cast_container(m_values); + } + return std::any_cast(m_values.front()); + } + + template + static auto any_cast_container(const std::vector &operand) -> T { + using ValueType = typename T::value_type; + + T result; + std::transform( + std::begin(operand), std::end(operand), std::back_inserter(result), + [](const auto &value) { return std::any_cast(value); }); + return result; + } + + void set_usage_newline_counter(int i) { m_usage_newline_counter = i; } + + void set_group_idx(std::size_t i) { m_group_idx = i; } + + std::vector m_names; + std::string_view m_used_name; + std::string m_help; + std::string m_metavar; + std::any m_default_value; + std::string m_default_value_repr; + std::optional + m_default_value_str; // used for checking default_value against choices + std::any m_implicit_value; + std::optional> m_choices{std::nullopt}; + using valued_action = std::function; + using void_action = std::function; + std::variant m_action{ + std::in_place_type, + [](const std::string &value) { return value; }}; + std::vector m_values; + NArgsRange m_num_args_range{1, 1}; + // Bit field of bool values. Set default value in ctor. + bool m_accepts_optional_like_value : 1; + bool m_is_optional : 1; + bool m_is_required : 1; + bool m_is_repeatable : 1; + bool m_is_used : 1; + bool m_is_hidden : 1; // if set, does not appear in usage or help + std::string_view m_prefix_chars; // ArgumentParser has the prefix_chars + int m_usage_newline_counter = 0; + std::size_t m_group_idx = 0; +}; + +class ArgumentParser { +public: + explicit ArgumentParser(std::string program_name = {}, + std::string version = "1.0", + default_arguments add_args = default_arguments::all, + bool exit_on_default_arguments = true, + std::ostream &os = std::cout) + : m_program_name(std::move(program_name)), m_version(std::move(version)), + m_exit_on_default_arguments(exit_on_default_arguments), + m_parser_path(m_program_name) { + if ((add_args & default_arguments::help) == default_arguments::help) { + add_argument("-h", "--help") + .action([&](const auto & /*unused*/) { + os << help().str(); + if (m_exit_on_default_arguments) { + std::exit(0); + } + }) + .default_value(false) + .help("shows help message and exits") + .implicit_value(true) + .nargs(0); + } + if ((add_args & default_arguments::version) == default_arguments::version) { + add_argument("-v", "--version") + .action([&](const auto & /*unused*/) { + os << m_version << std::endl; + if (m_exit_on_default_arguments) { + std::exit(0); + } + }) + .default_value(false) + .help("prints version information and exits") + .implicit_value(true) + .nargs(0); + } + } + + ~ArgumentParser() = default; + + // ArgumentParser is meant to be used in a single function. + // Setup everything and parse arguments in one place. + // + // ArgumentParser internally uses std::string_views, + // references, iterators, etc. + // Many of these elements become invalidated after a copy or move. + ArgumentParser(const ArgumentParser &other) = delete; + ArgumentParser &operator=(const ArgumentParser &other) = delete; + ArgumentParser(ArgumentParser &&) noexcept = delete; + ArgumentParser &operator=(ArgumentParser &&) = delete; + + explicit operator bool() const { + auto arg_used = std::any_of(m_argument_map.cbegin(), m_argument_map.cend(), + [](auto &it) { return it.second->m_is_used; }); + auto subparser_used = + std::any_of(m_subparser_used.cbegin(), m_subparser_used.cend(), + [](auto &it) { return it.second; }); + + return m_is_parsed && (arg_used || subparser_used); + } + + // Parameter packing + // Call add_argument with variadic number of string arguments + template Argument &add_argument(Targs... f_args) { + using array_of_sv = std::array; + auto argument = + m_optional_arguments.emplace(std::cend(m_optional_arguments), + m_prefix_chars, array_of_sv{f_args...}); + + if (!argument->m_is_optional) { + m_positional_arguments.splice(std::cend(m_positional_arguments), + m_optional_arguments, argument); + } + argument->set_usage_newline_counter(m_usage_newline_counter); + argument->set_group_idx(m_group_names.size()); + + index_argument(argument); + return *argument; + } + + class MutuallyExclusiveGroup { + friend class ArgumentParser; + + public: + MutuallyExclusiveGroup() = delete; + + explicit MutuallyExclusiveGroup(ArgumentParser &parent, + bool required = false) + : m_parent(parent), m_required(required), m_elements({}) {} + + MutuallyExclusiveGroup(const MutuallyExclusiveGroup &other) = delete; + MutuallyExclusiveGroup & + operator=(const MutuallyExclusiveGroup &other) = delete; + + MutuallyExclusiveGroup(MutuallyExclusiveGroup &&other) noexcept + : m_parent(other.m_parent), m_required(other.m_required), + m_elements(std::move(other.m_elements)) { + other.m_elements.clear(); + } + + template Argument &add_argument(Targs... f_args) { + auto &argument = m_parent.add_argument(std::forward(f_args)...); + m_elements.push_back(&argument); + argument.set_usage_newline_counter(m_parent.m_usage_newline_counter); + argument.set_group_idx(m_parent.m_group_names.size()); + return argument; + } + + private: + ArgumentParser &m_parent; + bool m_required{false}; + std::vector m_elements{}; + }; + + MutuallyExclusiveGroup &add_mutually_exclusive_group(bool required = false) { + m_mutually_exclusive_groups.emplace_back(*this, required); + return m_mutually_exclusive_groups.back(); + } + + // Parameter packed add_parents method + // Accepts a variadic number of ArgumentParser objects + template + ArgumentParser &add_parents(const Targs &... f_args) { + for (const ArgumentParser &parent_parser : {std::ref(f_args)...}) { + for (const auto &argument : parent_parser.m_positional_arguments) { + auto it = m_positional_arguments.insert( + std::cend(m_positional_arguments), argument); + index_argument(it); + } + for (const auto &argument : parent_parser.m_optional_arguments) { + auto it = m_optional_arguments.insert(std::cend(m_optional_arguments), + argument); + index_argument(it); + } + } + return *this; + } + + // Ask for the next optional arguments to be displayed on a separate + // line in usage() output. Only effective if set_usage_max_line_width() is + // also used. + ArgumentParser &add_usage_newline() { + ++m_usage_newline_counter; + return *this; + } + + // Ask for the next optional arguments to be displayed in a separate section + // in usage() and help (<< *this) output. + // For usage(), this is only effective if set_usage_max_line_width() is + // also used. + ArgumentParser &add_group(std::string group_name) { + m_group_names.emplace_back(std::move(group_name)); + return *this; + } + + ArgumentParser &add_description(std::string description) { + m_description = std::move(description); + return *this; + } + + ArgumentParser &add_epilog(std::string epilog) { + m_epilog = std::move(epilog); + return *this; + } + + // Add a un-documented/hidden alias for an argument. + // Ideally we'd want this to be a method of Argument, but Argument + // does not own its owing ArgumentParser. + ArgumentParser &add_hidden_alias_for(Argument &arg, std::string_view alias) { + for (auto it = m_optional_arguments.begin(); + it != m_optional_arguments.end(); ++it) { + if (&(*it) == &arg) { + m_argument_map.insert_or_assign(std::string(alias), it); + return *this; + } + } + throw std::logic_error( + "Argument is not an optional argument of this parser"); + } + + /* Getter for arguments and subparsers. + * @throws std::logic_error in case of an invalid argument or subparser name + */ + template T &at(std::string_view name) { + if constexpr (std::is_same_v) { + return (*this)[name]; + } else { + std::string str_name(name); + auto subparser_it = m_subparser_map.find(str_name); + if (subparser_it != m_subparser_map.end()) { + return subparser_it->second->get(); + } + throw std::logic_error("No such subparser: " + str_name); + } + } + + ArgumentParser &set_prefix_chars(std::string prefix_chars) { + m_prefix_chars = std::move(prefix_chars); + return *this; + } + + ArgumentParser &set_assign_chars(std::string assign_chars) { + m_assign_chars = std::move(assign_chars); + return *this; + } + + /* Call parse_args_internal - which does all the work + * Then, validate the parsed arguments + * This variant is used mainly for testing + * @throws std::runtime_error in case of any invalid argument + */ + void parse_args(const std::vector &arguments) { + parse_args_internal(arguments); + // Check if all arguments are parsed + for ([[maybe_unused]] const auto &[unused, argument] : m_argument_map) { + argument->validate(); + } + + // Check each mutually exclusive group and make sure + // there are no constraint violations + for (const auto &group : m_mutually_exclusive_groups) { + auto mutex_argument_used{false}; + Argument *mutex_argument_it{nullptr}; + for (Argument *arg : group.m_elements) { + if (!mutex_argument_used && arg->m_is_used) { + mutex_argument_used = true; + mutex_argument_it = arg; + } else if (mutex_argument_used && arg->m_is_used) { + // Violation + throw std::runtime_error("Argument '" + arg->get_usage_full() + + "' not allowed with '" + + mutex_argument_it->get_usage_full() + "'"); + } + } + + if (!mutex_argument_used && group.m_required) { + // at least one argument from the group is + // required + std::string argument_names{}; + std::size_t i = 0; + std::size_t size = group.m_elements.size(); + for (Argument *arg : group.m_elements) { + if (i + 1 == size) { + // last + argument_names += "'" + arg->get_usage_full() + "' "; + } else { + argument_names += "'" + arg->get_usage_full() + "' or "; + } + i += 1; + } + throw std::runtime_error("One of the arguments " + argument_names + + "is required"); + } + } + } + + /* Call parse_known_args_internal - which does all the work + * Then, validate the parsed arguments + * This variant is used mainly for testing + * @throws std::runtime_error in case of any invalid argument + */ + std::vector + parse_known_args(const std::vector &arguments) { + auto unknown_arguments = parse_known_args_internal(arguments); + // Check if all arguments are parsed + for ([[maybe_unused]] const auto &[unused, argument] : m_argument_map) { + argument->validate(); + } + return unknown_arguments; + } + + /* Main entry point for parsing command-line arguments using this + * ArgumentParser + * @throws std::runtime_error in case of any invalid argument + */ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) + void parse_args(int argc, const char *const argv[]) { + parse_args({argv, argv + argc}); + } + + /* Main entry point for parsing command-line arguments using this + * ArgumentParser + * @throws std::runtime_error in case of any invalid argument + */ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays) + auto parse_known_args(int argc, const char *const argv[]) { + return parse_known_args({argv, argv + argc}); + } + + /* Getter for options with default values. + * @throws std::logic_error if parse_args() has not been previously called + * @throws std::logic_error if there is no such option + * @throws std::logic_error if the option has no value + * @throws std::bad_any_cast if the option is not of type T + */ + template T get(std::string_view arg_name) const { + if (!m_is_parsed) { + throw std::logic_error("Nothing parsed, no arguments are available."); + } + return (*this)[arg_name].get(); + } + + /* Getter for options without default values. + * @pre The option has no default value. + * @throws std::logic_error if there is no such option + * @throws std::bad_any_cast if the option is not of type T + */ + template + auto present(std::string_view arg_name) const -> std::optional { + return (*this)[arg_name].present(); + } + + /* Getter that returns true for user-supplied options. Returns false if not + * user-supplied, even with a default value. + */ + auto is_used(std::string_view arg_name) const { + return (*this)[arg_name].m_is_used; + } + + /* Getter that returns true if a subcommand is used. + */ + auto is_subcommand_used(std::string_view subcommand_name) const { + return m_subparser_used.at(std::string(subcommand_name)); + } + + /* Getter that returns true if a subcommand is used. + */ + auto is_subcommand_used(const ArgumentParser &subparser) const { + return is_subcommand_used(subparser.m_program_name); + } + + /* Indexing operator. Return a reference to an Argument object + * Used in conjunction with Argument.operator== e.g., parser["foo"] == true + * @throws std::logic_error in case of an invalid argument name + */ + Argument &operator[](std::string_view arg_name) const { + std::string name(arg_name); + auto it = m_argument_map.find(name); + if (it != m_argument_map.end()) { + return *(it->second); + } + if (!is_valid_prefix_char(arg_name.front())) { + const auto legal_prefix_char = get_any_valid_prefix_char(); + const auto prefix = std::string(1, legal_prefix_char); + + // "-" + arg_name + name = prefix + name; + it = m_argument_map.find(name); + if (it != m_argument_map.end()) { + return *(it->second); + } + // "--" + arg_name + name = prefix + name; + it = m_argument_map.find(name); + if (it != m_argument_map.end()) { + return *(it->second); + } + } + throw std::logic_error("No such argument: " + std::string(arg_name)); + } + + // Print help message + friend auto operator<<(std::ostream &stream, const ArgumentParser &parser) + -> std::ostream & { + stream.setf(std::ios_base::left); + + auto longest_arg_length = parser.get_length_of_longest_argument(); + + stream << parser.usage() << "\n\n"; + + if (!parser.m_description.empty()) { + stream << parser.m_description << "\n\n"; + } + + const bool has_visible_positional_args = std::find_if( + parser.m_positional_arguments.begin(), + parser.m_positional_arguments.end(), + [](const auto &argument) { + return !argument.m_is_hidden; }) != + parser.m_positional_arguments.end(); + if (has_visible_positional_args) { + stream << "Positional arguments:\n"; + } + + for (const auto &argument : parser.m_positional_arguments) { + if (!argument.m_is_hidden) { + stream.width(static_cast(longest_arg_length)); + stream << argument; + } + } + + if (!parser.m_optional_arguments.empty()) { + stream << (!has_visible_positional_args ? "" : "\n") + << "Optional arguments:\n"; + } + + for (const auto &argument : parser.m_optional_arguments) { + if (argument.m_group_idx == 0 && !argument.m_is_hidden) { + stream.width(static_cast(longest_arg_length)); + stream << argument; + } + } + + for (size_t i_group = 0; i_group < parser.m_group_names.size(); ++i_group) { + stream << "\n" << parser.m_group_names[i_group] << " (detailed usage):\n"; + for (const auto &argument : parser.m_optional_arguments) { + if (argument.m_group_idx == i_group + 1 && !argument.m_is_hidden) { + stream.width(static_cast(longest_arg_length)); + stream << argument; + } + } + } + + bool has_visible_subcommands = std::any_of( + parser.m_subparser_map.begin(), parser.m_subparser_map.end(), + [](auto &p) { return !p.second->get().m_suppress; }); + + if (has_visible_subcommands) { + stream << (parser.m_positional_arguments.empty() + ? (parser.m_optional_arguments.empty() ? "" : "\n") + : "\n") + << "Subcommands:\n"; + for (const auto &[command, subparser] : parser.m_subparser_map) { + if (subparser->get().m_suppress) { + continue; + } + + stream << std::setw(2) << " "; + stream << std::setw(static_cast(longest_arg_length - 2)) + << command; + stream << " " << subparser->get().m_description << "\n"; + } + } + + if (!parser.m_epilog.empty()) { + stream << '\n'; + stream << parser.m_epilog << "\n\n"; + } + + return stream; + } + + // Format help message + auto help() const -> std::stringstream { + std::stringstream out; + out << *this; + return out; + } + + // Sets the maximum width for a line of the Usage message + ArgumentParser &set_usage_max_line_width(size_t w) { + this->m_usage_max_line_width = w; + return *this; + } + + // Asks to display arguments of mutually exclusive group on separate lines in + // the Usage message + ArgumentParser &set_usage_break_on_mutex() { + this->m_usage_break_on_mutex = true; + return *this; + } + + // Format usage part of help only + auto usage() const -> std::string { + std::stringstream stream; + + std::string curline("Usage: "); + curline += this->m_program_name; + const bool multiline_usage = + this->m_usage_max_line_width < std::numeric_limits::max(); + const size_t indent_size = curline.size(); + + const auto deal_with_options_of_group = [&](std::size_t group_idx) { + bool found_options = false; + // Add any options inline here + const MutuallyExclusiveGroup *cur_mutex = nullptr; + int usage_newline_counter = -1; + for (const auto &argument : this->m_optional_arguments) { + if (argument.m_is_hidden) { + continue; + } + if (multiline_usage) { + if (argument.m_group_idx != group_idx) { + continue; + } + if (usage_newline_counter != argument.m_usage_newline_counter) { + if (usage_newline_counter >= 0) { + if (curline.size() > indent_size) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + } + usage_newline_counter = argument.m_usage_newline_counter; + } + } + found_options = true; + const std::string arg_inline_usage = argument.get_inline_usage(); + const MutuallyExclusiveGroup *arg_mutex = + get_belonging_mutex(&argument); + if ((cur_mutex != nullptr) && (arg_mutex == nullptr)) { + curline += ']'; + if (this->m_usage_break_on_mutex) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + } else if ((cur_mutex == nullptr) && (arg_mutex != nullptr)) { + if ((this->m_usage_break_on_mutex && curline.size() > indent_size) || + curline.size() + 3 + arg_inline_usage.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + curline += " ["; + } else if ((cur_mutex != nullptr) && (arg_mutex != nullptr)) { + if (cur_mutex != arg_mutex) { + curline += ']'; + if (this->m_usage_break_on_mutex || + curline.size() + 3 + arg_inline_usage.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + curline += " ["; + } else { + curline += '|'; + } + } + cur_mutex = arg_mutex; + if (curline.size() + 1 + arg_inline_usage.size() > + this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + curline += " "; + } else if (cur_mutex == nullptr) { + curline += " "; + } + curline += arg_inline_usage; + } + if (cur_mutex != nullptr) { + curline += ']'; + } + return found_options; + }; + + const bool found_options = deal_with_options_of_group(0); + + if (found_options && multiline_usage && + !this->m_positional_arguments.empty()) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + // Put positional arguments after the optionals + for (const auto &argument : this->m_positional_arguments) { + if (argument.m_is_hidden) { + continue; + } + const std::string pos_arg = !argument.m_metavar.empty() + ? argument.m_metavar + : argument.m_names.front(); + if (curline.size() + 1 + pos_arg.size() > this->m_usage_max_line_width) { + stream << curline << std::endl; + curline = std::string(indent_size, ' '); + } + curline += " "; + if (argument.m_num_args_range.get_min() == 0 && + !argument.m_num_args_range.is_right_bounded()) { + curline += "["; + curline += pos_arg; + curline += "]..."; + } else if (argument.m_num_args_range.get_min() == 1 && + !argument.m_num_args_range.is_right_bounded()) { + curline += pos_arg; + curline += "..."; + } else { + curline += pos_arg; + } + } + + if (multiline_usage) { + // Display options of other groups + for (std::size_t i = 0; i < m_group_names.size(); ++i) { + stream << curline << std::endl << std::endl; + stream << m_group_names[i] << ":" << std::endl; + curline = std::string(indent_size, ' '); + deal_with_options_of_group(i + 1); + } + } + + stream << curline; + + // Put subcommands after positional arguments + if (!m_subparser_map.empty()) { + stream << " {"; + std::size_t i{0}; + for (const auto &[command, subparser] : m_subparser_map) { + if (subparser->get().m_suppress) { + continue; + } + + if (i == 0) { + stream << command; + } else { + stream << "," << command; + } + ++i; + } + stream << "}"; + } + + return stream.str(); + } + + // Printing the one and only help message + // I've stuck with a simple message format, nothing fancy. + [[deprecated("Use cout << program; instead. See also help().")]] std::string + print_help() const { + auto out = help(); + std::cout << out.rdbuf(); + return out.str(); + } + + void add_subparser(ArgumentParser &parser) { + parser.m_parser_path = m_program_name + " " + parser.m_program_name; + auto it = m_subparsers.emplace(std::cend(m_subparsers), parser); + m_subparser_map.insert_or_assign(parser.m_program_name, it); + m_subparser_used.insert_or_assign(parser.m_program_name, false); + } + + void set_suppress(bool suppress) { m_suppress = suppress; } + +protected: + const MutuallyExclusiveGroup *get_belonging_mutex(const Argument *arg) const { + for (const auto &mutex : m_mutually_exclusive_groups) { + if (std::find(mutex.m_elements.begin(), mutex.m_elements.end(), arg) != + mutex.m_elements.end()) { + return &mutex; + } + } + return nullptr; + } + + bool is_valid_prefix_char(char c) const { + return m_prefix_chars.find(c) != std::string::npos; + } + + char get_any_valid_prefix_char() const { return m_prefix_chars[0]; } + + /* + * Pre-process this argument list. Anything starting with "--", that + * contains an =, where the prefix before the = has an entry in the + * options table, should be split. + */ + std::vector + preprocess_arguments(const std::vector &raw_arguments) const { + std::vector arguments{}; + for (const auto &arg : raw_arguments) { + + const auto argument_starts_with_prefix_chars = + [this](const std::string &a) -> bool { + if (!a.empty()) { + + const auto legal_prefix = [this](char c) -> bool { + return m_prefix_chars.find(c) != std::string::npos; + }; + + // Windows-style + // if '/' is a legal prefix char + // then allow single '/' followed by argument name, followed by an + // assign char, e.g., ':' e.g., 'test.exe /A:Foo' + const auto windows_style = legal_prefix('/'); + + if (windows_style) { + if (legal_prefix(a[0])) { + return true; + } + } else { + // Slash '/' is not a legal prefix char + // For all other characters, only support long arguments + // i.e., the argument must start with 2 prefix chars, e.g, + // '--foo' e,g, './test --foo=Bar -DARG=yes' + if (a.size() > 1) { + return (legal_prefix(a[0]) && legal_prefix(a[1])); + } + } + } + return false; + }; + + // Check that: + // - We don't have an argument named exactly this + // - The argument starts with a prefix char, e.g., "--" + // - The argument contains an assign char, e.g., "=" + auto assign_char_pos = arg.find_first_of(m_assign_chars); + + if (m_argument_map.find(arg) == m_argument_map.end() && + argument_starts_with_prefix_chars(arg) && + assign_char_pos != std::string::npos) { + // Get the name of the potential option, and check it exists + std::string opt_name = arg.substr(0, assign_char_pos); + if (m_argument_map.find(opt_name) != m_argument_map.end()) { + // This is the name of an option! Split it into two parts + arguments.push_back(std::move(opt_name)); + arguments.push_back(arg.substr(assign_char_pos + 1)); + continue; + } + } + // If we've fallen through to here, then it's a standard argument + arguments.push_back(arg); + } + return arguments; + } + + /* + * @throws std::runtime_error in case of any invalid argument + */ + void parse_args_internal(const std::vector &raw_arguments) { + auto arguments = preprocess_arguments(raw_arguments); + if (m_program_name.empty() && !arguments.empty()) { + m_program_name = arguments.front(); + } + auto end = std::end(arguments); + auto positional_argument_it = std::begin(m_positional_arguments); + for (auto it = std::next(std::begin(arguments)); it != end;) { + const auto ¤t_argument = *it; + if (Argument::is_positional(current_argument, m_prefix_chars)) { + if (positional_argument_it == std::end(m_positional_arguments)) { + + // Check sub-parsers + auto subparser_it = m_subparser_map.find(current_argument); + if (subparser_it != m_subparser_map.end()) { + + // build list of remaining args + const auto unprocessed_arguments = + std::vector(it, end); + + // invoke subparser + m_is_parsed = true; + m_subparser_used[current_argument] = true; + return subparser_it->second->get().parse_args( + unprocessed_arguments); + } + + if (m_positional_arguments.empty()) { + + // Ask the user if they argument they provided was a typo + // for some sub-parser, + // e.g., user provided `git totes` instead of `git notes` + if (!m_subparser_map.empty()) { + throw std::runtime_error( + "Failed to parse '" + current_argument + "', did you mean '" + + std::string{details::get_most_similar_string( + m_subparser_map, current_argument)} + + "'"); + } + + // Ask the user if they meant to use a specific optional argument + if (!m_optional_arguments.empty()) { + for (const auto &opt : m_optional_arguments) { + if (!opt.m_implicit_value.has_value()) { + // not a flag, requires a value + if (!opt.m_is_used) { + throw std::runtime_error( + "Zero positional arguments expected, did you mean " + + opt.get_usage_full()); + } + } + } + + throw std::runtime_error("Zero positional arguments expected"); + } else { + throw std::runtime_error("Zero positional arguments expected"); + } + } else { + throw std::runtime_error("Maximum number of positional arguments " + "exceeded, failed to parse '" + + current_argument + "'"); + } + } + auto argument = positional_argument_it++; + it = argument->consume(it, end); + continue; + } + + auto arg_map_it = m_argument_map.find(current_argument); + if (arg_map_it != m_argument_map.end()) { + auto argument = arg_map_it->second; + it = argument->consume(std::next(it), end, arg_map_it->first); + } else if (const auto &compound_arg = current_argument; + compound_arg.size() > 1 && + is_valid_prefix_char(compound_arg[0]) && + !is_valid_prefix_char(compound_arg[1])) { + ++it; + for (std::size_t j = 1; j < compound_arg.size(); j++) { + auto hypothetical_arg = std::string{'-', compound_arg[j]}; + auto arg_map_it2 = m_argument_map.find(hypothetical_arg); + if (arg_map_it2 != m_argument_map.end()) { + auto argument = arg_map_it2->second; + it = argument->consume(it, end, arg_map_it2->first); + } else { + throw std::runtime_error("Unknown argument: " + current_argument); + } + } + } else { + throw std::runtime_error("Unknown argument: " + current_argument); + } + } + m_is_parsed = true; + } + + /* + * Like parse_args_internal but collects unused args into a vector + */ + std::vector + parse_known_args_internal(const std::vector &raw_arguments) { + auto arguments = preprocess_arguments(raw_arguments); + + std::vector unknown_arguments{}; + + if (m_program_name.empty() && !arguments.empty()) { + m_program_name = arguments.front(); + } + auto end = std::end(arguments); + auto positional_argument_it = std::begin(m_positional_arguments); + for (auto it = std::next(std::begin(arguments)); it != end;) { + const auto ¤t_argument = *it; + if (Argument::is_positional(current_argument, m_prefix_chars)) { + if (positional_argument_it == std::end(m_positional_arguments)) { + + // Check sub-parsers + auto subparser_it = m_subparser_map.find(current_argument); + if (subparser_it != m_subparser_map.end()) { + + // build list of remaining args + const auto unprocessed_arguments = + std::vector(it, end); + + // invoke subparser + m_is_parsed = true; + m_subparser_used[current_argument] = true; + return subparser_it->second->get().parse_known_args_internal( + unprocessed_arguments); + } + + // save current argument as unknown and go to next argument + unknown_arguments.push_back(current_argument); + ++it; + } else { + // current argument is the value of a positional argument + // consume it + auto argument = positional_argument_it++; + it = argument->consume(it, end); + } + continue; + } + + auto arg_map_it = m_argument_map.find(current_argument); + if (arg_map_it != m_argument_map.end()) { + auto argument = arg_map_it->second; + it = argument->consume(std::next(it), end, arg_map_it->first); + } else if (const auto &compound_arg = current_argument; + compound_arg.size() > 1 && + is_valid_prefix_char(compound_arg[0]) && + !is_valid_prefix_char(compound_arg[1])) { + ++it; + for (std::size_t j = 1; j < compound_arg.size(); j++) { + auto hypothetical_arg = std::string{'-', compound_arg[j]}; + auto arg_map_it2 = m_argument_map.find(hypothetical_arg); + if (arg_map_it2 != m_argument_map.end()) { + auto argument = arg_map_it2->second; + it = argument->consume(it, end, arg_map_it2->first); + } else { + unknown_arguments.push_back(current_argument); + break; + } + } + } else { + // current argument is an optional-like argument that is unknown + // save it and move to next argument + unknown_arguments.push_back(current_argument); + ++it; + } + } + m_is_parsed = true; + return unknown_arguments; + } + + // Used by print_help. + std::size_t get_length_of_longest_argument() const { + if (m_argument_map.empty()) { + return 0; + } + std::size_t max_size = 0; + for ([[maybe_unused]] const auto &[unused, argument] : m_argument_map) { + max_size = + std::max(max_size, argument->get_arguments_length()); + } + for ([[maybe_unused]] const auto &[command, unused] : m_subparser_map) { + max_size = std::max(max_size, command.size()); + } + return max_size; + } + + using argument_it = std::list::iterator; + using mutex_group_it = std::vector::iterator; + using argument_parser_it = + std::list>::iterator; + + void index_argument(argument_it it) { + for (const auto &name : std::as_const(it->m_names)) { + m_argument_map.insert_or_assign(name, it); + } + } + + std::string m_program_name; + std::string m_version; + std::string m_description; + std::string m_epilog; + bool m_exit_on_default_arguments = true; + std::string m_prefix_chars{"-"}; + std::string m_assign_chars{"="}; + bool m_is_parsed = false; + std::list m_positional_arguments; + std::list m_optional_arguments; + std::map m_argument_map; + std::string m_parser_path; + std::list> m_subparsers; + std::map m_subparser_map; + std::map m_subparser_used; + std::vector m_mutually_exclusive_groups; + bool m_suppress = false; + std::size_t m_usage_max_line_width = std::numeric_limits::max(); + bool m_usage_break_on_mutex = false; + int m_usage_newline_counter = 0; + std::vector m_group_names; +}; + +} // namespace argparse \ No newline at end of file diff --git a/dynamont-polya/include/polyA.hpp b/dynamont-polya/include/polyA.hpp new file mode 100644 index 0000000..e69de29 diff --git a/dynamont-polya/include/utils.hpp b/dynamont-polya/include/utils.hpp new file mode 100644 index 0000000..22876fd --- /dev/null +++ b/dynamont-polya/include/utils.hpp @@ -0,0 +1,213 @@ +// =============================================================== +// =============================================================== +// =========================== Utility =========================== +// =============================================================== +// =============================================================== + +#pragma once + +#include +#include +#include // file io +#include // file io +#include //log1p +#include //stable_sort +#include //iota + +using namespace std; + +inline constexpr int ALPHABET_SIZE = 5; +const unordered_map BASE2ID = { + {'A', 0}, + {'C', 1}, + {'G', 2}, + {'T', 3}, + {'U', 3}, + {'N', 4}, + {'a', 0}, + {'c', 1}, + {'g', 2}, + {'t', 3}, + {'u', 3}, + {'n', 4} +}; // Nucleotide : Token map +const unordered_map ID2BASE = { + {'0', 'A'}, + {'1', 'C'}, + {'2', 'G'}, + {'3', 'T'}, + {'4', 'N'} +}; // Token : Nucleotide map + +/** + * Sorts the column indices of a row-major-indexed double matrix. + * Complexity is O(C * log(C)), see https://en.cppreference.com/w/cpp/algorithm/stable_sort. + * + * @param matrix a double matrix in row major order + * @param C column size + * @param t the column to sort for + * + * @return size_t vector with the sorted index of column in descending order + */ +vector column_argsort(const double* matrix, const size_t C, const size_t t); + +/** + * C++ version 0.4 std::string style "itoa": + * Contributions from Stuart Lowe, Ray-Yuan Sheu, + * Rodrigo de Salvo Braz, Luc Gallant, John Maloney + * and Brian Hunt + * + * Converts a decimal to number to a number of base ALPHABET_SIZE. + * TODO Works for base between 2 and 16 (included) + * + * Returns kmer in reversed direction! + * + * @param value input number in decimal to convert to base + * @param kmerSize kmer size + * @returns kmer as reversed string, should be 5' - 3' direction +*/ +string itoa(const size_t value, const int kmerSize); + +/** + * C++ version 0.4 std::string style "itoa": + * Contributions from Stuart Lowe, Ray-Yuan Sheu, + * Rodrigo de Salvo Braz, Luc Gallant, John Maloney + * and Brian Hunt + * + * Converts a decimal to number to a number of base ALPHABET_SIZE. + * TODO Works for base between 2 and 16 (included) + * + * Returns kmer in reversed direction! + * + * @param value input number in decimal to convert to base + * @param kmerSize kmer size + * @returns kmer as reversed string, should be 5' - 3' direction +*/ +string itoa(const int value, const int kmerSize); + +/** + * Converts a number of base ALPHABET_SIZE to a decimal number. + * Works ONLY if ALPHABET_SIZE is smaller or equal to 10! + * + * @param i input number in the given base as an array + * @param kmerSize kmer size + * @returns Decimal number representation of given token array +*/ +int toDeci(const int* i, const int kmerSize); + +/** + * Converts the kmers of the model file to the integer representation using the BASE2ID map + * + * @param s kmer containing nucleotides + * @param BASE2ID base to id map + * @param kmerSize kmer size + * @returns integer representation of the given kmer + */ +int kmer2int(const string &s, const int kmerSize); + +/** + * Convert the read sequence to a kmer sequence which is represented by integers. + * + * @param seq read sequence + * @param N length of the read sequence, number of nucleotides + * @param kmerSize kmer size + * @return kmer sequence in integer representation +*/ +int* seq2kmer(const int* seq, const size_t N, const int kmerSize); + +/** + * Read the normal distribution parameters from a given TSV file + * + * @param file path to the TSV file containing the parameters + * @param model kmer model to fill + */ +void readKmerModel(const string &file, vector> &model, const int kmerSize); + +// https://en.wikipedia.org/wiki/Log_probability +/** + * Calculate addition of a+b in log space as efficiently as possible + * with x + log1p(exp(y-x)) : x>y + * + * @param a first value + * @param b second value + * @return log(exp(a) + exp(b)) + */ +inline double logPlus(const double x, const double y) { + if (isinf(x) && isinf(y)) { + return x; + } + if (x>=y){ + return x + log1p(exp(y-x)); + } + return y + log1p(exp(x-y)); +} + +/** + * Calculates the integer representation of the successing kmer given the current kmer and the upcoming nucleotide + * k_i+1 = (k_i mod base^(kmerSize-1)) * base + value(nextNt, base) + * + * @param currentKmer current kmer in decimal representation + * @param nextNt successing nucleotide as a token + * @param ALPHABET_SIZE number of accepted characters + * @param stepSize equals ALPHABET_SIZE ^ (kmerSize - 1) + * @return successing Kmer as integer representation in the current base + */ +inline size_t successingKmer(const size_t currentKmer, const int nextNt, const int stepSize) { + return (currentKmer % stepSize) * ALPHABET_SIZE + nextNt; +} + +/** + * Calculates the integer representation of the precessor kmer given the current kmer and the precessing nucleotide + * k_i-1 = int(k_i/base) + value(priorNt, base) * base^(kmerSize-1) + * + * @param currentKmer current kmer in decimal representation + * @param priorNt precessing nucleotide as a token + * @param ALPHABET_SIZE number of accepted characters + * @param stepSize equals ALPHABET_SIZE ^ (kmerSize - 1) + * @return precessing Kmer as integer representation in the current base + */ +inline size_t precessingKmer(const size_t currentKmer, const int priorNt, const int stepSize) { + return (currentKmer/ALPHABET_SIZE) + (priorNt * stepSize); +} + +// =============================================================== +// =============================================================== +// ===================== Scoring calculations ==================== +// =============================================================== +// =============================================================== + +inline constexpr double log2Pi = 1.8378770664093453; // Precomputed log(2 * M_PI) + +// https://ethz.ch/content/dam/ethz/special-interest/mavt/dynamic-systems-n-control/idsc-dam/Lectures/Stochastic-Systems/Statistical_Methods.pdf +/** + * Calculate log pdf for a given x, mean and standard deviation + * + * @param x value + * @param m mean + * @param s standard deviation + * @return probabily density at position x for N~(m, s²) +*/ +inline double log_normal_pdf(const double x, const double m, const double s) { + if (s == 0.0) { + return -INFINITY; // Handling edge case where standard deviation is 0 + } + + const double variance = s * s; + const double diff = x - m; + + return -0.5 * (log2Pi + log(variance) + (diff * diff) / variance); +} + +/** + * Return log probability density for a given value and a given normal distribution + * + * @param signal point to calculate probability density + * @param kmer key for the model kmer:(mean, stdev) map + * @param model map containing kmers as keys and (mean, stdev) tuples as values + * @return log probability density value for x in the given normal distribution + */ +inline double scoreKmer(const double signal, const size_t kmer, const vector> &model) { + // Access elements of the model tuple directly to avoid redundant tuple creation and overhead + const auto &[mean, stddev] = model[kmer]; + return log_normal_pdf(signal, mean, stddev); +} \ No newline at end of file diff --git a/dynamont-polya/main.cpp b/dynamont-polya/main.cpp new file mode 100644 index 0000000..e69de29 diff --git a/dynamont-polya/src/polyA.cpp b/dynamont-polya/src/polyA.cpp new file mode 100644 index 0000000..4e8b22b --- /dev/null +++ b/dynamont-polya/src/polyA.cpp @@ -0,0 +1,596 @@ +// author: Jannes Spangenberg +// e-mail: jannes.spangenberg@uni-jena.de +// github: https://github.com/JannesSP +// website: https://jannessp.github.io + +#include +#include +#include // file io +#include // file io +#include +#include // dictionary +#include +#include +#include // exp +#include +#include +#include +#include +#include "../include/argparse.hpp" +#include "../include/utils.hpp" + +// TODO split main and rest of functions +// TODO create and use polyA.hpp +// TODO for later, runtime improvement possible by writing explicit emission functions with loc, scale and df as constexpr, only sig_val is variable as parameter + +// TODO do not use this, write explicit std::cin, std::cout, std::exp, std::... -> DONE +//using namespace std; + +inline constexpr double EPSILON = 1e-5; // chose by eye just to distinguish real errors from numeric errors + +// Asserts doubleing point compatibility at compile time // ? +// necessary for INFINITY usage +static_assert(std::numeric_limits::is_iec559, "IEEE 754 required"); + +//! ------------------------------------------ PDFs, Forward, Backward & Posterior Probability ---------------------------------------------- + +/** + * DIST & PARAM IN -> 60 READS : + * adapter t, df: 5.612094 loc: -0.759701 scale: 0.535895 + * polyA t, df: 6.022091, loc: 0.839093, scale: 0.217290 + * leader gumbel l, loc: 0.927918 , scale: 0.398849 + * transcript gumbel r, loc: -0.341699 , scale: 0.890093 + * start gumbel r, loc: -1.552134, scale: 0.415937 + */ + +// TODO make function inline +/** + * logarithm t distribution PDF : checked the correctness with scipy.stats.t + */ +double log_t_pdf(const double sig_val, const double loc, const double scale, const double df) +{ + + const double pi = 3.14159265358979323846; + const double diff = (sig_val - loc) / scale; + const double logGammaNuPlusOneHalf = lgamma((df + 1.0) / 2.0); + const double logGammaNuHalf = lgamma(df / 2.0); + + return logGammaNuPlusOneHalf - logGammaNuHalf - 0.5 * log(df * pi * scale * scale) - (df + 1.0) / 2.0 * log(1.0 + (diff * diff) / df); +} + +// TODO make function inline +/** + * logarithm gumbel left skewed PDF : checked the correctness with scipy.stats.gumbel_l + */ +double log_gumbel_l_pdf(const double sig_val, const double loc, const double scale) +{ + // if (scale == 0.0) { + // return -INFINITY; // Handling edge case where beta (scale) is 0 + // } + + const double z = -(sig_val - loc) / scale; + + return -z - exp(-z); +} + +// TODO make function inline +/** + * logarithm gumbel right skewed PDF : checked with scipy.stats.gumbel_r, -> //! around 0.92 different with scipy.stat.gumbel_r + */ +double log_gumbel_r_pdf(const double sig_val, const double loc, const double scale) +{ + // if (scale == 0.0) { + // return -INFINITY; // Handling edge case where beta (scale) is 0 + // } + + const double z = (sig_val - loc) / scale; + + return -z - exp(-z); +} + +/** + * Calculate forward matrices using logarithmic values + * 1D array for each state : 5 1D arrays + * S L A PA TR : initialized matrices for each state + */ +void logF(double *sig, double *S, double *L, double *A, double *PA, double *TR, size_t T, + double s, double l1, double l2, double a1, double a2, double pa1, double pa2, double tr1, double tr2) +{ + double start, leader, adapter, polya, transcript; + + S[0] = 0; + + for (size_t t = 1; t < T; ++t) + { + // init state accumulators + start = -INFINITY; + leader = -INFINITY; + adapter = -INFINITY; + polya = -INFINITY; + transcript = -INFINITY; + + // calculate probabilities + // accumulator + (prevV * emission * transition) + start = logPlus(start, S[t - 1] + log_gumbel_r_pdf(sig[t - 1], -1.552134, 0.415937) + s); + + leader = logPlus(leader, S[t - 1] + log_gumbel_l_pdf(sig[t - 1], 0.927918, 0.398849) + l1); + leader = logPlus(leader, L[t - 1] + log_gumbel_l_pdf(sig[t - 1], 0.927918, 0.398849) + l2); + + adapter = logPlus(adapter, L[t - 1] + log_t_pdf(sig[t - 1], -0.759701, 0.535895, 5.612094) + a1); + adapter = logPlus(adapter, A[t - 1] + log_t_pdf(sig[t - 1], -0.759701, 0.535895, 5.612094) + a2); + + polya = logPlus(polya, A[t - 1] + log_t_pdf(sig[t - 1], 0.839093, 0.217290, 6.022091) + pa1); + polya = logPlus(polya, PA[t - 1] + log_t_pdf(sig[t - 1], 0.839093, 0.217290, 6.022091) + pa2); + + transcript = logPlus(transcript, PA[t - 1] + log_gumbel_r_pdf(sig[t - 1], -0.341699, 0.890093) + tr1); + transcript = logPlus(transcript, TR[t - 1] + log_gumbel_r_pdf(sig[t - 1], -0.341699, 0.890093) + tr2); + + // start = logPlus(start, TR[t-1] + log_gumbel_r_pdf(sig[t-1], -1.552134, 0.415937) + s0); + + // update state matrices + S[t] = start; + L[t] = leader; + A[t] = adapter; + PA[t] = polya; + TR[t] = transcript; + + // TODO compress code + // S[t] = S[t - 1] + log_gumbel_r_pdf(sig[t - 1], -1.552134, 0.415937) + s; + // L[t] = logPlus(S[t - 1] + log_gumbel_l_pdf(sig[t - 1], 0.927918, 0.398849) + l1, L[t - 1] + log_gumbel_l_pdf(sig[t - 1], 0.927918, 0.398849) + l2); + // ... + } +} + +/** + * Calculate backward matrices using logarithmic values + */ +void logB(double *sig, double *S, double *L, double *A, double *PA, double *TR, size_t T, + double s, double l1, double l2, double a1, double a2, double pa1, double pa2, double tr1, double tr2) +{ + + double start, leader, adapter, polya, transcript; + + TR[T - 1] = 0; + + for (size_t t = T - 1; t-- > 0;) + { // T-2, ..., 1, 0 + // init state accumulators + start = -INFINITY; + leader = -INFINITY; + adapter = -INFINITY; + polya = -INFINITY; + transcript = -INFINITY; + + // calculate probabilities + // accumulator + (prevV * emission(t) * transition) + start = logPlus(start, S[t + 1] + log_gumbel_r_pdf(sig[t], -1.552134, 0.415937) + s); + start = logPlus(start, L[t + 1] + log_gumbel_l_pdf(sig[t], 0.927918, 0.398849) + l1); + + leader = logPlus(leader, L[t + 1] + log_gumbel_l_pdf(sig[t], 0.927918, 0.398849) + l2); + leader = logPlus(leader, A[t + 1] + log_t_pdf(sig[t], -0.759701, 0.535895, 5.612094) + a1); + + adapter = logPlus(adapter, A[t + 1] + log_t_pdf(sig[t], -0.759701, 0.535895, 5.612094) + a2); + adapter = logPlus(adapter, PA[t + 1] + log_t_pdf(sig[t], 0.839093, 0.217290, 6.022091) + pa1); + + polya = logPlus(polya, PA[t + 1] + log_t_pdf(sig[t], 0.839093, 0.217290, 6.022091) + pa2); + polya = logPlus(polya, TR[t + 1] + log_gumbel_r_pdf(sig[t], -0.341699, 0.890093) + tr1); + + transcript = logPlus(transcript, TR[t + 1] + log_gumbel_r_pdf(sig[t], -0.341699, 0.890093) + tr2); + + // transcript = logPlus(transcript, S[t+1] + log_gumbel_r_pdf(sig[t], -1.552134, 0.415937) + s0); + + // update state matrices + S[t] = start; + L[t] = leader; + A[t] = adapter; + PA[t] = polya; + TR[t] = transcript; + + // TODO compress code + } +} + +/** + * Calculate the logarithmic probability matrix - posterior probability + */ +double *logP(const double *F, const double *B, const double Z, const size_t T) +{ + double *LP = new double[T]; + for (size_t t = 0; t < T; ++t) + { + LP[t] = F[t] + B[t] - Z; + } + return LP; +} + +//! --------------------------------------------------------- BACKTRACING SECTION ------------------------------------------------------ + +/** + * define backtracing function after each state + */ + +// Backtracking Funcs Declaration +void funcTR(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, + std::list &segString, std::vector &borders, std::string prevState); + +void funcS(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, + std::list &segString, std::vector &borders, std::string prevState); + +void funcL(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, + std::list &segString, std::vector &borders, std::string prevState); + +void funcA(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, + std::list &segString, std::vector &borders, std::string prevState); + +void funcPA(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, + std::list &segString, std::vector &borders, std::string prevState); + +void funcS(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, std::list &segString, std::vector &borders, std::string prevState) +{ + + // base case only in S as last region + if (t == 0) + { + return; + } + + if (S[t] == S[t - 1] + LPS[t]) + { + prevState = "START"; + segString.push_back(prevState); + funcS(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } + + /* + */ + if (S[t] == TR[t - 1] + LPS[t]) + { + const size_t border_start = t; + borders.push_back(border_start); + prevState = "TRANSCRIPT"; + segString.push_back(prevState); + funcTR(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } +} + +void funcL(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, std::list &segString, std::vector &borders, std::string prevState) +{ + + if (L[t] == S[t - 1] + LPL[t]) + { + const size_t border_start = t; + borders.push_back(border_start); + prevState = "START"; + segString.push_back(prevState); + funcS(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } + + if (L[t] == L[t - 1] + LPL[t]) + { + prevState = "LEADER"; + segString.push_back(prevState); + funcL(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } +} + +void funcA(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, std::list &segString, std::vector &borders, std::string prevState) +{ + + if (A[t] == L[t - 1] + LPA[t]) + { + const size_t border_leader = t; + borders.push_back(border_leader); + prevState = "LEADER"; + segString.push_back(prevState); + funcL(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } + + if (A[t] == A[t - 1] + LPA[t]) + { + prevState = "ADAPTOR"; + segString.push_back(prevState); + funcA(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } +} + +void funcPA(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, std::list &segString, std::vector &borders, std::string prevState) +{ + + if (PA[t] == A[t - 1] + LPPA[t]) + { + const size_t border_adaptor = t; + borders.push_back(border_adaptor); + prevState = "ADAPTOR"; + segString.push_back(prevState); + funcA(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } + + if (PA[t] == PA[t - 1] + LPPA[t]) + { + prevState = "POLYA"; + segString.push_back(prevState); + funcPA(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } +} + +void funcTR(const size_t t, const double *S, const double *L, const double *A, const double *PA, const double *TR, + const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, std::list &segString, std::vector &borders, std::string prevState) +{ + if (TR[t] == PA[t - 1] + LPTR[t]) + { + + const size_t border_polyA = t; + borders.push_back(border_polyA); + prevState = "POLYA"; + segString.push_back(prevState); + funcPA(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } + + if (TR[t] == TR[t - 1] + LPTR[t]) + { + prevState = "TRANSCRIPT"; + segString.push_back(prevState); + funcTR(t - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, prevState); + } +} + +/** + * Calculate the maximum a posteriori path (backtracing) - posterioir decoding + */ +std::string getBorders(const double *LPS, const double *LPL, const double *LPA, const double *LPPA, const double *LPTR, const size_t T) +{ + + double *S = new double[T]; + double *L = new double[T]; + double *A = new double[T]; + double *PA = new double[T]; + double *TR = new double[T]; + + // Initialize M and E in one step, no need for fill_n + for (size_t t = 0; t < T; ++t) + { + S[t] = -INFINITY; + L[t] = -INFINITY; + A[t] = -INFINITY; + PA[t] = -INFINITY; + TR[t] = -INFINITY; + } + + double start, leader, adapter, polya, transcript; + S[0] = 0; + + for (size_t t = 1; t < T; ++t) + { + + // TODO compress code -> compressed code (?) : + start, leader, adapter, polya, transcript = -INFINITY; + + start = std::max(start, S[t - 1] + LPS[t]); // s + leader = std::max(leader, S[t - 1] + LPL[t]); // l1 : leave start + leader = std::max(leader, L[t - 1] + LPL[t]); // l2 : stay in leader + adapter = std::max(adapter, L[t - 1] + LPA[t]); // a1 : leave leader + adapter = std::max(adapter, A[t - 1] + LPA[t]); // a2 : stay in adapter + polya = std::max(polya, A[t - 1] + LPPA[t]); // pa1 : leader adapter + polya = std::max(polya, PA[t - 1] + LPPA[t]); // pa2 : stay in polyA + transcript = std::max(transcript, PA[t - 1] + LPTR[t]); // tr1 : leave polyA + transcript = std::max(transcript, TR[t - 1] + LPTR[t]); // tr2 : stay in trancript + + S[t] = start; + L[t] = leader; + A[t] = adapter; + PA[t] = polya; + TR[t] = transcript; + } + + std::list segString; // define string of most probabale states at T-1 backward + std::vector borders; + segString.push_back("TRANSCRIPT"); // signal value at T - 1 pos. 100% in transcript region -> beginn recursion T - 2 onward + + funcTR(T - 1, S, L, A, PA, TR, LPS, LPL, LPA, LPPA, LPTR, segString, borders, "TRANSCRIPT"); + + std::ostringstream oss; + for (size_t i = 0; i < borders.size(); ++i) + { + oss << borders[i]; + if (i < borders.size() - 1) + { + oss << ","; + } + } + + delete[] S; + delete[] L; + delete[] A; + delete[] PA; + delete[] TR; + + return oss.str(); +} + +template // TODO is this necessary? +void writeBorders(const std::string &save_file, const std::string &read_id, const std::vector &borders) +{ + + ofstream output_file(save_file, ios::app); + + if (!output_file.is_open()) + { + cerr << "Error: Unable to open file"; + exit(EXIT_FAILURE); + } + + output_file << read_id << ","; + + for (size_t i = 0; i < borders.size(); ++i) + { + + output_file << borders[i]; + + if (i < borders.size() - 1) + { + output_file << ","; + } + else + { + output_file << "\n"; + } + } + output_file.close(); +} + +/** + * Get the signal Value from python script + */ +int main() +{ + + std::cout << std::fixed << std::showpoint; + std::cout << std::setprecision(20); + + // TODO use argparse, see dynamont + + // transition parameters : + double s = log(0.996943171897388); + double l1 = log(0.0030568281026119044); + double l2 = log(0.9963280807270234); + double a1 = log(0.003671919272976708); + double a2 = log(0.99980542449089); + double pa1 = log(0.0001945755091038867); + double pa2 = log(0.9996311333837735); + double tr1 = log(0.0003688666162265902); + double tr2 = log(1.0); + + std::string signal_values; + + std::getline(std::cin, signal_values); + + // due to buffer error while piping + if (signal_values.empty()) + { + //! diff ? : std:c* vs. printf + std::cerr << "no signal value are provided!"; + //printf("Error: no signal value provided."); + return 1; // non-zero value to indicate something is wrong! + } + + // PROCESS SIGNAL : convert string to double array + // How many signal values are there ? T values + const size_t T = std::count(signal_values.begin(), signal_values.end(), ',') + 2; // len(sig) + 1 + + // init a double array of T-1 elements for signal values + double *sig = new double[T - 1]; + + // put each signal value in i-position of sig + std::string value; + std::stringstream ss(signal_values); + int i = 0; + + while (getline(ss, value, ',')) + { + sig[i++] = std::stod(value); + } + + // so far we have the signal as an array of double values in sig variable + // initialize Forward Backward algorithm calculation + double *forS = new double[T]; + double *forL = new double[T]; + double *forA = new double[T]; + double *forPA = new double[T]; + double *forTR = new double[T]; + double *backS = new double[T]; + double *backL = new double[T]; + double *backA = new double[T]; + double *backPA = new double[T]; + double *backTR = new double[T]; + + for (size_t t = 0; t < T; ++t) + { + forS[t] = -INFINITY; + backS[t] = -INFINITY; + forL[t] = -INFINITY; + backL[t] = -INFINITY; + forA[t] = -INFINITY; + backA[t] = -INFINITY; + forPA[t] = -INFINITY; + backPA[t] = -INFINITY; + forTR[t] = -INFINITY; + backTR[t] = -INFINITY; + } + + // calculate segmentation probabilities, fill forward matrices + logF(sig, forS, forL, forA, forPA, forTR, T, s, l1, l2, a1, a2, pa1, pa2, tr1, tr2); + + // calculate segmentation probabilities, fill backward matrices + logB(sig, backS, backL, backA, backPA, backTR, T, s, l1, l2, a1, a2, pa1, pa2, tr1, tr2); + + // where both values should meet each other + const double Zf = forTR[T - 1]; // end of trancript for Forward + const double Zb = backS[0]; // is same as beginning of start for Backward + + //! ----------------------------------------------- THE START OF MAIN CALCULATATION ----------------------------------------------- + + const double *LPS = logP(forS, backS, Zf, T); + const double *LPL = logP(forL, backL, Zf, T); + const double *LPA = logP(forA, backA, Zf, T); + const double *LPPA = logP(forPA, backPA, Zf, T); + const double *LPTR = logP(forTR, backTR, Zf, T); + + std::string borders = getBorders(LPS, LPL, LPA, LPPA, LPTR, T); + + if (borders.empty()) + { + printf("segmentation failed!"); + // always clean up before return!! + delete[] LPS; + delete[] LPL; + delete[] LPA; + delete[] LPPA; + delete[] LPTR; + delete[] forS; + delete[] forL; + delete[] forA; + delete[] forPA; + delete[] forTR; + delete[] backS; + delete[] backL; + delete[] backA; + delete[] backPA; + delete[] backTR; + delete[] sig; + + return 1; + } + + // writes for python + std::cout << borders << std::endl; + + // Clean up + delete[] LPS; + delete[] LPL; + delete[] LPA; + delete[] LPPA; + delete[] LPTR; + + delete[] forS; + delete[] forL; + delete[] forA; + delete[] forPA; + delete[] forTR; + delete[] backS; + delete[] backL; + delete[] backA; + delete[] backPA; + delete[] backTR; + delete[] sig; + + return 0; // if prints the border then exits with 0 ! +} diff --git a/dynamont-polya/src/polyA.py b/dynamont-polya/src/polyA.py new file mode 100755 index 0000000..c61126c --- /dev/null +++ b/dynamont-polya/src/polyA.py @@ -0,0 +1,160 @@ + +""" +author: Hadi Vareno +e-mail: mohammad.noori.vareno@uni-jena.de +github: https://github.com/TheVareno +""" + +from read5.Reader import read # type: ignore +# import ont_fast5_api # type: ignore +from ont_fast5_api.conversion_tools.fast5_subset import Fast5Filter # type: ignore +import argparse +# TODO use hampelFilter from FileIO.py +from hampel import hampel # type: ignore +import subprocess as sp +import os +import multiprocessing as mp +import queue + + + +def setup_working_directory(working_dir=None)-> None: + if working_dir is None: + working_dir = os.getcwd() + if not os.path.exists(working_dir): + raise FileNotFoundError(f"The working directory: '{working_dir}' does not exist.") + + os.chdir(working_dir) + + + +def find_polya(task_queue: mp.Queue, result_queue: mp.Queue, read_object: str): + + while not task_queue.empty(): + try: + read_id = task_queue.get_nowait() + z_normalized_signal_values = read_object.getZNormSignal(read_id, mode='mean') + filter_object = hampel(z_normalized_signal_values, window_size=5, n_sigma=6.0) + filtered_signal_values = filter_object.filtered_data + + if len(filtered_signal_values) == 0: + print(f"the array of signal values empty for read id : {read_id}") + + polyA_app_call = './polyA' + + sig_vals_str = ','.join(map(str, filtered_signal_values)) + + process = sp.Popen(polyA_app_call, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE, text=True) + + if not sig_vals_str: + print(f"Empty signal values for read {read_id}") + + process.stdin.write(f"{sig_vals_str}\n") + process.stdin.flush() + stdout, stderr = process.communicate() + rc = process.returncode # returns int + + if rc == 0: + borders = stdout.strip() + result_queue.put((read_id, borders)) + else: + pass + + if stderr: + print(f"Error for {read_id}: {stderr}") + continue + + except queue.Empty: + break + + + +def split_segment_input(input_read_data: str, output_path: str, summary_file_path: str): + + raw_signal_path = input_read_data + + if not os.path.exists(output_path): # dir exist. check + os.makedirs(output_path) + + save_file = os.path.join(output_path, 'region_borders.csv') + with open(save_file, 'w') as f: # file exist. check + f.write("Read ID,Poly(A) end,Adapter end,Leader end,Start end\n") + + splitter = Fast5Filter(input_folder=raw_signal_path, + output_folder=output_path, + read_list_file=summary_file_path, + filename_base="subset", + batch_size=500, + threads=1, + recursive=False, + file_list_file=None, + follow_symlinks=False, + target_compression=None) + + splitter.run_batch() # raw signal splitting into 8 files of 500 reads + + for file in os.listdir(output_path): + filename = os.fsencode(file) + + if filename.endswith(b".fast5") or filename.endswith(b".pod5") or filename.endswith(b".slow5"): + + file = os.path.join(output_path, file) + read_object = read(file) # needs file path ends with .fast5 / .pod5 / .slow5 + + all_read_ids = read_object.getReads() # 500 each time + + task_queue = mp.Queue() + result_queue = mp.Queue() + + for read_id in all_read_ids: + task_queue.put(read_id) + + number_of_processes = os.cpu_count() + + processes = [ mp.Process(target=find_polya, args=(task_queue, result_queue, read_object)) for _ in range(number_of_processes) ] + + for proc in processes: + proc.start() + + for proc in processes: + proc.join() + + while not result_queue.empty(): + read_id, borders = result_queue.get() + with open (save_file, 'a') as f: + f.write(f"{read_id},{borders}\n") + else: + continue + + +def main()-> None: + # TODO use same parameters/namings as dynamont + parser = argparse.ArgumentParser(description="Process and Save output file.") + parser.add_argument('-i', "--fast5_path", type=str, required=True, help="Path to input ONT read data in FAST5, POD5, or SLOW5 format.") + parser.add_argument('-o', "--output_dir", type=str, required=True, help="Directory to save output files.") + parser.add_argument('-s', "--summary_file", type=str, required=True, help="Path to the sequence summary file.") + parser.add_argument('-w', "--working_dir", type=str, default=os.getcwd(), help="Working directory for the program (default is the current directory).") + + args = parser.parse_args() + + try: + setup_working_directory(working_dir=args.working_dir) + + except FileNotFoundError as e: + print(f"Error: {e}") + return + + split_segment_input(input_read_data=args.fast5_path, + output_path=args.output_dir, + summary_file_path=args.summary_file) + + + +if __name__ == '__main__' : + main() + + + + + + diff --git a/dynamont-polya/src/utils.cpp b/dynamont-polya/src/utils.cpp new file mode 100644 index 0000000..cdd85bb --- /dev/null +++ b/dynamont-polya/src/utils.cpp @@ -0,0 +1,197 @@ +// =============================================================== +// =============================================================== +// =========================== Utility =========================== +// =============================================================== +// =============================================================== + +#include "utils.hpp" + +/** + * Sorts the column indices of a row-major-indexed double matrix. + * Complexity is O(C * log(C)), see https://en.cppreference.com/w/cpp/algorithm/stable_sort. + * + * @param matrix a double matrix in row major order + * @param C column size + * @param t the column to sort for + * + * @return size_t vector with the sorted index of column in descending order + */ +vector column_argsort(const double* matrix, const size_t C, const size_t t) { + // Initialize original index locations (indices correspond to C) + vector idx(C); + iota(idx.begin(), idx.end(), 0); + + // Sort indexes based on comparing values in the given column 'c' + stable_sort(idx.begin(), idx.end(), + [matrix, C, t](size_t i1, size_t i2) { + return matrix[t * C + i1] > matrix[t * C + i2]; + }); + + return idx; +} + +/** + * C++ version 0.4 std::string style "itoa": + * Contributions from Stuart Lowe, Ray-Yuan Sheu, + * Rodrigo de Salvo Braz, Luc Gallant, John Maloney + * and Brian Hunt + * + * Converts a decimal to number to a number of base ALPHABET_SIZE. + * TODO Works for base between 2 and 16 (included) + * + * Returns kmer in reversed direction! + * + * @param value input number in decimal to convert to base + * @returns kmer as reversed string, should be 5' - 3' direction +*/ +string itoa(const size_t value, const int kmerSize) { + string buf; + int base = kmerSize; + + // check that the base if valid + if (base < 2 || base > 16) return to_string(value); + + enum { kMaxDigits = 35 }; + buf.reserve( kMaxDigits ); // Pre-allocate enough space. + int quotient = value; + + // Translating number to string with base: + do { + buf += ID2BASE.at("0123456789abcdef"[ abs( quotient % base ) ]); + quotient /= base; + } while ( quotient ); + + // Append the negative sign + // if ( value < 0) buf += '-'; + + while ((int) buf.length() < base) { + buf += ID2BASE.at('0'); + } + + // skip this so kmer is in 5' - 3' direction for output + // reverse( buf.begin(), buf.end() ); + return buf; +} + +/** + * C++ version 0.4 std::string style "itoa": + * Contributions from Stuart Lowe, Ray-Yuan Sheu, + * Rodrigo de Salvo Braz, Luc Gallant, John Maloney + * and Brian Hunt + * + * Converts a decimal to number to a number of base ALPHABET_SIZE. + * TODO Works for base between 2 and 16 (included) + * + * Returns kmer in reversed direction! + * + * @param value input number in decimal to convert to base + * @returns kmer as reversed string, should be 5' - 3' direction +*/ +string itoa(const int value, const int kmerSize) { + string buf; + int base = kmerSize; + + // check that the base if valid + if (base < 2 || base > 16) return to_string(value); + + enum { kMaxDigits = 35 }; + buf.reserve( kMaxDigits ); // Pre-allocate enough space. + int quotient = value; + + // Translating number to string with base: + do { + buf += ID2BASE.at("0123456789abcdef"[ abs( quotient % base ) ]); + quotient /= base; + } while ( quotient ); + + // Append the negative sign + if ( value < 0) buf += '-'; + + while ((int) buf.length() < base) { + buf += ID2BASE.at('0'); + } + + // skip this so kmer is in 5' - 3' direction for output + // reverse( buf.begin(), buf.end() ); + return buf; +} + +/** + * Converts a number of base ALPHABET_SIZE to a decimal number. + * Works ONLY if ALPHABET_SIZE is smaller or equal to 10! + * + * @param i input number in the given base as an array +*/ +int toDeci(const int* i, const int kmerSize) { + int ret = 0; + int m = 1; + for(int r = kmerSize - 1; r >= 0; r--) { + ret += m*i[r]; + m *= ALPHABET_SIZE; + } + return ret; +} + +/** + * Converts the kmers of the model file to the integer representation using the BASE2ID map + * + * @param s kmer containing nucleotides + * @param BASE2ID base to id map + * @param kmerSize kmer size + * @returns integer representation of the given kmer + */ +int kmer2int(const string &s, const int kmerSize) { + int ret = 0; + for(char const &c:s){ + // assert (BASE2ID.at(c)>=0); // check if nucleotide is known + ret*=kmerSize; // move the number in base to the left + ret+=BASE2ID.at(c); + } + return ret; +} + +/** + * Convert the read sequence to a kmer sequence which is represented by integers. + * + * @param seq read sequence + * @param N length of the read sequence, number of nucleotides + * @param kmerSize kmer size + * @return kmer sequence in integer representation +*/ +int* seq2kmer(const int* seq, const size_t N, const int kmerSize) { + int* kmer_seq = new int[N]; + int* tempKmer = new int[kmerSize]; + for(size_t n=0; n> &model, const int kmerSize) { + ifstream inputFile(file); + string line, kmer, tmp; + double mean, stdev; + getline(inputFile, line); + while(getline(inputFile, line)) { // read line + stringstream buffer(line); // parse line to stringstream for getline + getline(buffer, kmer, '\t'); + // legacy models are stored from 3' - 5' + // https://github.com/nanoporetech/kmer_models + // new models are stored in 5' - 3' + reverse(kmer.begin(), kmer.end()); // 5-3 -> 3-5 orientation + getline(buffer, tmp, '\t'); // level_mean + mean = atof(tmp.c_str()); + getline(buffer, tmp, '\t'); // level_stdv + stdev = atof(tmp.c_str()); + model[kmer2int(kmer, kmerSize)]=make_tuple(mean, stdev); + } + inputFile.close(); +} \ No newline at end of file