diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index ccf2be332..0b178a15e 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -62,9 +62,35 @@ jobs: exit 1 fi + - name: Find and fix librdkafka CMakeLists.txt + run: | + # Download the package first so it's in the registry + cargo fetch + + # Find the rdkafka-sys package directory + RDKAFKA_SYS_DIR=$(find ~/.cargo/registry/src -name "rdkafka-sys-*" -type d | head -n 1) + echo "Found rdkafka-sys at: $RDKAFKA_SYS_DIR" + + # Find the librdkafka CMakeLists.txt file + CMAKE_FILE="$RDKAFKA_SYS_DIR/librdkafka/CMakeLists.txt" + + if [ -f "$CMAKE_FILE" ]; then + echo "Found CMakeLists.txt at: $CMAKE_FILE" + + # Make a backup of the original file + cp "$CMAKE_FILE" "$CMAKE_FILE.bak" + + # Replace the minimum required version + sed -i 's/cmake_minimum_required(VERSION 3.2)/cmake_minimum_required(VERSION 3.5)/' "$CMAKE_FILE" + + echo "Modified CMakeLists.txt - before and after comparison:" + diff "$CMAKE_FILE.bak" "$CMAKE_FILE" || true + else + echo "Could not find librdkafka CMakeLists.txt file!" + exit 1 + fi + - name: Check with clippy - env: - CMAKE_FLAGS: "-DCMAKE_POLICY_VERSION_MINIMUM=3.5" run: cargo hack clippy --verbose --each-feature --no-dev-deps -- -D warnings - name: Test default feature set diff --git a/src/alerts/alerts_utils.rs b/src/alerts/alerts_utils.rs index 878f78644..55a2d6146 100644 --- a/src/alerts/alerts_utils.rs +++ b/src/alerts/alerts_utils.rs @@ -25,12 +25,13 @@ use datafusion::{ min_max::{max, min}, sum::sum, }, + logical_expr::{BinaryExpr, Literal, Operator}, prelude::{col, lit, DataFrame, Expr}, }; use tracing::trace; use crate::{ - alerts::AggregateCondition, + alerts::LogicalOperator, parseable::PARSEABLE, query::{TableScanVisitor, QUERY_SESSION}, rbac::{ @@ -42,8 +43,8 @@ use crate::{ }; use super::{ - AggregateConfig, AggregateOperation, AggregateResult, Aggregations, AlertConfig, AlertError, - AlertOperator, AlertState, ConditionConfig, Conditions, ALERTS, + AggregateConfig, AggregateFunction, AggregateResult, Aggregates, AlertConfig, AlertError, + AlertOperator, AlertState, ConditionConfig, Conditions, WhereConfigOperator, ALERTS, }; async fn get_tables_from_query(query: &str) -> Result { @@ -102,23 +103,23 @@ pub async fn evaluate_alert(alert: &AlertConfig) -> Result<(), AlertError> { trace!("RUNNING EVAL TASK FOR- {alert:?}"); let query = prepare_query(alert).await?; - let base_df = execute_base_query(&query, &alert.query).await?; - let agg_results = evaluate_aggregates(&alert.aggregate_config, &base_df).await?; - let final_res = calculate_final_result(&alert.aggregate_config, &agg_results); + let select_query = alert.get_base_query(); + let base_df = execute_base_query(&query, &select_query).await?; + let agg_results = evaluate_aggregates(&alert.aggregates, &base_df).await?; + let final_res = calculate_final_result(&alert.aggregates, &agg_results); update_alert_state(alert, final_res, &agg_results).await?; Ok(()) } async fn prepare_query(alert: &AlertConfig) -> Result { - let (start_time, end_time) = match &alert.eval_type { - super::EvalConfig::RollingWindow(rolling_window) => { - (&rolling_window.eval_start, &rolling_window.eval_end) - } + let (start_time, end_time) = match &alert.eval_config { + super::EvalConfig::RollingWindow(rolling_window) => (&rolling_window.eval_start, "now"), }; let session_state = QUERY_SESSION.state(); - let raw_logical_plan = session_state.create_logical_plan(&alert.query).await?; + let select_query = alert.get_base_query(); + let raw_logical_plan = session_state.create_logical_plan(&select_query).await?; let time_range = TimeRange::parse_human_time(start_time, end_time) .map_err(|err| AlertError::CustomError(err.to_string()))?; @@ -146,15 +147,15 @@ async fn execute_base_query( } async fn evaluate_aggregates( - agg_config: &Aggregations, + agg_config: &Aggregates, base_df: &DataFrame, ) -> Result, AlertError> { let agg_filter_exprs = get_exprs(agg_config); let mut results = Vec::new(); let conditions = match &agg_config.operator { - Some(_) => &agg_config.aggregate_conditions[0..2], - None => &agg_config.aggregate_conditions[0..1], + Some(_) => &agg_config.aggregate_config[0..2], + None => &agg_config.aggregate_config[0..1], }; for ((agg_expr, filter), agg) in agg_filter_exprs.into_iter().zip(conditions) { @@ -186,10 +187,10 @@ async fn evaluate_single_aggregate( let result = evaluate_condition(&agg.operator, final_value, agg.value); let message = if result { - agg.condition_config + agg.conditions .as_ref() .map(|config| config.generate_filter_message()) - .or(Some(String::default())) + .or(None) } else { None }; @@ -206,18 +207,17 @@ fn evaluate_condition(operator: &AlertOperator, actual: f64, expected: f64) -> b match operator { AlertOperator::GreaterThan => actual > expected, AlertOperator::LessThan => actual < expected, - AlertOperator::EqualTo => actual == expected, - AlertOperator::NotEqualTo => actual != expected, - AlertOperator::GreaterThanEqualTo => actual >= expected, - AlertOperator::LessThanEqualTo => actual <= expected, - _ => unreachable!(), + AlertOperator::Equal => actual == expected, + AlertOperator::NotEqual => actual != expected, + AlertOperator::GreaterThanOrEqual => actual >= expected, + AlertOperator::LessThanOrEqual => actual <= expected, } } -fn calculate_final_result(agg_config: &Aggregations, results: &[AggregateResult]) -> bool { +fn calculate_final_result(agg_config: &Aggregates, results: &[AggregateResult]) -> bool { match &agg_config.operator { - Some(AggregateCondition::And) => results.iter().all(|r| r.result), - Some(AggregateCondition::Or) => results.iter().any(|r| r.result), + Some(LogicalOperator::And) => results.iter().all(|r| r.result), + Some(LogicalOperator::Or) => results.iter().any(|r| r.result), None => results.first().is_some_and(|r| r.result), } } @@ -228,8 +228,12 @@ async fn update_alert_state( agg_results: &[AggregateResult], ) -> Result<(), AlertError> { if final_res { - trace!("ALERT!!!!!!"); let message = format_alert_message(agg_results); + let message = format!( + "{message}\nEvaluation Window: {}\nEvaluation Frequency: {}m", + alert.get_eval_window(), + alert.get_eval_frequency() + ); ALERTS .update_state(alert.id, AlertState::Triggered, Some(message)) .await @@ -249,8 +253,8 @@ fn format_alert_message(agg_results: &[AggregateResult]) -> String { for result in agg_results { if let Some(msg) = &result.message { message.extend([format!( - "|{}({}) WHERE ({}) {} {} (ActualValue: {})|", - result.config.agg, + "\nCondition: {}({}) WHERE ({}) {} {}\nActualValue: {}\n", + result.config.aggregate_function, result.config.column, msg, result.config.operator, @@ -259,8 +263,8 @@ fn format_alert_message(agg_results: &[AggregateResult]) -> String { )]); } else { message.extend([format!( - "|{}({}) {} {} (ActualValue: {})", - result.config.agg, + "\nCondition: {}({}) {} {}\nActualValue: {}\n", + result.config.aggregate_function, result.config.column, result.config.operator, result.config.value, @@ -305,17 +309,17 @@ fn get_final_value(aggregated_rows: Vec) -> f64 { /// returns a tuple of (aggregate expressions, filter expressions) /// /// It calls get_filter_expr() to get filter expressions -fn get_exprs(aggregate_config: &Aggregations) -> Vec<(Expr, Option)> { +fn get_exprs(aggregate_config: &Aggregates) -> Vec<(Expr, Option)> { let mut agg_expr = Vec::new(); match &aggregate_config.operator { Some(op) => match op { - AggregateCondition::And | AggregateCondition::Or => { - let agg1 = &aggregate_config.aggregate_conditions[0]; - let agg2 = &aggregate_config.aggregate_conditions[1]; + LogicalOperator::And | LogicalOperator::Or => { + let agg1 = &aggregate_config.aggregate_config[0]; + let agg2 = &aggregate_config.aggregate_config[1]; for agg in [agg1, agg2] { - let filter_expr = if let Some(where_clause) = &agg.condition_config { + let filter_expr = if let Some(where_clause) = &agg.conditions { let fe = get_filter_expr(where_clause); trace!("filter_expr-\n{fe:?}"); @@ -331,9 +335,9 @@ fn get_exprs(aggregate_config: &Aggregations) -> Vec<(Expr, Option)> { } }, None => { - let agg = &aggregate_config.aggregate_conditions[0]; + let agg = &aggregate_config.aggregate_config[0]; - let filter_expr = if let Some(where_clause) = &agg.condition_config { + let filter_expr = if let Some(where_clause) = &agg.conditions { let fe = get_filter_expr(where_clause); trace!("filter_expr-\n{fe:?}"); @@ -353,11 +357,11 @@ fn get_exprs(aggregate_config: &Aggregations) -> Vec<(Expr, Option)> { fn get_filter_expr(where_clause: &Conditions) -> Expr { match &where_clause.operator { Some(op) => match op { - AggregateCondition::And => { + LogicalOperator::And => { let mut expr = Expr::Literal(datafusion::scalar::ScalarValue::Boolean(Some(true))); - let expr1 = &where_clause.conditions[0]; - let expr2 = &where_clause.conditions[1]; + let expr1 = &where_clause.condition_config[0]; + let expr2 = &where_clause.condition_config[1]; for e in [expr1, expr2] { let ex = match_alert_operator(e); @@ -365,11 +369,11 @@ fn get_filter_expr(where_clause: &Conditions) -> Expr { } expr } - AggregateCondition::Or => { + LogicalOperator::Or => { let mut expr = Expr::Literal(datafusion::scalar::ScalarValue::Boolean(Some(false))); - let expr1 = &where_clause.conditions[0]; - let expr2 = &where_clause.conditions[1]; + let expr1 = &where_clause.condition_config[0]; + let expr2 = &where_clause.condition_config[1]; for e in [expr1, expr2] { let ex = match_alert_operator(e); @@ -379,30 +383,86 @@ fn get_filter_expr(where_clause: &Conditions) -> Expr { } }, None => { - let expr = &where_clause.conditions[0]; + let expr = &where_clause.condition_config[0]; match_alert_operator(expr) } } } fn match_alert_operator(expr: &ConditionConfig) -> Expr { + // the form accepts value as a string + // if it can be parsed as a number, then parse it + // else keep it as a string + let value = NumberOrString::from_string(expr.value.clone()); + + // for maintaining column case + let column = format!(r#""{}""#, expr.column); match expr.operator { - AlertOperator::GreaterThan => col(&expr.column).gt(lit(&expr.value)), - AlertOperator::LessThan => col(&expr.column).lt(lit(&expr.value)), - AlertOperator::EqualTo => col(&expr.column).eq(lit(&expr.value)), - AlertOperator::NotEqualTo => col(&expr.column).not_eq(lit(&expr.value)), - AlertOperator::GreaterThanEqualTo => col(&expr.column).gt_eq(lit(&expr.value)), - AlertOperator::LessThanEqualTo => col(&expr.column).lt_eq(lit(&expr.value)), - AlertOperator::Like => col(&expr.column).like(lit(&expr.value)), - AlertOperator::NotLike => col(&expr.column).not_like(lit(&expr.value)), + WhereConfigOperator::Equal => col(column).eq(lit(value)), + WhereConfigOperator::NotEqual => col(column).not_eq(lit(value)), + WhereConfigOperator::LessThan => col(column).lt(lit(value)), + WhereConfigOperator::GreaterThan => col(column).gt(lit(value)), + WhereConfigOperator::LessThanOrEqual => col(column).lt_eq(lit(value)), + WhereConfigOperator::GreaterThanOrEqual => col(column).gt_eq(lit(value)), + WhereConfigOperator::IsNull => col(column).is_null(), + WhereConfigOperator::IsNotNull => col(column).is_not_null(), + WhereConfigOperator::ILike => col(column).ilike(lit(&expr.value)), + WhereConfigOperator::Contains => col(column).like(lit(&expr.value)), + WhereConfigOperator::BeginsWith => Expr::BinaryExpr(BinaryExpr::new( + Box::new(col(column)), + Operator::RegexIMatch, + Box::new(lit(format!("^{}", expr.value))), + )), + WhereConfigOperator::EndsWith => Expr::BinaryExpr(BinaryExpr::new( + Box::new(col(column)), + Operator::RegexIMatch, + Box::new(lit(format!("{}$", expr.value))), + )), + WhereConfigOperator::DoesNotContain => col(column).not_ilike(lit(&expr.value)), + WhereConfigOperator::DoesNotBeginWith => Expr::BinaryExpr(BinaryExpr::new( + Box::new(col(column)), + Operator::RegexNotIMatch, + Box::new(lit(format!("^{}", expr.value))), + )), + WhereConfigOperator::DoesNotEndWith => Expr::BinaryExpr(BinaryExpr::new( + Box::new(col(column)), + Operator::RegexNotIMatch, + Box::new(lit(format!("{}$", expr.value))), + )), } } + fn match_aggregate_operation(agg: &AggregateConfig) -> Expr { - match agg.agg { - AggregateOperation::Avg => avg(col(&agg.column)), - AggregateOperation::Count => count(col(&agg.column)), - AggregateOperation::Min => min(col(&agg.column)), - AggregateOperation::Max => max(col(&agg.column)), - AggregateOperation::Sum => sum(col(&agg.column)), + // for maintaining column case + let column = format!(r#""{}""#, agg.column); + match agg.aggregate_function { + AggregateFunction::Avg => avg(col(column)), + AggregateFunction::Count => count(col(column)), + AggregateFunction::Min => min(col(column)), + AggregateFunction::Max => max(col(column)), + AggregateFunction::Sum => sum(col(column)), + } +} + +enum NumberOrString { + Number(f64), + String(String), +} + +impl Literal for NumberOrString { + fn lit(&self) -> Expr { + match self { + NumberOrString::Number(expr) => lit(*expr), + NumberOrString::String(expr) => lit(expr.clone()), + } + } +} +impl NumberOrString { + fn from_string(value: String) -> Self { + if let Ok(num) = value.parse::() { + NumberOrString::Number(num) + } else { + NumberOrString::String(value) + } } } diff --git a/src/alerts/mod.rs b/src/alerts/mod.rs index 737ba940a..5f5a75087 100644 --- a/src/alerts/mod.rs +++ b/src/alerts/mod.rs @@ -20,30 +20,29 @@ use actix_web::http::header::ContentType; use alerts_utils::user_auth_for_query; use async_trait::async_trait; use chrono::Utc; -use datafusion::common::tree_node::TreeNode; +use derive_more::derive::FromStr; +use derive_more::FromStrError; use http::StatusCode; -use itertools::Itertools; use once_cell::sync::Lazy; use serde::Serialize; use serde_json::Error as SerdeError; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Display}; -use tokio::sync::oneshot::{self, Receiver, Sender}; -use tokio::sync::RwLock; +use std::thread; +use tokio::sync::oneshot::{Receiver, Sender}; +use tokio::sync::{mpsc, RwLock}; use tokio::task::JoinHandle; -use tracing::{trace, warn}; +use tracing::{error, trace, warn}; use ulid::Ulid; pub mod alerts_utils; pub mod target; use crate::parseable::{StreamNotFound, PARSEABLE}; -use crate::query::{TableScanVisitor, QUERY_SESSION}; use crate::rbac::map::SessionKey; use crate::storage; use crate::storage::ObjectStorageError; -use crate::sync::schedule_alert_task; -use crate::utils::time::TimeRange; +use crate::sync::alert_runtime; use self::target::Target; @@ -52,12 +51,27 @@ pub type ScheduledTaskHandlers = (JoinHandle<()>, Receiver<()>, Sender<()>); pub const CURRENT_ALERTS_VERSION: &str = "v1"; -pub static ALERTS: Lazy = Lazy::new(Alerts::default); +pub static ALERTS: Lazy = Lazy::new(|| { + let (tx, rx) = mpsc::channel::(10); + let alerts = Alerts { + alerts: RwLock::new(HashMap::new()), + sender: tx, + }; -#[derive(Debug, Default)] + thread::spawn(|| alert_runtime(rx)); + + alerts +}); + +#[derive(Debug)] pub struct Alerts { pub alerts: RwLock>, - pub scheduled_tasks: RwLock>, + pub sender: mpsc::Sender, +} + +pub enum AlertTask { + Create(AlertConfig), + Delete(Ulid), } #[derive(Default, Debug, serde::Serialize, serde::Deserialize, Clone)] @@ -76,6 +90,7 @@ impl From<&str> for AlertVerison { } } +#[derive(Debug)] pub struct AggregateResult { result: bool, message: Option, @@ -106,7 +121,7 @@ impl Context { fn default_alert_string(&self) -> String { format!( - "AlertName: {}, Triggered TimeStamp: {}, Severity: {}, Message: {}", + "AlertName: {}\nTriggered TimeStamp: {}\nSeverity: {}\n{}", self.alert_info.alert_name, Utc::now().to_rfc3339(), self.alert_info.severity, @@ -183,17 +198,13 @@ pub enum AlertOperator { #[serde(rename = "<")] LessThan, #[serde(rename = "=")] - EqualTo, - #[serde(rename = "<>")] - NotEqualTo, + Equal, + #[serde(rename = "!=")] + NotEqual, #[serde(rename = ">=")] - GreaterThanEqualTo, + GreaterThanOrEqual, #[serde(rename = "<=")] - LessThanEqualTo, - #[serde(rename = "like")] - Like, - #[serde(rename = "not like")] - NotLike, + LessThanOrEqual, } impl Display for AlertOperator { @@ -201,19 +212,82 @@ impl Display for AlertOperator { match self { AlertOperator::GreaterThan => write!(f, ">"), AlertOperator::LessThan => write!(f, "<"), - AlertOperator::EqualTo => write!(f, "="), - AlertOperator::NotEqualTo => write!(f, "<>"), - AlertOperator::GreaterThanEqualTo => write!(f, ">="), - AlertOperator::LessThanEqualTo => write!(f, "<="), - AlertOperator::Like => write!(f, "like"), - AlertOperator::NotLike => write!(f, "not like"), + AlertOperator::Equal => write!(f, "="), + AlertOperator::NotEqual => write!(f, "!="), + AlertOperator::GreaterThanOrEqual => write!(f, ">="), + AlertOperator::LessThanOrEqual => write!(f, "<="), + } + } +} + +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, FromStr)] +#[serde(rename_all = "camelCase")] +pub enum WhereConfigOperator { + #[serde(rename = "=")] + Equal, + #[serde(rename = "!=")] + NotEqual, + #[serde(rename = "<")] + LessThan, + #[serde(rename = ">")] + GreaterThan, + #[serde(rename = "<=")] + LessThanOrEqual, + #[serde(rename = ">=")] + GreaterThanOrEqual, + #[serde(rename = "is null")] + IsNull, + #[serde(rename = "is not null")] + IsNotNull, + #[serde(rename = "ilike")] + ILike, + #[serde(rename = "contains")] + Contains, + #[serde(rename = "begins with")] + BeginsWith, + #[serde(rename = "ends with")] + EndsWith, + #[serde(rename = "does not contain")] + DoesNotContain, + #[serde(rename = "does not begin with")] + DoesNotBeginWith, + #[serde(rename = "does not end with")] + DoesNotEndWith, +} + +impl WhereConfigOperator { + /// Convert the enum value to its string representation + pub fn as_str(&self) -> &'static str { + match self { + Self::Equal => "=", + Self::NotEqual => "!=", + Self::LessThan => "<", + Self::GreaterThan => ">", + Self::LessThanOrEqual => "<=", + Self::GreaterThanOrEqual => ">=", + Self::IsNull => "is null", + Self::IsNotNull => "is not null", + Self::ILike => "ilike", + Self::Contains => "contains", + Self::BeginsWith => "begins with", + Self::EndsWith => "ends with", + Self::DoesNotContain => "does not contain", + Self::DoesNotBeginWith => "does not begin with", + Self::DoesNotEndWith => "does not end with", } } } +impl Display for WhereConfigOperator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // We can reuse our as_str method to get the string representation + write!(f, "{}", self.as_str()) + } +} + #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "camelCase")] -pub enum AggregateOperation { +pub enum AggregateFunction { Avg, Count, Min, @@ -221,14 +295,14 @@ pub enum AggregateOperation { Sum, } -impl Display for AggregateOperation { +impl Display for AggregateFunction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - AggregateOperation::Avg => write!(f, "Avg"), - AggregateOperation::Count => write!(f, "Count"), - AggregateOperation::Min => write!(f, "Min"), - AggregateOperation::Max => write!(f, "Max"), - AggregateOperation::Sum => write!(f, "Sum"), + AggregateFunction::Avg => write!(f, "Avg"), + AggregateFunction::Count => write!(f, "Count"), + AggregateFunction::Min => write!(f, "Min"), + AggregateFunction::Max => write!(f, "Max"), + AggregateFunction::Sum => write!(f, "Sum"), } } } @@ -249,33 +323,26 @@ pub struct FilterConfig { #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] pub struct ConditionConfig { pub column: String, - pub operator: AlertOperator, + pub operator: WhereConfigOperator, pub value: String, } #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct Conditions { - pub operator: Option, - pub conditions: Vec, + pub operator: Option, + pub condition_config: Vec, } -// #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] -// pub enum Conditions { -// AND((ConditionConfig, ConditionConfig)), -// OR((ConditionConfig, ConditionConfig)), -// Condition(ConditionConfig), -// } - impl Conditions { pub fn generate_filter_message(&self) -> String { match &self.operator { Some(op) => match op { - AggregateCondition::And | AggregateCondition::Or => { - let expr1 = &self.conditions[0]; - let expr2 = &self.conditions[1]; + LogicalOperator::And | LogicalOperator::Or => { + let expr1 = &self.condition_config[0]; + let expr2 = &self.condition_config[1]; format!( - "[{} {} {} AND {} {} {}]", + "[{} {} {} {op} {} {} {}]", expr1.column, expr1.operator, expr1.value, @@ -286,43 +353,53 @@ impl Conditions { } }, None => { - let expr = &self.conditions[0]; + let expr = &self.condition_config[0]; format!("[{} {} {}]", expr.column, expr.operator, expr.value) } } } } +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +pub struct GroupBy { + pub columns: Vec, +} + #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct AggregateConfig { - pub agg: AggregateOperation, - pub condition_config: Option, + pub aggregate_function: AggregateFunction, + pub conditions: Option, + pub group_by: Option, pub column: String, pub operator: AlertOperator, pub value: f64, } -// #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] -// pub enum Aggregations { -// AND((AggregateConfig, AggregateConfig)), -// OR((AggregateConfig, AggregateConfig)), -// Single(AggregateConfig), -// } - #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "camelCase")] -pub struct Aggregations { - pub operator: Option, - pub aggregate_conditions: Vec, +pub struct Aggregates { + pub operator: Option, + pub aggregate_config: Vec, } #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] -pub enum AggregateCondition { +#[serde(rename_all = "camelCase")] +pub enum LogicalOperator { And, Or, } +impl Display for LogicalOperator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LogicalOperator::And => write!(f, "AND"), + LogicalOperator::Or => write!(f, "OR"), + } + } +} + #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] #[serde(rename_all = "camelCase")] pub struct RollingWindow { @@ -344,7 +421,7 @@ pub enum EvalConfig { #[serde(rename_all = "camelCase")] pub struct AlertEval {} -#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, Default)] +#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, Copy, PartialEq, Default, FromStr)] #[serde(rename_all = "camelCase")] pub enum AlertState { Triggered, @@ -363,6 +440,46 @@ impl Display for AlertState { } } +impl AlertState { + pub async fn update_state( + &self, + new_state: AlertState, + alert_id: Ulid, + ) -> Result<(), AlertError> { + match self { + AlertState::Triggered => { + if new_state == AlertState::Triggered { + let msg = format!("Not allowed to manually go from Triggered to {new_state}"); + return Err(AlertError::InvalidStateChange(msg)); + } else { + // update state on disk and in memory + ALERTS + .update_state(alert_id, new_state, Some("".into())) + .await?; + } + } + AlertState::Silenced => { + // from here, the user can only go to Resolved + if new_state == AlertState::Resolved { + // update state on disk and in memory + ALERTS + .update_state(alert_id, new_state, Some("".into())) + .await?; + } else { + let msg = format!("Not allowed to manually go from Silenced to {new_state}"); + return Err(AlertError::InvalidStateChange(msg)); + } + } + AlertState::Resolved => { + // user shouldn't logically be changing states if current state is Resolved + let msg = format!("Not allowed to go manually from Resolved to {new_state}"); + return Err(AlertError::InvalidStateChange(msg)); + } + } + Ok(()) + } +} + #[derive(Debug, serde::Serialize, serde::Deserialize, Clone, Default)] #[serde(rename_all = "camelCase")] pub enum Severity { @@ -390,10 +507,10 @@ pub struct AlertRequest { #[serde(default = "Severity::default")] pub severity: Severity, pub title: String, - pub query: String, + pub stream: String, pub alert_type: AlertType, - pub aggregate_config: Aggregations, - pub eval_type: EvalConfig, + pub aggregates: Aggregates, + pub eval_config: EvalConfig, pub targets: Vec, } @@ -404,10 +521,10 @@ impl From for AlertConfig { id: Ulid::new(), severity: val.severity, title: val.title, - query: val.query, + stream: val.stream, alert_type: val.alert_type, - aggregate_config: val.aggregate_config, - eval_type: val.eval_type, + aggregates: val.aggregates, + eval_config: val.eval_config, targets: val.targets, state: AlertState::default(), } @@ -422,10 +539,10 @@ pub struct AlertConfig { pub id: Ulid, pub severity: Severity, pub title: String, - pub query: String, + pub stream: String, pub alert_type: AlertType, - pub aggregate_config: Aggregations, - pub eval_type: EvalConfig, + pub aggregates: Aggregates, + pub eval_config: EvalConfig, pub targets: Vec, // for new alerts, state should be resolved #[serde(default)] @@ -435,23 +552,23 @@ pub struct AlertConfig { impl AlertConfig { pub fn modify(&mut self, alert: AlertRequest) { self.title = alert.title; - self.query = alert.query; + self.stream = alert.stream; self.alert_type = alert.alert_type; - self.aggregate_config = alert.aggregate_config; - self.eval_type = alert.eval_type; + self.aggregates = alert.aggregates; + self.eval_config = alert.eval_config; self.targets = alert.targets; self.state = AlertState::default(); } + pub fn get_base_query(&self) -> String { + format!("SELECT * FROM \"{}\"", self.stream) + } + /// Validations pub async fn validate(&self) -> Result<(), AlertError> { // validate evalType - let eval_frequency = match &self.eval_type { + let eval_frequency = match &self.eval_config { EvalConfig::RollingWindow(rolling_window) => { - if rolling_window.eval_end != "now" { - return Err(AlertError::Metadata("evalEnd should be now")); - } - if humantime::parse_duration(&rolling_window.eval_start).is_err() { return Err(AlertError::Metadata( "evalStart should be of type humantime", @@ -476,62 +593,29 @@ impl AlertConfig { } } - // validate aggregateCnnfig and conditionConfig + // validate aggregateConfig and conditionConfig self.validate_configs()?; - let session_state = QUERY_SESSION.state(); - let raw_logical_plan = session_state.create_logical_plan(&self.query).await?; - - // create a visitor to extract the table names present in query - let mut visitor = TableScanVisitor::default(); - let _ = raw_logical_plan.visit(&mut visitor); + // validate the presence of columns + let columns = self.get_agg_config_cols(); - let table = visitor.into_inner().first().unwrap().to_owned(); + let schema = PARSEABLE.get_stream(&self.stream)?.get_schema(); - let lowercase = self.query.split(&table).collect_vec()[0].to_lowercase(); + let schema_columns = schema + .fields() + .iter() + .map(|f| f.name()) + .collect::>(); - if lowercase - .strip_prefix(" ") - .unwrap_or(&lowercase) - .strip_suffix(" ") - .unwrap_or(&lowercase) - .ne("select * from") - { - return Err(AlertError::Metadata( - "Query needs to be select * from ", - )); + for col in columns { + if !schema_columns.contains(col) { + return Err(AlertError::CustomError(format!( + "Column {} not found in stream {}", + col, self.stream + ))); + } } - // TODO: Filter tags should be taken care of!!! - let time_range = TimeRange::parse_human_time("1m", "now") - .map_err(|err| AlertError::CustomError(err.to_string()))?; - - let query = crate::query::Query { - raw_logical_plan, - time_range, - filter_tag: None, - }; - - // for now proceed in a similar fashion as we do in query - // TODO: in case of multiple table query does the selection of time partition make a difference? (especially when the tables don't have overlapping data) - let Some(stream_name) = query.first_table_name() else { - return Err(AlertError::CustomError(format!( - "Table name not found in query- {}", - self.query - ))); - }; - - let time_partition = PARSEABLE.get_stream(&stream_name)?.get_time_partition(); - let base_df = query - .get_dataframe(time_partition.as_ref()) - .await - .map_err(|err| AlertError::CustomError(err.to_string()))?; - - // now that we have base_df, verify that it has - // columns from aggregate config - let columns = self.get_agg_config_cols(); - - base_df.select_columns(columns.iter().map(|c| c.as_str()).collect_vec().as_slice())?; Ok(()) } @@ -544,17 +628,17 @@ impl AlertConfig { match &config.operator { Some(_) => { // only two aggregate conditions should be present - if config.conditions.len() != 2 { + if config.condition_config.len() != 2 { return Err(AlertError::CustomError( - "While using AND/OR, two conditions must be used".to_string(), + "While using AND/OR, only two conditions must be used".to_string(), )); } } None => { // only one aggregate condition should be present - if config.conditions.len() != 1 { + if config.condition_config.len() != 1 { return Err(AlertError::CustomError( - "While not using AND/OR, one conditions must be used".to_string(), + "While not using AND/OR, only one condition must be used".to_string(), )); } } @@ -563,32 +647,34 @@ impl AlertConfig { } // validate aggregate config(s) - match &self.aggregate_config.operator { + match &self.aggregates.operator { Some(_) => { // only two aggregate conditions should be present - if self.aggregate_config.aggregate_conditions.len() != 2 { + if self.aggregates.aggregate_config.len() != 2 { return Err(AlertError::CustomError( - "While using AND/OR, two aggregateConditions must be used".to_string(), + "While using AND/OR, only two aggregate conditions must be used" + .to_string(), )); } // validate condition config - let agg1 = &self.aggregate_config.aggregate_conditions[0]; - let agg2 = &self.aggregate_config.aggregate_conditions[0]; + let agg1 = &self.aggregates.aggregate_config[0]; + let agg2 = &self.aggregates.aggregate_config[1]; - validate_condition_config(&agg1.condition_config)?; - validate_condition_config(&agg2.condition_config)?; + validate_condition_config(&agg1.conditions)?; + validate_condition_config(&agg2.conditions)?; } None => { // only one aggregate condition should be present - if self.aggregate_config.aggregate_conditions.len() != 1 { + if self.aggregates.aggregate_config.len() != 1 { return Err(AlertError::CustomError( - "While not using AND/OR, one aggregateConditions must be used".to_string(), + "While not using AND/OR, only one aggregate condition must be used" + .to_string(), )); } - let agg = &self.aggregate_config.aggregate_conditions[0]; - validate_condition_config(&agg.condition_config)?; + let agg = &self.aggregates.aggregate_config[0]; + validate_condition_config(&agg.conditions)?; } } Ok(()) @@ -596,25 +682,25 @@ impl AlertConfig { fn get_agg_config_cols(&self) -> HashSet<&String> { let mut columns: HashSet<&String> = HashSet::new(); - match &self.aggregate_config.operator { + match &self.aggregates.operator { Some(op) => match op { - AggregateCondition::And | AggregateCondition::Or => { - let agg1 = &self.aggregate_config.aggregate_conditions[0]; - let agg2 = &self.aggregate_config.aggregate_conditions[1]; + LogicalOperator::And | LogicalOperator::Or => { + let agg1 = &self.aggregates.aggregate_config[0]; + let agg2 = &self.aggregates.aggregate_config[1]; columns.insert(&agg1.column); columns.insert(&agg2.column); - if let Some(condition) = &agg1.condition_config { + if let Some(condition) = &agg1.conditions { columns.extend(self.get_condition_cols(condition)); } } }, None => { - let agg = &self.aggregate_config.aggregate_conditions[0]; + let agg = &self.aggregates.aggregate_config[0]; columns.insert(&agg.column); - if let Some(condition) = &agg.condition_config { + if let Some(condition) = &agg.conditions { columns.extend(self.get_condition_cols(condition)); } } @@ -626,15 +712,15 @@ impl AlertConfig { let mut columns: HashSet<&String> = HashSet::new(); match &condition.operator { Some(op) => match op { - AggregateCondition::And | AggregateCondition::Or => { - let c1 = &condition.conditions[0]; - let c2 = &condition.conditions[1]; + LogicalOperator::And | LogicalOperator::Or => { + let c1 = &condition.condition_config[0]; + let c2 = &condition.condition_config[1]; columns.insert(&c1.column); columns.insert(&c2.column); } }, None => { - let c = &condition.conditions[0]; + let c = &condition.condition_config[0]; columns.insert(&c.column); } } @@ -642,10 +728,18 @@ impl AlertConfig { } pub fn get_eval_frequency(&self) -> u64 { - match &self.eval_type { + match &self.eval_config { EvalConfig::RollingWindow(rolling_window) => rolling_window.eval_frequency, } } + pub fn get_eval_window(&self) -> String { + match &self.eval_config { + EvalConfig::RollingWindow(rolling_window) => format!( + "Start={}\tEnd={}", + rolling_window.eval_start, rolling_window.eval_end + ), + } + } fn get_context(&self) -> Context { let deployment_instance = format!( @@ -656,12 +750,6 @@ impl AlertConfig { let deployment_id = storage::StorageMetadata::global().deployment_id; let deployment_mode = storage::StorageMetadata::global().mode.to_string(); - // let additional_labels = - // serde_json::to_value(rule).expect("rule is perfectly deserializable"); - // let flatten_additional_labels = - // utils::json::flatten::flatten_with_parent_prefix(additional_labels, "rule", "_") - // .expect("can be flattened"); - Context::new( AlertInfo::new( self.id, @@ -707,6 +795,10 @@ pub enum AlertError { StreamNotFound(#[from] StreamNotFound), #[error("{0}")] Anyhow(#[from] anyhow::Error), + #[error("No alert request body provided")] + InvalidAlertModifyRequest, + #[error("{0}")] + FromStrError(#[from] FromStrError), } impl actix_web::ResponseError for AlertError { @@ -722,6 +814,8 @@ impl actix_web::ResponseError for AlertError { Self::InvalidStateChange(_) => StatusCode::BAD_REQUEST, Self::StreamNotFound(_) => StatusCode::NOT_FOUND, Self::Anyhow(_) => StatusCode::INTERNAL_SERVER_ERROR, + Self::InvalidAlertModifyRequest => StatusCode::BAD_REQUEST, + Self::FromStrError(_) => StatusCode::BAD_REQUEST, } } @@ -739,19 +833,20 @@ impl Alerts { let store = PARSEABLE.storage.get_object_store(); for alert in store.get_alerts().await.unwrap_or_default() { - let (outbox_tx, outbox_rx) = oneshot::channel::<()>(); - let (inbox_tx, inbox_rx) = oneshot::channel::<()>(); - let handle = schedule_alert_task( - alert.get_eval_frequency(), - alert.clone(), - inbox_rx, - outbox_tx, - ) - .map_err(|e| anyhow::Error::msg(e.to_string()))?; - - self.update_task(alert.id, handle, outbox_rx, inbox_tx) - .await; - + match self.sender.send(AlertTask::Create(alert.clone())).await { + Ok(_) => {} + Err(e) => { + warn!("Failed to create alert task: {e}\nRetrying..."); + // Retry sending the task + match self.sender.send(AlertTask::Create(alert.clone())).await { + Ok(_) => {} + Err(e) => { + error!("Failed to create alert task: {e}"); + continue; + } + } + } + }; map.insert(alert.id, alert); } @@ -766,8 +861,8 @@ impl Alerts { let mut alerts: Vec = Vec::new(); for (_, alert) in self.alerts.read().await.iter() { // filter based on whether the user can execute this query or not - let query = &alert.query; - if user_auth_for_query(&session, query).await.is_ok() { + let query = alert.get_base_query(); + if user_auth_for_query(&session, &query).await.is_ok() { alerts.push(alert.to_owned()); } } @@ -851,31 +946,21 @@ impl Alerts { } } - /// Update the scheduled alert tasks in-memory map - pub async fn update_task( - &self, - id: Ulid, - handle: JoinHandle<()>, - rx: Receiver<()>, - tx: Sender<()>, - ) { - self.scheduled_tasks - .write() + /// Start a scheduled alert task + pub async fn start_task(&self, alert: AlertConfig) -> Result<(), AlertError> { + self.sender + .send(AlertTask::Create(alert)) .await - .insert(id, (handle, rx, tx)); + .map_err(|e| AlertError::CustomError(e.to_string()))?; + Ok(()) } /// Remove a scheduled alert task pub async fn delete_task(&self, alert_id: Ulid) -> Result<(), AlertError> { - if self - .scheduled_tasks - .write() + self.sender + .send(AlertTask::Delete(alert_id)) .await - .remove(&alert_id) - .is_none() - { - trace!("Alert task {alert_id} not found in hashmap"); - } + .map_err(|e| AlertError::CustomError(e.to_string()))?; Ok(()) } diff --git a/src/alerts/target.rs b/src/alerts/target.rs index b92784cc4..3c8939cca 100644 --- a/src/alerts/target.rs +++ b/src/alerts/target.rs @@ -29,13 +29,14 @@ use http::{header::AUTHORIZATION, HeaderMap, HeaderValue}; use humantime_serde::re::humantime; use reqwest::ClientBuilder; use tracing::{error, trace, warn}; +use url::Url; use super::ALERTS; use super::{AlertState, CallableTarget, Context}; #[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "camelCase")] #[serde(untagged)] pub enum Retry { Infinite, @@ -49,7 +50,7 @@ impl Default for Retry { } #[derive(Debug, serde::Serialize, serde::Deserialize, Clone)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "camelCase")] #[serde(try_from = "TargetVerifier")] pub struct Target { #[serde(flatten)] @@ -186,7 +187,7 @@ pub struct RepeatVerifier { } #[derive(Debug, serde::Deserialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "camelCase")] pub struct TargetVerifier { #[serde(flatten)] pub target: TargetType, @@ -229,13 +230,15 @@ impl TryFrom for Target { } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "lowercase")] +#[serde(rename_all = "camelCase")] #[serde(tag = "type")] #[serde(deny_unknown_fields)] pub enum TargetType { + #[serde(rename = "slack")] Slack(SlackWebHook), #[serde(rename = "webhook")] Other(OtherWebHook), + #[serde(rename = "alertManager")] AlertManager(AlertManager), } @@ -253,9 +256,9 @@ fn default_client_builder() -> ClientBuilder { ClientBuilder::new() } -#[derive(Default, Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct SlackWebHook { - endpoint: String, + endpoint: Url, } #[async_trait] @@ -277,16 +280,16 @@ impl CallableTarget for SlackWebHook { } }; - if let Err(e) = client.post(&self.endpoint).json(&alert).send().await { + if let Err(e) = client.post(self.endpoint.clone()).json(&alert).send().await { error!("Couldn't make call to webhook, error: {}", e) } } } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "snake_case")] +#[serde(rename_all = "camelCase")] pub struct OtherWebHook { - endpoint: String, + endpoint: Url, #[serde(default)] headers: HashMap, #[serde(default)] @@ -312,7 +315,7 @@ impl CallableTarget for OtherWebHook { }; let request = client - .post(&self.endpoint) + .post(self.endpoint.clone()) .headers((&self.headers).try_into().expect("valid_headers")); if let Err(e) = request.body(alert).send().await { @@ -322,8 +325,9 @@ impl CallableTarget for OtherWebHook { } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] pub struct AlertManager { - endpoint: String, + endpoint: Url, #[serde(default)] skip_tls_check: bool, #[serde(flatten)] @@ -401,7 +405,12 @@ impl CallableTarget for AlertManager { } }; - if let Err(e) = client.post(&self.endpoint).json(&alerts).send().await { + if let Err(e) = client + .post(self.endpoint.clone()) + .json(&alerts) + .send() + .await + { error!("Couldn't make call to alertmanager, error: {}", e) } } diff --git a/src/handlers/airplane.rs b/src/handlers/airplane.rs index 1a72c7470..efb95fef8 100644 --- a/src/handlers/airplane.rs +++ b/src/handlers/airplane.rs @@ -169,7 +169,7 @@ impl FlightService for AirServiceImpl { query.time_range.start.timestamp_millis(), query.time_range.end.timestamp_millis(), ) { - let sql = format!("select * from {}", &stream_name); + let sql = format!("select * from \"{}\"", &stream_name); let start_time = ticket.start_time.clone(); let end_time = ticket.end_time.clone(); let out_ticket = json!({ diff --git a/src/handlers/http/alerts.rs b/src/handlers/http/alerts.rs index 357cf4b87..28e5c80e9 100644 --- a/src/handlers/http/alerts.rs +++ b/src/handlers/http/alerts.rs @@ -16,8 +16,12 @@ * */ +use std::str::FromStr; + use crate::{ - parseable::PARSEABLE, storage::object_storage::alert_json_path, sync::schedule_alert_task, + parseable::PARSEABLE, + storage::object_storage::alert_json_path, + // sync::schedule_alert_task, utils::actix::extract_session_key_from_req, }; use actix_web::{ @@ -25,7 +29,6 @@ use actix_web::{ HttpRequest, Responder, }; use bytes::Bytes; -use tokio::sync::oneshot; use ulid::Ulid; use crate::alerts::{ @@ -53,17 +56,8 @@ pub async fn post( // validate the incoming alert query // does the user have access to these tables or not? let session_key = extract_session_key_from_req(&req)?; - user_auth_for_query(&session_key, &alert.query).await?; - - // create scheduled tasks - let (outbox_tx, outbox_rx) = oneshot::channel::<()>(); - let (inbox_tx, inbox_rx) = oneshot::channel::<()>(); - let handle = schedule_alert_task( - alert.get_eval_frequency(), - alert.clone(), - inbox_rx, - outbox_tx, - )?; + let query = alert.get_base_query(); + user_auth_for_query(&session_key, &query).await?; // now that we've validated that the user can run this query // move on to saving the alert in ObjectStore @@ -75,9 +69,8 @@ pub async fn post( let alert_bytes = serde_json::to_vec(&alert)?; store.put_object(&path, Bytes::from(alert_bytes)).await?; - ALERTS - .update_task(alert.id, handle, outbox_rx, inbox_tx) - .await; + // start the task + ALERTS.start_task(alert.clone()).await?; Ok(web::Json(alert)) } @@ -89,7 +82,8 @@ pub async fn get(req: HttpRequest, alert_id: Path) -> Result) -> Result) -> Result, - Json(alert_request): Json, ) -> Result { let session_key = extract_session_key_from_req(&req)?; let alert_id = alert_id.into_inner(); // check if alert id exists in map - let mut alert = ALERTS.get_alert_by_id(alert_id).await?; - + let alert = ALERTS.get_alert_by_id(alert_id).await?; // validate that the user has access to the tables mentioned - // in the old as well as the modified alert - user_auth_for_query(&session_key, &alert.query).await?; - user_auth_for_query(&session_key, &alert_request.query).await?; + let query = alert.get_base_query(); + user_auth_for_query(&session_key, &query).await?; - alert.modify(alert_request); - alert.validate().await?; - - // modify task - let (outbox_tx, outbox_rx) = oneshot::channel::<()>(); - let (inbox_tx, inbox_rx) = oneshot::channel::<()>(); - let handle = schedule_alert_task( - alert.get_eval_frequency(), - alert.clone(), - inbox_rx, - outbox_tx, - )?; - - // modify on disk - PARSEABLE - .storage - .get_object_store() - .put_alert(alert.id, &alert) - .await?; - - // modify in memory - ALERTS.update(&alert).await; - - ALERTS - .update_task(alert.id, handle, outbox_rx, inbox_tx) - .await; - - Ok(web::Json(alert)) -} - -// PUT /alerts/{alert_id}/update_state -pub async fn update_state( - req: HttpRequest, - alert_id: Path, - state: String, -) -> Result { - let session_key = extract_session_key_from_req(&req)?; - let alert_id = alert_id.into_inner(); + let query_string = req.query_string(); - // check if alert id exists in map - let alert = ALERTS.get_alert_by_id(alert_id).await?; + if query_string.is_empty() { + return Err(AlertError::InvalidStateChange( + "No query string provided".to_string(), + )); + } - // validate that the user has access to the tables mentioned - user_auth_for_query(&session_key, &alert.query).await?; + let tokens = query_string.split('=').collect::>(); + let state_key = tokens[0]; + let state_value = tokens[1]; + if state_key != "state" { + return Err(AlertError::InvalidStateChange( + "Invalid query parameter".to_string(), + )); + } // get current state let current_state = ALERTS.get_state(alert_id).await?; - let new_state: AlertState = serde_json::from_str(&state)?; - - match current_state { - AlertState::Triggered => { - if new_state == AlertState::Triggered { - let msg = format!("Not allowed to manually go from Triggered to {new_state}"); - return Err(AlertError::InvalidStateChange(msg)); - } else { - // update state on disk and in memory - ALERTS - .update_state(alert_id, new_state, Some("".into())) - .await?; - } - } - AlertState::Silenced => { - // from here, the user can only go to Resolved - if new_state == AlertState::Resolved { - // update state on disk and in memory - ALERTS - .update_state(alert_id, new_state, Some("".into())) - .await?; - } else { - let msg = format!("Not allowed to manually go from Silenced to {new_state}"); - return Err(AlertError::InvalidStateChange(msg)); - } - } - AlertState::Resolved => { - // user shouldn't logically be changing states if current state is Resolved - let msg = format!("Not allowed to go manually from Resolved to {new_state}"); - return Err(AlertError::InvalidStateChange(msg)); - } - } + let new_state = AlertState::from_str(state_value)?; - Ok("") + current_state.update_state(new_state, alert_id).await?; + let alert = ALERTS.get_alert_by_id(alert_id).await?; + Ok(web::Json(alert)) } diff --git a/src/handlers/http/modal/server.rs b/src/handlers/http/modal/server.rs index 002c102b3..1674c3ee2 100644 --- a/src/handlers/http/modal/server.rs +++ b/src/handlers/http/modal/server.rs @@ -228,20 +228,17 @@ impl Server { .service( web::resource("/{alert_id}") .route(web::get().to(alerts::get).authorize(Action::GetAlert)) - .route(web::put().to(alerts::modify).authorize(Action::PutAlert)) + .route( + web::put() + .to(alerts::update_state) + .authorize(Action::PutAlert), + ) .route( web::delete() .to(alerts::delete) .authorize(Action::DeleteAlert), ), ) - .service( - web::resource("/{alert_id}/update_state").route( - web::put() - .to(alerts::update_state) - .authorize(Action::PutAlert), - ), - ) } // get the dashboards web scope diff --git a/src/rbac/role.rs b/src/rbac/role.rs index 20bb31025..863c984d2 100644 --- a/src/rbac/role.rs +++ b/src/rbac/role.rs @@ -340,4 +340,4 @@ pub mod model { tag: None, } } -} \ No newline at end of file +} diff --git a/src/sync.rs b/src/sync.rs index 86b489893..6fa568331 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -17,15 +17,16 @@ */ use chrono::{TimeDelta, Timelike}; +use std::collections::HashMap; use std::future::Future; use std::panic::AssertUnwindSafe; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; use tokio::task::JoinSet; use tokio::time::{interval_at, sleep, Duration, Instant}; use tokio::{select, task}; use tracing::{error, info, trace, warn}; -use crate::alerts::{alerts_utils, AlertConfig, AlertError}; +use crate::alerts::{alerts_utils, AlertTask}; use crate::parseable::PARSEABLE; use crate::{LOCAL_SYNC_INTERVAL, STORAGE_UPLOAD_INTERVAL}; @@ -227,49 +228,60 @@ pub fn local_sync() -> ( (handle, outbox_rx, inbox_tx) } -pub fn schedule_alert_task( - eval_frequency: u64, - alert: AlertConfig, - inbox_rx: oneshot::Receiver<()>, - outbox_tx: oneshot::Sender<()>, -) -> Result, AlertError> { - let handle = tokio::task::spawn(async move { - info!("new alert task started for {alert:?}"); +/// A separate runtime for running all alert tasks +#[tokio::main(flavor = "multi_thread")] +pub async fn alert_runtime(mut rx: mpsc::Receiver) -> Result<(), anyhow::Error> { + let mut alert_tasks = HashMap::new(); - let result = std::panic::catch_unwind(AssertUnwindSafe(|| async move { - let mut sync_interval = - interval_at(next_minute(), Duration::from_secs(eval_frequency * 60)); - let mut inbox_rx = AssertUnwindSafe(inbox_rx); + // this is the select! loop which will keep waiting for the alert task to finish or get cancelled + while let Some(task) = rx.recv().await { + match task { + AlertTask::Create(alert) => { + // check if the alert already exists + if alert_tasks.contains_key(&alert.id) { + error!("Alert with id {} already exists", alert.id); + continue; + } - loop { - select! { - _ = sync_interval.tick() => { - trace!("Flushing stage to disk..."); + let alert = alert.clone(); + let id = alert.id; + let handle = tokio::spawn(async move { + let mut retry_counter = 0; + let mut sleep_duration = alert.get_eval_frequency(); + loop { match alerts_utils::evaluate_alert(&alert).await { - Ok(_) => {} - Err(err) => error!("Error while evaluation- {err}"), + Ok(_) => { + retry_counter = 0; + } + Err(err) => { + warn!("Error while evaluation- {}\nRetrying after sleeping for 1 minute", err); + sleep_duration = 1; + retry_counter += 1; + + if retry_counter > 3 { + error!("Alert with id {} failed to evaluate after 3 retries with err- {}", id, err); + break; + } + } } - }, - res = &mut inbox_rx => {match res{ - Ok(_) => break, - Err(_) => { - warn!("Inbox channel closed unexpectedly"); - break; - }} + tokio::time::sleep(Duration::from_secs(sleep_duration * 60)).await; } - } - } - })); + }); - match result { - Ok(future) => { - future.await; + // store the handle in the map, since it is not awaited, it will keep on running + alert_tasks.insert(id, handle); } - Err(panic_error) => { - error!("Panic in scheduled alert task: {panic_error:?}"); - let _ = outbox_tx.send(()); + AlertTask::Delete(ulid) => { + // check if the alert exists + if let Some(handle) = alert_tasks.remove(&ulid) { + // cancel the task + handle.abort(); + warn!("Alert with id {} deleted", ulid); + } else { + error!("Alert with id {} does not exist", ulid); + } } } - }); - Ok(handle) + } + Ok(()) }