Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions library/std/src/os/unix/net/addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::os::unix::ffi::OsStrExt;
use crate::path::Path;
use crate::sealed::Sealed;
use crate::sys::cvt;
use crate::sys::net::SockaddrLike;
use crate::{fmt, io, mem, ptr};

// FIXME(#43348): Make libc adapt #[doc(cfg(...))] so we don't need these fake definitions here?
Expand Down Expand Up @@ -253,6 +254,27 @@ impl SocketAddr {
}
}

impl SockaddrLike for SocketAddr {
unsafe fn from_storage(
storage: &libc::sockaddr_storage,
len: libc::socklen_t,
) -> io::Result<Self> {
let p = (storage as *const libc::sockaddr_storage).cast();
SocketAddr::from_parts(*p, len)
}

fn to_storage(&self, storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t {
unsafe {
crate::ptr::copy_nonoverlapping(
&raw const self.addr,
(storage_ret as *mut libc::sockaddr_storage).cast(),
self.len as _,
);
self.len
}
}
}

#[stable(feature = "unix_socket_abstract", since = "1.70.0")]
impl Sealed for SocketAddr {}

Expand Down
4 changes: 2 additions & 2 deletions library/std/src/os/unix/net/ancillary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub(super) fn recv_vectored_with_ancillary_from(
msg.msg_control = ancillary.buffer.as_mut_ptr().cast();
}

let count = socket.recv_msg(&mut msg)?;
let count = socket.recv_msg_(&mut msg)?;

ancillary.length = msg.msg_controllen as usize;
ancillary.truncated = msg.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC;
Expand Down Expand Up @@ -83,7 +83,7 @@ pub(super) fn send_vectored_with_ancillary_to(

ancillary.truncated = false;

socket.send_msg(&mut msg)
socket.send_msg_(&mut msg)
}
}

Expand Down
40 changes: 40 additions & 0 deletions library/std/src/sys/net/connection/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,46 @@ unsafe fn socket_addr_from_c(
}
}

// Structs that have sockaddr header and can be marshalled to and from sockaddr_storage
pub(crate) trait SockaddrLike: Sized {
// used in recvmsg to parse the received addr
unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result<Self>;

// used in sendmsg to write to a suckaddr_storage buffer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in suckaddr

fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t;
}

impl SockaddrLike for SocketAddr {
unsafe fn from_storage(storage: &c::sockaddr_storage, len: c::socklen_t) -> io::Result<Self> {
socket_addr_from_c(storage as *const _, len as _)
}

fn to_storage(&self, storage_ret: &mut c::sockaddr_storage) -> c::socklen_t {
let (crep, len) = socket_addr_to_c(self);
unsafe {
crate::ptr::copy_nonoverlapping(
&raw const crep,
(storage_ret as *mut c::sockaddr_storage).cast(),
len as _,
);
}
len as _
}
}

impl SockaddrLike for () {
unsafe fn from_storage(
_storage: &libc::sockaddr_storage,
_len: libc::socklen_t,
) -> io::Result<Self> {
Ok(())
}

fn to_storage(&self, _storage_ret: &mut libc::sockaddr_storage) -> libc::socklen_t {
0
}
}

////////////////////////////////////////////////////////////////////////////////
// sockaddr and misc bindings
////////////////////////////////////////////////////////////////////////////////
Expand Down
163 changes: 159 additions & 4 deletions library/std/src/sys/net/connection/socket/unix.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
use libc::{MSG_PEEK, c_int, c_void, size_t, sockaddr, socklen_t};
use libc::{
CMSG_DATA, CMSG_FIRSTHDR, CMSG_LEN, CMSG_NXTHDR, MSG_PEEK, c_int, c_uint, c_void, cmsghdr,
iovec, msghdr, size_t, sockaddr, sockaddr_storage, socklen_t,
};

#[cfg(not(any(target_os = "espidf", target_os = "nuttx")))]
use crate::ffi::CStr;
use crate::io::{self, BorrowedBuf, BorrowedCursor, IoSlice, IoSliceMut};
use crate::mem::zeroed;
use crate::net::{Shutdown, SocketAddr};
use crate::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
use crate::ptr::copy_nonoverlapping;
use crate::sys::fd::FileDesc;
use crate::sys::net::{getsockopt, setsockopt};
use crate::sys::net::{SockaddrLike, getsockopt, setsockopt};
use crate::sys::pal::IsMinusOne;
use crate::sys_common::{AsInner, FromInner, IntoInner};
use crate::time::{Duration, Instant};
Expand Down Expand Up @@ -62,6 +67,64 @@ pub fn cvt_gai(err: c_int) -> io::Result<()> {
))
}

#[repr(C)]
pub union CmsgIter<'buf> {
_align: msghdr,
inner: CmsgIterInner<'buf>,
}

#[derive(Clone, Copy)]
#[repr(C)]
struct CmsgIterInner<'buf> {
_padding: [u8; size_of::<usize>() + size_of::<socklen_t>() + size_of::<size_t>()],
curr_cmsg: *mut cmsghdr,
cmsg_buf: &'buf [u8],
cmsg_buf_len: usize,
}

#[repr(transparent)]
pub struct CmsgBuf<'buf>(&'buf mut [u8]);

impl<'buf> CmsgBuf<'buf> {
// fails if buf isn't aligned to alignof(cmsghdr)
pub fn new(buf: &'buf mut [u8]) -> io::Result<Self> {
if buf.as_ptr().align_offset(align_of::<cmsghdr>()) == 0 {
Ok(CmsgBuf(buf))
} else {
Err(io::Error::new(io::ErrorKind::InvalidInput, "unaligned buffer"))
}
}

pub unsafe fn new_unchecked(buf: &'buf mut [u8]) -> Self {
CmsgBuf(buf)
}
}

