Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,6 +1917,68 @@ def gen() -> Iterator[int]:
)


class IterableSchema(TypedDict, total=False):
type: Required[Literal['iterable']]
items_schema: CoreSchema
min_length: int
max_length: int
lazy: bool
ref: str
metadata: dict[str, Any]
serialization: IncExSeqOrElseSerSchema


def iterable_schema(
items_schema: CoreSchema | None = None,
*,
min_length: int | None = None,
max_length: int | None = None,
lazy: bool | None = None,
ref: str | None = None,
metadata: dict[str, Any] | None = None,
serialization: IncExSeqOrElseSerSchema | None = None,
) -> IterableSchema:
"""
Returns a schema that matches an iterable value, e.g.:

```py
from typing import Iterator
from pydantic_core import SchemaValidator, core_schema

def gen() -> Iterator[int]:
yield 1

schema = core_schema.iterable_schema(items_schema=core_schema.int_schema())
v = SchemaValidator(schema)
v.validate_python(gen())
```

Lazy validation (the default) is equivalent to `generator_schema` for
backwards compatibility in Pydantic V2.

When not using lazy validation, validated iterables will be collected into a list.

Args:
items_schema: The value must be an iterable with items that match this schema
min_length: The value must be an iterable that yields at least this many items
max_length: The value must be an iterable that yields at most this many items
lazy: Whether to use lazy evaluation, defaults to True
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
"""
return _dict_not_none(
type='iterable',
items_schema=items_schema,
min_length=min_length,
max_length=max_length,
lazy=lazy,
ref=ref,
metadata=metadata,
serialization=serialization,
)


IncExDict = set[Union[int, str]]


Expand Down
3 changes: 2 additions & 1 deletion src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ pub(crate) use input_python::{downcast_python_input, input_as_python_instance};
pub(crate) use input_string::StringMapping;
pub(crate) use return_enums::{
no_validator_iter_to_vec, py_string_str, validate_iter_to_set, validate_iter_to_vec, EitherBytes, EitherFloat,
EitherInt, EitherString, GenericIterator, Int, MaxLengthCheck, ValidationMatch,
EitherInt, EitherString, GenericIterator, GenericJsonIterator, GenericPyIterator, Int, MaxLengthCheck,
ValidationMatch,
};

// Defined here as it's not exported by pyo3
Expand Down
128 changes: 128 additions & 0 deletions src/validators/iterable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
use std::sync::Arc;

use jiter::JsonValue;
use pyo3::types::PyDict;
use pyo3::{intern, prelude::*, IntoPyObjectExt};

use crate::errors::ValResult;
use crate::input::{
validate_iter_to_vec, GenericIterator, GenericJsonIterator, GenericPyIterator, Input, MaxLengthCheck,
};
use crate::tools::SchemaDict;
use crate::validators::any::AnyValidator;
use crate::validators::generator::GeneratorValidator;
use crate::validators::list::min_length_check;

use super::list::get_items_schema;
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct IterableValidator {
item_validator: Option<Arc<CombinedValidator>>,
min_length: Option<usize>,
max_length: Option<usize>,
name: String,
}

impl BuildValidator for IterableValidator {
const EXPECTED_TYPE: &'static str = "iterable";

fn build(
schema: &Bound<'_, PyDict>,
config: Option<&Bound<'_, PyDict>>,
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
// TODO: in Pydantic V3 default will be lazy=False
let lazy_iterable: bool = schema.get_as(intern!(schema.py(), "lazy"))?.unwrap_or(true);

if lazy_iterable {
// lazy iterable is equivalent to generator, for backwards compatibility
return GeneratorValidator::build(schema, config, definitions);
}

let item_validator = get_items_schema(schema, config, definitions)?.map(Arc::new);
let name = match item_validator {
Some(ref v) => format!("{}[{}]", Self::EXPECTED_TYPE, v.get_name()),
None => format!("{}[any]", Self::EXPECTED_TYPE),
};
Ok(Self {
item_validator,
name,
min_length: schema.get_as(pyo3::intern!(schema.py(), "min_length"))?,
max_length: schema.get_as(pyo3::intern!(schema.py(), "max_length"))?,
}
.into())
}
}

impl_py_gc_traverse!(IterableValidator { item_validator });

impl Validator for IterableValidator {
fn validate<'py>(
&self,
py: Python<'py>,
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<Py<PyAny>> {
// this validator does not yet support partial validation, disable it to avoid incorrect results
state.allow_partial = false.into();

let iterator = input.validate_iter()?;

let item_validator = self
.item_validator
.as_deref()
.unwrap_or(&CombinedValidator::Any(AnyValidator));

let max_length_check = MaxLengthCheck::new(self.max_length, "Iterable", input, None);
let vec = match iterator {
GenericIterator::PyIterator(iter) => validate_iter_to_vec(
py,
IterWithPy { py, iter },
0,
max_length_check,
item_validator,
state,
false,
)?,
GenericIterator::JsonArray(iter) => validate_iter_to_vec(
py,
IterWithPy { py, iter },
0,
max_length_check,
item_validator,
state,
false,
)?,
};

min_length_check!(input, "Iterable", self.min_length, vec);

vec.into_py_any(py).map_err(Into::into)
}

fn get_name(&self) -> &str {
&self.name
}
}

struct IterWithPy<'py, I> {
py: Python<'py>,
iter: I,
}

impl<'py> Iterator for IterWithPy<'py, GenericPyIterator> {
type Item = PyResult<Bound<'py, PyAny>>;

fn next(&mut self) -> Option<Self::Item> {
Some(self.iter.next(self.py).transpose()?.map(|(v, _)| v))
}
}

