Skip to content

Commit ea614af

Browse files
authored
Support user defined functions (#2)
1 parent a8c5b1b commit ea614af

File tree

7 files changed

+308
-75
lines changed

7 files changed

+308
-75
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ name = "cel"
99
crate-type = ["cdylib"]
1010

1111
[dependencies]
12-
pyo3 = { version = "0.22.5", features = ["chrono", "gil-refs"]}
13-
cel-interpreter = "0.8.1"
12+
pyo3 = { version = "0.22.6", features = ["chrono", "gil-refs", "py-clone"]}
13+
cel-interpreter = { version = "0.9.0", features = ["chrono", "json", "regex"] }
1414
log = "0.4.22"
1515
pyo3-log = "0.11.0"
1616
chrono = { version = "0.4.38", features = ["serde"] }

README.md

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,23 @@ evaluate(
3737
)
3838
True
3939
```
40-
## Future work
4140

41+
### Custom Python Functions
4242

43-
### Command line interface
43+
This Python library supports user defined Python functions
44+
in the context:
4445

45-
The package (plans to) provides a command line interface for evaluating CEL expressions:
46+
```python
47+
from cel import evaluate
4648

47-
```bash
48-
$ python -m cel '1 + 2'
49-
3
50-
```
49+
def is_adult(age):
50+
return age > 21
5151

52-
### Separate compilation and Execution steps
53-
### Custom Python Functions
52+
evaluate("is_adult(age)", {'is_adult': is_adult, 'age': 18})
53+
# False
54+
```
5455

55-
Ability to add Python functions to the Context object:
56+
You can also explicitly create a Context object:
5657

5758
```python
5859
from cel import evaluate, Context
@@ -62,5 +63,24 @@ def is_adult(age):
6263

6364
context = Context()
6465
context.add_function("is_adult", is_adult)
65-
print(evaluate("is_adult(age)", {"age": 18}, context)) # False
66+
context.update({"age": 18})
67+
68+
evaluate("is_adult(age)", context)
69+
# False
6670
```
71+
72+
73+
## Future work
74+
75+
76+
### Command line interface
77+
78+
The package (plans to) provides a command line interface for evaluating CEL expressions:
79+
80+
```bash
81+
$ python -m cel '1 + 2'
82+
3
83+
```
84+
85+
### Separate compilation and Execution steps
86+

src/context.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use cel_interpreter::objects::TryIntoValue;
2+
use cel_interpreter::Value;
3+
use pyo3::prelude::*;
4+
use pyo3::types::{PyDict, PyTuple};
5+
use std::collections::HashMap;
6+
use pyo3::exceptions::PyValueError;
7+
8+
#[pyo3::pyclass]
9+
pub struct Context {
10+
pub variables: HashMap<String, Value>,
11+
pub functions: HashMap<String, Py<PyAny>>,
12+
}
13+
14+
#[pyo3::pymethods]
15+
impl Context {
16+
17+
#[new]
18+
pub fn new(variables: Option<&PyDict>, functions: Option<&PyDict>) -> PyResult<Self> {
19+
let mut context = Context {
20+
variables: HashMap::new(),
21+
functions: HashMap::new(),
22+
};
23+
24+
if let Some(variables) = variables {
25+
//context.variables.extend(variables.clone());
26+
for (k, v) in variables {
27+
let key = k.extract::<String>().map_err(|_| {
28+
PyValueError::new_err("Keys must be strings")
29+
});
30+
key.map(|key| context.add_variable(key, v))??;
31+
32+
}
33+
};
34+
35+
if let Some(functions) = functions {
36+
context.update(functions)?;
37+
};
38+
39+
40+
41+
Ok(context)
42+
}
43+
44+
45+
fn add_function(&mut self, name: String, function: Py<PyAny>) {
46+
self.functions.insert(name, function);
47+
}
48+
49+
pub fn add_variable(&mut self, name: String, value: &PyAny) -> PyResult<()> {
50+
let value = crate::RustyPyType(value).try_into_value().map_err(|e| {
51+
pyo3::exceptions::PyValueError::new_err(format!(
52+
"Failed to convert variable '{}': {}",
53+
name, e
54+
))
55+
})?;
56+
self.variables.insert(name, value);
57+
Ok(())
58+
}
59+
60+
pub fn update(&mut self, variables: &PyDict) -> PyResult<()> {
61+
62+
for (key, value) in variables {
63+
// Attempt to extract the key as a String
64+
let key = key.extract::<String>().map_err(|_| {
65+
PyValueError::new_err("Keys must be strings")
66+
})?;
67+
68+
if value.is_callable() {
69+
// Value is a function, add it to the functions hashmap
70+
let py_function = value.to_object(value.py());
71+
self.functions.insert(key, py_function);
72+
} else {
73+
// Value is a variable, add it to the variables hashmap
74+
let value = crate::RustyPyType(value)
75+
.try_into_value()
76+
.map_err(|e| PyValueError::new_err(e.to_string()))?;
77+
78+
self.variables.insert(key, value);
79+
}
80+
81+
}
82+
83+
Ok(())
84+
}
85+
}

src/lib.rs

Lines changed: 93 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1+
mod context;
2+
13
use cel_interpreter::objects::{Key, TryIntoValue};
2-
use cel_interpreter::{Context, Program, Value};
4+
use cel_interpreter::{ExecutionError, Program, Value};
35
use log::{debug, info, warn};
46
use pyo3::exceptions::PyValueError;
57
use pyo3::prelude::*;
68

79
use chrono::{DateTime, Duration as ChronoDuration, Offset, TimeZone, Utc};
8-
use pyo3::types::PyDelta;
910
use pyo3::types::{PyBytes, PyDateTime, PyDict, PyList, PyTuple};
11+
use pyo3::types::{PyDelta, PyFunction};
1012
use std::time::{Duration, SystemTime, UNIX_EPOCH};
1113

1214
use std::collections::HashMap;
1315
use std::error::Error;
1416
use std::fmt;
17+
use std::ops::Deref;
1518
use std::sync::Arc;
1619

1720
#[derive(Debug)]
@@ -183,79 +186,109 @@ impl TryIntoValue for RustyPyType<'_> {
183186
fn evaluate(src: String, evaluation_context: Option<&PyAny>) -> PyResult<RustyCelType> {
184187
debug!("Evaluating CEL expression: {}", src);
185188

186-
let context: Option<&PyDict> = evaluation_context.map(|context| {
187-
context
188-
.downcast::<PyDict>()
189-
.expect("Failed to downcast PyDict")
190-
});
189+
let program = Program::compile(&src)
190+
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to compile expression '{}': {}", src, e)))?;
191191

192-
let program = Program::compile(src.as_str());
192+
debug!("Compiled program: {:?}", program);
193193

194-
// Handle the result of the compilation
195-
match program {
196-
Err(compile_error) => {
197-
debug!("An error occurred during compilation");
198-
debug!("compile_error: {:?}", compile_error);
199-
// compile_error
200-
// .into_iter()
201-
// .for_each(|e| println!("Parse error: {:?}", e));
202-
return Err(PyValueError::new_err("Parse Error"));
194+
debug!("Preparing context");
195+
let mut environment = cel_interpreter::Context::default();
196+
let mut ctx = context::Context::new(None, None)?;
197+
198+
// Custom Rust functions can also be added to the environment...
199+
//environment.add_function("add", |a: i64, b: i64| a + b);
200+
201+
// Process the evaluation context if provided
202+
if let Some(evaluation_context) = evaluation_context {
203+
// Attempt to extract directly as a Context object
204+
if let Ok(py_context_ref) = evaluation_context.extract::<PyRef<context::Context>>() {
205+
// Clone variables and functions into our local Context
206+
ctx.variables = py_context_ref.variables.clone();
207+
ctx.functions = py_context_ref.functions.clone();
208+
209+
} else if let Ok(py_dict) = evaluation_context.extract::<&PyDict>() {
210+
// User passed in a dict - let's process variables and functions from the dict
211+
ctx.update(&py_dict)?;
212+
} else {
213+
return Err(PyValueError::new_err("evaluation_context must be a Context object or a dict"))
214+
};
215+
216+
217+
// Add any variables from the passed in Python context
218+
for (name, value) in &ctx.variables {
219+
environment
220+
.add_variable(name.clone(), value.clone())
221+
.map_err(|e| PyValueError::new_err(format!("Failed to add variable '{}': {}", name, e)))?;
203222
}
204-
Ok(program) => {
205-
let mut environment = Context::default();
206-
207-
// Custom functions can be added to the environment
208-
//environment.add_function("add", |a: i64, b: i64| a + b);
209-
210-
// Add any variables from the passed in Dict context
211-
if let Some(context) = context {
212-
for (key, value) in context {
213-
debug!("Adding context {:?}", key);
214-
let key = key.extract::<String>().unwrap();
215-
// Each value is of type PyAny, we need to try to extract into a Value
216-
// and then add it to the CEL context
217-
218-
let wrapped_value = RustyPyType(value);
219-
match wrapped_value.try_into_value() {
220-
Ok(value) => {
221-
debug!("Converted value: {:?}", value);
222-
environment
223-
.add_variable(key, value)
224-
.expect("Failed to add variable to context");
225-
}
226-
Err(error) => {
227-
debug!("An error occurred during context conversion");
228-
warn!("Conversion error: {:?}", error);
229-
warn!("Key: {:?}", key);
230223

231-
return Err(PyValueError::new_err(error.to_string()));
224+
// Add functions
225+
let collected_functions: Vec<(String, Py<PyAny>)> = Python::with_gil(|py| {
226+
ctx.functions
227+
.iter()
228+
.map(|(name, py_function)| (name.clone(), py_function.clone_ref(py)))
229+
.collect()
230+
});
231+
232+
for (name, py_function) in collected_functions.into_iter() {
233+
environment.add_function(
234+
&name.clone(),
235+
move |ftx: &cel_interpreter::FunctionContext| -> cel_interpreter::ResolveResult {
236+
Python::with_gil(|py| {
237+
// Convert arguments from Expression in ftx.args to PyObjects
238+
let mut py_args = Vec::new();
239+
for arg_expr in &ftx.args {
240+
let arg_value = ftx.ptx.resolve(arg_expr)?;
241+
let py_arg = RustyCelType(arg_value).into_py(py);
242+
py_args.push(py_arg);
232243
}
233-
}
234-
}
235-
}
244+
let py_args = PyTuple::new_bound(py, py_args);
236245

237-
let result = program.execute(&environment);
238-
match result {
239-
Err(error) => {
240-
warn!("An error occurred during execution");
241-
warn!("Execution error: {:?}", error);
242-
// errors
243-
// .into_iter()
244-
// .for_each(|e| println!("Execution error: {:?}", e));
245-
Err(PyValueError::new_err("Execution Error"))
246-
}
246+
// Call the Python function
247+
let py_result = py_function.call1(py, py_args)
248+
.map_err(|e| ExecutionError::FunctionError {
249+
function: name.clone(),
250+
message: e.to_string(),
251+
})?;
252+
// Convert the PyObject to &PyAny
253+
let py_result_ref = py_result.as_ref(py);
247254

248-
Ok(value) => return Ok(RustyCelType(value)),
249-
}
255+
// Convert the result back to Value
256+
let value = RustyPyType(py_result_ref).try_into_value().map_err(|e| {
257+
ExecutionError::FunctionError {
258+
function: name.clone(),
259+
message: format!("Error calling function '{}': {}", name, e),
260+
}
261+
})?;
262+
Ok(value)
263+
})
264+
},
265+
);
250266
}
251267
}
268+
269+
270+
let result = program.execute(&environment);
271+
match result {
272+
Err(error) => {
273+
warn!("An error occurred during execution");
274+
warn!("Execution error: {:?}", error);
275+
// errors
276+
// .into_iter()
277+
// .for_each(|e| println!("Execution error: {:?}", e));
278+
Err(PyValueError::new_err("Execution Error"))
279+
}
280+
281+
Ok(value) => return Ok(RustyCelType(value)),
282+
}
252283
}
253284

254285
/// A Python module implemented in Rust.
255286
#[pymodule]
256-
fn cel(py: Python<'_>, m: &PyModule) -> PyResult<()> {
287+
fn cel<'py>(py: Python<'py>, m: &PyModule) -> PyResult<()> {
257288
pyo3_log::init();
258289

259290
m.add_function(wrap_pyfunction!(evaluate, m)?)?;
291+
292+
m.add_class::<context::Context>()?;
260293
Ok(())
261294
}

tests/test_basics.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,15 @@ def test_invalid_expression_raises_parse_value_error():
99
result = cel.evaluate("1 +")
1010

1111

12+
def test_readme_example():
13+
assert cel.evaluate(
14+
'resource.name.startsWith("/groups/" + claim.group)',
15+
{
16+
"resource": {"name": "/groups/hardbyte"},
17+
"claim": {"group": "hardbyte"}
18+
}
19+
)
20+
1221
def test_hello_world():
1322
assert cel.evaluate("'Hello ' + name", {'name': "World"}) == "Hello World"
1423

0 commit comments

Comments
 (0)