4747
4848model_registry = {}
4949_T = TypeVar ("_T" )
50+ Model = TypeVar ("Model" , bound = "RedisModel" )
5051log = logging .getLogger (__name__ )
5152escaper = TokenEscaper ()
5253
@@ -1152,16 +1153,16 @@ async def delete(
11521153 return await cls ._delete (db , cls .make_primary_key (pk ))
11531154
11541155 @classmethod
1155- async def get (cls , pk : Any ) -> "RedisModel " :
1156+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
11561157 raise NotImplementedError
11571158
11581159 async def update (self , ** field_values ):
11591160 """Update this model instance with the specified key-value pairs."""
11601161 raise NotImplementedError
11611162
11621163 async def save (
1163- self , pipeline : Optional [redis .client .Pipeline ] = None
1164- ) -> "RedisModel " :
1164+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1165+ ) -> "Model " :
11651166 raise NotImplementedError
11661167
11671168 async def expire (
@@ -1258,11 +1259,11 @@ def get_annotations(cls):
12581259
12591260 @classmethod
12601261 async def add (
1261- cls ,
1262- models : Sequence ["RedisModel " ],
1262+ cls : Type [ "Model" ] ,
1263+ models : Sequence ["Model " ],
12631264 pipeline : Optional [redis .client .Pipeline ] = None ,
12641265 pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1265- ) -> Sequence ["RedisModel " ]:
1266+ ) -> Sequence ["Model " ]:
12661267 db = cls ._get_db (pipeline , bulk = True )
12671268
12681269 for model in models :
@@ -1337,8 +1338,8 @@ def __init_subclass__(cls, **kwargs):
13371338 )
13381339
13391340 async def save (
1340- self , pipeline : Optional [redis .client .Pipeline ] = None
1341- ) -> "HashModel " :
1341+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1342+ ) -> "Model " :
13421343 self .check ()
13431344 db = self ._get_db (pipeline )
13441345
@@ -1364,7 +1365,7 @@ async def all_pks(cls): # type: ignore
13641365 )
13651366
13661367 @classmethod
1367- async def get (cls , pk : Any ) -> "HashModel " :
1368+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
13681369 document = await cls .db ().hgetall (cls .make_primary_key (pk ))
13691370 if not document :
13701371 raise NotFoundError
@@ -1509,8 +1510,8 @@ def __init__(self, *args, **kwargs):
15091510 super ().__init__ (* args , ** kwargs )
15101511
15111512 async def save (
1512- self , pipeline : Optional [redis .client .Pipeline ] = None
1513- ) -> "JsonModel " :
1513+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1514+ ) -> "Model " :
15141515 self .check ()
15151516 db = self ._get_db (pipeline )
15161517
@@ -1559,7 +1560,7 @@ async def update(self, **field_values):
15591560 await self .save ()
15601561
15611562 @classmethod
1562- async def get (cls , pk : Any ) -> "JsonModel " :
1563+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
15631564 document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
15641565 if document == "null" :
15651566 raise NotFoundError
0 commit comments