4747
4848model_registry = {}
4949_T = TypeVar ("_T" )
50+ Model = TypeVar ("Model" , bound = "RedisModel" )
5051log = logging .getLogger (__name__ )
5152escaper = TokenEscaper ()
5253
@@ -1310,16 +1311,16 @@ async def delete(
13101311 return await cls ._delete (db , cls .make_primary_key (pk ))
13111312
13121313 @classmethod
1313- async def get (cls , pk : Any ) -> "RedisModel " :
1314+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
13141315 raise NotImplementedError
13151316
13161317 async def update (self , ** field_values ):
13171318 """Update this model instance with the specified key-value pairs."""
13181319 raise NotImplementedError
13191320
13201321 async def save (
1321- self , pipeline : Optional [redis .client .Pipeline ] = None
1322- ) -> "RedisModel " :
1322+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1323+ ) -> "Model " :
13231324 raise NotImplementedError
13241325
13251326 async def expire (
@@ -1423,11 +1424,11 @@ def get_annotations(cls):
14231424
14241425 @classmethod
14251426 async def add (
1426- cls ,
1427- models : Sequence ["RedisModel " ],
1427+ cls : Type [ "Model" ] ,
1428+ models : Sequence ["Model " ],
14281429 pipeline : Optional [redis .client .Pipeline ] = None ,
14291430 pipeline_verifier : Callable [..., Any ] = verify_pipeline_response ,
1430- ) -> Sequence ["RedisModel " ]:
1431+ ) -> Sequence ["Model " ]:
14311432 db = cls ._get_db (pipeline , bulk = True )
14321433
14331434 for model in models :
@@ -1502,8 +1503,8 @@ def __init_subclass__(cls, **kwargs):
15021503 )
15031504
15041505 async def save (
1505- self , pipeline : Optional [redis .client .Pipeline ] = None
1506- ) -> "HashModel " :
1506+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1507+ ) -> "Model " :
15071508 self .check ()
15081509 db = self ._get_db (pipeline )
15091510
@@ -1525,7 +1526,7 @@ async def all_pks(cls): # type: ignore
15251526 )
15261527
15271528 @classmethod
1528- async def get (cls , pk : Any ) -> "HashModel " :
1529+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
15291530 document = await cls .db ().hgetall (cls .make_primary_key (pk ))
15301531 if not document :
15311532 raise NotFoundError
@@ -1676,8 +1677,8 @@ def __init__(self, *args, **kwargs):
16761677 super ().__init__ (* args , ** kwargs )
16771678
16781679 async def save (
1679- self , pipeline : Optional [redis .client .Pipeline ] = None
1680- ) -> "JsonModel " :
1680+ self : "Model" , pipeline : Optional [redis .client .Pipeline ] = None
1681+ ) -> "Model " :
16811682 self .check ()
16821683 db = self ._get_db (pipeline )
16831684
@@ -1722,7 +1723,7 @@ async def update(self, **field_values):
17221723 await self .save ()
17231724
17241725 @classmethod
1725- async def get (cls , pk : Any ) -> "JsonModel " :
1726+ async def get (cls : Type [ "Model" ] , pk : Any ) -> "Model " :
17261727 document = json .dumps (await cls .db ().json ().get (cls .make_key (pk )))
17271728 if document == "null" :
17281729 raise NotFoundError
0 commit comments