@@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)
211
211
212
212
private void ProcessNeedKmsState ( CryptContext context , CancellationToken cancellationToken )
213
213
{
214
- var requests = context . GetKmsMessageRequests ( ) ;
215
- foreach ( var request in requests )
214
+ while ( context . GetNextKmsMessageRequest ( ) is { } request )
216
215
{
217
216
SendKmsRequest ( request , cancellationToken ) ;
218
217
}
219
- requests . MarkDone ( ) ;
218
+ context . MarkKmsDone ( ) ;
220
219
}
221
220
222
221
private async Task ProcessNeedKmsStateAsync ( CryptContext context , CancellationToken cancellationToken )
223
222
{
224
- var requests = context . GetKmsMessageRequests ( ) ;
225
- foreach ( var request in requests )
223
+ while ( context . GetNextKmsMessageRequest ( ) is { } request )
226
224
{
227
225
await SendKmsRequestAsync ( request , cancellationToken ) . ConfigureAwait ( false ) ;
228
226
}
229
- requests . MarkDone ( ) ;
227
+ context . MarkKmsDone ( ) ;
230
228
}
231
229
232
230
private void ProcessNeedMongoKeysState ( CryptContext context , CancellationToken cancellationToken )
@@ -278,48 +276,90 @@ private static byte[] ProcessReadyState(CryptContext context)
278
276
279
277
private void SendKmsRequest ( KmsRequest request , CancellationToken cancellation )
280
278
{
281
- var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
282
-
283
- var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
284
- var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
285
- using ( var sslStream = sslStreamFactory . CreateStream ( endpoint , cancellation ) )
286
- using ( var binary = request . GetMessage ( ) )
279
+ try
287
280
{
281
+ var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
282
+
283
+ var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
284
+ var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
285
+ using var sslStream = sslStreamFactory . CreateStream ( endpoint , cancellation ) ;
286
+
287
+ var sleepMs = request . Sleep ;
288
+ if ( sleepMs > 0 )
289
+ {
290
+ Thread . Sleep ( sleepMs ) ;
291
+ }
292
+
293
+ using var binary = request . GetMessage ( ) ;
288
294
var requestBytes = binary . ToArray ( ) ;
289
295
sslStream . Write ( requestBytes , 0 , requestBytes . Length ) ;
290
296
291
297
while ( request . BytesNeeded > 0 )
292
298
{
293
299
var buffer = new byte [ request . BytesNeeded ] ; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
294
300
var count = sslStream . Read ( buffer , 0 , buffer . Length ) ;
301
+
302
+ if ( count == 0 )
303
+ {
304
+ throw new IOException ( "Unexpected end of stream. No data was read from the SSL stream." ) ;
305
+ }
306
+
295
307
var responseBytes = new byte [ count ] ;
296
308
Buffer . BlockCopy ( buffer , 0 , responseBytes , 0 , count ) ;
297
309
request . Feed ( responseBytes ) ;
298
310
}
299
311
}
312
+ catch ( Exception ex ) when ( ex is IOException or SocketException )
313
+ {
314
+ if ( ! request . Fail ( ) )
315
+ {
316
+ throw ;
317
+ }
318
+ }
300
319
}
301
320
302
321
private async Task SendKmsRequestAsync ( KmsRequest request , CancellationToken cancellation )
303
322
{
304
- var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
305
-
306
- var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
307
- var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
308
- using ( var sslStream = await sslStreamFactory . CreateStreamAsync ( endpoint , cancellation ) . ConfigureAwait ( false ) )
309
- using ( var binary = request . GetMessage ( ) )
323
+ try
310
324
{
325
+ var endpoint = CreateKmsEndPoint ( request . Endpoint ) ;
326
+
327
+ var tlsStreamSettings = GetTlsStreamSettings ( request . KmsProvider ) ;
328
+ var sslStreamFactory = new SslStreamFactory ( tlsStreamSettings , _networkStreamFactory ) ;
329
+ using var sslStream = await sslStreamFactory . CreateStreamAsync ( endpoint , cancellation ) . ConfigureAwait ( false ) ;
330
+
331
+ var sleepMs = request . Sleep ;
332
+ if ( sleepMs > 0 )
333
+ {
334
+ await Task . Delay ( sleepMs , cancellation ) . ConfigureAwait ( false ) ;
335
+ }
336
+
337
+ using var binary = request . GetMessage ( ) ;
311
338
var requestBytes = binary . ToArray ( ) ;
312
339
await sslStream . WriteAsync ( requestBytes , 0 , requestBytes . Length ) . ConfigureAwait ( false ) ;
313
340
314
341
while ( request . BytesNeeded > 0 )
315
342
{
316
343
var buffer = new byte [ request . BytesNeeded ] ; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
317
344
var count = await sslStream . ReadAsync ( buffer , 0 , buffer . Length ) . ConfigureAwait ( false ) ;
345
+
346
+ if ( count == 0 )
347
+ {
348
+ throw new IOException ( "Unexpected end of stream. No data was read from the SSL stream." ) ;
349
+ }
350
+
318
351
var responseBytes = new byte [ count ] ;
319
352
Buffer . BlockCopy ( buffer , 0 , responseBytes , 0 , count ) ;
320
353
request . Feed ( responseBytes ) ;
321
354
}
322
355
}
356
+ catch ( Exception ex ) when ( ex is IOException or SocketException )
357
+ {
358
+ if ( ! request . Fail ( ) )
359
+ {
360
+ throw ;
361
+ }
362
+ }
323
363
}
324
364
325
365
// nested type
0 commit comments