Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def get_config(name: str) -> Any:
# If it was passed by kwargs, ensure it's also set in config
new_cls.__config__.table = config_table
for k, v in new_cls.__fields__.items():
col = get_column_from_field(v)
col = get_column_from_field(v, cls=new_cls)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
# in orm_mode instead of preemptively converting it to a dict.
Expand Down Expand Up @@ -330,7 +330,7 @@ def __init__(
if getattr(cls.__config__, "table", False) and not base_is_table:
dict_used = dict_.copy()
for field_name, field_value in cls.__fields__.items():
dict_used[field_name] = get_column_from_field(field_value)
dict_used[field_name] = get_column_from_field(field_value, cls=cls)
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
Expand Down Expand Up @@ -369,6 +369,8 @@ def __init__(
relationship_to, *rel_args, **rel_kwargs
)
dict_used[rel_name] = rel_value
# From https://github.com/tiangolo/sqlmodel/pull/322
setattr(cls, rel_name, rel_value) # Fix #315
DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)
else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
Expand Down Expand Up @@ -416,7 +418,12 @@ def get_sqlachemy_type(field: ModelField) -> Any:
return GUID


def get_column_from_field(field: ModelField) -> Column: # type: ignore
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


def get_column_from_field(
field: ModelField, cls: Union[_TSQLModel, SQLModelMetaclass]
) -> Column: # type: ignore
sa_column = getattr(field.field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
return sa_column
Expand All @@ -433,7 +440,12 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore
args = []
foreign_key = getattr(field.field_info, "foreign_key", None)
if foreign_key:
args.append(ForeignKey(foreign_key))
tablename = getattr(cls, "__tablename__", None)
if tablename is not None:
fk_name = f"{tablename}_{field.name}_fkey"
args.append(ForeignKey(foreign_key, name=fk_name))
else:
args.append(ForeignKey(foreign_key))
kwargs = {
"primary_key": primary_key,
"nullable": nullable,
Expand Down Expand Up @@ -466,9 +478,6 @@ def _value_items_is_true(v: Any) -> bool:
return v is True or v is ...


_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")


class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
Expand Down