|
| 1 | +mod context; |
| 2 | + |
1 | 3 | use cel_interpreter::objects::{Key, TryIntoValue}; |
2 | | -use cel_interpreter::{Context, Program, Value}; |
| 4 | +use cel_interpreter::{ExecutionError, Program, Value}; |
3 | 5 | use log::{debug, info, warn}; |
4 | 6 | use pyo3::exceptions::PyValueError; |
5 | 7 | use pyo3::prelude::*; |
6 | 8 |
|
7 | 9 | use chrono::{DateTime, Duration as ChronoDuration, Offset, TimeZone, Utc}; |
8 | | -use pyo3::types::PyDelta; |
9 | 10 | use pyo3::types::{PyBytes, PyDateTime, PyDict, PyList, PyTuple}; |
| 11 | +use pyo3::types::{PyDelta, PyFunction}; |
10 | 12 | use std::time::{Duration, SystemTime, UNIX_EPOCH}; |
11 | 13 |
|
12 | 14 | use std::collections::HashMap; |
13 | 15 | use std::error::Error; |
14 | 16 | use std::fmt; |
| 17 | +use std::ops::Deref; |
15 | 18 | use std::sync::Arc; |
16 | 19 |
|
17 | 20 | #[derive(Debug)] |
@@ -183,79 +186,109 @@ impl TryIntoValue for RustyPyType<'_> { |
183 | 186 | fn evaluate(src: String, evaluation_context: Option<&PyAny>) -> PyResult<RustyCelType> { |
184 | 187 | debug!("Evaluating CEL expression: {}", src); |
185 | 188 |
|
186 | | - let context: Option<&PyDict> = evaluation_context.map(|context| { |
187 | | - context |
188 | | - .downcast::<PyDict>() |
189 | | - .expect("Failed to downcast PyDict") |
190 | | - }); |
| 189 | + let program = Program::compile(&src) |
| 190 | + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to compile expression '{}': {}", src, e)))?; |
191 | 191 |
|
192 | | - let program = Program::compile(src.as_str()); |
| 192 | + debug!("Compiled program: {:?}", program); |
193 | 193 |
|
194 | | - // Handle the result of the compilation |
195 | | - match program { |
196 | | - Err(compile_error) => { |
197 | | - debug!("An error occurred during compilation"); |
198 | | - debug!("compile_error: {:?}", compile_error); |
199 | | - // compile_error |
200 | | - // .into_iter() |
201 | | - // .for_each(|e| println!("Parse error: {:?}", e)); |
202 | | - return Err(PyValueError::new_err("Parse Error")); |
| 194 | + debug!("Preparing context"); |
| 195 | + let mut environment = cel_interpreter::Context::default(); |
| 196 | + let mut ctx = context::Context::new(None, None)?; |
| 197 | + |
| 198 | + // Custom Rust functions can also be added to the environment... |
| 199 | + //environment.add_function("add", |a: i64, b: i64| a + b); |
| 200 | + |
| 201 | + // Process the evaluation context if provided |
| 202 | + if let Some(evaluation_context) = evaluation_context { |
| 203 | + // Attempt to extract directly as a Context object |
| 204 | + if let Ok(py_context_ref) = evaluation_context.extract::<PyRef<context::Context>>() { |
| 205 | + // Clone variables and functions into our local Context |
| 206 | + ctx.variables = py_context_ref.variables.clone(); |
| 207 | + ctx.functions = py_context_ref.functions.clone(); |
| 208 | + |
| 209 | + } else if let Ok(py_dict) = evaluation_context.extract::<&PyDict>() { |
| 210 | + // User passed in a dict - let's process variables and functions from the dict |
| 211 | + ctx.update(&py_dict)?; |
| 212 | + } else { |
| 213 | + return Err(PyValueError::new_err("evaluation_context must be a Context object or a dict")) |
| 214 | + }; |
| 215 | + |
| 216 | + |
| 217 | + // Add any variables from the passed in Python context |
| 218 | + for (name, value) in &ctx.variables { |
| 219 | + environment |
| 220 | + .add_variable(name.clone(), value.clone()) |
| 221 | + .map_err(|e| PyValueError::new_err(format!("Failed to add variable '{}': {}", name, e)))?; |
203 | 222 | } |
204 | | - Ok(program) => { |
205 | | - let mut environment = Context::default(); |
206 | | - |
207 | | - // Custom functions can be added to the environment |
208 | | - //environment.add_function("add", |a: i64, b: i64| a + b); |
209 | | - |
210 | | - // Add any variables from the passed in Dict context |
211 | | - if let Some(context) = context { |
212 | | - for (key, value) in context { |
213 | | - debug!("Adding context {:?}", key); |
214 | | - let key = key.extract::<String>().unwrap(); |
215 | | - // Each value is of type PyAny, we need to try to extract into a Value |
216 | | - // and then add it to the CEL context |
217 | | - |
218 | | - let wrapped_value = RustyPyType(value); |
219 | | - match wrapped_value.try_into_value() { |
220 | | - Ok(value) => { |
221 | | - debug!("Converted value: {:?}", value); |
222 | | - environment |
223 | | - .add_variable(key, value) |
224 | | - .expect("Failed to add variable to context"); |
225 | | - } |
226 | | - Err(error) => { |
227 | | - debug!("An error occurred during context conversion"); |
228 | | - warn!("Conversion error: {:?}", error); |
229 | | - warn!("Key: {:?}", key); |
230 | 223 |
|
231 | | - return Err(PyValueError::new_err(error.to_string())); |
| 224 | + // Add functions |
| 225 | + let collected_functions: Vec<(String, Py<PyAny>)> = Python::with_gil(|py| { |
| 226 | + ctx.functions |
| 227 | + .iter() |
| 228 | + .map(|(name, py_function)| (name.clone(), py_function.clone_ref(py))) |
| 229 | + .collect() |
| 230 | + }); |
| 231 | + |
| 232 | + for (name, py_function) in collected_functions.into_iter() { |
| 233 | + environment.add_function( |
| 234 | + &name.clone(), |
| 235 | + move |ftx: &cel_interpreter::FunctionContext| -> cel_interpreter::ResolveResult { |
| 236 | + Python::with_gil(|py| { |
| 237 | + // Convert arguments from Expression in ftx.args to PyObjects |
| 238 | + let mut py_args = Vec::new(); |
| 239 | + for arg_expr in &ftx.args { |
| 240 | + let arg_value = ftx.ptx.resolve(arg_expr)?; |
| 241 | + let py_arg = RustyCelType(arg_value).into_py(py); |
| 242 | + py_args.push(py_arg); |
232 | 243 | } |
233 | | - } |
234 | | - } |
235 | | - } |
| 244 | + let py_args = PyTuple::new_bound(py, py_args); |
236 | 245 |
|
237 | | - let result = program.execute(&environment); |
238 | | - match result { |
239 | | - Err(error) => { |
240 | | - warn!("An error occurred during execution"); |
241 | | - warn!("Execution error: {:?}", error); |
242 | | - // errors |
243 | | - // .into_iter() |
244 | | - // .for_each(|e| println!("Execution error: {:?}", e)); |
245 | | - Err(PyValueError::new_err("Execution Error")) |
246 | | - } |
| 246 | + // Call the Python function |
| 247 | + let py_result = py_function.call1(py, py_args) |
| 248 | + .map_err(|e| ExecutionError::FunctionError { |
| 249 | + function: name.clone(), |
| 250 | + message: e.to_string(), |
| 251 | + })?; |
| 252 | + // Convert the PyObject to &PyAny |
| 253 | + let py_result_ref = py_result.as_ref(py); |
247 | 254 |
|
248 | | - Ok(value) => return Ok(RustyCelType(value)), |
249 | | - } |
| 255 | + // Convert the result back to Value |
| 256 | + let value = RustyPyType(py_result_ref).try_into_value().map_err(|e| { |
| 257 | + ExecutionError::FunctionError { |
| 258 | + function: name.clone(), |
| 259 | + message: format!("Error calling function '{}': {}", name, e), |
| 260 | + } |
| 261 | + })?; |
| 262 | + Ok(value) |
| 263 | + }) |
| 264 | + }, |
| 265 | + ); |
250 | 266 | } |
251 | 267 | } |
| 268 | + |
| 269 | + |
| 270 | + let result = program.execute(&environment); |
| 271 | + match result { |
| 272 | + Err(error) => { |
| 273 | + warn!("An error occurred during execution"); |
| 274 | + warn!("Execution error: {:?}", error); |
| 275 | + // errors |
| 276 | + // .into_iter() |
| 277 | + // .for_each(|e| println!("Execution error: {:?}", e)); |
| 278 | + Err(PyValueError::new_err("Execution Error")) |
| 279 | + } |
| 280 | + |
| 281 | + Ok(value) => return Ok(RustyCelType(value)), |
| 282 | + } |
252 | 283 | } |
253 | 284 |
|
254 | 285 | /// A Python module implemented in Rust. |
255 | 286 | #[pymodule] |
256 | | -fn cel(py: Python<'_>, m: &PyModule) -> PyResult<()> { |
| 287 | +fn cel<'py>(py: Python<'py>, m: &PyModule) -> PyResult<()> { |
257 | 288 | pyo3_log::init(); |
258 | 289 |
|
259 | 290 | m.add_function(wrap_pyfunction!(evaluate, m)?)?; |
| 291 | + |
| 292 | + m.add_class::<context::Context>()?; |
260 | 293 | Ok(()) |
261 | 294 | } |
0 commit comments