diff --git a/include/net/net_pkt.h b/include/net/net_pkt.h index 73347c213d9f4..7803be9aab33a 100644 --- a/include/net/net_pkt.h +++ b/include/net/net_pkt.h @@ -1359,6 +1359,19 @@ void net_pkt_get_info(struct k_mem_slab **rx, struct net_buf_pool **rx_data, struct net_buf_pool **tx_data); +/** + * @brief Get source socket address. + * + * @param pkt Nework packet + * @param addr Source socket address + * @param addrlen The length of source socket address + * @return 0 on success, <0 otherwise. + */ + +int net_pkt_get_src_addr(struct net_pkt *pkt, + struct sockaddr *addr, + socklen_t addrlen); + #if defined(CONFIG_NET_DEBUG_NET_PKT) /** * @brief Debug helper to print out the buffer allocations diff --git a/include/net/socket.h b/include/net/socket.h index 9d344acf528c1..8335fdb386593 100644 --- a/include/net/socket.h +++ b/include/net/socket.h @@ -61,6 +61,10 @@ int zsock_listen(int sock, int backlog); int zsock_accept(int sock, struct sockaddr *addr, socklen_t *addrlen); ssize_t zsock_send(int sock, const void *buf, size_t len, int flags); ssize_t zsock_recv(int sock, void *buf, size_t max_len, int flags); +ssize_t zsock_sendto(int sock, const void *buf, size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen); +ssize_t zsock_recvfrom(int sock, void *buf, size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen); int zsock_fcntl(int sock, int cmd, int flags); int zsock_poll(struct zsock_pollfd *fds, int nfds, int timeout); int zsock_inet_pton(sa_family_t family, const char *src, void *dst); @@ -78,6 +82,8 @@ int zsock_getaddrinfo(const char *host, const char *service, #define send zsock_send #define recv zsock_recv #define fcntl zsock_fcntl +#define sendto zsock_sendto +#define recvfrom zsock_recvfrom #define poll zsock_poll #define pollfd zsock_pollfd diff --git a/subsys/net/ip/net_context.c b/subsys/net/ip/net_context.c index 5ebccb67f893d..f33ddf6902abc 100644 --- a/subsys/net/ip/net_context.c +++ b/subsys/net/ip/net_context.c @@ -2073,6 +2073,16 @@ static int sendto(struct net_pkt *pkt, } #endif /* CONFIG_NET_TCP */ +#if defined(CONFIG_NET_UDP) + /* Bind default address and port only if UDP */ + if (net_context_get_ip_proto(context) == IPPROTO_UDP) { + ret = bind_default(context); + if (ret) { + return ret; + } + } +#endif /* CONFIG_NET_UDP */ + if (!dst_addr) { return -EDESTADDRREQ; } @@ -2312,6 +2322,11 @@ static int recv_udp(struct net_context *context, context->conn_handler = NULL; } + ret = bind_default(context); + if (ret) { + return ret; + } + #if defined(CONFIG_NET_IPV6) if (net_context_get_family(context) == AF_INET6) { if (net_sin6_ptr(&context->local)->sin6_addr) { diff --git a/subsys/net/ip/net_pkt.c b/subsys/net/ip/net_pkt.c index 1bb8d549bb316..f1905e44702bc 100644 --- a/subsys/net/ip/net_pkt.c +++ b/subsys/net/ip/net_pkt.c @@ -1654,6 +1654,61 @@ void net_pkt_get_info(struct k_mem_slab **rx, } } +int net_pkt_get_src_addr(struct net_pkt *pkt, struct sockaddr *addr, + socklen_t addrlen) +{ + enum net_ip_protocol proto; + sa_family_t family; + + if (!addr || !pkt) { + return -EINVAL; + } + + family = net_pkt_family(pkt); + + if (IS_ENABLED(CONFIG_NET_IPV6) && family == AF_INET6) { + struct sockaddr_in6 *addr6 = net_sin6(addr); + + if (addrlen < sizeof(struct sockaddr_in6)) { + return -EINVAL; + } + + net_ipaddr_copy(&addr6->sin6_addr, &NET_IPV6_HDR(pkt)->src); + proto = NET_IPV6_HDR(pkt)->nexthdr; + + if (IS_ENABLED(CONFIG_NET_TCP) && proto == IPPROTO_TCP) { + addr6->sin6_port = net_pkt_tcp_data(pkt)->src_port; + } else if (IS_ENABLED(CONFIG_NET_UDP) && proto == IPPROTO_UDP) { + addr6->sin6_port = net_pkt_udp_data(pkt)->src_port; + } else { + return -ENOTSUP; + } + + } else if (IS_ENABLED(CONFIG_NET_IPV4) && family == AF_INET) { + struct sockaddr_in *addr4 = net_sin(addr); + + if (addrlen < sizeof(struct sockaddr_in)) { + return -EINVAL; + } + + net_ipaddr_copy(&addr4->sin_addr, &NET_IPV4_HDR(pkt)->src); + proto = NET_IPV4_HDR(pkt)->proto; + + if (IS_ENABLED(CONFIG_NET_TCP) && proto == IPPROTO_TCP) { + addr4->sin_port = net_pkt_tcp_data(pkt)->src_port; + } else if (IS_ENABLED(CONFIG_NET_UDP) && proto == IPPROTO_UDP) { + addr4->sin_port = net_pkt_udp_data(pkt)->src_port; + } else { + return -ENOTSUP; + } + + } else { + return -ENOTSUP; + } + + return 0; +} + #if defined(CONFIG_NET_DEBUG_NET_PKT) void net_pkt_print(void) { diff --git a/subsys/net/lib/sockets/sockets.c b/subsys/net/lib/sockets/sockets.c index b0909ef0c93e9..4d57a3a8eb8ae 100644 --- a/subsys/net/lib/sockets/sockets.c +++ b/subsys/net/lib/sockets/sockets.c @@ -145,11 +145,12 @@ static void zsock_received_cb(struct net_context *ctx, struct net_pkt *pkt, /* Normal packet */ net_pkt_set_eof(pkt, false); - /* We don't care about packet header, so get rid of it asap */ - header_len = net_pkt_appdata(pkt) - pkt->frags->data; - net_buf_pull(pkt->frags, header_len); - if (net_context_get_type(ctx) == SOCK_STREAM) { + /* TCP: we don't care about packet header, get rid of it asap. + * UDP: keep packet header to support recvfrom(). + */ + header_len = net_pkt_appdata(pkt) - pkt->frags->data; + net_buf_pull(pkt->frags, header_len); net_context_update_recv_wnd(ctx, -net_pkt_appdatalen(pkt)); } @@ -213,12 +214,20 @@ int zsock_accept(int sock, struct sockaddr *addr, socklen_t *addrlen) ssize_t zsock_send(int sock, const void *buf, size_t len, int flags) { - ARG_UNUSED(flags); + return zsock_sendto(sock, buf, len, flags, NULL, 0); +} + +ssize_t zsock_sendto(int sock, const void *buf, size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen) +{ int err; struct net_pkt *send_pkt; s32_t timeout = K_FOREVER; struct net_context *ctx = INT_TO_POINTER(sock); size_t max_len = net_if_get_mtu(net_context_get_iface(ctx)); + enum net_sock_type sock_type = net_context_get_type(ctx); + + ARG_UNUSED(flags); if (sock_is_nonblock(ctx)) { timeout = K_NO_WAIT; @@ -249,7 +258,20 @@ ssize_t zsock_send(int sock, const void *buf, size_t len, int flags) return -1; } - err = net_context_send(send_pkt, /*cb*/NULL, timeout, NULL, NULL); + /* Register the callback before sending in order to receive the response + * from the peer. + */ + SET_ERRNO(net_context_recv(ctx, zsock_received_cb, K_NO_WAIT, NULL)); + + if (sock_type == SOCK_DGRAM) { + err = net_context_sendto(send_pkt, dest_addr, addrlen, NULL, + timeout, NULL, NULL); + } else if (sock_type == SOCK_STREAM) { + err = net_context_send(send_pkt, NULL, timeout, NULL, NULL); + } else { + __ASSERT(0, "Unknown socket type"); + } + if (err < 0) { net_pkt_unref(send_pkt); errno = -err; @@ -259,7 +281,9 @@ ssize_t zsock_send(int sock, const void *buf, size_t len, int flags) return len; } -static inline ssize_t zsock_recv_stream(struct net_context *ctx, void *buf, size_t max_len) +static inline ssize_t zsock_recv_stream(struct net_context *ctx, + void *buf, + size_t max_len) { size_t recv_len = 0; s32_t timeout = K_FOREVER; @@ -333,11 +357,18 @@ static inline ssize_t zsock_recv_stream(struct net_context *ctx, void *buf, size } ssize_t zsock_recv(int sock, void *buf, size_t max_len, int flags) +{ + return zsock_recvfrom(sock, buf, max_len, flags, NULL, NULL); +} + +ssize_t zsock_recvfrom(int sock, void *buf, size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) { ARG_UNUSED(flags); struct net_context *ctx = INT_TO_POINTER(sock); enum net_sock_type sock_type = net_context_get_type(ctx); size_t recv_len = 0; + unsigned int header_len; if (sock_type == SOCK_DGRAM) { @@ -354,6 +385,19 @@ ssize_t zsock_recv(int sock, void *buf, size_t max_len, int flags) return -1; } + if (src_addr && addrlen) { + int rv; + + rv = net_pkt_get_src_addr(pkt, src_addr, *addrlen); + if (rv < 0) { + errno = rv; + return -1; + } + } + /* Remove packet header since we've handled src addr and port */ + header_len = net_pkt_appdata(pkt) - pkt->frags->data; + net_buf_pull(pkt->frags, header_len); + recv_len = net_pkt_appdatalen(pkt); if (recv_len > max_len) { recv_len = max_len; diff --git a/tests/net/socket/udp/prj.conf b/tests/net/socket/udp/prj.conf index 65cd9435eb507..4b29cd5f9260d 100644 --- a/tests/net/socket/udp/prj.conf +++ b/tests/net/socket/udp/prj.conf @@ -4,7 +4,7 @@ CONFIG_NEWLIB_LIBC=y # Networking config CONFIG_NETWORKING=y CONFIG_NET_IPV4=y -CONFIG_NET_IPV6=n +CONFIG_NET_IPV6=y CONFIG_NET_UDP=y CONFIG_NET_SOCKETS=y CONFIG_NET_SOCKETS_POSIX_NAMES=y @@ -13,13 +13,14 @@ CONFIG_NET_SOCKETS_POSIX_NAMES=y CONFIG_TEST_RANDOM_GENERATOR=y # Network address config -#CONFIG_NET_APP_SETTINGS=y -#CONFIG_NET_APP_MY_IPV4_ADDR="192.0.2.1" -#CONFIG_NET_APP_PEER_IPV4_ADDR="192.0.2.2" +CONFIG_NET_APP_SETTINGS=y +CONFIG_NET_APP_MY_IPV4_ADDR="192.0.2.1" +CONFIG_NET_APP_MY_IPV6_ADDR="2001:db8::1" # Network debug config #CONFIG_NET_LOG=y #CONFIG_NET_DEBUG_SOCKETS=y #CONFIG_SYS_LOG_NET_LEVEL=4 +CONFIG_MAIN_STACK_SIZE=2048 CONFIG_ZTEST=y diff --git a/tests/net/socket/udp/src/main.c b/tests/net/socket/udp/src/main.c index bbd821fd8a969..6bdda57247c1a 100644 --- a/tests/net/socket/udp/src/main.c +++ b/tests/net/socket/udp/src/main.c @@ -8,13 +8,207 @@ #include #include -#include #define BUF_AND_SIZE(buf) buf, sizeof(buf) - 1 #define STRLEN(buf) (sizeof(buf) - 1) #define TEST_STR_SMALL "test" +#define LOCAL_PORT 9898 +#define REMOTE_PORT 4242 + +#define V4_ANY_ADDR "0.0.0.0" +#define V6_ANY_ADDR "0:0:0:0:0:0:0:0" + +#define V4_REMOTE_ADDR "192.0.2.2" +#define V6_REMOTE_ADDR "2001:db8::2" + +static void test_v4_sendto_recvfrom(void) +{ + int rv; + int sock; + ssize_t sent = 0; + ssize_t recved = 0; + char rx_buf[30] = {0}; + struct sockaddr_in addr; + socklen_t socklen; + + sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + zassert_true(sock >= 0, "socket open failed"); + + addr.sin_family = AF_INET; + addr.sin_port = htons(REMOTE_PORT); + rv = inet_pton(AF_INET, V4_REMOTE_ADDR, &(addr.sin_addr)); + zassert_equal(rv, 1, "inet_pton failed"); + + sent = sendto(sock, + TEST_STR_SMALL, + strlen(TEST_STR_SMALL), + 0, + (struct sockaddr *)&addr, + sizeof(addr)); + zassert_equal(sent, strlen(TEST_STR_SMALL), "sendto failed"); + + socklen = sizeof(addr); + recved = recvfrom(sock, + rx_buf, + sizeof(rx_buf), + 0, + (struct sockaddr *)&addr, + &socklen); + zassert_true(recved > 0, "recvfrom fail"); + zassert_equal(recved, + strlen(TEST_STR_SMALL), + "unexpected received bytes"); + zassert_equal(strncmp(rx_buf, TEST_STR_SMALL, strlen(TEST_STR_SMALL)), + 0, + "unexpected data"); +} + +static void test_v6_sendto_recvfrom(void) +{ + int rv; + int sock; + ssize_t sent = 0; + ssize_t recved = 0; + char rx_buf[30] = {0}; + struct sockaddr_in6 addr; + socklen_t socklen; + + sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + zassert_true(sock >= 0, "socket open failed"); + + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(REMOTE_PORT); + rv = inet_pton(AF_INET6, V6_REMOTE_ADDR, &(addr.sin6_addr)); + zassert_equal(rv, 1, "inet_pton failed"); + + sent = sendto(sock, + TEST_STR_SMALL, + strlen(TEST_STR_SMALL), + 0, + (struct sockaddr *)&addr, + sizeof(addr)); + zassert_equal(sent, strlen(TEST_STR_SMALL), "sendto failed"); + + socklen = sizeof(addr); + recved = recvfrom(sock, + rx_buf, + sizeof(rx_buf), + 0, + (struct sockaddr *)&addr, + &socklen); + zassert_true(recved > 0, "recvfrom fail"); + zassert_equal(recved, + strlen(TEST_STR_SMALL), + "unexpected received bytes"); + zassert_equal(strncmp(rx_buf, TEST_STR_SMALL, strlen(TEST_STR_SMALL)), + 0, + "unexpected data"); +} + +static void test_v4_bind_sendto(void) +{ + int rv; + int sock; + ssize_t sent = 0; + ssize_t recved = 0; + char rx_buf[30] = {0}; + struct sockaddr_in remote_addr; + struct sockaddr_in local_addr; + socklen_t socklen; + + sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + zassert_true(sock >= 0, "socket open failed"); + + local_addr.sin_family = AF_INET; + local_addr.sin_port = htons(LOCAL_PORT); + rv = inet_pton(AF_INET, V4_ANY_ADDR, &(local_addr.sin_addr)); + zassert_equal(rv, 1, "inet_pton failed"); + + rv = bind(sock, (struct sockaddr *)&local_addr, sizeof(local_addr)); + zassert_equal(rv, 0, "bind failed"); + + remote_addr.sin_family = AF_INET; + remote_addr.sin_port = htons(REMOTE_PORT); + rv = inet_pton(AF_INET, V4_REMOTE_ADDR, &(remote_addr.sin_addr)); + zassert_equal(rv, 1, "inet_pton failed"); + + sent = sendto(sock, + TEST_STR_SMALL, + strlen(TEST_STR_SMALL), + 0, + (struct sockaddr *)&remote_addr, + sizeof(remote_addr)); + zassert_equal(sent, strlen(TEST_STR_SMALL), "sendto failed"); + + socklen = sizeof(remote_addr); + recved = recvfrom(sock, + rx_buf, + sizeof(rx_buf), + 0, + (struct sockaddr *)&remote_addr, + &socklen); + zassert_true(recved > 0, "recvfrom fail"); + zassert_equal(recved, + strlen(TEST_STR_SMALL), + "unexpected received bytes"); + zassert_equal(strncmp(rx_buf, TEST_STR_SMALL, strlen(TEST_STR_SMALL)), + 0, + "unexpected data"); +} + +static void test_v6_bind_sendto(void) +{ + int rv; + int sock; + ssize_t sent = 0; + ssize_t recved = 0; + char rx_buf[30] = {0}; + struct sockaddr_in6 remote_addr; + struct sockaddr_in6 local_addr; + socklen_t socklen; + + sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + zassert_true(sock >= 0, "socket open failed"); + + local_addr.sin6_family = AF_INET6; + local_addr.sin6_port = htons(LOCAL_PORT); + rv = inet_pton(AF_INET6, V6_ANY_ADDR, &(local_addr.sin6_addr)); + zassert_equal(rv, 1, "inet_pton failed"); + + rv = bind(sock, (struct sockaddr *)&local_addr, sizeof(local_addr)); + zassert_equal(rv, 0, "bind failed"); + + remote_addr.sin6_family = AF_INET6; + remote_addr.sin6_port = htons(REMOTE_PORT); + rv = inet_pton(AF_INET6, V6_REMOTE_ADDR, &(remote_addr.sin6_addr)); + zassert_equal(rv, 1, "inet_pton failed"); + + sent = sendto(sock, + TEST_STR_SMALL, + strlen(TEST_STR_SMALL), + 0, + (struct sockaddr *)&remote_addr, + sizeof(remote_addr)); + zassert_equal(sent, strlen(TEST_STR_SMALL), "sendto failed"); + + socklen = sizeof(remote_addr); + recved = recvfrom(sock, + rx_buf, + sizeof(rx_buf), + 0, + (struct sockaddr *)&remote_addr, + &socklen); + zassert_true(recved > 0, "recvfrom fail"); + zassert_equal(recved, + strlen(TEST_STR_SMALL), + "unexpected received bytes"); + zassert_equal(strncmp(rx_buf, TEST_STR_SMALL, strlen(TEST_STR_SMALL)), + 0, + "unexpected data"); +} + void test_send_recv_2_sock(void) { int sock1, sock2; @@ -45,15 +239,12 @@ void test_send_recv_2_sock(void) void test_main(void) { - zassert_not_null(net_if_get_default(), "No default netif"); - static struct in_addr in4addr_my = { { {192, 0, 2, 1} } }; - - net_if_ipv4_addr_add(net_if_get_default(), &in4addr_my, - NET_ADDR_MANUAL, 0); - ztest_test_suite(socket_udp, - ztest_unit_test(test_send_recv_2_sock) - ); + ztest_unit_test(test_send_recv_2_sock), + ztest_unit_test(test_v4_sendto_recvfrom), + ztest_unit_test(test_v6_sendto_recvfrom), + ztest_unit_test(test_v4_bind_sendto), + ztest_unit_test(test_v6_bind_sendto)); ztest_run_test_suite(socket_udp); }