77
88import asyncio
99import collections
10+ from collections .abc import Callable
1011import enum
1112import functools
1213import getpass
@@ -802,9 +803,21 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
802803 exc = ConnectionError ('unexpected connection_lost() call' )
803804 self .on_data .set_exception (exc )
804805
806+ _ProctolFactoryR = typing .TypeVar ("_ProctolFactoryR" , bound = asyncio .protocols .Protocol )
805807
806- async def _create_ssl_connection (protocol_factory , host , port , * ,
807- loop , ssl_context , ssl_is_advisory = False ):
808+
809+ async def _create_ssl_connection (
810+ # TODO: The return type is a specific combination of subclasses of asyncio.protocols.Protocol
811+ # that we can't express. For now, having the return type be dependent on signature of the
812+ # factory is an improvement
813+ protocol_factory : Callable [[], _ProctolFactoryR ],
814+ host : str ,
815+ port : int ,
816+ * ,
817+ loop : asyncio .AbstractEventLoop ,
818+ ssl_context : ssl_module .SSLContext ,
819+ ssl_is_advisory : bool = False ,
820+ ) -> typing .Tuple [asyncio .Transport , _ProctolFactoryR ]:
808821
809822 tr , pr = await loop .create_connection (
810823 lambda : TLSUpgradeProto (loop , host , port ,
@@ -824,6 +837,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
824837 try :
825838 new_tr = await loop .start_tls (
826839 tr , pr , ssl_context , server_hostname = host )
840+ assert new_tr is not None
827841 except (Exception , asyncio .CancelledError ):
828842 tr .close ()
829843 raise
0 commit comments