@@ -108,13 +108,14 @@ class SharedStateRenderer(SingleStateRenderer):
108108 """
109109
110110 _render_task : asyncio .Task
111- _join_event : asyncio .Event
112111
113112 def __init__ (self , layout : AbstractLayout ) -> None :
114113 super ().__init__ (layout )
115114 self ._models : Dict [str , Dict [str , Any ]] = {}
116115 self ._updates : Dict [str , asyncio .Queue [LayoutUpdate ]] = {}
117116 self ._task_group = create_task_group ()
117+ self ._join_event = asyncio .Event ()
118+ self ._active = False
118119 self ._joining = False
119120
120121 async def start (self ):
@@ -124,10 +125,9 @@ async def join(self):
124125 await self .__aexit__ (None , None , None )
125126
126127 async def __aenter__ (self ):
127- if hasattr ( self , "_join_event" ) :
128+ if self . _active :
128129 raise RuntimeError ("Renderer already active" )
129- self ._join_event = asyncio .Event ()
130-
130+ self ._active = True
131131 await self ._task_group .__aenter__ ()
132132 self ._render_task = asyncio .ensure_future (self ._render_loop (), loop = self .loop )
133133 return self
@@ -138,23 +138,28 @@ async def __aexit__(
138138 exc_val : Optional [BaseException ],
139139 exc_tb : Optional [TracebackType ],
140140 ) -> None :
141- if not self ._joining :
142- self ._joining = True
143- try :
144- await self ._task_group .__aexit__ (exc_type , exc_val , exc_tb )
145- finally :
146- self ._render_task .cancel ()
147- self ._join_event .set ()
148- else :
149- await self ._join_event .wait ()
141+ try :
142+ if not self ._joining :
143+ self ._joining = True
144+ try :
145+ await self ._task_group .__aexit__ (exc_type , exc_val , exc_tb )
146+ finally :
147+ self ._render_task .cancel ()
148+ self ._join_event .set ()
149+ self ._join_event .clear ()
150+ else :
151+ await self ._join_event .wait ()
152+ finally :
153+ self ._active = False
154+ self ._joining = False
150155
151156 async def run (
152157 self , send : SendCoroutine , recv : RecvCoroutine , context : str , join : bool = False
153158 ) -> None :
154159 self ._updates [context ] = asyncio .Queue ()
155160 await self ._task_group .spawn (super ().run , send , recv , context )
156161 if join :
157- self .join ()
162+ await self .join ()
158163
159164 async def _render_loop (self ) -> None :
160165 while True :
0 commit comments