Skip to content

Commit 94531cf

Browse files
Jiang-BD authored and anakryiko committed
af_unix: Add unix_stream_proto for sockmap
Previously, sockmap for the AF_UNIX protocol supported only the dgram type. This patch adds support for the unix stream type, implemented similarly to unix_dgram_proto. To support sockmap, dgram and stream can no longer share the same unix_proto, because they need different implementations — for example, the stream type needs an unhash handler (which removes closed or disconnected sockets from the map). Therefore, rename unix_proto to unix_dgram_proto and add a new unix_stream_proto. Also implement the stream-related sockmap functions, and add the dgram keyword to the names of dgram-specific functions. Signed-off-by: Jiang Wang <[email protected]> Signed-off-by: Andrii Nakryiko <[email protected]> Reviewed-by: Cong Wang <[email protected]> Acked-by: Jakub Sitnicki <[email protected]> Acked-by: John Fastabend <[email protected]> Link: https://lore.kernel.org/bpf/[email protected]
1 parent 77462de commit 94531cf

File tree

4 files changed

+148
-37
lines changed

4 files changed

+148
-37
lines changed

include/net/af_unix.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ long unix_outq_len(struct sock *sk);
8787

8888
int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
8989
int flags);
90+
int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg, size_t size,
91+
int flags);
9092
#ifdef CONFIG_SYSCTL
9193
int unix_sysctl_register(struct net *net);
9294
void unix_sysctl_unregister(struct net *net);
@@ -96,9 +98,11 @@ static inline void unix_sysctl_unregister(struct net *net) {}
9698
#endif
9799

98100
#ifdef CONFIG_BPF_SYSCALL
99-
extern struct proto unix_proto;
101+
extern struct proto unix_dgram_proto;
102+
extern struct proto unix_stream_proto;
100103

101-
int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
104+
int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
105+
int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore);
102106
void __init unix_bpf_build_proto(void);
103107
#else
104108
static inline void __init unix_bpf_build_proto(void)

net/core/sock_map.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1494,6 +1494,7 @@ void sock_map_unhash(struct sock *sk)
14941494
rcu_read_unlock();
14951495
saved_unhash(sk);
14961496
}
1497+
EXPORT_SYMBOL_GPL(sock_map_unhash);
14971498

14981499
void sock_map_close(struct sock *sk, long timeout)
14991500
{

net/unix/af_unix.c

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -798,17 +798,35 @@ static void unix_close(struct sock *sk, long timeout)
798798
*/
799799
}
800800

801-
struct proto unix_proto = {
802-
.name = "UNIX",
801+
static void unix_unhash(struct sock *sk)
802+
{
803+
/* Nothing to do here, unix socket does not need a ->unhash().
804+
* This is merely for sockmap.
805+
*/
806+
}
807+
808+
struct proto unix_dgram_proto = {
809+
.name = "UNIX-DGRAM",
810+
.owner = THIS_MODULE,
811+
.obj_size = sizeof(struct unix_sock),
812+
.close = unix_close,
813+
#ifdef CONFIG_BPF_SYSCALL
814+
.psock_update_sk_prot = unix_dgram_bpf_update_proto,
815+
#endif
816+
};
817+
818+
struct proto unix_stream_proto = {
819+
.name = "UNIX-STREAM",
803820
.owner = THIS_MODULE,
804821
.obj_size = sizeof(struct unix_sock),
805822
.close = unix_close,
823+
.unhash = unix_unhash,
806824
#ifdef CONFIG_BPF_SYSCALL
807-
.psock_update_sk_prot = unix_bpf_update_proto,
825+
.psock_update_sk_prot = unix_stream_bpf_update_proto,
808826
#endif
809827
};
810828

811-
static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
829+
static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type)
812830
{
813831
struct sock *sk = NULL;
814832
struct unix_sock *u;
@@ -817,7 +835,11 @@ static struct sock *unix_create1(struct net *net, struct socket *sock, int kern)
817835
if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files())
818836
goto out;
819837

820-
sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_proto, kern);
838+
if (type == SOCK_STREAM)
839+
sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern);
840+
else /*dgram and seqpacket */
841+
sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern);
842+
821843
if (!sk)
822844
goto out;
823845

@@ -879,7 +901,7 @@ static int unix_create(struct net *net, struct socket *sock, int protocol,
879901
return -ESOCKTNOSUPPORT;
880902
}
881903

882-
return unix_create1(net, sock, kern) ? 0 : -ENOMEM;
904+
return unix_create1(net, sock, kern, sock->type) ? 0 : -ENOMEM;
883905
}
884906

885907
static int unix_release(struct socket *sock)
@@ -1293,7 +1315,7 @@ static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr,
12931315
err = -ENOMEM;
12941316

12951317
/* create new sock for complete connection */
1296-
newsk = unix_create1(sock_net(sk), NULL, 0);
1318+
newsk = unix_create1(sock_net(sk), NULL, 0, sock->type);
12971319
if (newsk == NULL)
12981320
goto out;
12991321

