diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs
index 69ec38b26..63dc2bf54 100644
--- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs
+++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs
@@ -12,34 +12,6 @@ namespace Renci.SshNet.Abstractions
{
internal static partial class SocketAbstraction
{
- public static bool CanRead(Socket socket)
- {
- if (socket.Connected)
- {
- return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
- }
-
- return false;
- }
-
- ///
- /// Returns a value indicating whether the specified can be used
- /// to send data.
- ///
- /// The to check.
- ///
- /// if can be written to; otherwise, .
- ///
- public static bool CanWrite(Socket socket)
- {
- if (socket != null && socket.Connected)
- {
- return socket.Poll(-1, SelectMode.SelectWrite);
- }
-
- return false;
- }
-
public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
{
var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs
index b7a97d067..6cc65a779 100644
--- a/src/Renci.SshNet/Common/Extensions.cs
+++ b/src/Renci.SshNet/Common/Extensions.cs
@@ -10,7 +10,6 @@
using System.Runtime.CompilerServices;
using System.Threading;
-using Renci.SshNet.Abstractions;
using Renci.SshNet.Messages;
namespace Renci.SshNet.Common
@@ -319,16 +318,6 @@ public static byte[] Concat(this byte[] first, byte[] second)
return concat;
}
- internal static bool CanRead(this Socket socket)
- {
- return SocketAbstraction.CanRead(socket);
- }
-
- internal static bool CanWrite(this Socket socket)
- {
- return SocketAbstraction.CanWrite(socket);
- }
-
internal static bool IsConnected(this Socket socket)
{
if (socket is null)
diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs
index ec3eac878..77fe9d4c2 100644
--- a/src/Renci.SshNet/Session.cs
+++ b/src/Renci.SshNet/Session.cs
@@ -81,12 +81,6 @@ public sealed class Session : ISession
private readonly ISocketFactory _socketFactory;
private readonly ILogger _logger;
- ///
- /// Holds an object that is used to ensure only a single thread can read from
- /// at any given time.
- ///
- private readonly Lock _socketReadLock = new Lock();
-
///
/// Holds an object that is used to ensure only a single thread can write to
/// at any given time.
@@ -105,7 +99,7 @@ public sealed class Session : ISession
/// This is also used to ensure that will not be disposed
/// while performing a given operation or set of operations on .
///
- private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1);
+ private readonly Lock _socketDisposeLock = new Lock();
///
/// Holds an object that is used to ensure only a single thread can connect
@@ -279,17 +273,11 @@ public bool IsConnected
{
get
{
- if (_disposed || _isDisconnectMessageSent || !_isAuthenticated)
- {
- return false;
- }
-
- if (_messageListenerCompleted is null || _messageListenerCompleted.WaitOne(0))
- {
- return false;
- }
-
- return IsSocketConnected();
+ return !_disposed &&
+ !_isDisconnectMessageSent &&
+ _isAuthenticated &&
+ _messageListenerCompleted?.WaitOne(0) == false &&
+ _socket.IsConnected();
}
}
@@ -1046,7 +1034,7 @@ internal void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout)
/// The size of the packet exceeds the maximum size defined by the protocol.
internal void SendMessage(Message message)
{
- if (!_socket.CanWrite())
+ if (!_socket.IsConnected())
{
throw new SshConnectionException("Client not connected.");
}
@@ -1161,9 +1149,7 @@ internal void SendMessage(Message message)
///
private void SendPacket(byte[] packet, int offset, int length)
{
- _socketDisposeLock.Wait();
-
- try
+ lock (_socketDisposeLock)
{
if (!_socket.IsConnected())
{
@@ -1172,10 +1158,6 @@ private void SendPacket(byte[] packet, int offset, int length)
SocketAbstraction.Send(_socket, packet, offset, length);
}
- finally
- {
- _ = _socketDisposeLock.Release();
- }
}
///
@@ -1259,76 +1241,70 @@ private Message ReceiveMessage(Socket socket)
byte[] data;
uint packetLength;
- // avoid reading from socket while IsSocketConnected is attempting to determine whether the
- // socket is still connected by invoking Socket.Poll(...) and subsequently verifying value of
- // Socket.Available
- lock (_socketReadLock)
+ // Read first block - which starts with the packet length
+ var firstBlock = new byte[blockSize];
+ if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0)
{
- // Read first block - which starts with the packet length
- var firstBlock = new byte[blockSize];
- if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0)
- {
- // connection with SSH server was closed
- return null;
- }
+ // connection with SSH server was closed
+ return null;
+ }
- var plainFirstBlock = firstBlock;
+ var plainFirstBlock = firstBlock;
- // First block is not encrypted in AES GCM mode.
- if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
- {
- _serverCipher.SetSequenceNumber(_inboundPacketSequence);
+ // First block is not encrypted in AES GCM mode.
+ if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher)
+ {
+ _serverCipher.SetSequenceNumber(_inboundPacketSequence);
- // First block is not encrypted in ETM mode.
- if (_serverMac == null || !_serverEtm)
- {
- plainFirstBlock = _serverCipher.Decrypt(firstBlock);
- }
+ // First block is not encrypted in ETM mode.
+ if (_serverMac == null || !_serverEtm)
+ {
+ plainFirstBlock = _serverCipher.Decrypt(firstBlock);
}
+ }
- packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock);
+ packetLength = BinaryPrimitives.ReadUInt32BigEndian(plainFirstBlock);
- // Test packet minimum and maximum boundaries
- if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
- {
- throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength),
- DisconnectReason.ProtocolError);
- }
+ // Test packet minimum and maximum boundaries
+ if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
+ {
+ throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength),
+ DisconnectReason.ProtocolError);
+ }
- // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
- // "packet length" field itself - which is 4 bytes - is not included in the length of the packet
- var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
-
- // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
- // to generate the hash.
- //
- // The total length of the "data" buffer is an addition of:
- // - inboundPacketSequenceLength (4 bytes)
- // - packetLength
- // - serverMacLength
- //
- // We include the inbound packet sequence to allow us to have the the full SSH packet in a single
- // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen
- // to read the packet including server MAC in a single pass (except for the initial block).
- data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength];
- BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence);
-
- // Use raw packet length field to calculate the mac in AEAD mode.
- if (_serverAead)
- {
- Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize);
- }
- else
- {
- Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize);
- }
+ // Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
+ // "packet length" field itself - which is 4 bytes - is not included in the length of the packet
+ var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
+
+ // Construct buffer for holding the payload and the inbound packet sequence as we need both in order
+ // to generate the hash.
+ //
+ // The total length of the "data" buffer is an addition of:
+ // - inboundPacketSequenceLength (4 bytes)
+ // - packetLength
+ // - serverMacLength
+ //
+ // We include the inbound packet sequence to allow us to have the the full SSH packet in a single
+ // byte[] for the purpose of calculating the client hash. Room for the server MAC is foreseen
+ // to read the packet including server MAC in a single pass (except for the initial block).
+ data = new byte[bytesToRead + blockSize + inboundPacketSequenceLength];
+ BinaryPrimitives.WriteUInt32BigEndian(data, _inboundPacketSequence);
+
+ // Use raw packet length field to calculate the mac in AEAD mode.
+ if (_serverAead)
+ {
+ Buffer.BlockCopy(firstBlock, 0, data, inboundPacketSequenceLength, blockSize);
+ }
+ else
+ {
+ Buffer.BlockCopy(plainFirstBlock, 0, data, inboundPacketSequenceLength, blockSize);
+ }
- if (bytesToRead > 0)
+ if (bytesToRead > 0)
+ {
+ if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0)
{
- if (TrySocketRead(socket, data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0)
- {
- return null;
- }
+ return null;
}
}
@@ -1888,84 +1864,6 @@ private static string ToHex(byte[] bytes)
#endif
}
- ///
- /// Gets a value indicating whether the socket is connected.
- ///
- ///
- /// if the socket is connected; otherwise, .
- ///
- ///
- ///
- /// As a first check we verify whether is
- /// . However, this only returns the state of the socket as of
- /// the last I/O operation.
- ///
- ///
- /// Therefore we use the combination of with mode
- /// and to verify if the socket is still connected.
- ///
- ///
- /// The MSDN doc mention the following on the return value of
- /// with mode :
- ///
- /// -
- /// if data is available for reading;
- ///
- /// -
- /// if the connection has been closed, reset, or terminated; otherwise, returns .
- ///
- ///
- ///
- ///
- /// Conclusion: when the return value is - but no data is available for reading - then
- /// the socket is no longer connected.
- ///
- ///
- /// When a is used from multiple threads, there's a race condition
- /// between the invocation of and the moment
- /// when the value of is obtained. To workaround this issue
- /// we synchronize reads from the .
- ///
- ///
- /// We assume the socket is still connected if the read lock cannot be acquired immediately.
- /// In this case, we just return without actually waiting to acquire
- /// the lock. We don't want to wait for the read lock if another thread already has it because
- /// there are cases where the other thread holding the lock can be waiting indefinitely for
- /// a socket read operation to complete.
- ///
- ///
- private bool IsSocketConnected()
- {
- _socketDisposeLock.Wait();
-
- try
- {
- if (!_socket.IsConnected())
- {
- return false;
- }
-
- if (!_socketReadLock.TryEnter())
- {
- return true;
- }
-
- try
- {
- var connectionClosedOrDataAvailable = _socket.Poll(0, SelectMode.SelectRead);
- return !(connectionClosedOrDataAvailable && _socket.Available == 0);
- }
- finally
- {
- _socketReadLock.Exit();
- }
- }
- finally
- {
- _ = _socketDisposeLock.Release();
- }
- }
-
///
/// Performs a blocking read on the socket until bytes are received.
///
@@ -1988,46 +1886,37 @@ private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int l
///
private void SocketDisconnectAndDispose()
{
- if (_socket != null)
+ lock (_socketDisposeLock)
{
- _socketDisposeLock.Wait();
+ if (_socket is null)
+ {
+ return;
+ }
- try
+ if (_socket.Connected)
{
-#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread.
- if (_socket != null)
-#pragma warning restore CA1508 // Avoid dead conditional code
+ try
{
- if (_socket.Connected)
- {
- try
- {
- _logger.LogDebug("[{SessionId}] Shutting down socket.", SessionIdHex);
-
- // Interrupt any pending reads; should be done outside of socket read lock as we
- // actually want shutdown the socket to make sure blocking reads are interrupted.
- //
- // This may result in a SocketException (eg. An existing connection was forcibly
- // closed by the remote host) which we'll log and ignore as it means the socket
- // was already shut down.
- _socket.Shutdown(SocketShutdown.Send);
- }
- catch (SocketException ex)
- {
- _logger.LogInformation(ex, "Failure shutting down socket");
- }
- }
-
- _logger.LogDebug("[{SessionId}] Disposing socket.", SessionIdHex);
- _socket.Dispose();
- _logger.LogDebug("[{SessionId}] Disposed socket.", SessionIdHex);
- _socket = null;
+ _logger.LogDebug("[{SessionId}] Shutting down socket.", SessionIdHex);
+
+ // Interrupt any pending reads; should be done outside of socket read lock as we
+ // actually want shutdown the socket to make sure blocking reads are interrupted.
+ //
+ // This may result in a SocketException (eg. An existing connection was forcibly
+ // closed by the remote host) which we'll log and ignore as it means the socket
+ // was already shut down.
+ _socket.Shutdown(SocketShutdown.Both);
+ }
+ catch (SocketException ex)
+ {
+ _logger.LogInformation(ex, "Failure shutting down socket");
}
}
- finally
- {
- _ = _socketDisposeLock.Release();
- }
+
+ _logger.LogDebug("[{SessionId}] Disposing socket.", SessionIdHex);
+ _socket.Dispose();
+ _logger.LogDebug("[{SessionId}] Disposed socket.", SessionIdHex);
+ _socket = null;
}
}
@@ -2048,25 +1937,6 @@ private void MessageListener()
break;
}
- try
- {
- // Block until either data is available or the socket is closed
- var connectionClosedOrDataAvailable = socket.Poll(-1, SelectMode.SelectRead);
- if (connectionClosedOrDataAvailable && socket.Available == 0)
- {
- // connection with SSH server was closed or connection was reset
- break;
- }
- }
- catch (ObjectDisposedException)
- {
- // The socket was disposed by either:
- // * a call to Disconnect()
- // * a call to Dispose()
- // * a SSH_MSG_DISCONNECT received from server
- break;
- }
-
var message = ReceiveMessage(socket);
if (message is null)
{
@@ -2102,25 +1972,12 @@ private void MessageListener()
/// The .
private void RaiseError(Exception exp)
{
- var connectionException = exp as SshConnectionException;
-
_logger.LogInformation(exp, "[{SessionId}] Raised exception", SessionIdHex);
- if (_isDisconnecting)
+ if (_isDisconnecting && exp is SshConnectionException or ObjectDisposedException)
{
- // a connection exception which is raised while isDisconnecting is normal and
- // should be ignored
- if (connectionException != null)
- {
- return;
- }
-
- // any timeout while disconnecting can be caused by loss of connectivity
- // altogether and should be ignored
- if (exp is SocketException socketException && socketException.SocketErrorCode == SocketError.TimedOut)
- {
- return;
- }
+ // Such an exception raised while isDisconnecting is expected and can be ignored.
+ return;
}
// "save" exception and set exception wait handle to ensure any waits are interrupted
@@ -2129,10 +1986,10 @@ private void RaiseError(Exception exp)
ErrorOccured?.Invoke(this, new ExceptionEventArgs(exp));
- if (connectionException != null)
+ if (exp is SshConnectionException connectionException)
{
_logger.LogInformation(exp, "[{SessionId}] Disconnecting after exception", SessionIdHex);
- Disconnect(connectionException.DisconnectReason, exp.ToString());
+ Disconnect(connectionException.DisconnectReason, exp.Message);
}
}
diff --git a/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs b/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs
index e79eb8561..c37b3dcf9 100644
--- a/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs
+++ b/test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs
@@ -8,12 +8,6 @@ namespace Renci.SshNet.Tests.Classes
[TestClass]
public class AbstractionsTest
{
- [TestMethod]
- public void SocketAbstraction_CanWrite_ShouldReturnFalseWhenSocketIsNull()
- {
- Assert.IsFalse(SocketAbstraction.CanWrite(null));
- }
-
[TestMethod]
public void CryptoAbstraction_GenerateRandom_ShouldPerformNoOpWhenDataIsZeroLength()
{
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs
index 1b11e702d..a09999b36 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ConnectionReset.cs
@@ -64,8 +64,6 @@ public void ErrorOccurredIsRaisedOnce()
var connectionException = (SshConnectionException)exception;
Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason);
- Assert.IsNull(connectionException.InnerException);
- Assert.AreEqual("An established connection was aborted by the server.", connectionException.Message);
}
[TestMethod]
@@ -137,45 +135,29 @@ public void ISession_TrySendMessageShouldReturnFalse()
public void ISession_WaitOnHandle_WaitHandle_ShouldThrowSshConnectionException()
{
var session = (ISession)Session;
- var waitHandle = new ManualResetEvent(false);
+ using var waitHandle = new ManualResetEvent(false);
- try
- {
- session.WaitOnHandle(waitHandle);
- Assert.Fail();
- }
- catch (SshConnectionException ex)
- {
- Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
- Assert.IsNull(ex.InnerException);
- Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
- }
+ var ex = Assert.ThrowsExactly(() => session.WaitOnHandle(waitHandle));
+
+ Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
}
[TestMethod]
public void ISession_WaitOnHandle_WaitHandleAndTimeout_ShouldThrowSshConnectionException()
{
var session = (ISession)Session;
- var waitHandle = new ManualResetEvent(false);
+ using var waitHandle = new ManualResetEvent(false);
- try
- {
- session.WaitOnHandle(waitHandle, Timeout.InfiniteTimeSpan);
- Assert.Fail();
- }
- catch (SshConnectionException ex)
- {
- Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
- Assert.IsNull(ex.InnerException);
- Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
- }
+ var ex = Assert.ThrowsExactly(() => session.WaitOnHandle(waitHandle));
+
+ Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
}
[TestMethod]
public void ISession_TryWait_WaitHandleAndTimeout_ShouldReturnDisconnected()
{
var session = (ISession)Session;
- var waitHandle = new ManualResetEvent(false);
+ using var waitHandle = new ManualResetEvent(false);
var result = session.TryWait(waitHandle, Timeout.InfiniteTimeSpan);
@@ -186,7 +168,7 @@ public void ISession_TryWait_WaitHandleAndTimeout_ShouldReturnDisconnected()
public void ISession_TryWait_WaitHandleAndTimeoutAndException_ShouldReturnDisconnected()
{
var session = (ISession)Session;
- var waitHandle = new ManualResetEvent(false);
+ using var waitHandle = new ManualResetEvent(false);
var result = session.TryWait(waitHandle, Timeout.InfiniteTimeSpan, out var exception);
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs
index 5dd53a292..61f1b494e 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsBadPacket.cs
@@ -91,13 +91,20 @@ public void DisposeShouldFinishImmediately()
}
[TestMethod]
- public void ReceiveOnServerSocketShouldReturnZero()
+ public void ServerShouldBeDisconnected()
{
- var buffer = new byte[1];
+ try
+ {
+ var buffer = new byte[1];
- var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
+ var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
- Assert.AreEqual(0, actual);
+ Assert.AreEqual(0, actual); // FIN
+ }
+ catch (SocketException sx)
+ {
+ Assert.AreEqual(SocketError.ConnectionReset, sx.SocketErrorCode); // RST
+ }
}
[TestMethod]