impl<'j> Iterator for IterWithPy<'_, GenericJsonIterator<'j>> {
type Item = PyResult<JsonValue<'j>>;

fn next(&mut self) -> Option<Self::Item> {
Some(self.iter.next(self.py).transpose()?.map(|(v, _)| v.clone()))
}
}
5 changes: 5 additions & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ mod generator;
mod int;
mod is_instance;
mod is_subclass;
mod iterable;
mod json;
mod json_or_python;
mod lax_or_strict;
Expand Down Expand Up @@ -645,6 +646,8 @@ fn build_validator_inner(
json_or_python::JsonOrPython,
// generator validators
generator::GeneratorValidator,
// iterables
iterable::IterableValidator,
// custom error
custom_error::CustomErrorValidator,
// json data
Expand Down Expand Up @@ -822,6 +825,8 @@ pub enum CombinedValidator {
LaxOrStrict(lax_or_strict::LaxOrStrictValidator),
// generator validators
Generator(generator::GeneratorValidator),
// iterables
Iterable(iterable::IterableValidator),
// custom error
CustomError(custom_error::CustomErrorValidator),
// json data
Expand Down
44 changes: 28 additions & 16 deletions tests/validators/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from typing import Callable

import pytest
from dirty_equals import HasRepr, IsStr
Expand All @@ -9,6 +10,17 @@
from ..conftest import Err, PyAndJson


@pytest.fixture(params=['generator', 'iterable'])
def schema_type(request):
# both generator and (lazy) iterable should behave the same
return request.param


@pytest.fixture(params=[cs.generator_schema, cs.iterable_schema])
def schema_func(request):
return request.param


@pytest.mark.parametrize(
'input_value,expected',
[
Expand All @@ -21,8 +33,8 @@
],
ids=repr,
)
def test_generator_json_int(py_and_json: PyAndJson, input_value, expected):
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}})
def test_generator_json_int(schema_type: str, py_and_json: PyAndJson, input_value, expected):
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
list(v.validate_test(input_value))
Expand All @@ -39,8 +51,8 @@ def test_generator_json_int(py_and_json: PyAndJson, input_value, expected):
(CoreConfig(hide_input_in_errors=True), 'type=iterable_type'),
),
)
def test_generator_json_hide_input(py_and_json: PyAndJson, config, input_str):
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}}, config)
def test_generator_json_hide_input(schema_type: str, py_and_json: PyAndJson, config, input_str):
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}}, config)
with pytest.raises(ValidationError, match=re.escape(f'[{input_str}]')):
list(v.validate_test(5))

Expand All @@ -57,8 +69,8 @@ def test_generator_json_hide_input(py_and_json: PyAndJson, config, input_str):
],
ids=repr,
)
def test_generator_json_any(py_and_json: PyAndJson, input_value, expected):
v = py_and_json({'type': 'generator'})
def test_generator_json_any(schema_type: str, py_and_json: PyAndJson, input_value, expected):
v = py_and_json({'type': schema_type})
if isinstance(expected, Err):
with pytest.raises(ValidationError, match=re.escape(expected.message)):
list(v.validate_test(input_value))
Expand All @@ -67,8 +79,8 @@ def test_generator_json_any(py_and_json: PyAndJson, input_value, expected):
assert list(v.validate_test(input_value)) == expected


def test_error_index(py_and_json: PyAndJson):
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}})
def test_error_index(schema_type: str, py_and_json: PyAndJson):
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}})
gen = v.validate_test(['wrong'])
assert gen.index == 0
with pytest.raises(ValidationError) as exc_info:
Expand Down Expand Up @@ -108,8 +120,8 @@ def test_error_index(py_and_json: PyAndJson):
assert gen.index == 5


def test_too_long(py_and_json: PyAndJson):
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}, 'max_length': 2})
def test_too_long(schema_type: str, py_and_json: PyAndJson):
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}, 'max_length': 2})
assert list(v.validate_test([1])) == [1]
assert list(v.validate_test([1, 2])) == [1, 2]
with pytest.raises(ValidationError) as exc_info:
Expand All @@ -126,8 +138,8 @@ def test_too_long(py_and_json: PyAndJson):
]


def test_too_short(py_and_json: PyAndJson):
v = py_and_json({'type': 'generator', 'items_schema': {'type': 'int'}, 'min_length': 2})
def test_too_short(schema_type: str, py_and_json: PyAndJson):
v = py_and_json({'type': schema_type, 'items_schema': {'type': 'int'}, 'min_length': 2})
assert list(v.validate_test([1, 2, 3])) == [1, 2, 3]
assert list(v.validate_test([1, 2])) == [1, 2]
with pytest.raises(ValidationError) as exc_info:
Expand All @@ -150,8 +162,8 @@ def gen():
yield 3


def test_generator_too_long():
v = SchemaValidator(cs.generator_schema(items_schema=cs.int_schema(), max_length=2))
def test_generator_too_long(schema_func: Callable):
v = SchemaValidator(schema_func(items_schema=cs.int_schema(), max_length=2))

validating_iterator = v.validate_python(gen())

Expand All @@ -174,8 +186,8 @@ def test_generator_too_long():
]


def test_generator_too_short():
v = SchemaValidator(cs.generator_schema(items_schema=cs.int_schema(), min_length=4))
def test_generator_too_short(schema_func: Callable):
v = SchemaValidator(schema_func(items_schema=cs.int_schema(), min_length=4))

validating_iterator = v.validate_python(gen())

Expand Down
Loading
Loading