@@ -52,6 +52,7 @@ struct tls_decrypt_arg {
5252 struct_group (inargs ,
5353 bool zc ;
5454 bool async ;
55+ bool async_done ;
5556 u8 tail ;
5657 );
5758
@@ -274,22 +275,30 @@ static int tls_do_decryption(struct sock *sk,
274275 DEBUG_NET_WARN_ON_ONCE (atomic_read (& ctx -> decrypt_pending ) < 1 );
275276 atomic_inc (& ctx -> decrypt_pending );
276277 } else {
278+ DECLARE_CRYPTO_WAIT (wait );
279+
277280 aead_request_set_callback (aead_req ,
278281 CRYPTO_TFM_REQ_MAY_BACKLOG ,
279- crypto_req_done , & ctx -> async_wait );
282+ crypto_req_done , & wait );
283+ ret = crypto_aead_decrypt (aead_req );
284+ if (ret == - EINPROGRESS || ret == - EBUSY )
285+ ret = crypto_wait_req (ret , & wait );
286+ return ret ;
280287 }
281288
282289 ret = crypto_aead_decrypt (aead_req );
290+ if (ret == - EINPROGRESS )
291+ return 0 ;
292+
283293 if (ret == - EBUSY ) {
284294 ret = tls_decrypt_async_wait (ctx );
285- ret = ret ?: - EINPROGRESS ;
295+ darg -> async_done = true;
296+ /* all completions have run, we're not doing async anymore */
297+ darg -> async = false;
298+ return ret ;
286299 }
287- if (ret == - EINPROGRESS ) {
288- if (darg -> async )
289- return 0 ;
290300
291- ret = crypto_wait_req (ret , & ctx -> async_wait );
292- }
301+ atomic_dec (& ctx -> decrypt_pending );
293302 darg -> async = false;
294303
295304 return ret ;
@@ -1588,8 +1597,11 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
15881597 /* Prepare and submit AEAD request */
15891598 err = tls_do_decryption (sk , sgin , sgout , dctx -> iv ,
15901599 data_len + prot -> tail_size , aead_req , darg );
1591- if (err )
1600+ if (err ) {
1601+ if (darg -> async_done )
1602+ goto exit_free_skb ;
15921603 goto exit_free_pages ;
1604+ }
15931605
15941606 darg -> skb = clear_skb ?: tls_strp_msg (ctx );
15951607 clear_skb = NULL ;
@@ -1601,6 +1613,9 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
16011613 return err ;
16021614 }
16031615
1616+ if (unlikely (darg -> async_done ))
1617+ return 0 ;
1618+
16041619 if (prot -> tail_size )
16051620 darg -> tail = dctx -> tail ;
16061621
@@ -1948,6 +1963,7 @@ int tls_sw_recvmsg(struct sock *sk,
19481963 struct strp_msg * rxm ;
19491964 struct tls_msg * tlm ;
19501965 ssize_t copied = 0 ;
1966+ ssize_t peeked = 0 ;
19511967 bool async = false;
19521968 int target , err ;
19531969 bool is_kvec = iov_iter_is_kvec (& msg -> msg_iter );
@@ -2095,8 +2111,10 @@ int tls_sw_recvmsg(struct sock *sk,
20952111 if (err < 0 )
20962112 goto put_on_rx_list_err ;
20972113
2098- if (is_peek )
2114+ if (is_peek ) {
2115+ peeked += chunk ;
20992116 goto put_on_rx_list ;
2117+ }
21002118
21012119 if (partially_consumed ) {
21022120 rxm -> offset += chunk ;
@@ -2135,8 +2153,8 @@ int tls_sw_recvmsg(struct sock *sk,
21352153
21362154 /* Drain records from the rx_list & copy if required */
21372155 if (is_peek || is_kvec )
2138- err = process_rx_list (ctx , msg , & control , copied ,
2139- decrypted , is_peek , NULL );
2156+ err = process_rx_list (ctx , msg , & control , copied + peeked ,
2157+ decrypted - peeked , is_peek , NULL );
21402158 else
21412159 err = process_rx_list (ctx , msg , & control , 0 ,
21422160 async_copy_bytes , is_peek , NULL );
0 commit comments