4747
4848model_registry = {}
4949_T = TypeVar ("_T" )
50+ Model = TypeVar ("Model" , bound = "RedisModel" )
5051log = logging .getLogger (__name__ )
5152escaper = TokenEscaper ()
5253
@@ -1160,16 +1161,16 @@ async def delete(
11601161 return await cls ._delete (db , cls .make_primary_key (pk ))
11611162
11621163 @classmethod
1163- async def get (cls , pk : Any ) -> "RedisModel " :
1164+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
11641165 raise NotImplementedError
11651166
11661167 async def update (self , ** field_values ):
11671168 """Update this model instance with the specified key-value pairs."""
11681169 raise NotImplementedError
11691170
11701171 async def save (
1171- self , pipeline : Optional [redis .client .Pipeline ] = None
1172- ) -> "RedisModel " :
1172+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1173+ ) -> "Model " :
11731174 raise NotImplementedError
11741175
11751176 async def expire (
@@ -1266,11 +1267,11 @@ def get_annotations(cls):
12661267
12671268 @classmethod
12681269 async def add (
1269- cls ,
1270- models : Sequence ["RedisModel " ],
1270+ cls : Type [ "Model" ] ,
1271+ models : Sequence ["Model " ],
12711272 pipeline : Optional [redis .client .Pipeline ] = None ,
12721273 pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1273- ) -> Sequence ["RedisModel " ]:
1274+ ) -> Sequence ["Model " ]:
12741275 db = cls ._get_db (pipeline , bulk = True )
12751276
12761277 for model in models :
@@ -1345,8 +1346,8 @@ def __init_subclass__(cls, **kwargs):
13451346 )
13461347
13471348 async def save (
1348- self , pipeline : Optional [redis .client .Pipeline ] = None
1349- ) -> "HashModel " :
1349+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1350+ ) -> "Model " :
13501351 self .check ()
13511352 db = self ._get_db (pipeline )
13521353
@@ -1368,7 +1369,7 @@ async def all_pks(cls): # type: ignore
13681369 )
13691370
13701371 @classmethod
1371- async def get (cls , pk : Any ) -> "HashModel " :
1372+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
13721373 document = await cls .db ().hgetall (cls .make_primary_key (pk ))
13731374 if not document :
13741375 raise NotFoundError
@@ -1513,8 +1514,8 @@ def __init__(self, *args, **kwargs):
15131514 super ().__init__ (* args , ** kwargs )
15141515
15151516 async def save (
1516- self , pipeline : Optional [redis .client .Pipeline ] = None
1517- ) -> "JsonModel " :
1517+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1518+ ) -> "Model " :
15181519 self .check ()
15191520 db = self ._get_db (pipeline )
15201521
@@ -1559,7 +1560,7 @@ async def update(self, **field_values):
15591560 await self .save ()
15601561
15611562 @classmethod
1562- async def get (cls , pk : Any ) -> "JsonModel " :
1563+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
15631564 document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
15641565 if document == "null" :
15651566 raise NotFoundError
0 commit comments