@@ -2323,8 +2345,10 @@ static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t si
23232345
struct sock *sk = sock->sk;
23242346

23252347
#ifdef CONFIG_BPF_SYSCALL
2326-
if (sk->sk_prot != &unix_proto)
2327-
return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
2348+
const struct proto *prot = READ_ONCE(sk->sk_prot);
2349+
2350+
if (prot != &unix_dgram_proto)
2351+
return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
23282352
flags & ~MSG_DONTWAIT, NULL);
23292353
#endif
23302354
return __unix_dgram_recvmsg(sk, msg, size, flags);
@@ -2728,6 +2752,20 @@ static int unix_stream_read_actor(struct sk_buff *skb,
27282752
return ret ?: chunk;
27292753
}
27302754

2755+
int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg,
2756+
size_t size, int flags)
2757+
{
2758+
struct unix_stream_read_state state = {
2759+
.recv_actor = unix_stream_read_actor,
2760+
.socket = sk->sk_socket,
2761+
.msg = msg,
2762+
.size = size,
2763+
.flags = flags
2764+
};
2765+
2766+
return unix_stream_read_generic(&state, true);
2767+
}
2768+
27312769
static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
27322770
size_t size, int flags)
27332771
{
@@ -2739,6 +2777,14 @@ static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg,
27392777
.flags = flags
27402778
};
27412779

2780+
#ifdef CONFIG_BPF_SYSCALL
2781+
struct sock *sk = sock->sk;
2782+
const struct proto *prot = READ_ONCE(sk->sk_prot);
2783+
2784+
if (prot != &unix_stream_proto)
2785+
return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
2786+
flags & ~MSG_DONTWAIT, NULL);
2787+
#endif
27422788
return unix_stream_read_generic(&state, true);
27432789
}
27442790

@@ -2799,7 +2845,9 @@ static int unix_shutdown(struct socket *sock, int mode)
27992845
(sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)) {
28002846

28012847
int peer_mode = 0;
2848+
const struct proto *prot = READ_ONCE(other->sk_prot);
28022849

2850+
prot->unhash(other);
28032851
if (mode&RCV_SHUTDOWN)
28042852
peer_mode |= SEND_SHUTDOWN;
28052853
if (mode&SEND_SHUTDOWN)
@@ -2808,10 +2856,12 @@ static int unix_shutdown(struct socket *sock, int mode)
28082856
other->sk_shutdown |= peer_mode;
28092857
unix_state_unlock(other);
28102858
other->sk_state_change(other);
2811-
if (peer_mode == SHUTDOWN_MASK)
2859+
if (peer_mode == SHUTDOWN_MASK) {
28122860
sk_wake_async(other, SOCK_WAKE_WAITD, POLL_HUP);
2813-
else if (peer_mode & RCV_SHUTDOWN)
2861+
other->sk_state = TCP_CLOSE;
2862+
} else if (peer_mode & RCV_SHUTDOWN) {
28142863
sk_wake_async(other, SOCK_WAKE_WAITD, POLL_IN);
2864+
}
28152865
}
28162866
if (other)
28172867
sock_put(other);
@@ -3289,7 +3339,13 @@ static int __init af_unix_init(void)
32893339

32903340
BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb));
32913341

3292-
rc = proto_register(&unix_proto, 1);
3342+
rc = proto_register(&unix_dgram_proto, 1);
3343+
if (rc != 0) {
3344+
pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
3345+
goto out;
3346+
}
3347+
3348+
rc = proto_register(&unix_stream_proto, 1);
32933349
if (rc != 0) {
32943350
pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__);
32953351
goto out;
@@ -3310,7 +3366,8 @@ static int __init af_unix_init(void)
33103366
static void __exit af_unix_exit(void)
33113367
{
33123368
sock_unregister(PF_UNIX);
3313-
proto_unregister(&unix_proto);
3369+
proto_unregister(&unix_dgram_proto);
3370+
proto_unregister(&unix_stream_proto);
33143371
unregister_pernet_subsys(&unix_net_ops);
33153372
}
33163373

net/unix/unix_bpf.c

Lines changed: 71 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,33 @@ static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock,
3838
return ret;
3939
}
4040

