diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index 69ec38b26..63dc2bf54 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -12,34 +12,6 @@ namespace Renci.SshNet.Abstractions { internal static partial class SocketAbstraction { - public static bool CanRead(Socket socket) - { - if (socket.Connected) - { - return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0; - } - - return false; - } - - /// - /// Returns a value indicating whether the specified can be used - /// to send data. - /// - /// The to check. - /// - /// if can be written to; otherwise, . - /// - public static bool CanWrite(Socket socket) - { - if (socket != null && socket.Connected) - { - return socket.Poll(-1, SelectMode.SelectWrite); - } - - return false; - } - public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout) { var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs index b7a97d067..6cc65a779 100644 --- a/src/Renci.SshNet/Common/Extensions.cs +++ b/src/Renci.SshNet/Common/Extensions.cs @@ -10,7 +10,6 @@ using System.Runtime.CompilerServices; using System.Threading; -using Renci.SshNet.Abstractions; using Renci.SshNet.Messages; namespace Renci.SshNet.Common @@ -319,16 +318,6 @@ public static byte[] Concat(this byte[] first, byte[] second) return concat; } - internal static bool CanRead(this Socket socket) - { - return SocketAbstraction.CanRead(socket); - } - - internal static bool CanWrite(this Socket socket) - { - return SocketAbstraction.CanWrite(socket); - } - internal static bool IsConnected(this Socket socket) { if (socket is null) diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index ec3eac878..77fe9d4c2 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -81,12 +81,6 @@ public sealed class Session : ISession private readonly ISocketFactory _socketFactory; private readonly ILogger _logger; - /// - /// Holds an object that is used to ensure only a single thread can read from - /// at any given time. - /// - private readonly Lock _socketReadLock = new Lock(); - /// /// Holds an object that is used to ensure only a single thread can write to /// at any given time. @@ -105,7 +99,7 @@ public sealed class Session : ISession /// This is also used to ensure that will not be disposed /// while performing a given operation or set of operations on . /// - private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1); + private readonly Lock _socketDisposeLock = new Lock(); /// /// Holds an object that is used to ensure only a single thread can connect @@ -279,17 +273,11 @@ public bool IsConnected { get { - if (_disposed || _isDisconnectMessageSent || !_isAuthenticated) - { - return false; - } - - if (_messageListenerCompleted is null || _messageListenerCompleted.WaitOne(0)) - { - return false; - } - - return IsSocketConnected(); + return !_disposed && + !_isDisconnectMessageSent && + _isAuthenticated && + _messageListenerCompleted?.WaitOne(0) == false && + _socket.IsConnected(); } } @@ -1046,7 +1034,7 @@ internal void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout) /// The size of the packet exceeds the maximum size defined by the protocol. internal void SendMessage(Message message) { - if (!_socket.CanWrite()) + if (!_socket.IsConnected()) { throw new SshConnectionException("Client not connected."); } @@ -1161,9 +1149,7 @@ internal void SendMessage(Message message) /// private void SendPacket(byte[] packet, int offset, int length) { - _socketDisposeLock.Wait(); - - try + lock (_socketDisposeLock) { if (!_socket.IsConnected()) { @@ -1172,10 +1158,6 @@ private void SendPacket(byte[] packet, int offset, int length) SocketAbstraction.Send(_socket, packet, offset, length); } - finally - { - _ = _socketDisposeLock.Release(); - } } /// @@ -1259,76 +1241,70 @@ private Message ReceiveMessage(Socket socket) byte[] data; uint packetLength; - // avoid reading from socket while IsSocketConnected is attempting to determine whether the - // socket is still connected by invoking Socket.Poll(...) and subsequently verifying value of - // Socket.Available - lock (_socketReadLock) + // Read first block - which starts with the packet length + var firstBlock = new byte[blockSize]; + if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0) { - // Read first block - which starts with the packet length - var firstBlock = new byte[blockSize]; - if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0) - { - // connection with SSH server was closed - return null; - } + // connection with SSH server was closed + return null; + } - var plainFirstBlock = firstBlock; + var plainFirstBlock = firstBlock; - // First block is not encrypted in AES GCM mode. - if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher) - { - _serverCipher.SetSequenceNumber(_inboundPacketSequence); + // First block is not encrypted in AES GCM mode. + if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher) + { + _serverCipher.SetSequenceNumber(_inboundPacketSequence); - // First block is not encrypted in ETM mode. - if (_serverMac == null || !_serverEtm) - { - plainFirstBlock = _serverCipher.Decrypt(firstBlock); - } + // First block is not encrypted in ETM mode. + if (_serverMac == null || !_serverEtm) + { + plainFirstBlock = _serverCipher.Decrypt(firstBlock); } + } - packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock); + packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock); - // Test packet minimum and maximum boundaries - if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4) - { - throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), - DisconnectReason.ProtocolError); - } + // Test packet minimum and maximum boundaries + if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4) + { + throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), + DisconnectReason.ProtocolError); + } - // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the - // "packet length" field itself - which is 4 bytes - is not included in the length of the packet - var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength; - - // Construct buffer for holding the payload and the inbound packet sequence as we need both in order - // to generate the hash. - // - // The total length of the "data" buffer is an addition of: - // - inboundPacketSequenceLength (4 bytes) - // - packetLength - // - serverMacLength - // - // We include the inbound packet sequence to allow us to have the the full SSH packet in a single - // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen - // to read the packet including server MAC in a single pass (except for the initial block). - data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength]; - BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence); - - // Use raw packet length field to calculate the mac in AEAD mode. - if (_serverAead) - { - Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize); - } - else - { - Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize); - } + // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the + // "packet length" field itself - which is 4 bytes - is not included in the length of the packet + var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength; + + // Construct buffer for holding the payload and the inbound packet sequence as we need both in order + // to generate the hash. + // + // The total length of the "data" buffer is an addition of: + // - inboundPacketSequenceLength (4 bytes) + // - packetLength + // - serverMacLength + // + // We include the inbound packet sequence to allow us to have the the full SSH packet in a single + // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen + // to read the packet including server MAC in a single pass (except for the initial block). + data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength]; + BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence); + + // Use raw packet length field to calculate the mac in AEAD mode. + if (_serverAead) + { + Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize); + } + else + { + Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize); + } - if (bytesToRead > 0) + if (bytesToRead > 0) + { + if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0) { - if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0) - { - return null; - } + return null; } } @@ -1888,84 +1864,6 @@ private static string ToHex(byte[] bytes) #endif } - /// - /// Gets a value indicating whether the socket is connected. - /// - /// - /// if the socket is connected; otherwise, . - /// - /// - /// - /// As a first check we verify whether is - /// . However, this only returns the state of the socket as of - /// the last I/O operation. - /// - /// - /// Therefore we use the combination of with mode - /// and to verify if the socket is still connected. - /// - /// - /// The MSDN doc mention the following on the return value of - /// with mode : - /// - /// - /// if data is available for reading; - /// - /// - /// if the connection has been closed, reset, or terminated; otherwise, returns . - /// - /// - /// - /// - /// Conclusion: when the return value is - but no data is available for reading - then - /// the socket is no longer connected. - /// - /// - /// When a is used from multiple threads, there's a race condition - /// between the invocation of and the moment - /// when the value of is obtained. To workaround this issue - /// we synchronize reads from the . - /// - /// - /// We assume the socket is still connected if the read lock cannot be acquired immediately. - /// In this case, we just return without actually waiting to acquire - /// the lock. We don't want to wait for the read lock if another thread already has it because - /// there are cases where the other thread holding the lock can be waiting indefinitely for - /// a socket read operation to complete. - /// - /// - private bool IsSocketConnected() - { - _socketDisposeLock.Wait(); - - try - { - if (!_socket.IsConnected()) - { - return false; - } - - if (!_socketReadLock.TryEnter()) - { - return true; - } - - try - { - var connectionClosedOrDataAvailable = _socket.Poll(0, SelectMode.SelectRead); - return !(connectionClosedOrDataAvailable && _socket.Available == 0); - } - finally - { - _socketReadLock.Exit(); - } - } - finally - { - _ = _socketDisposeLock.Release(); - } - } - /// /// Performs a blocking read on the socket until bytes are received. /// @@ -1988,46 +1886,37 @@ private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int l /// private void SocketDisconnectAndDispose() { - if (_socket != null) + lock (_socketDisposeLock) { - _socketDisposeLock.Wait(); + if (_socket is null) + { + return; + } - try + if (_socket.Connected) { -#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread. - if (_socket != null) -#pragma warning restore CA1508 // Avoid dead conditional code + try { - if (_socket.Connected) - { - try - { - _logger.LogDebug("[{SessionId}] Shutting down socket.", SessionIdHex); - - // Interrupt any pending reads; should be done outside of socket read lock as we - // actually want shutdown the socket to make sure blocking reads are interrupted. - // - // This may result in a SocketException (eg. An existing connection was forcibly - // closed by the remote host) which we'll log and ignore as it means the socket - // was already shut down. - _socket.Shutdown(SocketShutdown.Send); - } - catch (SocketException ex) - { - _logger.LogInformation(ex, "Failure shutting down socket"); - } - } - - _logger.LogDebug("[{SessionId}] Disposing socket.", SessionIdHex); - _socket.Dispose(); - _logger.LogDebug("[{SessionId}] Disposed socket.", SessionIdHex); - _socket = null; + _logger.LogDebug("[{SessionId}] Shutting down socket.", SessionIdHex); + + // Interrupt any pending reads; should be done outside of socket read lock as we + // actually want shutdown the socket to make sure blocking reads are interrupted. + // + // This may result in a SocketException (eg. An existing connection was forcibly + // closed by the remote host) which we'll log and ignore as it means the socket + // was already shut down. + _socket.Shutdown(SocketShutdown.Both); + } + catch (SocketException ex) + { + _logger.LogInformation(ex, "Failure shutting down socket"); } } - finally - { - _ = _socketDisposeLock.Release(); - } + + _logger.LogDebug("[{SessionId}] Disposing socket.", SessionIdHex); + _socket.Dispose(); + _logger.LogDebug("[{SessionId}] Disposed socket.", SessionIdHex); + _socket = null; } } @@ -2048,25 +1937,6 @@ private void MessageListener() break; } - try - { - // Block until either data is available or the socket is closed - var connectionClosedOrDataAvailable = socket.Poll(-1, SelectMode.SelectRead); - if (connectionClosedOrDataAvailable && socket.Available == 0) - { - // connection with SSH server was closed or connection was reset - break; - } - } - catch (ObjectDisposedException) - { - // The socket was disposed by either: - // * a call to Disconnect() - // * a call to Dispose() - // * a SSH_MSG_DISCONNECT received from server - break; - } - var message = ReceiveMessage(socket); if (message is null) { @@ -2102,25 +1972,12 @@ private void MessageListener() /// The . private void RaiseError(Exception exp) { - var connectionException = exp as SshConnectionException; - _logger.LogInformation(exp, "[{SessionId}] Raised exception", SessionIdHex); - if (_isDisconnecting) + if (_isDisconnecting && exp is SshConnectionException or ObjectDisposedException) { - // a connection exception which is raised while isDisconnecting is normal and - // should be ignored - if (connectionException != null) - { - return; - } - - // any timeout while disconnecting can be caused by loss of connectivity - // altogether and should be ignored - if (exp is SocketException socketException && socketException.SocketErrorCode == SocketError.TimedOut) - { - return; - } + // Such an exception raised while isDisconnecting is expected and can be ignored. + return; } // "save" exception and set exception wait handle to ensure any waits are interrupted @@ -2129,10 +1986,10 @@ private void RaiseError(Exception exp) ErrorOccured?.Invoke(this, new ExceptionEventArgs(exp)); - if (connectionException != null) + if (exp is SshConnectionException connectionException) { _logger.LogInformation(exp, "[{SessionId}] Disconnecting after exception", SessionIdHex); - Disconnect(connectionException.DisconnectReason, exp.ToString()); + Disconnect(connectionException.DisconnectReason, exp.Message); } } diff --git a/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs b/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs index e79eb8561..c37b3dcf9 100644 --- a/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs +++ b/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs @@ -8,12 +8,6 @@ namespace Renci.SshNet.Tests.Classes [TestClass] public class AbstractionsTest { - [TestMethod] - public void SocketAbstraction_CanWrite_ShouldReturnFalseWhenSocketIsNull() - { - Assert.IsFalse(SocketAbstraction.CanWrite(null)); - } - [TestMethod] public void CryptoAbstraction_GenerateRandom_ShouldPerformNoOpWhenDataIsZeroLength() { diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs index 1b11e702d..a09999b36 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs @@ -64,8 +64,6 @@ public void ErrorOccurredIsRaisedOnce() var connectionException = (SshConnectionException)exception; Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason); - Assert.IsNull(connectionException.InnerException); - Assert.AreEqual("An established connection was aborted by the server.", connectionException.Message); } [TestMethod] @@ -137,45 +135,29 @@ public void ISession_TrySendMessageShouldReturnFalse() public void ISession_WaitOnHandle_WaitHandle_ShouldThrowSshConnectionException() { var session = (ISession)Session; - var waitHandle = new ManualResetEvent(false); + using var waitHandle = new ManualResetEvent(false); - try - { - session.WaitOnHandle(waitHandle); - Assert.Fail(); - } - catch (SshConnectionException ex) - { - Assert.AreEqual("An established connection was aborted by the server.", ex.Message); - Assert.IsNull(ex.InnerException); - Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason); - } + var ex = Assert.ThrowsExactly(() => session.WaitOnHandle(waitHandle)); + + Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason); } [TestMethod] public void ISession_WaitOnHandle_WaitHandleAndTimeout_ShouldThrowSshConnectionException() { var session = (ISession)Session; - var waitHandle = new ManualResetEvent(false); + using var waitHandle = new ManualResetEvent(false); - try - { - session.WaitOnHandle(waitHandle, Timeout.InfiniteTimeSpan); - Assert.Fail(); - } - catch (SshConnectionException ex) - { - Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason); - Assert.IsNull(ex.InnerException); - Assert.AreEqual("An established connection was aborted by the server.", ex.Message); - } + var ex = Assert.ThrowsExactly(() => session.WaitOnHandle(waitHandle)); + + Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason); } [TestMethod] public void ISession_TryWait_WaitHandleAndTimeout_ShouldReturnDisconnected() { var session = (ISession)Session; - var waitHandle = new ManualResetEvent(false); + using var waitHandle = new ManualResetEvent(false); var result = session.TryWait(waitHandle, Timeout.InfiniteTimeSpan); @@ -186,7 +168,7 @@ public void ISession_TryWait_WaitHandleAndTimeout_ShouldReturnDisconnected() public void ISession_TryWait_WaitHandleAndTimeoutAndException_ShouldReturnDisconnected() { var session = (ISession)Session; - var waitHandle = new ManualResetEvent(false); + using var waitHandle = new ManualResetEvent(false); var result = session.TryWait(waitHandle, Timeout.InfiniteTimeSpan, out var exception); diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs index 5dd53a292..61f1b494e 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs @@ -91,13 +91,20 @@ public void DisposeShouldFinishImmediately() } [TestMethod] - public void ReceiveOnServerSocketShouldReturnZero() + public void ServerShouldBeDisconnected() { - var buffer = new byte[1]; + try + { + var buffer = new byte[1]; - var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None); + var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None); - Assert.AreEqual(0, actual); + Assert.AreEqual(0, actual); // FIN + } + catch (SocketException sx) + { + Assert.AreEqual(SocketError.ConnectionReset, sx.SocketErrorCode); // RST + } } [TestMethod]