diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 4d6d2f2712..e971566e23 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -294,7 +294,7 @@ def get_config(name: str) -> Any: # If it was passed by kwargs, ensure it's also set in config new_cls.__config__.table = config_table for k, v in new_cls.__fields__.items(): - col = get_column_from_field(v) + col = get_column_from_field(v, cls=new_cls) setattr(new_cls, k, col) # Set a config flag to tell FastAPI that this should be read with a field # in orm_mode instead of preemptively converting it to a dict. @@ -330,7 +330,7 @@ def __init__( if getattr(cls.__config__, "table", False) and not base_is_table: dict_used = dict_.copy() for field_name, field_value in cls.__fields__.items(): - dict_used[field_name] = get_column_from_field(field_value) + dict_used[field_name] = get_column_from_field(field_value, cls=cls) for rel_name, rel_info in cls.__sqlmodel_relationships__.items(): if rel_info.sa_relationship: # There's a SQLAlchemy relationship declared, that takes precedence @@ -369,6 +369,8 @@ def __init__( relationship_to, *rel_args, **rel_kwargs ) dict_used[rel_name] = rel_value + # From https://github.com/tiangolo/sqlmodel/pull/322 + setattr(cls, rel_name, rel_value) # Fix #315 DeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw) else: ModelMetaclass.__init__(cls, classname, bases, dict_, **kw) @@ -416,7 +418,12 @@ def get_sqlachemy_type(field: ModelField) -> Any: return GUID -def get_column_from_field(field: ModelField) -> Column: # type: ignore +_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") + + +def get_column_from_field( + field: ModelField, cls: Union[_TSQLModel, SQLModelMetaclass] +) -> Column: # type: ignore sa_column = getattr(field.field_info, "sa_column", Undefined) if isinstance(sa_column, Column): return sa_column @@ -433,7 +440,12 @@ def get_column_from_field(field: ModelField) -> Column: # type: ignore args = [] foreign_key = getattr(field.field_info, "foreign_key", None) if foreign_key: - args.append(ForeignKey(foreign_key)) + tablename = getattr(cls, "__tablename__", None) + if tablename is not None: + fk_name = f"{tablename}_{field.name}_fkey" + args.append(ForeignKey(foreign_key, name=fk_name)) + else: + args.append(ForeignKey(foreign_key)) kwargs = { "primary_key": primary_key, "nullable": nullable, @@ -466,9 +478,6 @@ def _value_items_is_true(v: Any) -> bool: return v is True or v is ... -_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel") - - class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry): # SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values __slots__ = ("__weakref__",)