@@ -335,8 +335,31 @@ def connection_made(prot, transport):
335335 def datagram_received (prot , data , addr ):
336336 asyncio .ensure_future (datagram_handler (prot .transport , data , addr , ** vars (self ), ** args ))
337337 return asyncio .get_event_loop ().create_datagram_endpoint (Protocol , local_addr = (self .host_name , self .port ))
338+ async def make_ssh_connect (self , ** kwargs ):
339+ if self .streams is None :
340+ self .streams = asyncio .get_event_loop ().create_future ()
341+ else :
342+ if not self .streams .done ():
343+ await self .streams
344+ return self .streams .result ()
345+ try :
346+ import asyncssh
347+ for s in ('read_' , 'read_n' , 'read_until' ):
348+ setattr (asyncssh .SSHReader , s , getattr (asyncio .StreamReader , s ))
349+ except Exception :
350+ raise Exception ('Missing library: "pip3 install asyncssh"' )
351+ username , password = self .auth .decode ().split (':' , 1 )
352+ if password .startswith (':' ):
353+ client_keys = [password [1 :]]
354+ password = None
355+ else :
356+ client_keys = None
357+ conn = await asyncssh .connect (host = self .host_name , port = self .port , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 , ** kwargs )
358+ if not self .streams .done ():
359+ self .streams .set_result ((conn , None ))
360+ return conn , None
338361 async def open_connection (self , host , port , local_addr , lbind , timeout = SOCKET_TIMEOUT ):
339- if self .reuse or self . ssh :
362+ if self .reuse :
340363 if self .streams is None or self .streams .done () and (self .reuse and not self .handler ):
341364 self .streams = asyncio .get_event_loop ().create_future ()
342365 else :
@@ -352,22 +375,7 @@ async def open_connection(self, host, port, local_addr, lbind, timeout=SOCKET_TI
352375 raise Exception ('Unknown tunnel endpoint' )
353376 wait = asyncio .open_connection (host = host , port = port , local_addr = local_addr , family = family )
354377 elif self .ssh :
355- try :
356- import asyncssh
357- for s in ('read_' , 'read_n' , 'read_until' ):
358- setattr (asyncssh .SSHReader , s , getattr (asyncio .StreamReader , s ))
359- except Exception :
360- raise Exception ('Missing library: "pip3 install asyncssh"' )
361- username , password = self .auth .decode ().split (':' , 1 )
362- if password .startswith (':' ):
363- client_keys = [password [1 :]]
364- password = None
365- else :
366- client_keys = None
367- conn = await asyncssh .connect (host = self .host_name , port = self .port , local_addr = local_addr , family = family , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 )
368- if not self .streams .done ():
369- self .streams .set_result ((conn , None ))
370- return conn , None
378+ wait = self .make_ssh_connect (local_addr = local_addr , family = family )
371379 elif self .backward :
372380 wait = self .backward .open_connection ()
373381 elif self .unix :
@@ -399,14 +407,7 @@ async def prepare_ciphers_and_headers(self, reader_remote, writer_remote, host,
399407 reader_remote , writer_remote = handler .connect (whost , wport )
400408 elif self .ssh :
401409 if self .relay .ssh :
402- import asyncssh
403- username , password = self .relay .auth .decode ().split (':' , 1 )
404- if password .startswith (':' ):
405- client_keys = [password [1 :]]
406- password = None
407- else :
408- client_keys = None
409- reader_remote , writer_remote = await asyncssh .connect (tunnel = reader_remote , host = self .host_name , port = self .port , x509_trusted_certs = None , known_hosts = None , username = username , password = password , client_keys = client_keys , keepalive_interval = 60 ), None
410+ reader_remote , writer_remote = await self .relay .make_ssh_connect (tunnel = reader_remote )
410411 else :
411412 reader_remote , writer_remote = await reader_remote .open_connection (whost , wport )
412413 else :
0 commit comments