impl<'buf> Iterator for CmsgIter<'buf> {
type Item = (size_t, c_int, c_int, &'buf [u8]);

fn next(&mut self) -> Option<Self::Item> {
unsafe {
if self.inner.curr_cmsg.is_null() {
None
} else {
let curr = *self.inner.curr_cmsg;
let data_ptr = CMSG_DATA(self.inner.curr_cmsg);
let ptrdiff = data_ptr.offset_from_unsigned(self.inner.curr_cmsg as *const u8);
let r = (
curr.cmsg_len,
curr.cmsg_level,
curr.cmsg_type,
crate::slice::from_raw_parts(data_ptr, curr.cmsg_len - ptrdiff),
);
self.inner.curr_cmsg =
CMSG_NXTHDR(self as *mut _ as *mut msghdr, self.inner.curr_cmsg);
Some(r)
}
}
}
}

impl Socket {
pub fn new(addr: &SocketAddr, ty: c_int) -> io::Result<Socket> {
let fam = match *addr {
Expand Down Expand Up @@ -362,11 +425,51 @@ impl Socket {
}

#[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
pub fn recv_msg(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
pub fn recv_msg_(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
let n = cvt(unsafe { libc::recvmsg(self.as_raw_fd(), msg, libc::MSG_CMSG_CLOEXEC) })?;
Ok(n as usize)
}

pub fn recv_msg<'a, 'b, T>(
&self,
iov_buf: &mut [IoSliceMut<'_>],
cmsg_buf: CmsgBuf<'a>,
flags: c_int,
) -> io::Result<(usize, T, c_int, CmsgIter<'b>)>
where
T: SockaddrLike,
'a: 'b,
{
unsafe {
let mut msg: msghdr = zeroed();
let mut addr: sockaddr_storage = zeroed();
msg.msg_name = (&raw mut addr).cast();
msg.msg_namelen = mem::size_of_val(&addr) as _;

msg.msg_iovlen = iov_buf.len();
msg.msg_iov = iov_buf.as_mut_ptr().cast();

msg.msg_controllen = cmsg_buf.0.len();
if msg.msg_controllen != 0 {
msg.msg_control = cmsg_buf.0.as_mut_ptr().cast();
}

msg.msg_flags = 0;

let bytes = cvt(libc::recvmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize;

let addr = SockaddrLike::from_storage(&addr, msg.msg_namelen)?;

let mut iter: CmsgIter<'_> = zeroed();
iter.inner.cmsg_buf = cmsg_buf.0;
iter.inner.cmsg_buf_len = msg.msg_controllen;
let fst_cmsg = CMSG_FIRSTHDR((&raw const iter).cast());
iter.inner.curr_cmsg = fst_cmsg;

Ok((bytes, addr, msg.msg_flags, iter))
}
}

pub fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.recv_from_with_flags(buf, MSG_PEEK)
}
Expand All @@ -385,11 +488,63 @@ impl Socket {
}

#[cfg(any(target_os = "android", target_os = "linux", target_os = "cygwin"))]
pub fn send_msg(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
pub fn send_msg_(&self, msg: &mut libc::msghdr) -> io::Result<usize> {
let n = cvt(unsafe { libc::sendmsg(self.as_raw_fd(), msg, 0) })?;
Ok(n as usize)
}

pub fn send_msg<T>(
&self,
addr: Option<&T>,
iov: &[IoSlice<'_>],
cmsgs: &[(c_int, c_int, &[u8])],
cmsg_buf: CmsgBuf<'_>,
flags: c_int,
) -> io::Result<usize>
where
T: SockaddrLike,
{
unsafe {
let mut msg: msghdr = zeroed();
let mut addr_s: sockaddr_storage = zeroed();

if let Some(addr_) = addr {
let len = addr_.to_storage(&mut addr_s);
msg.msg_namelen = len;
msg.msg_name = (&raw mut addr_s).cast();
}

msg.msg_iovlen = iov.len();
msg.msg_iov = iov.as_ptr().cast::<IoSlice<'_>>() as *mut iovec;

// cmsg
msg.msg_controllen = cmsg_buf.0.len();
msg.msg_control = cmsg_buf.0.as_mut_ptr().cast();
let mut curr_cmsg_hdr = CMSG_FIRSTHDR(&raw const msg);
for (cmsg_level, cmsg_type, cmsg_data) in cmsgs {
if curr_cmsg_hdr.is_null() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"cmsg_buf supplied is too small to hold all control messages",
));
}

(*curr_cmsg_hdr).cmsg_level = *cmsg_level;
(*curr_cmsg_hdr).cmsg_type = *cmsg_type;
(*curr_cmsg_hdr).cmsg_len = CMSG_LEN(cmsg_data.len() as c_uint) as usize;

let cmsg_data_ptr = CMSG_DATA(curr_cmsg_hdr);
copy_nonoverlapping((*cmsg_data).as_ptr(), cmsg_data_ptr, cmsg_data.len());

curr_cmsg_hdr = CMSG_NXTHDR(&raw const msg, curr_cmsg_hdr as *const _);
}

let bytes = cvt(libc::sendmsg(self.as_raw_fd(), &raw mut msg, flags))? as usize;

Ok(bytes)
}
}

pub fn set_timeout(&self, dur: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
let timeout = match dur {
Some(dur) => {
Expand Down
Loading