Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,6 @@ namespace Renci.SshNet.Abstractions
{
internal static partial class SocketAbstraction
{
public static bool CanRead(Socket socket)
{
if (socket.Connected)
{
return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
}

return false;
}

/// <summary>
/// Returns a value indicating whether the specified <see cref="Socket"/> can be used
/// to send data.
/// </summary>
/// <param name="socket">The <see cref="Socket"/> to check.</param>
/// <returns>
/// <see langword="true"/> if <paramref name="socket"/> can be written to; otherwise, <see langword="false"/>.
/// </returns>
public static bool CanWrite(Socket socket)
{
if (socket != null && socket.Connected)
{
return socket.Poll(-1, SelectMode.SelectWrite);
}

return false;
}

public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
{
var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
Expand Down
11 changes: 0 additions & 11 deletions src/Renci.SshNet/Common/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using System.Runtime.CompilerServices;
using System.Threading;

using Renci.SshNet.Abstractions;
using Renci.SshNet.Messages;

namespace Renci.SshNet.Common
Expand Down Expand Up @@ -319,16 +318,6 @@ public static byte[] Concat(this byte[] first, byte[] second)
return concat;
}

internal static bool CanRead(this Socket socket)
{
return SocketAbstraction.CanRead(socket);
}

internal static bool CanWrite(this Socket socket)
{
return SocketAbstraction.CanWrite(socket);
}

internal static bool IsConnected(this Socket socket)
{
if (socket is null)
Expand Down
179 changes: 18 additions & 161 deletions src/Renci.SshNet/Session.cs
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ public sealed class Session : ISession
private readonly ISocketFactory _socketFactory;
private readonly ILogger _logger;

/// <summary>
/// Holds an object that is used to ensure only a single thread can read from
/// <see cref="_socket"/> at any given time.
/// </summary>
private readonly Lock _socketReadLock = new Lock();
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no longer needed


/// <summary>
/// Holds an object that is used to ensure only a single thread can write to
/// <see cref="_socket"/> at any given time.
Expand All @@ -105,7 +99,7 @@ public sealed class Session : ISession
/// This is also used to ensure that <see cref="_socket"/> will not be disposed
/// while performing a given operation or set of operations on <see cref="_socket"/>.
/// </remarks>
private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1);
private readonly Lock _socketDisposeLock = new Lock();

/// <summary>
/// Holds an object that is used to ensure only a single thread can connect
Expand Down Expand Up @@ -279,17 +273,11 @@ public bool IsConnected
{
get
{
if (_disposed || _isDisconnectMessageSent || !_isAuthenticated)
{
return false;
}

if (_messageListenerCompleted is null || _messageListenerCompleted.WaitOne(0))
{
return false;
}

return IsSocketConnected();
return !_disposed &&
!_isDisconnectMessageSent &&
_isAuthenticated &&
_messageListenerCompleted?.WaitOne(0) == false &&
_socket.IsConnected();
}
}

Expand Down Expand Up @@ -1046,7 +1034,7 @@ internal void WaitOnHandle(WaitHandle waitHandle, TimeSpan timeout)
/// <exception cref="InvalidOperationException">The size of the packet exceeds the maximum size defined by the protocol.</exception>
internal void SendMessage(Message message)
{
if (!_socket.CanWrite())
if (!_socket.IsConnected())
{
throw new SshConnectionException("Client not connected.");
}
Expand Down Expand Up @@ -1161,9 +1149,7 @@ internal void SendMessage(Message message)
/// </remarks>
private void SendPacket(byte[] packet, int offset, int length)
{
_socketDisposeLock.Wait();

try
lock (_socketDisposeLock)
{
if (!_socket.IsConnected())
{
Expand All @@ -1172,10 +1158,6 @@ private void SendPacket(byte[] packet, int offset, int length)

SocketAbstraction.Send(_socket, packet, offset, length);
}
finally
{
_ = _socketDisposeLock.Release();
}
}

/// <summary>
Expand Down Expand Up @@ -1259,11 +1241,6 @@ private Message ReceiveMessage(Socket socket)
byte[] data;
uint packetLength;

// avoid reading from socket while IsSocketConnected is attempting to determine whether the
// socket is still connected by invoking Socket.Poll(...) and subsequently verifying value of
// Socket.Available
lock (_socketReadLock)
{
// Read first block - which starts with the packet length
var firstBlock = new byte[blockSize];
if (TrySocketRead(socket, firstBlock, 0, blockSize) == 0)
Expand Down Expand Up @@ -1330,7 +1307,6 @@ private Message ReceiveMessage(Socket socket)
return null;
}
}
}

// validate encrypted message against MAC
if (_serverMac != null && _serverEtm)
Expand Down Expand Up @@ -1888,84 +1864,6 @@ private static string ToHex(byte[] bytes)
#endif
}

/// <summary>
/// Gets a value indicating whether the socket is connected.
/// </summary>
/// <returns>
/// <see langword="true"/> if the socket is connected; otherwise, <see langword="false"/>.
/// </returns>
/// <remarks>
/// <para>
/// As a first check we verify whether <see cref="Socket.Connected"/> is
/// <see langword="true"/>. However, this only returns the state of the socket as of
/// the last I/O operation.
/// </para>
/// <para>
/// Therefore we use the combination of <see cref="Socket.Poll(int, SelectMode)"/> with mode <see cref="SelectMode.SelectRead"/>
/// and <see cref="Socket.Available"/> to verify if the socket is still connected.
/// </para>
/// <para>
/// The MSDN doc mention the following on the return value of <see cref="Socket.Poll(int, SelectMode)"/>
/// with mode <see cref="SelectMode.SelectRead"/>:
/// <list type="bullet">
/// <item>
/// <description><see langword="true"/> if data is available for reading;</description>
/// </item>
/// <item>
/// <description><see langword="true"/> if the connection has been closed, reset, or terminated; otherwise, returns <see langword="false"/>.</description>
/// </item>
/// </list>
/// </para>
/// <para>
/// <c>Conclusion:</c> when the return value is <see langword="true"/> - but no data is available for reading - then
/// the socket is no longer connected.
/// </para>
/// <para>
/// When a <see cref="Socket"/> is used from multiple threads, there's a race condition
/// between the invocation of <see cref="Socket.Poll(int, SelectMode)"/> and the moment
/// when the value of <see cref="Socket.Available"/> is obtained. To workaround this issue
/// we synchronize reads from the <see cref="Socket"/>.
/// </para>
/// <para>
/// We assume the socket is still connected if the read lock cannot be acquired immediately.
/// In this case, we just return <see langword="true"/> without actually waiting to acquire
/// the lock. We don't want to wait for the read lock if another thread already has it because
/// there are cases where the other thread holding the lock can be waiting indefinitely for
/// a socket read operation to complete.
/// </para>
/// </remarks>
private bool IsSocketConnected()
{
_socketDisposeLock.Wait();

try
{
if (!_socket.IsConnected())
{
return false;
}

if (!_socketReadLock.TryEnter())
{
return true;
}

try
{
var connectionClosedOrDataAvailable = _socket.Poll(0, SelectMode.SelectRead);
return !(connectionClosedOrDataAvailable && _socket.Available == 0);
}
finally
{
_socketReadLock.Exit();
}
}
finally
{
_ = _socketDisposeLock.Release();
}
}

/// <summary>
/// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
/// </summary>
Expand All @@ -1988,16 +1886,13 @@ private static int TrySocketRead(Socket socket, byte[] buffer, int offset, int l
/// </summary>
private void SocketDisconnectAndDispose()
{
if (_socket != null)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a hot path, so simplify the locking and save some indentation

{
_socketDisposeLock.Wait();

try
lock (_socketDisposeLock)
{
#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread.
if (_socket != null)
#pragma warning restore CA1508 // Avoid dead conditional code
if (_socket is null)
{
return;
}

if (_socket.Connected)
{
try
Expand All @@ -2010,7 +1905,7 @@ private void SocketDisconnectAndDispose()
// This may result in a SocketException (eg. An existing connection was forcibly
// closed by the remote host) which we'll log and ignore as it means the socket
// was already shut down.
_socket.Shutdown(SocketShutdown.Send);
_socket.Shutdown(SocketShutdown.Both);
}
catch (SocketException ex)
{
Expand All @@ -2024,12 +1919,6 @@ private void SocketDisconnectAndDispose()
_socket = null;
}
}
finally
{
_ = _socketDisposeLock.Release();
}
}
}

/// <summary>
/// Listens for incoming message from the server and handles them. This method run as a task on separate thread.
Expand All @@ -2048,25 +1937,6 @@ private void MessageListener()
break;
}

try
{
// Block until either data is available or the socket is closed
var connectionClosedOrDataAvailable = socket.Poll(-1, SelectMode.SelectRead);
if (connectionClosedOrDataAvailable && socket.Available == 0)
{
// connection with SSH server was closed or connection was reset
break;
}
}
catch (ObjectDisposedException)
{
// The socket was disposed by either:
// * a call to Disconnect()
// * a call to Dispose()
// * a SSH_MSG_DISCONNECT received from server
break;
}

var message = ReceiveMessage(socket);
if (message is null)
{
Expand Down Expand Up @@ -2102,37 +1972,24 @@ private void MessageListener()
/// <param name="exp">The <see cref="Exception"/>.</param>
private void RaiseError(Exception exp)
{
var connectionException = exp as SshConnectionException;

_logger.LogInformation(exp, "[{SessionId}] Raised exception", SessionIdHex);

if (_isDisconnecting)
{
// a connection exception which is raised while isDisconnecting is normal and
// should be ignored
if (connectionException != null)
{
return;
}

// any timeout while disconnecting can be caused by loss of connectivity
// altogether and should be ignored
if (exp is SocketException socketException && socketException.SocketErrorCode == SocketError.TimedOut)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dead code, we always wrap SocketException in SshConnection

if (_isDisconnecting && exp is SshConnectionException or ObjectDisposedException)
{
// Such an exception raised while isDisconnecting is expected and can be ignored.
return;
}
}

// "save" exception and set exception wait handle to ensure any waits are interrupted
_exception = exp;
_ = _exceptionWaitHandle.Set();

ErrorOccured?.Invoke(this, new ExceptionEventArgs(exp));

if (connectionException != null)
if (exp is SshConnectionException connectionException)
{
_logger.LogInformation(exp, "[{SessionId}] Disconnecting after exception", SessionIdHex);
Disconnect(connectionException.DisconnectReason, exp.ToString());
Disconnect(connectionException.DisconnectReason, exp.Message);
}
}

Expand Down
6 changes: 0 additions & 6 deletions test/Renci.SshNet.Tests/Classes/AbstractionsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ namespace Renci.SshNet.Tests.Classes
[TestClass]
public class AbstractionsTest
{
[TestMethod]
public void SocketAbstraction_CanWrite_ShouldReturnFalseWhenSocketIsNull()
{
Assert.IsFalse(SocketAbstraction.CanWrite(null));
}

[TestMethod]
public void CryptoAbstraction_GenerateRandom_ShouldPerformNoOpWhenDataIsZeroLength()
{
Expand Down
Loading