Skip to content

Commit aeb26d9

Browse files
author
Sabrina Dubroca
committed
net/tls: Use RCU API to access tls_ctx->netdev
Tested: selftests Bugzilla: https://bugzilla.redhat.com/show_bug.cgi?id=2143700 Conflicts: skipped the mlx5e_ktls_handle_tx_skb bits, we don't have that code yet, it will come through the driver rebase commit 94ce3b6 Author: Maxim Mikityanskiy <[email protected]> Date: Wed Aug 10 11:16:02 2022 +0300 net/tls: Use RCU API to access tls_ctx->netdev Currently, tls_device_down synchronizes with tls_device_resync_rx using RCU, however, the pointer to netdev is stored using WRITE_ONCE and loaded using READ_ONCE. Although such approach is technically correct (rcu_dereference is essentially a READ_ONCE, and rcu_assign_pointer uses WRITE_ONCE to store NULL), using special RCU helpers for pointers is more valid, as it includes additional checks and might change the implementation transparently to the callers. Mark the netdev pointer as __rcu and use the correct RCU helpers to access it. For non-concurrent access pass the right conditions that guarantee safe access (locks taken, refcount value). Also use the correct helper in mlx5e, where even READ_ONCE was missing. The transition to RCU exposes existing issues, fixed by this commit: 1. bond_tls_device_xmit could read netdev twice, and it could become NULL the second time, after the NULL check passed. 2. Drivers shouldn't stop processing the last packet if tls_device_down just set netdev to NULL, before tls_dev_del was called. This prevents a possible packet drop when transitioning to the fallback software mode. Fixes: 89df6a8 ("net/bonding: Implement TLS TX device offload") Fixes: c55dcdd ("net/tls: Fix use-after-free after the TLS device goes down and up") Signed-off-by: Maxim Mikityanskiy <[email protected]> Link: https://lore.kernel.org/r/[email protected] Signed-off-by: Jakub Kicinski <[email protected]> Signed-off-by: Sabrina Dubroca <[email protected]>
1 parent 498c582 commit aeb26d9

File tree

5 files changed

+47
-14
lines changed

5 files changed

+47
-14
lines changed

