@@ -51,12 +51,9 @@ enum {
5151 TLSV6 ,
5252 TLS_NUM_PROTS ,
5353};
54-
5554enum {
5655 TLS_BASE ,
57- TLS_SW_TX ,
58- TLS_SW_RX ,
59- TLS_SW_RXTX ,
56+ TLS_SW ,
6057 TLS_HW_RECORD ,
6158 TLS_NUM_CONFIG ,
6259};
@@ -65,14 +62,14 @@ static struct proto *saved_tcpv6_prot;
6562static DEFINE_MUTEX (tcpv6_prot_mutex );
6663static LIST_HEAD (device_list );
6764static DEFINE_MUTEX (device_mutex );
68- static struct proto tls_prots [TLS_NUM_PROTS ][TLS_NUM_CONFIG ];
65+ static struct proto tls_prots [TLS_NUM_PROTS ][TLS_NUM_CONFIG ][ TLS_NUM_CONFIG ] ;
6966static struct proto_ops tls_sw_proto_ops ;
7067
71- static inline void update_sk_prot (struct sock * sk , struct tls_context * ctx )
68+ static void update_sk_prot (struct sock * sk , struct tls_context * ctx )
7269{
7370 int ip_ver = sk -> sk_family == AF_INET6 ? TLSV6 : TLSV4 ;
7471
75- sk -> sk_prot = & tls_prots [ip_ver ][ctx -> conf ];
72+ sk -> sk_prot = & tls_prots [ip_ver ][ctx -> tx_conf ][ ctx -> rx_conf ];
7673}
7774
7875int wait_on_pending_writer (struct sock * sk , long * timeo )
@@ -245,10 +242,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
245242 lock_sock (sk );
246243 sk_proto_close = ctx -> sk_proto_close ;
247244
248- if (ctx -> conf == TLS_HW_RECORD )
245+ if (ctx -> tx_conf == TLS_HW_RECORD && ctx -> rx_conf == TLS_HW_RECORD )
249246 goto skip_tx_cleanup ;
250247
251- if (ctx -> conf == TLS_BASE ) {
248+ if (ctx -> tx_conf == TLS_BASE && ctx -> rx_conf == TLS_BASE ) {
252249 kfree (ctx );
253250 ctx = NULL ;
254251 goto skip_tx_cleanup ;
@@ -270,15 +267,17 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
270267 }
271268 }
272269
273- kfree (ctx -> tx .rec_seq );
274- kfree (ctx -> tx .iv );
275- kfree (ctx -> rx .rec_seq );
276- kfree (ctx -> rx .iv );
270+ /* We need these for tls_sw_fallback handling of other packets */
271+ if (ctx -> tx_conf == TLS_SW ) {
272+ kfree (ctx -> tx .rec_seq );
273+ kfree (ctx -> tx .iv );
274+ tls_sw_free_resources_tx (sk );
275+ }
277276
278- if (ctx -> conf == TLS_SW_TX ||
279- ctx -> conf == TLS_SW_RX ||
280- ctx -> conf == TLS_SW_RXTX ) {
281- tls_sw_free_resources (sk );
277+ if (ctx -> rx_conf == TLS_SW ) {
278+ kfree ( ctx -> rx . rec_seq );
279+ kfree ( ctx -> rx . iv );
280+ tls_sw_free_resources_rx (sk );
282281 }
283282
284283skip_tx_cleanup :
@@ -287,7 +286,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
287286 /* free ctx for TLS_HW_RECORD, used by tcp_set_state
288287 * for sk->sk_prot->unhash [tls_hw_unhash]
289288 */
290- if (ctx && ctx -> conf == TLS_HW_RECORD )
289+ if (ctx && ctx -> tx_conf == TLS_HW_RECORD &&
290+ ctx -> rx_conf == TLS_HW_RECORD )
291291 kfree (ctx );
292292}
293293
@@ -441,25 +441,21 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
441441 goto err_crypto_info ;
442442 }
443443
444- /* currently SW is default, we will have ethtool in future */
445444 if (tx ) {
446445 rc = tls_set_sw_offload (sk , ctx , 1 );
447- if (ctx -> conf == TLS_SW_RX )
448- conf = TLS_SW_RXTX ;
449- else
450- conf = TLS_SW_TX ;
446+ conf = TLS_SW ;
451447 } else {
452448 rc = tls_set_sw_offload (sk , ctx , 0 );
453- if (ctx -> conf == TLS_SW_TX )
454- conf = TLS_SW_RXTX ;
455- else
456- conf = TLS_SW_RX ;
449+ conf = TLS_SW ;
457450 }
458451
459452 if (rc )
460453 goto err_crypto_info ;
461454
462- ctx -> conf = conf ;
455+ if (tx )
456+ ctx -> tx_conf = conf ;
457+ else
458+ ctx -> rx_conf = conf ;
463459 update_sk_prot (sk , ctx );
464460 if (tx ) {
465461 ctx -> sk_write_space = sk -> sk_write_space ;
@@ -535,7 +531,8 @@ static int tls_hw_prot(struct sock *sk)
535531 ctx -> hash = sk -> sk_prot -> hash ;
536532 ctx -> unhash = sk -> sk_prot -> unhash ;
537533 ctx -> sk_proto_close = sk -> sk_prot -> close ;
538- ctx -> conf = TLS_HW_RECORD ;
534+ ctx -> rx_conf = TLS_HW_RECORD ;
535+ ctx -> tx_conf = TLS_HW_RECORD ;
539536 update_sk_prot (sk , ctx );
540537 rc = 1 ;
541538 break ;
@@ -579,29 +576,30 @@ static int tls_hw_hash(struct sock *sk)
579576 return err ;
580577}
581578
582- static void build_protos (struct proto * prot , struct proto * base )
579+ static void build_protos (struct proto prot [TLS_NUM_CONFIG ][TLS_NUM_CONFIG ],
580+ struct proto * base )
583581{
584- prot [TLS_BASE ] = * base ;
585- prot [TLS_BASE ].setsockopt = tls_setsockopt ;
586- prot [TLS_BASE ].getsockopt = tls_getsockopt ;
587- prot [TLS_BASE ].close = tls_sk_proto_close ;
588-
589- prot [TLS_SW_TX ] = prot [TLS_BASE ];
590- prot [TLS_SW_TX ].sendmsg = tls_sw_sendmsg ;
591- prot [TLS_SW_TX ] .sendpage = tls_sw_sendpage ;
592-
593- prot [TLS_SW_RX ] = prot [TLS_BASE ];
594- prot [TLS_SW_RX ].recvmsg = tls_sw_recvmsg ;
595- prot [TLS_SW_RX ].close = tls_sk_proto_close ;
596-
597- prot [TLS_SW_RXTX ] = prot [TLS_SW_TX ];
598- prot [TLS_SW_RXTX ].recvmsg = tls_sw_recvmsg ;
599- prot [TLS_SW_RXTX ] .close = tls_sk_proto_close ;
600-
601- prot [TLS_HW_RECORD ] = * base ;
602- prot [TLS_HW_RECORD ].hash = tls_hw_hash ;
603- prot [TLS_HW_RECORD ].unhash = tls_hw_unhash ;
604- prot [TLS_HW_RECORD ].close = tls_sk_proto_close ;
582+ prot [TLS_BASE ][ TLS_BASE ] = * base ;
583+ prot [TLS_BASE ][ TLS_BASE ] .setsockopt = tls_setsockopt ;
584+ prot [TLS_BASE ][ TLS_BASE ] .getsockopt = tls_getsockopt ;
585+ prot [TLS_BASE ][ TLS_BASE ] .close = tls_sk_proto_close ;
586+
587+ prot [TLS_SW ][ TLS_BASE ] = prot [ TLS_BASE ] [TLS_BASE ];
588+ prot [TLS_SW ][ TLS_BASE ].sendmsg = tls_sw_sendmsg ;
589+ prot [TLS_SW ][ TLS_BASE ] .sendpage = tls_sw_sendpage ;
590+
591+ prot [TLS_BASE ][ TLS_SW ] = prot [ TLS_BASE ] [TLS_BASE ];
592+ prot [TLS_BASE ][ TLS_SW ].recvmsg = tls_sw_recvmsg ;
593+ prot [TLS_BASE ][ TLS_SW ].close = tls_sk_proto_close ;
594+
595+ prot [TLS_SW ][ TLS_SW ] = prot [TLS_SW ][ TLS_BASE ];
596+ prot [TLS_SW ][ TLS_SW ].recvmsg = tls_sw_recvmsg ;
597+ prot [TLS_SW ][ TLS_SW ] .close = tls_sk_proto_close ;
598+
599+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] = * base ;
600+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] .hash = tls_hw_hash ;
601+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] .unhash = tls_hw_unhash ;
602+ prot [TLS_HW_RECORD ][ TLS_HW_RECORD ] .close = tls_sk_proto_close ;
605603}
606604
607605static int tls_init (struct sock * sk )
@@ -643,7 +641,8 @@ static int tls_init(struct sock *sk)
643641 mutex_unlock (& tcpv6_prot_mutex );
644642 }
645643
646- ctx -> conf = TLS_BASE ;
644+ ctx -> tx_conf = TLS_BASE ;
645+ ctx -> rx_conf = TLS_BASE ;
647646 update_sk_prot (sk , ctx );
648647out :
649648 return rc ;
0 commit comments