41-
static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
42-
size_t len, int nonblock, int flags,
43-
int *addr_len)
41+
static int __unix_recvmsg(struct sock *sk, struct msghdr *msg,
42+
size_t len, int flags)
43+
{
44+
if (sk->sk_type == SOCK_DGRAM)
45+
return __unix_dgram_recvmsg(sk, msg, len, flags);
46+
else
47+
return __unix_stream_recvmsg(sk, msg, len, flags);
48+
}
49+
50+
static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
51+
size_t len, int nonblock, int flags,
52+
int *addr_len)
4453
{
4554
struct unix_sock *u = unix_sk(sk);
4655
struct sk_psock *psock;
4756
int copied;
4857

4958
psock = sk_psock_get(sk);
5059
if (unlikely(!psock))
51-
return __unix_dgram_recvmsg(sk, msg, len, flags);
60+
return __unix_recvmsg(sk, msg, len, flags);
5261

5362
mutex_lock(&u->iolock);
5463
if (!skb_queue_empty(&sk->sk_receive_queue) &&
5564
sk_psock_queue_empty(psock)) {
5665
mutex_unlock(&u->iolock);
5766
sk_psock_put(sk, psock);
58-
return __unix_dgram_recvmsg(sk, msg, len, flags);
67+
return __unix_recvmsg(sk, msg, len, flags);
5968
}
6069

6170
msg_bytes_ready:
@@ -71,7 +80,7 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
7180
goto msg_bytes_ready;
7281
mutex_unlock(&u->iolock);
7382
sk_psock_put(sk, psock);
74-
return __unix_dgram_recvmsg(sk, msg, len, flags);
83+
return __unix_recvmsg(sk, msg, len, flags);
7584
}
7685
copied = -EAGAIN;
7786
}
@@ -80,30 +89,55 @@ static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg,
8089
return copied;
8190
}
8291

83-
static struct proto *unix_prot_saved __read_mostly;
84-
static DEFINE_SPINLOCK(unix_prot_lock);
85-
static struct proto unix_bpf_prot;
92+
static struct proto *unix_dgram_prot_saved __read_mostly;
93+
static DEFINE_SPINLOCK(unix_dgram_prot_lock);
94+
static struct proto unix_dgram_bpf_prot;
95+
96+
static struct proto *unix_stream_prot_saved __read_mostly;
97+
static DEFINE_SPINLOCK(unix_stream_prot_lock);
98+
static struct proto unix_stream_bpf_prot;
8699

87-
static void unix_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
100+
static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
88101
{
89102
*prot = *base;
90103
prot->close = sock_map_close;
91-
prot->recvmsg = unix_dgram_bpf_recvmsg;
104+
prot->recvmsg = unix_bpf_recvmsg;
105+
}
106+
107+
static void unix_stream_bpf_rebuild_protos(struct proto *prot,
108+
const struct proto *base)
109+
{
110+
*prot = *base;
111+
prot->close = sock_map_close;
112+
prot->recvmsg = unix_bpf_recvmsg;
113+
prot->unhash = sock_map_unhash;
114+
}
115+
116+
static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops)
117+
{
118+
if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) {
119+
spin_lock_bh(&unix_dgram_prot_lock);
120+
if (likely(ops != unix_dgram_prot_saved)) {
121+
unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops);
122+
smp_store_release(&unix_dgram_prot_saved, ops);
123+
}
124+
spin_unlock_bh(&unix_dgram_prot_lock);
125+
}
92126
}
93127

94-
static void unix_bpf_check_needs_rebuild(struct proto *ops)
128+
static void unix_stream_bpf_check_needs_rebuild(struct proto *ops)
95129
{
96-
if (unlikely(ops != smp_load_acquire(&unix_prot_saved))) {
97-
spin_lock_bh(&unix_prot_lock);
98-
if (likely(ops != unix_prot_saved)) {
99-
unix_bpf_rebuild_protos(&unix_bpf_prot, ops);
100-
smp_store_release(&unix_prot_saved, ops);
130+
if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) {
131+
spin_lock_bh(&unix_stream_prot_lock);
132+
if (likely(ops != unix_stream_prot_saved)) {
133+
unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops);
134+
smp_store_release(&unix_stream_prot_saved, ops);
101135
}
102-
spin_unlock_bh(&unix_prot_lock);
136+
spin_unlock_bh(&unix_stream_prot_lock);
103137
}
104138
}
105139

106-
int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
140+
int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
107141
{
108142
if (sk->sk_type != SOCK_DGRAM)
109143
return -EOPNOTSUPP;
@@ -114,12 +148,27 @@ int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
114148
return 0;
115149
}
116150

117-
unix_bpf_check_needs_rebuild(psock->sk_proto);
118-
WRITE_ONCE(sk->sk_prot, &unix_bpf_prot);
151+
unix_dgram_bpf_check_needs_rebuild(psock->sk_proto);
152+
WRITE_ONCE(sk->sk_prot, &unix_dgram_bpf_prot);
153+
return 0;
154+
}
155+
156+
int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
157+
{
158+
if (restore) {
159+
sk->sk_write_space = psock->saved_write_space;
160+
WRITE_ONCE(sk->sk_prot, psock->sk_proto);
161+
return 0;
162+
}
163+
164+
unix_stream_bpf_check_needs_rebuild(psock->sk_proto);
165+
WRITE_ONCE(sk->sk_prot, &unix_stream_bpf_prot);
119166
return 0;
120167
}
121168

122169
void __init unix_bpf_build_proto(void)
123170
{
124-
unix_bpf_rebuild_protos(&unix_bpf_prot, &unix_proto);
171+
unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto);
172+
unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto);
173+
125174
}

0 commit comments

Comments
 (0)