Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 100 additions & 39 deletions src/serializers/type_serializers/function.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::borrow::Cow;
use std::sync::Arc;
use std::f32::consts::E;
use std::marker::PhantomData;
use std::ptr::{self, NonNull};
use std::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};

use pyo3::exceptions::{PyAttributeError, PyRecursionError, PyRuntimeError};
use pyo3::gc::PyVisit;
Expand All @@ -11,6 +14,7 @@ use pyo3::PyTraverseError;
use pyo3::types::PyString;

use crate::definitions::DefinitionsBuilder;
use crate::serializers::extra;
use crate::tools::SchemaDict;
use crate::tools::{function_name, py_err, py_error_type};
use crate::{PydanticOmit, PydanticSerializationUnexpectedValue};
Expand Down Expand Up @@ -393,7 +397,9 @@ impl FunctionWrapSerializer {
) -> PyResult<(bool, PyObject)> {
let py = value.py();
if self.when_used.should_use(value, extra) {
let serialize = SerializationCallable::new(&self.serializer, include, exclude, extra);
let extra_ref_guard = ExtraRef::new(extra);
let serialize =
SerializationCallable::new(&self.serializer, include, exclude, extra_ref_guard.inner().clone());
let v = if self.is_field_serializer {
if let Some(model) = extra.model {
if self.info_arg {
Expand Down Expand Up @@ -434,11 +440,56 @@ impl_py_gc_traverse!(FunctionWrapSerializer {

function_type_serializer!(FunctionWrapSerializer);

/// A wrapper around `&Extra` which drops the lifetime, in order to be stored inside a Python object.
#[derive(Clone)]
struct ExtraRef {
value: Arc<RwLock<Option<*const Extra<'static>>>>,
}

// Safety: `&Extra` is `Send + Sync`
unsafe impl Send for ExtraRef {}
unsafe impl Sync for ExtraRef {}

impl ExtraRef {
fn new<'a>(extra: &'a Extra<'a>) -> ExtraRefGuard<'a> {
ExtraRefGuard(
ExtraRef {
value: Arc::new(RwLock::new(Some(ptr::from_ref(extra).cast()))),
},
PhantomData,
)
}

fn map<R>(&self, f: impl FnOnce(&Extra<'_>) -> R) -> Option<R> {
// FIXME: deal with lock poisoning?, use try_read
let guard = self.value.read().unwrap();
guard.as_ref().map(|ptr| {
// Safety: we ensure that the pointer is valid while `ExtraRef` is alive
let extra: &Extra = unsafe { &**ptr };
f(extra)
})
}
}

struct ExtraRefGuard<'a>(ExtraRef, PhantomData<&'a Extra<'a>>);

impl ExtraRefGuard<'_> {
fn inner(&self) -> &ExtraRef {
&self.0
}
}

impl Drop for ExtraRefGuard<'_> {
fn drop(&mut self) {
let mut guard = self.0.value.write().unwrap();
*guard = None;
}
}

#[pyclass(module = "pydantic_core._pydantic_core")]
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct SerializationCallable {
serializer: Arc<CombinedSerializer>,
extra_owned: ExtraOwned,
extra: ExtraRef,
filter: AnyFilter,
include: Option<PyObject>,
exclude: Option<PyObject>,
Expand All @@ -449,11 +500,11 @@ impl SerializationCallable {
serializer: &Arc<CombinedSerializer>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
extra: ExtraRef,
) -> Self {
Self {
serializer: serializer.clone(),
extra_owned: ExtraOwned::new(extra),
extra: extra,
filter: AnyFilter::new(),
include: include.map(|v| v.clone().unbind()),
exclude: exclude.map(|v| v.clone().unbind()),
Expand All @@ -467,24 +518,22 @@ impl SerializationCallable {
if let Some(exclude) = &self.exclude {
visit.call(exclude)?;
}
if let Some(model) = &self.extra_owned.model {
visit.call(model)?;
}
if let Some(fallback) = &self.extra_owned.fallback {
visit.call(fallback)?;
}
if let Some(context) = &self.extra_owned.context {
visit.call(context)?;
}
self.extra
.map(|extra| {
// FIXME: not sound to get .read() of extra inside GC, probably need to make `Extra` not
// have the `'py` lifetime
visit.call(extra.model.map(Bound::as_unbound))?;
visit.call(extra.fallback.map(Bound::as_unbound))?;
visit.call(extra.context.map(Bound::as_unbound))?;
Ok(())
})
.transpose()?;
Ok(())
}

fn __clear__(&mut self) {
self.include = None;
self.exclude = None;
self.extra_owned.model = None;
self.extra_owned.fallback = None;
self.extra_owned.context = None;
}
}

Expand All @@ -503,28 +552,40 @@ impl SerializationCallable {

let include = self.include.as_ref().map(|o| o.bind(py));
let exclude = self.exclude.as_ref().map(|o| o.bind(py));
let extra = self.extra_owned.to_extra(py);

if let Some(index_key) = index_key {
let filter = if let Ok(index) = index_key.extract::<usize>() {
self.filter.index_filter(index, include, exclude, None)?
} else {
self.filter.key_filter(index_key, include, exclude)?
};
if let Some((next_include, next_exclude)) = filter {
let v =
self.serializer
.to_python_no_infer(value, next_include.as_ref(), next_exclude.as_ref(), &extra)?;
extra.warnings.final_check(py)?;
Ok(Some(v))
} else {
Err(PydanticOmit::new_err())
}
} else {
let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?;
extra.warnings.final_check(py)?;
Ok(Some(v))
}
// FIXME: the &T is not sound here, since the guard is dropped at the end of this statement.
// Probably need to have a .map() method to avoid scope leak?
self.extra
.map(|extra| {
if let Some(index_key) = index_key {
let filter = if let Ok(index) = index_key.extract::<usize>() {
self.filter.index_filter(index, include, exclude, None)?
} else {
self.filter.key_filter(index_key, include, exclude)?
};
if let Some((next_include, next_exclude)) = filter {
let v = self.serializer.to_python_no_infer(
value,
next_include.as_ref(),
next_exclude.as_ref(),
&extra,
)?;
extra.warnings.final_check(py)?;
Ok(Some(v))
} else {
Err(PydanticOmit::new_err())
}
} else {
let v = self.serializer.to_python_no_infer(value, include, exclude, &extra)?;
extra.warnings.final_check(py)?;
Ok(Some(v))
}
})
.unwrap_or_else(|| {
Err(PyRuntimeError::new_err(
"Attempted to use SerializationCallable after its wrap validation context was exited",
))
})
}

fn __repr__(&self) -> PyResult<String> {
Expand Down
Loading