77
88import  asyncio 
99import  collections 
10+ from  collections .abc  import  Callable 
1011import  enum 
1112import  functools 
1213import  getpass 
@@ -802,9 +803,21 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
802803                exc  =  ConnectionError ('unexpected connection_lost() call' )
803804            self .on_data .set_exception (exc )
804805
806+ _ProctolFactoryR  =  typing .TypeVar ("_ProctolFactoryR" , bound = asyncio .protocols .Protocol )
805807
806- async  def  _create_ssl_connection (protocol_factory , host , port , * ,
807-                                  loop , ssl_context , ssl_is_advisory = False ):
808+ 
809+ async  def  _create_ssl_connection (
810+     # TODO: The return type is a specific combination of subclasses of asyncio.protocols.Protocol 
811+     # that we can't express. For now, having the return type be dependent on signature of the 
812+     # factory is an improvement 
813+     protocol_factory : Callable [[], _ProctolFactoryR ],
814+     host : str ,
815+     port : int ,
816+     * ,
817+     loop : asyncio .AbstractEventLoop ,
818+     ssl_context : ssl_module .SSLContext ,
819+     ssl_is_advisory : bool  =  False ,
820+ ) ->  typing .Tuple [asyncio .Transport , _ProctolFactoryR ]:
808821
809822    tr , pr  =  await  loop .create_connection (
810823        lambda : TLSUpgradeProto (loop , host , port ,
@@ -824,6 +837,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
824837            try :
825838                new_tr  =  await  loop .start_tls (
826839                    tr , pr , ssl_context , server_hostname = host )
840+                 assert  new_tr  is  not None 
827841            except  (Exception , asyncio .CancelledError ):
828842                tr .close ()
829843                raise 
0 commit comments