diff --git a/include/net/net_ip.h b/include/net/net_ip.h index 8fbeb3b12bb15..256e01a7e2800 100644 --- a/include/net/net_ip.h +++ b/include/net/net_ip.h @@ -45,6 +45,8 @@ extern "C" { #define PF_PACKET 3 /**< Packet family. */ #define PF_CAN 4 /**< Controller Area Network. */ #define PF_NET_MGMT 5 /**< Network management info. */ +#define PF_LOCAL 6 /**< Inter-process communication */ +#define PF_UNIX PF_LOCAL /**< Inter-process communication */ /* Address families. */ #define AF_UNSPEC PF_UNSPEC /**< Unspecified address family. */ @@ -53,6 +55,8 @@ extern "C" { #define AF_PACKET PF_PACKET /**< Packet family. */ #define AF_CAN PF_CAN /**< Controller Area Network. */ #define AF_NET_MGMT PF_NET_MGMT /**< Network management info. */ +#define AF_LOCAL PF_LOCAL /**< Inter-process communication */ +#define AF_UNIX PF_UNIX /**< Inter-process communication */ /** Protocol numbers from IANA/BSD */ enum net_ip_protocol { @@ -341,6 +345,12 @@ struct sockaddr_storage { char data[NET_SOCKADDR_MAX_SIZE - sizeof(sa_family_t)]; }; +/* Socket address struct for UNIX domain sockets */ +struct sockaddr_un { + sa_family_t sun_family; /* AF_UNIX */ + char sun_path[NET_SOCKADDR_MAX_SIZE - sizeof(sa_family_t)]; +}; + struct net_addr { sa_family_t family; union { diff --git a/include/net/socket.h b/include/net/socket.h index 6384ed667388a..4011f6fd2e0c8 100644 --- a/include/net/socket.h +++ b/include/net/socket.h @@ -159,6 +159,20 @@ struct zsock_addrinfo { */ __syscall int zsock_socket(int family, int type, int proto); +/** + * @brief Create an unnamed pair of connected sockets + * + * @details + * @rst + * See `POSIX.1-2017 article + * `__ + * for normative description. + * This function is also exposed as ``socketpair()`` + * if :option:`CONFIG_NET_SOCKETS_POSIX_NAMES` is defined. + * @endrst + */ +__syscall int zsock_socketpair(int family, int type, int proto, int *sv); + /** * @brief Close a network socket * @@ -566,6 +580,11 @@ static inline int socket(int family, int type, int proto) return zsock_socket(family, type, proto); } +static inline int socketpair(int family, int type, int proto, int sv[2]) +{ + return zsock_socketpair(family, type, proto, sv); +} + static inline int close(int sock) { return zsock_close(sock); diff --git a/include/posix/sys/socket.h b/include/posix/sys/socket.h index 11e6e2c5b59d8..faf659432a89c 100644 --- a/include/posix/sys/socket.h +++ b/include/posix/sys/socket.h @@ -1,5 +1,6 @@ /* * Copyright (c) 2019 Linaro Limited + * Copyright (c) 2020 Friedt Professional Engineering Services, Inc * * SPDX-License-Identifier: Apache-2.0 */ @@ -18,6 +19,11 @@ static inline int socket(int family, int type, int proto) return zsock_socket(family, type, proto); } +static inline int socketpair(int family, int type, int proto, int sv[2]) +{ + return zsock_socketpair(family, type, proto, sv); +} + #define SHUT_RD ZSOCK_SHUT_RD #define SHUT_WR ZSOCK_SHUT_WR #define SHUT_RDWR ZSOCK_SHUT_RDWR diff --git a/subsys/net/lib/sockets/CMakeLists.txt b/subsys/net/lib/sockets/CMakeLists.txt index c61620e496f93..42f4637d89bc5 100644 --- a/subsys/net/lib/sockets/CMakeLists.txt +++ b/subsys/net/lib/sockets/CMakeLists.txt @@ -28,4 +28,6 @@ if(CONFIG_SOCKS) zephyr_include_directories(${ZEPHYR_BASE}/subsys/net/lib/socks) endif() +zephyr_sources_ifdef(CONFIG_NET_SOCKETPAIR socketpair.c) + zephyr_link_libraries_ifdef(CONFIG_MBEDTLS mbedTLS) diff --git a/subsys/net/lib/sockets/Kconfig b/subsys/net/lib/sockets/Kconfig index e77080f8f9fbc..f1351c52917d1 100644 --- a/subsys/net/lib/sockets/Kconfig +++ b/subsys/net/lib/sockets/Kconfig @@ -139,6 +139,20 @@ config NET_SOCKETS_CAN_RECEIVERS The value tells how many sockets can receive data from same Socket-CAN interface. +config NET_SOCKETPAIR + bool "Support for the socketpair syscall [EXPERIMENTAL]" + help + Choose y here if you would like to use the socketpair(2) + system call. + +config NET_SOCKETPAIR_BUFFER_SIZE + int "Size of the intermediate buffer, in bytes" + default 64 + range 64 1048576 + depends on NET_SOCKETPAIR + help + Buffer size for socketpair(2) + config NET_SOCKETS_NET_MGMT bool "Enable network management socket support [EXPERIMENTAL]" depends on NET_MGMT_EVENT diff --git a/subsys/net/lib/sockets/socketpair.c b/subsys/net/lib/sockets/socketpair.c new file mode 100644 index 0000000000000..3779b571db77a --- /dev/null +++ b/subsys/net/lib/sockets/socketpair.c @@ -0,0 +1,414 @@ +/* + * Copyright (c) 2020 Friedt Professional Engineering Services, Inc + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include + +/* Zephyr headers */ +#include +LOG_MODULE_REGISTER(net_sock, CONFIG_NET_SOCKETS_LOG_LEVEL); + +#include +#include +#include +#include + +#include "sockets_internal.h" + +#define SPAIR_MAGIC 0x9a12 + +#define D(fmt, args...) printk("D: %s(): %d: " fmt "\n", __func__, __LINE__, ##args) + +enum { + SPAIR_CANCEL, + SPAIR_CAN_READ, + SPAIR_CAN_WRITE, +}; + + + +struct spair { + u32_t magic; + int remote; + struct k_pipe recv_q; + struct k_poll_signal signal; + struct k_poll_event events[1]; + bool blocking :1; + u8_t __aligned(4) buf[CONFIG_NET_SOCKETPAIR_BUFFER_SIZE]; +}; + +static const struct socket_op_vtable spair_fd_op_vtable; + +static void spair_init(struct spair *spair) { + + spair->magic = SPAIR_MAGIC; + spair->remote = -1; + + k_pipe_init(&spair->recv_q, spair->buf, sizeof(spair->buf)); + + k_poll_signal_init(&spair->signal); + + k_poll_event_init(&spair->events[0], + K_POLL_TYPE_SIGNAL, + K_POLL_MODE_NOTIFY_ONLY, + &spair->signal); + + spair->blocking = true; +} + +static void spair_fini(struct spair *spair) +{ + memset(spair, 0, sizeof(*spair)); + k_free(spair); +} + +static size_t k_pipe_write_avail(struct k_pipe *pipe) +{ + if (pipe->write_index >= pipe->read_index) { + return pipe->size - (pipe->write_index - pipe->read_index); + } + + return pipe->read_index - pipe->write_index; +} + +static size_t k_pipe_read_avail(struct k_pipe *pipe) +{ + if (pipe->read_index >= pipe->write_index) { + return pipe->size - (pipe->read_index - pipe->write_index); + } + + return pipe->write_index - pipe->read_index; +} + +int z_impl_zsock_socketpair(int family, int type, int proto, int *sv) +{ + int res; + int tmp[2] = {-1, -1}; + struct spair *obj[2] = {}; + + if (family != AF_UNIX) { + errno = EAFNOSUPPORT; + res = -1; + goto out; + } + + if (!(type == SOCK_STREAM || type == SOCK_RAW)) { + errno = EPROTOTYPE; + res = -1; + goto out; + } + + if (proto != 0) { + errno = EPROTONOSUPPORT; + res = -1; + goto out; + } + + if (sv == NULL) { + /* not listed in the normative standard, but probably safe */ + errno = EINVAL; + res = -1; + goto out; + } + + for(size_t i = 0; i < 2; ++i) { + + tmp[i] = z_reserve_fd(); + if (tmp[i] == -1) { + errno = ENFILE; + res = -1; + goto cleanup; + } + D("reserved fd %d", tmp[i]); + + obj[i] = k_malloc(sizeof(*obj)); + if (obj[i] == NULL) { + errno = ENOMEM; + res = -1; + goto cleanup; + } + memset(obj[i], 0, sizeof(*(obj[i]))); + D("allocated spair[%u] at %p", i, obj[i]); + + spair_init(obj[i]); + D("initialized spair[%u] at %p", i, obj[i]); + + z_finalize_fd(tmp[i], obj[i], (const struct fd_op_vtable *) &spair_fd_op_vtable); + D("finalized fd: %d spair: %p vtable: %p &recv_q: %p", tmp[i], obj[i], &spair_fd_op_vtable, &obj[i]->recv_q); + } + + for(size_t i = 0; i < 2; ++i) { + obj[i]->remote = tmp[(!i) & 1]; + sv[i] = tmp[i]; + } + + res = 0; + + goto out; + +cleanup: + for(size_t i = 0; i < 2; ++i) { + if (obj[i] != NULL) { + spair_fini(obj[i]); + obj[i] = NULL; + } + } + + for(size_t i = 0; i < 2; ++i) { + if (tmp[i] != -1) { + z_free_fd(tmp[i]); + tmp[i] = -1; + } + } + +out: + return res; +} + +#ifdef CONFIG_USERSPACE +static inline int z_vrfy_zsock_socketpair(int family, int type, int proto, int *sv) +{ + int ret; + int tmp[2]; + + ret = z_impl_zsock_socketpair(family, type, proto, tmp); + Z_OOPS(z_user_to_copy(sv, tmp, sizeof(tmp)); + return ret; +} +#include +#endif /* CONFIG_USERSPACE */ + + +static ssize_t spair_read(void *obj, void *buffer, size_t count) +{ + struct spair *const spair = (struct spair *)obj; + + if ( spair->magic != SPAIR_MAGIC ) { + D( + "invalid magic for struct spair * at %p:\n" + "actual: %u\n" + "expected: %u", + spair, + spair->magic, + SPAIR_MAGIC + ); + } + __ASSERT(spair->magic == SPAIR_MAGIC, ""); + + int res; + + size_t bytes_read; + + if (count == 0) { + return 0; + } + + D("calling k_pipe_get(%p, %p, %u)", &spair->recv_q, buffer, count); + D("recv_q: buffer: %p size: %u bytes_used: %u read_index: %u write_index: %u flags: %u", + spair->recv_q.buffer, spair->recv_q.size, spair->recv_q.bytes_used, + spair->recv_q.read_index, spair->recv_q.write_index, spair->recv_q.flags); + res = k_pipe_get(&spair->recv_q, (void *)buffer, count, & bytes_read, 0, K_NO_WAIT); + D("k_pipe_get() returned %d", res); + if (res < 0) { + errno = -res; + return -1; + } + D("read %u bytes", bytes_read); + + return bytes_read; +} + +static ssize_t spair_write(void *obj, const void *buffer, size_t count) +{ + struct spair *const spair = (struct spair *)obj; + struct spair *const remote = z_get_fd_obj(spair->remote, (const struct fd_op_vtable *) &spair_fd_op_vtable, 0); + + int res; + size_t bytes_written; + + if ( spair->magic != SPAIR_MAGIC ) { + D( + "invalid magic for struct spair * at %p:\n" + "actual: %u\n" + "expected: %u", + spair, + spair->magic, + SPAIR_MAGIC + ); + } + __ASSERT(spair->magic == SPAIR_MAGIC, ""); + + if (remote == NULL) { + errno = EPIPE; + return -1; + } + + if (count == 0) { + return 0; + } + + D("calling k_pipe_put(%p, %p, %u)", &remote->recv_q, buffer, count); + D("recv_q: buffer: %p size: %u bytes_used: %u read_index: %u write_index: %u flags: %u", + remote->recv_q.buffer, remote->recv_q.size, remote->recv_q.bytes_used, + remote->recv_q.read_index, remote->recv_q.write_index, remote->recv_q.flags); + res = k_pipe_put(&remote->recv_q, (void *)buffer, count, & bytes_written, 0, K_NO_WAIT); + D("k_pipe_put() returned %d", res); + if (res < 0) { + errno = -res; + return -1; + } + D("wrote %u bytes", bytes_written); + + return bytes_written; +} + +static int spair_ioctl(void *obj, unsigned int request, va_list args) +{ + + struct spair *const spair = (struct spair *)obj; + + if ( spair->magic != SPAIR_MAGIC ) { + D( + "invalid magic for struct spair * at %p:\n" + "actual: %u\n" + "expected: %u", + spair, + spair->magic, + SPAIR_MAGIC + ); + } + __ASSERT(spair->magic == SPAIR_MAGIC, ""); + + switch (request) { + case ZFD_IOCTL_CLOSE: + spair_fini((struct spair *)obj); + return 0; + + default: + errno = EOPNOTSUPP; + return -1; + } +} + +static int spair_bind(void *obj, const struct sockaddr *addr, + socklen_t addrlen) +{ + (void) obj; + (void) addr; + (void) addrlen; + + errno = EISCONN; + return -1; +} + +static int spair_connect(void *obj, const struct sockaddr *addr, + socklen_t addrlen) +{ + (void) obj; + (void) addr; + (void) addrlen; + + errno = EISCONN; + return -1; +} + +static int spair_listen(void *obj, int backlog) +{ + (void) obj; + (void) backlog; + + errno = EINVAL; + return -1; +} + +static int spair_accept(void *obj, struct sockaddr *addr, + socklen_t *addrlen) +{ + (void) obj; + (void) addr; + (void) addrlen; + + errno = EOPNOTSUPP; + return -1; +} + +static ssize_t spair_sendto(void *obj, const void *buf, size_t len, + int flags, const struct sockaddr *dest_addr, + socklen_t addrlen) +{ + (void) flags; + (void) dest_addr; + (void) addrlen; + + return spair_write(obj, buf, len); +} + +static ssize_t spair_sendmsg(void *obj, const struct msghdr *msg, + int flags) +{ + (void) obj; + (void) msg; + (void) flags; + + errno = ENOSYS; + return -1; +} + +static ssize_t spair_recvfrom(void *obj, void *buf, size_t max_len, + int flags, struct sockaddr *src_addr, + socklen_t *addrlen) +{ + (void) flags; + (void) src_addr; + (void) addrlen; + + return spair_read(obj, buf, max_len); +} + +static int spair_getsockopt(void *obj, int level, int optname, + void *optval, socklen_t *optlen) +{ + (void) obj; + (void) level; + (void) optname; + (void) optval; + (void) optlen; + + errno = ENOSYS; + return -1; +} + +static int spair_setsockopt(void *obj, int level, int optname, + const void *optval, socklen_t optlen) +{ + (void) obj; + (void) level; + (void) optname; + (void) optval; + (void) optlen; + + errno = ENOSYS; + return -1; +} + +static const struct socket_op_vtable spair_fd_op_vtable = { + .fd_vtable = { + .read = spair_read, + .write = spair_write, + .ioctl = spair_ioctl, + }, + .bind = spair_bind, + .connect = spair_connect, + .listen = spair_listen, + .accept = spair_accept, + .sendto = spair_sendto, + .sendmsg = spair_sendmsg, + .recvfrom = spair_recvfrom, + .getsockopt = spair_getsockopt, + .setsockopt = spair_setsockopt, +}; diff --git a/tests/net/socket/socketpair/CMakeLists.txt b/tests/net/socket/socketpair/CMakeLists.txt new file mode 100644 index 0000000000000..27b977b095b55 --- /dev/null +++ b/tests/net/socket/socketpair/CMakeLists.txt @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 + +cmake_minimum_required(VERSION 3.13.1) +find_package(Zephyr HINTS $ENV{ZEPHYR_BASE}) +project(socketpair) + +FILE(GLOB app_sources src/*.c) +target_sources(app PRIVATE ${app_sources}) diff --git a/tests/net/socket/socketpair/prj.conf b/tests/net/socket/socketpair/prj.conf new file mode 100644 index 0000000000000..af079dd65558e --- /dev/null +++ b/tests/net/socket/socketpair/prj.conf @@ -0,0 +1,34 @@ +# General config +CONFIG_NEWLIB_LIBC=y + +# Networking config +CONFIG_NETWORKING=y +CONFIG_NET_TEST=y +CONFIG_NET_LOOPBACK=y +CONFIG_NET_IPV4=y +CONFIG_NET_IPV6=y +CONFIG_NET_SOCKETS=y +CONFIG_NET_SOCKETPAIR=y +CONFIG_NET_SOCKETS_POSIX_NAMES=y +# Defines fd_set size +CONFIG_POSIX_MAX_FDS=33 + +# Network driver config +CONFIG_TEST_RANDOM_GENERATOR=y + +# Network address config +CONFIG_NET_CONFIG_SETTINGS=y +CONFIG_NET_CONFIG_MY_IPV4_ADDR="192.0.2.1" +CONFIG_NET_CONFIG_MY_IPV6_ADDR="2001:db8::1" + +CONFIG_MAIN_STACK_SIZE=4096 +CONFIG_ZTEST=y + +# User mode requirements +CONFIG_TEST_USERSPACE=y +CONFIG_HEAP_MEM_POOL_SIZE=128 + +CONFIG_QEMU_TICKLESS_WORKAROUND=y + +CONFIG_NO_OPTIMIZATIONS=y +CONFIG_SPIN_VALIDATE=n diff --git a/tests/net/socket/socketpair/src/main.c b/tests/net/socket/socketpair/src/main.c new file mode 100644 index 0000000000000..6330238e7a0d8 --- /dev/null +++ b/tests/net/socket/socketpair/src/main.c @@ -0,0 +1,307 @@ +/* + * Copyright (c) 2019 Linaro Limited + * Copyright (c) 2020 Friedt Professional Engineering Services, Inc + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +LOG_MODULE_REGISTER(net_test, CONFIG_NET_SOCKETS_LOG_LEVEL); + +#include +#include +#include +#include +#include + +#include + +#undef read +#define read(fd, buf, len) zsock_recv(fd, buf, len, 0) + +#undef write +#define write(fd, buf, len) zsock_send(fd, buf, len, 0) + +static void happy_path_inner( + const unsigned family, const char *family_s, + const unsigned type, const char *type_s, + const unsigned proto, const char *proto_s +) +{ + int res; + int sv[2]; + + const char *expected_msg = "Hello, socketpair(2) world!"; + const size_t expected_msg_len = strlen(expected_msg); + char actual_msg[32]; + size_t actual_msg_len; + + printf("calling socketpair(%s, %s, %s, sv)\n", family_s, type_s, + proto_s); + res = socketpair(family, type, proto, sv); + zassert_true(res == -1 || res == 0, + "socketpair returned an unspecified value"); + zassert_equal(res, 0, "socketpair failed"); + + printf("sv: {%d, %d}\n", sv[0], sv[1]); + + socklen_t len; + + /* sockets are bidirectional. test functions from both ends */ + for(int i = 1; i > 0; --i) { + + printf("data direction: %d -> %d\n", sv[i], sv[(!i) & 1]); + + /* + * Test with write(2) / read(2) + */ + + printf("testing write(2)\n"); + res = write(sv[i], expected_msg, expected_msg_len); + + zassert_not_equal(res, -1, "write(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "did not write entire message"); + + memset(actual_msg, 0, sizeof(actual_msg)); + + printf("testing read(2)\n"); + res = read(sv[(!i) & 1], actual_msg, sizeof(actual_msg)); + + zassert_not_equal(res, -1, "read(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "wrong return value"); + + zassert_true(0 == strncmp(expected_msg, actual_msg, + actual_msg_len), + "the wrong message was passed through the socketpair"); + + /* + * Test with send(2) / recv(2) + */ + + printf("testing send(2)\n"); + res = send(sv[i], expected_msg, expected_msg_len, 0); + + zassert_not_equal(res, -1, "send(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "did not send entire message"); + + memset(actual_msg, 0, sizeof(actual_msg)); + + printf("testing recv(2)\n"); + res = recv(sv[(!i) & 1], actual_msg, sizeof(actual_msg), 0); + + zassert_not_equal(res, -1, "recv(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "wrong return value"); + + zassert_true(0 == strncmp(expected_msg, actual_msg, + actual_msg_len), + "the wrong message was passed through the socketpair"); + + /* + * Test with sendto(2) / recvfrom(2) + */ + + printf("testing sendto(2)\n"); + res = sendto(sv[i], expected_msg, expected_msg_len, 0, NULL, 0); + + zassert_not_equal(res, -1, "sendto(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "did not sendto entire message"); + + memset(actual_msg, 0, sizeof(actual_msg)); + + printf("testing recvfrom(2)\n"); + res = recvfrom(sv[(!i) & 1], actual_msg, sizeof(actual_msg), 0, NULL, &len); + + zassert_not_equal(res, -1, "recvfrom(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "wrong return value"); + + zassert_true(0 == strncmp(expected_msg, actual_msg, + actual_msg_len), + "the wrong message was passed through the socketpair"); + +#if 0 + /* + * Test with sendmsg(2) / recvmsg(2) + */ + + printf("testing sendmsg(2)\n"); + res = sendmsg(sv[i], expected_msg, expected_msg_len, 0); + + zassert_not_equal(res, -1, "sendmsg(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "did not sendmsg entire message"); + + memset(actual_msg, 0, sizeof(actual_msg)); + + printf("testing recvmsg(2)\n"); + res = recvmsg(sv[(!i) & 1], actual_msg, sizeof(actual_msg), 0); + + zassert_not_equal(res, -1, "recvmsg(2) failed: %d", errno); + actual_msg_len = res; + zassert_equal(actual_msg_len, expected_msg_len, + "wrong return value"); + + zassert_true(0 == strncmp(expected_msg, actual_msg, + actual_msg_len), + "the wrong message was passed through the socketpair"); +#endif + + } + + printf("closing sv[0]: fd %d\n", sv[0]); + res = close(sv[0]); + zassert_equal(res, 0, "close failed"); + + printf("closing sv[1]: fd %d\n", sv[1]); + res = close(sv[1]); + zassert_equal(res, 0, "close failed"); +} + +void test_socket_socketpair_happy_path(void) +{ + struct unstr { + unsigned u; + const char *s; + }; + static const struct unstr address_family[] = { + { AF_LOCAL, "AF_LOCAL" }, + { AF_UNIX, "AF_UNIX" }, + }; + static const struct unstr socket_type[] = { + { SOCK_STREAM, "SOCK_STREAM" }, + //{ SOCK_RAW, "SOCK_RAW" }, + }; + static const struct unstr protocol[] = { + { 0, "0" }, + }; + + for(size_t i = 0; i < ARRAY_SIZE(address_family); ++i) { + for(size_t j = 0; j < ARRAY_SIZE(socket_type); ++j) { + for(size_t k = 0; k < ARRAY_SIZE(protocol); ++k) { + happy_path_inner( + address_family[i].u, address_family[i].s, + socket_type[j].u, socket_type[j].s, + protocol[k].u, protocol[k].s + ); + } + } + } +} + +void test_socket_socketpair_expected_failures(void) +{ + int res; + int sv[2]; + + /* Use invalid values in fields starting from left to right */ + + res = socketpair(AF_INET, SOCK_STREAM, 0, sv); + zassert_equal(res, -1, "socketpair with fail with bad address family"); + zassert_equal(errno, EAFNOSUPPORT, + "errno should be EAFNOSUPPORT with bad adddress family"); + if (res != -1) { + close(sv[0]); + close(sv[1]); + } + + res = socketpair(AF_UNIX, 42, 0, sv); + zassert_equal(res, -1, + "socketpair should fail with unsupported socket type"); + zassert_equal(errno, EPROTOTYPE, + "errno should be EPROTOTYPE with bad socket type"); + if (res != -1) { + close(sv[0]); + close(sv[1]); + } + + res = socketpair(AF_UNIX, SOCK_STREAM, IPPROTO_TLS_1_0, sv); + zassert_equal(res, -1, + "socketpair should fail with unsupported protocol"); + zassert_equal(errno, EPROTONOSUPPORT, + "errno should be EPROTONOSUPPORT with bad protocol"); + if (res != -1) { + close(sv[0]); + close(sv[1]); + } + + /* This is not a POSIX requirement, but should be safe */ + res = socketpair(AF_UNIX, SOCK_STREAM, 0, NULL); + zassert_equal(res, -1, + "socketpair should fail with invalid socket vector"); + zassert_equal(errno, EINVAL, + "errno should be EINVAL with bad socket vector"); +} + +void test_socket_socketpair_unsupported_calls(void) +{ + int res; + int sv[2]; + struct sockaddr_un addr = { + .sun_family = AF_UNIX, + }; + socklen_t len = sizeof(addr); + + res = socketpair(AF_UNIX, SOCK_STREAM, 0, sv); + zassert_equal(res, 0, "socketpair(AF_UNIX, SOCK_STREAM, 0, sv) failed"); + + + for(size_t i = 0; i < 2; ++i) { + + res = bind(sv[i], (struct sockaddr *)&addr, len); + zassert_equal(res, -1, "bind should fail on a socketpair endpoint"); + zassert_equal(errno, EISCONN, "bind should set errno to EISCONN"); + + res = connect(sv[i], (struct sockaddr *)&addr, len); + zassert_equal(res, -1, "connect should fail on a socketpair endpoint"); + zassert_equal(errno, EISCONN, "connect should set errno to EISCONN"); + + res = listen(sv[i], 1); + zassert_equal(res, -1, "listen should fail on a socketpair endpoint"); + zassert_equal(errno, EINVAL, "listen should set errno to EINVAL"); + + res = accept(sv[i], (struct sockaddr *)&addr, &len); + zassert_equal(res, -1, "accept should fail on a socketpair endpoint"); + zassert_equal(errno, EOPNOTSUPP, "accept should set errno to EOPNOTSUPP"); + } + + res = close(sv[0]); + zassert_equal(res, 0, "close failed"); + + res = close(sv[1]); + zassert_equal(res, 0, "close failed"); +} + +void test_socket_socketpair_select(void) { + /* TODO */ +} + +void test_socket_socketpair_poll(void) { + /* TODO */ +} + +void test_main(void) +{ + k_thread_system_pool_assign(k_current_get()); + + ztest_test_suite( + socket_socketpair, + ztest_user_unit_test(test_socket_socketpair_happy_path), + ztest_user_unit_test(test_socket_socketpair_expected_failures), + ztest_user_unit_test(test_socket_socketpair_unsupported_calls), + ztest_user_unit_test(test_socket_socketpair_select), + ztest_user_unit_test(test_socket_socketpair_poll)); + + ztest_run_test_suite(socket_socketpair); +} diff --git a/tests/net/socket/socketpair/testcase.yaml b/tests/net/socket/socketpair/testcase.yaml new file mode 100644 index 0000000000000..5f327889d9c9c --- /dev/null +++ b/tests/net/socket/socketpair/testcase.yaml @@ -0,0 +1,6 @@ +common: + depends_on: +tests: + net.socket.socketpair: + min_ram: 21 + tags: net socket userspace