drivers/net/bonding/bond_main.c

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5327,8 +5327,14 @@ static struct net_device *bond_sk_get_lower_dev(struct net_device *dev,
53275327
static netdev_tx_t bond_tls_device_xmit(struct bonding *bond, struct sk_buff *skb,
53285328
struct net_device *dev)
53295329
{
5330-
if (likely(bond_get_slave_by_dev(bond, tls_get_ctx(skb->sk)->netdev)))
5331-
return bond_dev_queue_xmit(bond, skb, tls_get_ctx(skb->sk)->netdev);
5330+
struct net_device *tls_netdev = rcu_dereference(tls_get_ctx(skb->sk)->netdev);
5331+
5332+
/* tls_netdev might become NULL, even if tls_is_sk_tx_device_offloaded
5333+
* was true, if tls_device_down is running in parallel, but it's OK,
5334+
* because bond_get_slave_by_dev has a NULL check.
5335+
*/
5336+
if (likely(bond_get_slave_by_dev(bond, tls_netdev)))
5337+
return bond_dev_queue_xmit(bond, skb, tls_netdev);
53325338
return bond_tx_drop(dev, skb);
53335339
}
53345340
#endif

drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
19321932
int data_len, qidx, ret = 0, mss;
19331933
struct tls_record_info *record;
19341934
struct chcr_ktls_info *tx_info;
1935+
struct net_device *tls_netdev;
19351936
struct tls_context *tls_ctx;
19361937
struct sge_eth_txq *q;
19371938
struct adapter *adap;
@@ -1945,7 +1946,12 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
19451946
mss = skb_is_gso(skb) ? skb_shinfo(skb)->gso_size : data_len;
19461947

19471948
tls_ctx = tls_get_ctx(skb->sk);
1948-
if (unlikely(tls_ctx->netdev != dev))
1949+
tls_netdev = rcu_dereference_bh(tls_ctx->netdev);
1950+
/* Don't quit on NULL: if tls_device_down is running in parallel,
1951+
* netdev might become NULL, even if tls_is_sk_tx_device_offloaded was
1952+
* true. Rather continue processing this packet.
1953+
*/
1954+
if (unlikely(tls_netdev && tls_netdev != dev))
19491955
goto out;
19501956

19511957
tx_ctx = chcr_get_ktls_tx_context(tls_ctx);

include/net/tls.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ struct tls_context {
237237
void *priv_ctx_tx;
238238
void *priv_ctx_rx;
239239

240-
struct net_device *netdev;
240+
struct net_device __rcu *netdev;
241241

242242
/* rw cache line */
243243
struct cipher_context tx;

net/tls/tls_device.c

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ static void tls_device_tx_del_task(struct work_struct *work)
7171
struct tls_offload_context_tx *offload_ctx =
7272
container_of(work, struct tls_offload_context_tx, destruct_work);
7373
struct tls_context *ctx = offload_ctx->ctx;
74-
struct net_device *netdev = ctx->netdev;
74+
struct net_device *netdev;
75+
76+
/* Safe, because this is the destroy flow, refcount is 0, so
77+
* tls_device_down can't store this field in parallel.
78+
*/
79+
netdev = rcu_dereference_protected(ctx->netdev,
80+
!refcount_read(&ctx->refcount));
7581

7682
netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
7783
dev_put(netdev);
@@ -81,6 +87,7 @@ static void tls_device_tx_del_task(struct work_struct *work)
8187

8288
static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
8389
{
90+
struct net_device *netdev;
8491
unsigned long flags;
8592
bool async_cleanup;
8693

@@ -91,7 +98,14 @@ static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
9198
}
9299

93100
list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */
94-
async_cleanup = ctx->netdev && ctx->tx_conf == TLS_HW;
101+
102+
/* Safe, because this is the destroy flow, refcount is 0, so
103+
* tls_device_down can't store this field in parallel.
104+
*/
105+
netdev = rcu_dereference_protected(ctx->netdev,
106+
!refcount_read(&ctx->refcount));
107+
108+
async_cleanup = netdev && ctx->tx_conf == TLS_HW;
95109
if (async_cleanup) {
96110
struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx);
97111

@@ -229,7 +243,8 @@ static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
229243

230244
trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
231245
down_read(&device_offload_lock);
232-
netdev = tls_ctx->netdev;
246+
netdev = rcu_dereference_protected(tls_ctx->netdev,
247+
lockdep_is_held(&device_offload_lock));
233248
if (netdev)
234249
err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
235250
rcd_sn,
@@ -710,7 +725,7 @@ static void tls_device_resync_rx(struct tls_context *tls_ctx,
710725

711726
trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
712727
rcu_read_lock();
713-
netdev = READ_ONCE(tls_ctx->netdev);
728+
netdev = rcu_dereference(tls_ctx->netdev);
714729
if (netdev)
715730
netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
716731
TLS_OFFLOAD_CTX_DIR_RX);
@@ -1035,7 +1050,7 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
10351050
if (sk->sk_destruct != tls_device_sk_destruct) {
10361051
refcount_set(&ctx->refcount, 1);
10371052
dev_hold(netdev);
1038-
ctx->netdev = netdev;
1053+
RCU_INIT_POINTER(ctx->netdev, netdev);
10391054
spin_lock_irq(&tls_device_lock);
10401055
list_add_tail(&ctx->list, &tls_device_list);
10411056
spin_unlock_irq(&tls_device_lock);
@@ -1306,7 +1321,8 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
13061321
struct net_device *netdev;
13071322

13081323
down_read(&device_offload_lock);
1309-
netdev = tls_ctx->netdev;
1324+
netdev = rcu_dereference_protected(tls_ctx->netdev,
1325+
lockdep_is_held(&device_offload_lock));
13101326
if (!netdev)
13111327
goto out;
13121328

@@ -1315,7 +1331,7 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
13151331

13161332
if (tls_ctx->tx_conf != TLS_HW) {
13171333
dev_put(netdev);
1318-
tls_ctx->netdev = NULL;
1334+
rcu_assign_pointer(tls_ctx->netdev, NULL);
13191335
} else {
13201336
set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
13211337
}
@@ -1335,7 +1351,11 @@ static int tls_device_down(struct net_device *netdev)
13351351

13361352
spin_lock_irqsave(&tls_device_lock, flags);
13371353
list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
1338-
if (ctx->netdev != netdev ||
1354+
struct net_device *ctx_netdev =
1355+
rcu_dereference_protected(ctx->netdev,
1356+
lockdep_is_held(&device_offload_lock));
1357+
1358+
if (ctx_netdev != netdev ||
13391359
!refcount_inc_not_zero(&ctx->refcount))
13401360
continue;
13411361

@@ -1352,7 +1372,7 @@ static int tls_device_down(struct net_device *netdev)
13521372
/* Stop the RX and TX resync.
13531373
* tls_dev_resync must not be called after tls_dev_del.
13541374
*/
1355-
WRITE_ONCE(ctx->netdev, NULL);
1375+
rcu_assign_pointer(ctx->netdev, NULL);
13561376

13571377
/* Start skipping the RX resync logic completely. */
13581378
set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);

net/tls/tls_device_fallback.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
426426
struct net_device *dev,
427427
struct sk_buff *skb)
428428
{
429-
if (dev == tls_get_ctx(sk)->netdev || netif_is_bond_master(dev))
429+
if (dev == rcu_dereference_bh(tls_get_ctx(sk)->netdev) ||
430+
netif_is_bond_master(dev))
430431
return skb;
431432

432433
return tls_sw_fallback(sk, skb);

0 commit comments

Comments
 (0)