|
| 1 | +from typing import ( |
| 2 | + TYPE_CHECKING, |
| 3 | + Any, |
| 4 | + Callable, |
| 5 | + Dict, |
| 6 | + ForwardRef, |
| 7 | + Optional, |
| 8 | + Type, |
| 9 | + TypeVar, |
| 10 | + Union, |
| 11 | + get_args, |
| 12 | + get_origin, |
| 13 | +) |
| 14 | + |
| 15 | +from pydantic import VERSION as PYDANTIC_VERSION |
| 16 | + |
| 17 | +IS_PYDANTIC_V2 = int(PYDANTIC_VERSION.split(".")[0]) >= 2 |
| 18 | + |
| 19 | + |
| 20 | +if IS_PYDANTIC_V2: |
| 21 | + from pydantic import ConfigDict |
| 22 | + from pydantic_core import PydanticUndefined as PydanticUndefined, PydanticUndefinedType as PydanticUndefinedType # noqa |
| 23 | +else: |
| 24 | + from pydantic import BaseConfig # noqa |
| 25 | + from pydantic.fields import ModelField # noqa |
| 26 | + from pydantic.fields import Undefined as PydanticUndefined, UndefinedType as PydanticUndefinedType # noqa |
| 27 | + |
| 28 | +if TYPE_CHECKING: |
| 29 | + from .main import FieldInfo, RelationshipInfo, SQLModel, SQLModelMetaclass |
| 30 | + |
| 31 | + |
| 32 | +NoArgAnyCallable = Callable[[], Any] |
| 33 | +T = TypeVar("T") |
| 34 | +InstanceOrType = Union[T, Type[T]] |
| 35 | + |
| 36 | +if IS_PYDANTIC_V2: |
| 37 | + |
| 38 | + class SQLModelConfig(ConfigDict, total=False): |
| 39 | + table: Optional[bool] |
| 40 | + read_from_attributes: Optional[bool] |
| 41 | + registry: Optional[Any] |
| 42 | + |
| 43 | +else: |
| 44 | + |
| 45 | + class SQLModelConfig(BaseConfig): |
| 46 | + table: Optional[bool] = None |
| 47 | + read_from_attributes: Optional[bool] = None |
| 48 | + registry: Optional[Any] = None |
| 49 | + |
| 50 | + def __getitem__(self, item: str) -> Any: |
| 51 | + return self.__getattr__(item) |
| 52 | + |
| 53 | + def __setitem__(self, item: str, value: Any) -> None: |
| 54 | + return self.__setattr__(item, value) |
| 55 | + |
| 56 | + |
| 57 | +# Inspired from https://github.com/roman-right/beanie/blob/main/beanie/odm/utils/pydantic.py |
| 58 | +def get_model_config(model: type) -> Optional[SQLModelConfig]: |
| 59 | + if IS_PYDANTIC_V2: |
| 60 | + return getattr(model, "model_config", None) |
| 61 | + else: |
| 62 | + return getattr(model, "Config", None) |
| 63 | + |
| 64 | + |
| 65 | +def get_config_value( |
| 66 | + model: InstanceOrType["SQLModel"], parameter: str, default: Any = None |
| 67 | +) -> Any: |
| 68 | + if IS_PYDANTIC_V2: |
| 69 | + return model.model_config.get(parameter, default) |
| 70 | + else: |
| 71 | + return getattr(model.Config, parameter, default) |
| 72 | + |
| 73 | + |
| 74 | +def set_config_value( |
| 75 | + model: InstanceOrType["SQLModel"], parameter: str, value: Any, v1_parameter: str = None |
| 76 | +) -> None: |
| 77 | + if IS_PYDANTIC_V2: |
| 78 | + model.model_config[parameter] = value # type: ignore |
| 79 | + else: |
| 80 | + model.Config[v1_parameter or parameter] = value # type: ignore |
| 81 | + |
| 82 | + |
| 83 | +def get_model_fields(model: InstanceOrType["SQLModel"]) -> Dict[str, "FieldInfo"]: |
| 84 | + if IS_PYDANTIC_V2: |
| 85 | + return model.model_fields # type: ignore |
| 86 | + else: |
| 87 | + return model.__fields__ # type: ignore |
| 88 | + |
| 89 | + |
| 90 | +def get_fields_set(model: InstanceOrType["SQLModel"]) -> set[str]: |
| 91 | + if IS_PYDANTIC_V2: |
| 92 | + return model.__pydantic_fields_set__ |
| 93 | + else: |
| 94 | + return model.__fields_set__ # type: ignore |
| 95 | + |
| 96 | + |
| 97 | +def set_fields_set( |
| 98 | + new_object: InstanceOrType["SQLModel"], fields: set["FieldInfo"] |
| 99 | +) -> None: |
| 100 | + if IS_PYDANTIC_V2: |
| 101 | + object.__setattr__(new_object, "__pydantic_fields_set__", fields) |
| 102 | + else: |
| 103 | + object.__setattr__(new_object, "__fields_set__", fields) |
| 104 | + |
| 105 | + |
| 106 | +def set_attribute_mode(cls: Type["SQLModelMetaclass"]) -> None: |
| 107 | + if IS_PYDANTIC_V2: |
| 108 | + cls.model_config["read_from_attributes"] = True |
| 109 | + else: |
| 110 | + cls.__config__.read_with_orm_mode = True # type: ignore |
| 111 | + |
| 112 | + |
| 113 | +def get_relationship_to( |
| 114 | + name: str, |
| 115 | + rel_info: "RelationshipInfo", |
| 116 | + annotation: Any, |
| 117 | +) -> Any: |
| 118 | + if IS_PYDANTIC_V2: |
| 119 | + relationship_to = get_origin(annotation) |
| 120 | + # Direct relationships (e.g. 'Team' or Team) have None as an origin |
| 121 | + if relationship_to is None: |
| 122 | + relationship_to = annotation |
| 123 | + # If Union (e.g. Optional), get the real field |
| 124 | + elif relationship_to is Union: |
| 125 | + relationship_to = get_args(annotation)[0] |
| 126 | + # If a list, then also get the real field |
| 127 | + elif relationship_to is list: |
| 128 | + relationship_to = get_args(annotation)[0] |
| 129 | + if isinstance(relationship_to, ForwardRef): |
| 130 | + relationship_to = relationship_to.__forward_arg__ |
| 131 | + return relationship_to |
| 132 | + else: |
| 133 | + temp_field = ModelField.infer( |
| 134 | + name=name, |
| 135 | + value=rel_info, |
| 136 | + annotation=annotation, |
| 137 | + class_validators=None, |
| 138 | + config=SQLModelConfig, |
| 139 | + ) |
| 140 | + relationship_to = temp_field.type_ |
| 141 | + if isinstance(temp_field.type_, ForwardRef): |
| 142 | + relationship_to = temp_field.type_.__forward_arg__ |
| 143 | + return relationship_to |
| 144 | + |
| 145 | + |
| 146 | +def set_empty_defaults(annotations: Dict[str, Any], class_dict: Dict[str, Any]) -> None: |
| 147 | + """ |
| 148 | + Pydantic v2 without required fields with no optionals cannot do empty initialisations. |
| 149 | + This means we cannot do Table() and set fields later. |
| 150 | + We go around this by adding a default to everything, being None |
| 151 | +
|
| 152 | + Args: |
| 153 | + annotations: Dict[str, Any]: The annotations to provide to pydantic |
| 154 | + class_dict: Dict[str, Any]: The class dict for the defaults |
| 155 | + """ |
| 156 | + if IS_PYDANTIC_V2: |
| 157 | + from .main import FieldInfo |
| 158 | + # Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything |
| 159 | + for key in annotations.keys(): |
| 160 | + value = class_dict.get(key, PydanticUndefined) |
| 161 | + if value is PydanticUndefined: |
| 162 | + class_dict[key] = None |
| 163 | + elif isinstance(value, FieldInfo): |
| 164 | + if ( |
| 165 | + value.default in (PydanticUndefined, Ellipsis) |
| 166 | + ) and value.default_factory is None: |
| 167 | + # So we can check for nullable |
| 168 | + value.original_default = value.default |
| 169 | + value.default = None |
0 commit comments