From 1cf2965acc7ada720878253ff62359f267aea666 Mon Sep 17 00:00:00 2001 From: Kevin Lane Date: Sun, 26 Jun 2022 16:04:54 -0700 Subject: [PATCH] Add argument for SQLAlchemy type kwargs --- sqlmodel/main.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 4d6d2f2712..9feb5fb7e0 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -35,6 +35,7 @@ Column, Date, DateTime, + Enum as SAEnum, Float, ForeignKey, Integer, @@ -73,6 +74,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: sa_column = kwargs.pop("sa_column", Undefined) sa_column_args = kwargs.pop("sa_column_args", Undefined) sa_column_kwargs = kwargs.pop("sa_column_kwargs", Undefined) + sa_type_kwargs = kwargs.pop("sa_type_kwargs", Undefined) if sa_column is not Undefined: if sa_column_args is not Undefined: raise RuntimeError( @@ -92,6 +94,7 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: self.sa_column = sa_column self.sa_column_args = sa_column_args self.sa_column_kwargs = sa_column_kwargs + self.sa_type_kwargs = sa_type_kwargs class RelationshipInfo(Representation): @@ -154,6 +157,7 @@ def Field( sa_column: Union[Column, UndefinedType] = Undefined, # type: ignore sa_column_args: Union[Sequence[Any], UndefinedType] = Undefined, sa_column_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, + sa_type_kwargs: Union[Mapping[str, Any], UndefinedType] = Undefined, schema_extra: Optional[Dict[str, Any]] = None, ) -> Any: current_schema_extra = schema_extra or {} @@ -184,6 +188,7 @@ def Field( sa_column=sa_column, sa_column_args=sa_column_args, sa_column_kwargs=sa_column_kwargs, + sa_type_kwargs=sa_type_kwargs, **current_schema_extra, ) field_info._validate() @@ -375,28 +380,31 @@ def __init__( def get_sqlachemy_type(field: ModelField) -> Any: + sa_type_kwargs = getattr(field.field_info, "sa_type_kwargs", Undefined) + if sa_type_kwargs is Undefined: + sa_type_kwargs = {} + if issubclass(field.type_, Enum): + return SAEnum(field.type_, **sa_type_kwargs) if issubclass(field.type_, str): if field.field_info.max_length: return AutoString(length=field.field_info.max_length) return AutoString if issubclass(field.type_, float): - return Float + return Float(**sa_type_kwargs) if issubclass(field.type_, bool): - return Boolean + return Boolean(**sa_type_kwargs) if issubclass(field.type_, int): return Integer if issubclass(field.type_, datetime): - return DateTime + return DateTime(**sa_type_kwargs) if issubclass(field.type_, date): return Date if issubclass(field.type_, timedelta): - return Interval + return Interval(**sa_type_kwargs) if issubclass(field.type_, time): - return Time - if issubclass(field.type_, Enum): - return Enum + return Time(**sa_type_kwargs) if issubclass(field.type_, bytes): - return LargeBinary + return LargeBinary(**sa_type_kwargs) if issubclass(field.type_, Decimal): return Numeric( precision=getattr(field.type_, "max_digits", None),