Skip to content

Commit 8288816

Browse files
Make all tests but fastapi work
1 parent 43d5d41 commit 8288816

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

sqlmodel/main.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
ForwardRef,
1515
List,
1616
Mapping,
17-
NoneType,
1817
Optional,
1918
Sequence,
2019
Set,
@@ -48,6 +47,7 @@
4847

4948
_T = TypeVar("_T")
5049
NoArgAnyCallable = Callable[[], Any]
50+
NoneType = type(None)
5151

5252

5353
def __dataclass_transform__(
@@ -273,13 +273,17 @@ def __new__(
273273
key: pydantic_kwargs.pop(key)
274274
for key in pydantic_kwargs.keys() & allowed_config_kwargs
275275
}
276-
config_table = getattr(class_dict.get("Config", object()), "table", False)
276+
config_table = getattr(class_dict.get("Config", object()), "table", False) or kwargs.get("table", False)
277277
# If we have a table, we need to have defaults for all fields
278278
# Pydantic v2 sets a __pydantic_core_schema__ which is very hard to change. Changing the fields does not do anything
279279
if config_table is True:
280-
for key in original_annotations.keys():
281-
if dict_used.get(key, PydanticUndefined) is PydanticUndefined:
280+
for key in pydantic_annotations.keys():
281+
value = dict_used.get(key, PydanticUndefined)
282+
if value is PydanticUndefined:
282283
dict_used[key] = None
284+
elif isinstance(value, FieldInfo):
285+
if value.default is PydanticUndefined and value.default_factory is None:
286+
value.default = None
283287

284288
new_cls: Type["SQLModelMetaclass"] = super().__new__(
285289
cls, name, bases, dict_used, **config_kwargs
@@ -349,8 +353,11 @@ def __init__(
349353
continue
350354
ann = cls.__annotations__[rel_name]
351355
relationship_to = get_origin(ann)
352-
# If Union (Optional), get the real field
353-
if relationship_to is Union:
356+
# Direct relationships (e.g. 'Team' or Team) have None as an origin
357+
if relationship_to is None:
358+
relationship_to = ann
359+
# If Union (e.g. Optional), get the real field
360+
elif relationship_to is Union:
354361
relationship_to = get_args(ann)[0]
355362
# If a list, then also get the real field
356363
elif relationship_to is list:
@@ -501,6 +508,16 @@ class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry
501508
__allow_unmapped__ = True # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#migration-20-step-six
502509
model_config = SQLModelConfig(from_attributes=True)
503510

511+
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
512+
new_object = super().__new__(cls)
513+
# SQLAlchemy doesn't call __init__ on the base class
514+
# Ref: https://docs.sqlalchemy.org/en/14/orm/constructors.html
515+
# Set __fields_set__ here, that would have been set when calling __init__
516+
# in the Pydantic model so that when SQLAlchemy sets attributes that are
517+
# added (e.g. when querying from DB) to the __fields_set__, this already exists
518+
object.__setattr__(new_object, "__pydantic_fields_set__", set())
519+
return new_object
520+
504521
def __init__(__pydantic_self__, **data: Any) -> None:
505522
old_dict = __pydantic_self__.__dict__.copy()
506523
super().__init__(**data)
@@ -531,6 +548,10 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
531548
def __tablename__(cls) -> str:
532549
return cls.__name__.lower()
533550

551+
@classmethod
552+
def model_validate(cls, *args, **kwargs):
553+
return super().model_validate(*args, **kwargs)
554+
534555

535556
def _is_field_noneable(field: FieldInfo) -> bool:
536557
if not field.is_required():

0 commit comments

Comments
 (0)