diff --git a/library/std/src/os/unix/net/addr.rs b/library/std/src/os/unix/net/addr.rs index 25b95014e08b2..5ea77322d0d7c 100644 --- a/library/std/src/os/unix/net/addr.rs +++ b/library/std/src/os/unix/net/addr.rs @@ -6,6 +6,7 @@ use crate::os::unix::ffi::OsStrExt; use crate::path::Path; use crate::sealed::Sealed; use crate::sys::cvt; +use crate::sys::net::SockaddrLike; use crate::{fmt, io, mem, ptr}; // FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here? @@ -253,6 +254,27 @@ impl SocketAddr { } } +impl SockaddrLike for SocketAddr { + unsafe fn from_storage( + storage: &libc::sockaddr_storage, + len: libc::socklen_t, + ) -> io::Result { + let p = (storage as *const libc::sockaddr_storage).cast(); + SocketAddr::from_parts(*p, len) + } + + fn to_storage(&self, storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t { + unsafe { + crate::ptr::copy_nonoverlapping( + &raw const self.addr, + (storage_ret as *mut libc::sockaddr_storage).cast(), + self.len as _, + ); + self.len + } + } +} + #[stable(feature = "unix_socket_abstract", since = "1.70.0")] impl Sealed for SocketAddr {} diff --git a/library/std/src/os/unix/net/ancillary.rs b/library/std/src/os/unix/net/ancillary.rs index d0984bdfb99d1..bfcef288d5beb 100644 --- a/library/std/src/os/unix/net/ancillary.rs +++ b/library/std/src/os/unix/net/ancillary.rs @@ -48,7 +48,7 @@ pub(super) fn recv_vectored_with_ancillary_from( msg.msg_control = ancillary.buffer.as_mut_ptr().cast(); } - let count = socket.recv_msg(&mut msg)?; + let count = socket.recv_msg_(&mut msg)?; ancillary.length = msg.msg_controllen as usize; ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC; @@ -83,7 +83,7 @@ pub(super) fn send_vectored_with_ancillary_to( ancillary.truncated = false; - socket.send_msg(&mut msg) + socket.send_msg_(&mut msg) } } diff --git a/library/std/src/sys/net/connection/socket/mod.rs b/library/std/src/sys/net/connection/socket/mod.rs index 1dd06e97bbabd..8f7d050a53e0a 100644 --- a/library/std/src/sys/net/connection/socket/mod.rs +++ b/library/std/src/sys/net/connection/socket/mod.rs @@ -198,6 +198,46 @@ unsafe fn socket_addr_from_c( } } +// Structs that have sockaddr header and can be marshalled to and from sockaddr_storage +pub(crate) trait SockaddrLike: Sized { + // used in recvmsg to parse the received addr + unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result; + + // used in sendmsg to write to a suckaddr_storage buffer + fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t; +} + +impl SockaddrLike for SocketAddr { + unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result { + socket_addr_from_c(storage as *const _, len as _) + } + + fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t { + let (crep, len) = socket_addr_to_c(self); + unsafe { + crate::ptr::copy_nonoverlapping( + &raw const crep, + (storage_ret as *mut c::sockaddr_storage).cast(), + len as _, + ); + } + len as _ + } +} + +impl SockaddrLike for () { + unsafe fn from_storage( + _storage: &libc::sockaddr_storage, + _len: libc::socklen_t, + ) -> io::Result { + Ok(()) + } + + fn to_storage(&self, _storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t { + 0 + } +} + //////////////////////////////////////////////////////////////////////////////// // sockaddr and misc bindings //////////////////////////////////////////////////////////////////////////////// diff --git a/library/std/src/sys/net/connection/socket/unix.rs b/library/std/src/sys/net/connection/socket/unix.rs index a191576d93b9d..3e6ec5b4a4f12 100644 --- a/library/std/src/sys/net/connection/socket/unix.rs +++ b/library/std/src/sys/net/connection/socket/unix.rs @@ -1,12 +1,17 @@ -use libc::{MSG_PEEK, c_int, c_void, size_t, sockaddr, socklen_t}; +use libc::{ + CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, MSG_PEEK, c_int, c_uint, c_void, cmsghdr, + iovec, msghdr, size_t, sockaddr, sockaddr_storage, socklen_t, +}; #[cfg(not(any(target_os = "espidf", target_os = "nuttx")))] use crate::ffi::CStr; use crate::io::{self, BorrowedBuf, BorrowedCursor, IoSlice, IoSliceMut}; +use crate::mem::zeroed; use crate::net::{Shutdown, SocketAddr}; use crate::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}; +use crate::ptr::copy_nonoverlapping; use crate::sys::fd::FileDesc; -use crate::sys::net::{getsockopt, setsockopt}; +use crate::sys::net::{SockaddrLike, getsockopt, setsockopt}; use crate::sys::pal::IsMinusOne; use crate::sys_common::{AsInner, FromInner, IntoInner}; use crate::time::{Duration, Instant}; @@ -62,6 +67,64 @@ pub fn cvt_gai(err: c_int) -> io::Result<()> { )) } +#[repr(C)] +pub union CmsgIter<'buf> { + _align: msghdr, + inner: CmsgIterInner<'buf>, +} + +#[derive(Clone, Copy)] +#[repr(C)] +struct CmsgIterInner<'buf> { + _padding: [u8; size_of::() + size_of::() + size_of::()], + curr_cmsg: *mut cmsghdr, + cmsg_buf: &'buf [u8], + cmsg_buf_len: usize, +} + +#[repr(transparent)] +pub struct CmsgBuf<'buf>(&'buf mut [u8]); + +impl<'buf> CmsgBuf<'buf> { + // fails if buf isn't aligned to alignof(cmsghdr) + pub fn new(buf: &'buf mut [u8]) -> io::Result { + if buf.as_ptr().align_offset(align_of::()) == 0 { + Ok(CmsgBuf(buf)) + } else { + Err(io::Error::new(io::ErrorKind::InvalidInput, "unaligned buffer")) + } + } + + pub unsafe fn new_unchecked(buf: &'buf mut [u8]) -> Self { + CmsgBuf(buf) + } +} + +impl<'buf> Iterator for CmsgIter<'buf> { + type Item = (size_t, c_int, c_int, &'buf [u8]); + + fn next(&mut self) -> Option { + unsafe { + if self.inner.curr_cmsg.is_null() { + None + } else { + let curr = *self.inner.curr_cmsg; + let data_ptr = CMSG_DATA(self.inner.curr_cmsg); + let ptrdiff = data_ptr.offset_from_unsigned(self.inner.curr_cmsg as *const u8); + let r = ( + curr.cmsg_len, + curr.cmsg_level, + curr.cmsg_type, + crate::slice::from_raw_parts(data_ptr, curr.cmsg_len - ptrdiff), + ); + self.inner.curr_cmsg = + CMSG_NXTHDR(self as *mut _ as *mut msghdr, self.inner.curr_cmsg); + Some(r) + } + } + } +} + impl Socket { pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result { let fam = match *addr { @@ -362,11 +425,51 @@ impl Socket { } #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))] - pub fn recv_msg(&self, msg: &mut libc::msghdr) -> io::Result { + pub fn recv_msg_(&self, msg: &mut libc::msghdr) -> io::Result { let n = cvt(unsafe { libc::recvmsg(self.as_raw_fd(), msg, libc::MSG_CMSG_CLOEXEC) })?; Ok(n as usize) } + pub fn recv_msg<'a, 'b, T>( + &self, + iov_buf: &mut [IoSliceMut<'_>], + cmsg_buf: CmsgBuf<'a>, + flags: c_int, + ) -> io::Result<(usize, T, c_int, CmsgIter<'b>)> + where + T: SockaddrLike, + 'a: 'b, + { + unsafe { + let mut msg: msghdr = zeroed(); + let mut addr: sockaddr_storage = zeroed(); + msg.msg_name = (&raw mut addr).cast(); + msg.msg_namelen = mem::size_of_val(&addr) as _; + + msg.msg_iovlen = iov_buf.len(); + msg.msg_iov = iov_buf.as_mut_ptr().cast(); + + msg.msg_controllen = cmsg_buf.0.len(); + if msg.msg_controllen != 0 { + msg.msg_control = cmsg_buf.0.as_mut_ptr().cast(); + } + + msg.msg_flags = 0; + + let bytes = cvt(libc::recvmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize; + + let addr = SockaddrLike::from_storage(&addr, msg.msg_namelen)?; + + let mut iter: CmsgIter<'_> = zeroed(); + iter.inner.cmsg_buf = cmsg_buf.0; + iter.inner.cmsg_buf_len = msg.msg_controllen; + let fst_cmsg = CMSG_FIRSTHDR((&raw const iter).cast()); + iter.inner.curr_cmsg = fst_cmsg; + + Ok((bytes, addr, msg.msg_flags, iter)) + } + } + pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { self.recv_from_with_flags(buf, MSG_PEEK) } @@ -385,11 +488,63 @@ impl Socket { } #[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))] - pub fn send_msg(&self, msg: &mut libc::msghdr) -> io::Result { + pub fn send_msg_(&self, msg: &mut libc::msghdr) -> io::Result { let n = cvt(unsafe { libc::sendmsg(self.as_raw_fd(), msg, 0) })?; Ok(n as usize) } + pub fn send_msg( + &self, + addr: Option<&T>, + iov: &[IoSlice<'_>], + cmsgs: &[(c_int, c_int, &[u8])], + cmsg_buf: CmsgBuf<'_>, + flags: c_int, + ) -> io::Result + where + T: SockaddrLike, + { + unsafe { + let mut msg: msghdr = zeroed(); + let mut addr_s: sockaddr_storage = zeroed(); + + if let Some(addr_) = addr { + let len = addr_.to_storage(&mut addr_s); + msg.msg_namelen = len; + msg.msg_name = (&raw mut addr_s).cast(); + } + + msg.msg_iovlen = iov.len(); + msg.msg_iov = iov.as_ptr().cast::>() as *mut iovec; + + // cmsg + msg.msg_controllen = cmsg_buf.0.len(); + msg.msg_control = cmsg_buf.0.as_mut_ptr().cast(); + let mut curr_cmsg_hdr = CMSG_FIRSTHDR(&raw const msg); + for (cmsg_level, cmsg_type, cmsg_data) in cmsgs { + if curr_cmsg_hdr.is_null() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "cmsg_buf supplied is too small to hold all control messages", + )); + } + + (*curr_cmsg_hdr).cmsg_level = *cmsg_level; + (*curr_cmsg_hdr).cmsg_type = *cmsg_type; + (*curr_cmsg_hdr).cmsg_len = CMSG_LEN(cmsg_data.len() as c_uint) as usize; + + let cmsg_data_ptr = CMSG_DATA(curr_cmsg_hdr); + copy_nonoverlapping((*cmsg_data).as_ptr(), cmsg_data_ptr, cmsg_data.len()); + + curr_cmsg_hdr = CMSG_NXTHDR(&raw const msg, curr_cmsg_hdr as *const _); + } + + let bytes = cvt(libc::sendmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize; + + Ok(bytes) + } + } + pub fn set_timeout(&self, dur: Option, kind: libc::c_int) -> io::Result<()> { let timeout = match dur { Some(dur) => {