From 00aa3e1382224d2f1c26f4f6326c73ca4c75ad55 Mon Sep 17 00:00:00 2001 From: Robert Hague Date: Fri, 15 Dec 2023 11:57:53 +0100 Subject: [PATCH 1/2] Send the client key exchange init in Connect --- src/Renci.SshNet/Security/IKeyExchange.cs | 5 ++- src/Renci.SshNet/Security/KeyExchange.cs | 13 +++--- .../Security/KeyExchangeDiffieHellman.cs | 10 ++--- ...changeDiffieHellmanGroupExchangeShaBase.cs | 10 ++--- .../KeyExchangeDiffieHellmanGroupShaBase.cs | 10 ++--- src/Renci.SshNet/Security/KeyExchangeEC.cs | 10 ++--- .../Security/KeyExchangeECCurve25519.cs | 10 ++--- src/Renci.SshNet/Security/KeyExchangeECDH.cs | 10 ++--- src/Renci.SshNet/Session.cs | 44 +++++++++++-------- .../Classes/SessionTest_ConnectedBase.cs | 20 ++++----- ...Connected_ServerAndClientDisconnectRace.cs | 16 +++---- 11 files changed, 71 insertions(+), 87 deletions(-) diff --git a/src/Renci.SshNet/Security/IKeyExchange.cs b/src/Renci.SshNet/Security/IKeyExchange.cs index f12a18322..7ffd2f465 100644 --- a/src/Renci.SshNet/Security/IKeyExchange.cs +++ b/src/Renci.SshNet/Security/IKeyExchange.cs @@ -38,8 +38,9 @@ public interface IKeyExchange : IDisposable /// Starts the key exchange algorithm. /// /// The session. - /// Key exchange init message. - void Start(Session session, KeyExchangeInitMessage message); + /// The key exchange init message received from the server. + /// Whether to send a key exchange init message in response. + void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage); /// /// Finishes the key exchange algorithm. diff --git a/src/Renci.SshNet/Security/KeyExchange.cs b/src/Renci.SshNet/Security/KeyExchange.cs index 44684a92e..f01a4b117 100644 --- a/src/Renci.SshNet/Security/KeyExchange.cs +++ b/src/Renci.SshNet/Security/KeyExchange.cs @@ -61,16 +61,15 @@ public byte[] ExchangeHash /// public event EventHandler HostKeyReceived; - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public virtual void Start(Session session, KeyExchangeInitMessage message) + /// + public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { Session = session; - SendMessage(session.ClientInitMessage); + if (sendClientInitMessage) + { + SendMessage(session.ClientInitMessage); + } // Determine encryption algorithm var clientEncryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys diff --git a/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs b/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs index 4f31514a7..7dfc51e34 100644 --- a/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs +++ b/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs @@ -76,14 +76,10 @@ protected override bool ValidateExchangeHash() return ValidateExchangeHash(_hostKey, _signature); } - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public override void Start(Session session, KeyExchangeInitMessage message) + /// + public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { - base.Start(session, message); + base.Start(session, message, sendClientInitMessage); _serverPayload = message.GetBytes(); _clientPayload = Session.ClientInitMessage.GetBytes(); diff --git a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs index 93703ee8f..5774f2c34 100644 --- a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs +++ b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs @@ -39,14 +39,10 @@ protected override byte[] CalculateHash() return Hash(groupExchangeHashData.GetBytes()); } - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public override void Start(Session session, KeyExchangeInitMessage message) + /// + public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { - base.Start(session, message); + base.Start(session, message, sendClientInitMessage); // Register SSH_MSG_KEX_DH_GEX_GROUP message Session.RegisterMessage("SSH_MSG_KEX_DH_GEX_GROUP"); diff --git a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs index 63c2bba40..b0db30eaa 100644 --- a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs +++ b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs @@ -13,14 +13,10 @@ internal abstract class KeyExchangeDiffieHellmanGroupShaBase : KeyExchangeDiffie /// public abstract BigInteger GroupPrime { get; } - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public override void Start(Session session, KeyExchangeInitMessage message) + /// + public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { - base.Start(session, message); + base.Start(session, message, sendClientInitMessage); Session.RegisterMessage("SSH_MSG_KEXDH_REPLY"); diff --git a/src/Renci.SshNet/Security/KeyExchangeEC.cs b/src/Renci.SshNet/Security/KeyExchangeEC.cs index 4368affbf..8bc61e7fc 100644 --- a/src/Renci.SshNet/Security/KeyExchangeEC.cs +++ b/src/Renci.SshNet/Security/KeyExchangeEC.cs @@ -78,14 +78,10 @@ protected override bool ValidateExchangeHash() return ValidateExchangeHash(_hostKey, _signature); } - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public override void Start(Session session, KeyExchangeInitMessage message) + /// + public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { - base.Start(session, message); + base.Start(session, message, sendClientInitMessage); _serverPayload = message.GetBytes(); _clientPayload = Session.ClientInitMessage.GetBytes(); diff --git a/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs b/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs index 18443fe73..c6c060bab 100644 --- a/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs +++ b/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs @@ -29,14 +29,10 @@ protected override int HashSize get { return 256; } } - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public override void Start(Session session, KeyExchangeInitMessage message) + /// + public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { - base.Start(session, message); + base.Start(session, message, sendClientInitMessage); Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY"); diff --git a/src/Renci.SshNet/Security/KeyExchangeECDH.cs b/src/Renci.SshNet/Security/KeyExchangeECDH.cs index c3fc7bfe4..c756fb6cb 100644 --- a/src/Renci.SshNet/Security/KeyExchangeECDH.cs +++ b/src/Renci.SshNet/Security/KeyExchangeECDH.cs @@ -24,14 +24,10 @@ internal abstract class KeyExchangeECDH : KeyExchangeEC private ECDHCBasicAgreement _keyAgreement; private ECDomainParameters _domainParameters; - /// - /// Starts key exchange algorithm. - /// - /// The session. - /// Key exchange init message. - public override void Start(Session session, KeyExchangeInitMessage message) + /// + public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage) { - base.Start(session, message); + base.Start(session, message, sendClientInitMessage); Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY"); diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 5bf6d8eef..a57d5a1e7 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -160,12 +160,7 @@ public class Session : ISession /// /// WaitHandle to signal that key exchange was completed. /// - private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(initialState: false); - - /// - /// WaitHandle to signal that key exchange is in progress. - /// - private bool _keyExchangeInProgress; + private ManualResetEventSlim _keyExchangeCompletedWaitHandle = new ManualResetEventSlim(initialState: false); /// /// Exception that need to be thrown by waiting thread. @@ -643,6 +638,11 @@ public void Connect() // Some server implementations might sent this message first, prior to establishing encryption algorithm RegisterMessage("SSH_MSG_USERAUTH_BANNER"); + // Send our key exchange init. + // We need to do this before starting the message listener to avoid the case where we receive the server + // key exchange init and we continue the key exchange before having sent our own init. + SendMessage(ClientInitMessage); + // Mark the message listener threads as started _ = _messageListenerCompleted.Reset(); @@ -651,7 +651,7 @@ public void Connect() _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener); // Wait for key exchange to be completed - WaitOnHandle(_keyExchangeCompletedWaitHandle); + WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle); // If sessionId is not set then its not connected if (SessionId is null) @@ -757,6 +757,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken) // Some server implementations might sent this message first, prior to establishing encryption algorithm RegisterMessage("SSH_MSG_USERAUTH_BANNER"); + // Send our key exchange init. + // We need to do this before starting the message listener to avoid the case where we receive the server + // key exchange init and we continue the key exchange before having sent our own init. + SendMessage(ClientInitMessage); + // Mark the message listener threads as started _ = _messageListenerCompleted.Reset(); @@ -765,7 +770,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) _ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener); // Wait for key exchange to be completed - WaitOnHandle(_keyExchangeCompletedWaitHandle); + WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle); // If sessionId is not set then its not connected if (SessionId is null) @@ -1046,10 +1051,10 @@ internal void SendMessage(Message message) throw new SshConnectionException("Client not connected."); } - if (_keyExchangeInProgress && message is not IKeyExchangedAllowed) + if (!_keyExchangeCompletedWaitHandle.IsSet && message is not IKeyExchangedAllowed) { // Wait for key exchange to be completed - WaitOnHandle(_keyExchangeCompletedWaitHandle); + WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle); } DiagnosticAbstraction.Log(string.Format("[{0}] Sending message '{1}' to server: '{2}'.", ToHex(SessionId), message.GetType().Name, message)); @@ -1394,9 +1399,15 @@ internal void OnKeyExchangeDhGroupExchangeReplyReceived(KeyExchangeDhGroupExchan /// message. internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message) { - _keyExchangeInProgress = true; + // If _keyExchangeCompletedWaitHandle is already set, then this is a key + // re-exchange initiated by the server, and we need to send our own init + // message. + // Otherwise, the wait handle is not set and this received init is part of the + // initial connection for which we have already sent our init, so we shouldn't + // send another one. + var sendClientInitMessage = _keyExchangeCompletedWaitHandle.IsSet; - _ = _keyExchangeCompletedWaitHandle.Reset(); + _keyExchangeCompletedWaitHandle.Reset(); // Disable messages that are not key exchange related _sshMessageFactory.DisableNonKeyExchangeMessages(); @@ -1411,7 +1422,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message) _keyExchange.HostKeyReceived += KeyExchange_HostKeyReceived; // Start the algorithm implementation - _keyExchange.Start(this, message); + _keyExchange.Start(this, message, sendClientInitMessage); KeyExchangeInitReceived?.Invoke(this, new MessageEventArgs(message)); } @@ -1477,9 +1488,7 @@ internal void OnNewKeysReceived(NewKeysMessage message) NewKeysReceived?.Invoke(this, new MessageEventArgs(message)); // Signal that key exchange completed - _ = _keyExchangeCompletedWaitHandle.Set(); - - _keyExchangeInProgress = false; + _keyExchangeCompletedWaitHandle.Set(); } /// @@ -1967,7 +1976,7 @@ private void RaiseError(Exception exp) private void Reset() { _ = _exceptionWaitHandle?.Reset(); - _ = _keyExchangeCompletedWaitHandle?.Reset(); + _keyExchangeCompletedWaitHandle?.Reset(); _ = _messageListenerCompleted?.Set(); SessionId = null; @@ -1975,7 +1984,6 @@ private void Reset() _isDisconnecting = false; _isAuthenticated = false; _exception = null; - _keyExchangeInProgress = false; } private static SshConnectionException CreateConnectionAbortedByServerException() diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs index 7fa1ac24e..cf4066b93 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs @@ -115,6 +115,15 @@ protected virtual void SetupData() var newKeysMessage = new NewKeysMessage(); var newKeys = newKeysMessage.GetPacket(8, null); _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None); + + if (!_authenticationStarted) + { + var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication) + .Build(); + _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None); + + _authenticationStarted = true; + } }; ServerListener = new AsyncSocketListener(_serverEndPoint) @@ -147,15 +156,6 @@ protected virtual void SetupData() ServerListener.BytesReceived += (received, socket) => { ServerBytesReceivedRegister.Add(received); - - if (!_authenticationStarted) - { - var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication) - .Build(); - _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None); - - _authenticationStarted = true; - } }; ServerListener.Start(); @@ -187,7 +187,7 @@ private void SetupMocks() _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object); _ = _keyExchangeMock.Setup(p => p.Name) .Returns(_keyExchangeAlgorithm); - _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny())); + _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false)); _ = _keyExchangeMock.Setup(p => p.ExchangeHash) .Returns(SessionId); _ = _keyExchangeMock.Setup(p => p.CreateServerCipher()) diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs index 96797d727..11cda2d90 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs @@ -89,6 +89,13 @@ protected virtual void SetupData() var newKeysMessage = new NewKeysMessage(); var newKeys = newKeysMessage.GetPacket(8, null); _ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None); + + if (!_authenticationStarted) + { + var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build(); + _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None); + _authenticationStarted = true; + } }; ServerListener = new AsyncSocketListener(_serverEndPoint); @@ -118,13 +125,6 @@ protected virtual void SetupData() ServerListener.BytesReceived += (received, socket) => { ServerBytesReceivedRegister.Add(received); - - if (!_authenticationStarted) - { - var serviceAcceptMessage =ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build(); - _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None); - _authenticationStarted = true; - } }; ServerListener.Start(); @@ -156,7 +156,7 @@ private void SetupMocks() .Returns(_keyExchangeMock.Object); _ = _keyExchangeMock.Setup(p => p.Name) .Returns(_keyExchangeAlgorithm); - _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny())); + _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false)); _ = _keyExchangeMock.Setup(p => p.ExchangeHash) .Returns(SessionId); _ = _keyExchangeMock.Setup(p => p.CreateServerCipher()) From 89391f7516d834cb3dff9e91ff160887587f41c4 Mon Sep 17 00:00:00 2001 From: Robert Hague Date: Wed, 20 Dec 2023 17:32:53 +0100 Subject: [PATCH 2/2] Add a test --- .../Classes/SessionTest_ConnectedBase.cs | 65 ++++++++++++------- ...Test_Connected_ServerDoesNotSendKexInit.cs | 24 +++++++ 2 files changed, 67 insertions(+), 22 deletions(-) create mode 100644 test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs index cf4066b93..6331f7b9c 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs @@ -49,6 +49,12 @@ public abstract class SessionTest_ConnectedBase internal SshIdentification ServerIdentification { get; set; } protected bool CallSessionConnectWhenArrange { get; set; } + /// + /// Should the "server" wait for the client kexinit before sending its own. + /// A regression test simulating e.g. cisco devices. + /// + protected bool WaitForClientKeyExchangeInit { get; set; } + [TestInitialize] public void Setup() { @@ -59,18 +65,18 @@ public void Setup() [TestCleanup] public void TearDown() { - if (ServerSocket != null) - { - ServerSocket.Dispose(); - ServerSocket = null; - } - if (ServerListener != null) { ServerListener.Dispose(); ServerListener = null; } + if (ServerSocket != null) + { + ServerSocket.Dispose(); + ServerSocket = null; + } + if (Session != null) { Session.Dispose(); @@ -134,34 +140,49 @@ protected virtual void SetupData() { ServerSocket = socket; - // Since we're mocking the protocol version exchange, we'll immediately stat KEX upon + // Since we're mocking the protocol version exchange, we'll immediately start KEX upon // having established the connection instead of when the client has been identified - var keyExchangeInitMessage = new KeyExchangeInitMessage - { - CompressionAlgorithmsClientToServer = new string[0], - CompressionAlgorithmsServerToClient = new string[0], - EncryptionAlgorithmsClientToServer = new string[0], - EncryptionAlgorithmsServerToClient = new string[0], - KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm }, - LanguagesClientToServer = new string[0], - LanguagesServerToClient = new string[0], - MacAlgorithmsClientToServer = new string[0], - MacAlgorithmsServerToClient = new string[0], - ServerHostKeyAlgorithms = new string[0] - }; - var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null); - _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None); + if (!WaitForClientKeyExchangeInit) + { + SendKeyExchangeInit(); + } }; ServerListener.BytesReceived += (received, socket) => { ServerBytesReceivedRegister.Add(received); + + if (WaitForClientKeyExchangeInit && received.Length > 5 && received[5] == 20) + { + // This is the KEXINIT. Send one back. + SendKeyExchangeInit(); + WaitForClientKeyExchangeInit = false; + } }; ServerListener.Start(); ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo); CallSessionConnectWhenArrange = true; + + void SendKeyExchangeInit() + { + var keyExchangeInitMessage = new KeyExchangeInitMessage + { + CompressionAlgorithmsClientToServer = new string[0], + CompressionAlgorithmsServerToClient = new string[0], + EncryptionAlgorithmsClientToServer = new string[0], + EncryptionAlgorithmsServerToClient = new string[0], + KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm }, + LanguagesClientToServer = new string[0], + LanguagesServerToClient = new string[0], + MacAlgorithmsClientToServer = new string[0], + MacAlgorithmsServerToClient = new string[0], + ServerHostKeyAlgorithms = new string[0] + }; + var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null); + _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None); + } } private void CreateMocks() diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs new file mode 100644 index 000000000..44bfa74fd --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connected_ServerDoesNotSendKexInit : SessionTest_ConnectedBase + { + protected override void SetupData() + { + WaitForClientKeyExchangeInit = true; + + base.SetupData(); + } + + protected override void Act() + { + } + + [TestMethod] + public void ConnectShouldSucceed() + { + } + } +}