diff --git a/src/utils/src/lib.rs b/src/utils/src/lib.rs index c16dddce0d4..b0940a18c89 100644 --- a/src/utils/src/lib.rs +++ b/src/utils/src/lib.rs @@ -14,6 +14,7 @@ pub use vmm_sys_util::{ pub mod arg_parser; pub mod byte_order; pub mod net; +pub mod ring_buffer; pub mod signal; pub mod sm; pub mod time;
diff --git a/src/utils/src/ring_buffer.rs b/src/utils/src/ring_buffer.rs new file mode 100644 index 00000000000..a7be015d30e --- /dev/null +++ b/src/utils/src/ring_buffer.rs @@ -0,0 +1,193 @@ +// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +// +#[derive(Debug, Default, Clone)] +pub struct RingBuffer<T> { + pub items: Box<[T]>, + pub start: usize, + pub len: usize, +} + +impl<T: Default + Clone> RingBuffer<T> { + /// New ring buffer with zero capacity + pub fn new() -> Self { + Self { + items: Box::new([]), + start: 0, + len: 0, + } + } + + /// New ring buffer with the specified capacity + pub fn new_with_size(size: usize) -> Self { + Self { + items: vec![T::default(); size].into_boxed_slice(), + start: 0, + len: 0, + } + } + + /// Get number of items in the buffer + pub fn len(&self) -> usize { + self.len + } + + /// Check if ring is empty + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Check if ring is full + pub fn is_full(&self) -> bool { + self.len == self.items.len() + } + + /// Push a new item to the end of the ring and increase + /// the length. + /// If there is no space for it, nothing will happen. + pub fn push(&mut self, item: T) { + if !self.is_full() { + let index = (self.start + self.len) % self.items.len(); + self.items[index] = item; + self.len += 1; + } + } + + /// Return the next item that will be written to and increase + /// the length. + /// If the ring is full, returns None. + pub fn next_available(&mut self) -> Option<&mut T> { + if self.is_full() { + None + } else { + let index = (self.start + self.len) % self.items.len(); + self.len += 1; + Some(&mut self.items[index]) + } + }
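For illustration, a minimal usage sketch of this ring buffer outside the VMM (the element type and capacity are arbitrary; pop_front is defined just below):

use utils::ring_buffer::RingBuffer;

fn main() {
    // Capacity of 2: pushes beyond capacity are silently dropped.
    let mut ring: RingBuffer<u32> = RingBuffer::new_with_size(2);
    ring.push(1);
    ring.push(2);
    ring.push(3); // no-op, the ring is already full
    assert_eq!(*ring.pop_front().unwrap(), 1);
    // `next_available` reserves the next free slot and returns a mutable reference to it.
    *ring.next_available().unwrap() = 4;
    assert_eq!(*ring.pop_front().unwrap(), 2);
    assert_eq!(*ring.pop_front().unwrap(), 4);
    assert!(ring.pop_front().is_none());
}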
+ + /// Pop an item from the front of the ring. + /// If the ring is empty, returns None. + pub fn pop_front(&mut self) -> Option<&mut T> { + if self.is_empty() { + None + } else { + let index = self.start; + self.start += 1; + self.start %= self.items.len(); + self.len -= 1; + Some(&mut self.items[index]) + } + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + #[test] + fn test_new() { + let a = RingBuffer::<u8>::new(); + assert_eq!(a.items.len(), 0); + assert_eq!(a.start, 0); + assert_eq!(a.len, 0); + assert!(a.is_empty()); + assert!(a.is_full()); + + let a = RingBuffer::<u8>::new_with_size(69); + assert_eq!(a.items.len(), 69); + assert_eq!(a.start, 0); + assert_eq!(a.len, 0); + assert!(a.is_empty()); + assert!(!a.is_full()); + } + + #[test] + fn test_push() { + let mut a = RingBuffer::<u8>::new_with_size(4); + + a.push(0); + assert!(!a.is_empty()); + assert!(!a.is_full()); + + a.push(1); + assert!(!a.is_empty()); + assert!(!a.is_full()); + + a.push(2); + assert!(!a.is_empty()); + assert!(!a.is_full()); + + a.push(3); + assert!(!a.is_empty()); + assert!(a.is_full()); + + assert_eq!(a.items.as_ref(), &[0, 1, 2, 3]); + + a.push(4); + assert!(!a.is_empty()); + assert!(a.is_full()); + + assert_eq!(a.items.as_ref(), &[0, 1, 2, 3]); + } + + #[test] + fn test_next_available() { + let mut a = RingBuffer::<u8>::new_with_size(4); + assert!(a.is_empty()); + assert!(!a.is_full()); + + *a.next_available().unwrap() = 0; + assert!(!a.is_empty()); + assert!(!a.is_full()); + + *a.next_available().unwrap() = 1; + assert!(!a.is_empty()); + assert!(!a.is_full()); + + *a.next_available().unwrap() = 2; + assert!(!a.is_empty()); + assert!(!a.is_full()); + + *a.next_available().unwrap() = 3; + assert!(!a.is_empty()); + assert!(a.is_full()); + + assert_eq!(a.items.as_ref(), &[0, 1, 2, 3]); + + assert!(a.next_available().is_none()); + + assert_eq!(a.items.as_ref(), &[0, 1, 2, 3]); + } + + #[test] + fn test_pop_front() { + let mut a = RingBuffer::<u8>::new_with_size(4); + a.push(0); + a.push(1); + a.push(2); + a.push(3); + assert!(!a.is_empty()); + assert!(a.is_full()); + + assert_eq!(*a.pop_front().unwrap(), 0); + assert!(!a.is_empty()); + assert!(!a.is_full()); + + assert_eq!(*a.pop_front().unwrap(), 1); + assert!(!a.is_empty()); + assert!(!a.is_full()); + + assert_eq!(*a.pop_front().unwrap(), 2); + assert!(!a.is_empty()); + assert!(!a.is_full()); + + assert_eq!(*a.pop_front().unwrap(), 3); + assert!(a.is_empty()); + assert!(!a.is_full()); + + assert!(a.pop_front().is_none()); + assert!(a.is_empty()); + assert!(!a.is_full()); + } +}
diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs index 27d2ce72dad..bb903be9d19 100644 --- a/src/vmm/src/builder.rs +++ b/src/vmm/src/builder.rs @@ -1225,7 +1225,16 @@ pub mod tests { MmdsNetworkStack::default_ipv4_addr(), Arc::new(Mutex::new(mmds)), ); - + // Minimal setup for queues to be considered `valid` + for q in net.lock().unwrap().queues.iter_mut() { + q.ready = true; + q.size = 1; + // Need to explicitly set these addresses otherwise aarch64 + // will error out as its memory does not start at 0.
+ q.desc_table = GuestAddress(crate::arch::SYSTEM_MEM_START); + q.avail_ring = GuestAddress(crate::arch::SYSTEM_MEM_START); + q.used_ring = GuestAddress(crate::arch::SYSTEM_MEM_START); + } attach_net_devices(vmm, cmdline, net_builder.iter(), event_manager).unwrap(); } diff --git a/src/vmm/src/devices/virtio/iovec.rs b/src/vmm/src/devices/virtio/iovec.rs index fd48a94ca2c..9a05b854e41 100644 --- a/src/vmm/src/devices/virtio/iovec.rs +++ b/src/vmm/src/devices/virtio/iovec.rs @@ -97,10 +97,8 @@ impl IoVecBuffer { /// /// The descriptor chain cannot be referencing the same memory location as another chain pub unsafe fn from_descriptor_chain(head: DescriptorChain) -> Result { - let mut new_buffer: Self = Default::default(); - + let mut new_buffer = Self::default(); new_buffer.load_descriptor_chain(head)?; - Ok(new_buffer) } @@ -217,21 +215,31 @@ impl IoVecBuffer { /// It describes a write-only buffer passed to us by the guest that is scattered across multiple /// memory regions. Additionally, this wrapper provides methods that allow reading arbitrary ranges /// of data from that buffer. -#[derive(Debug)] +#[derive(Debug, Default, Clone)] pub struct IoVecBufferMut { + // Index of the head desciptor + pub head_index: u16, // container of the memory regions included in this IO vector - vecs: IoVecVec, + pub vecs: IoVecVec, // Total length of the IoVecBufferMut - len: u32, + pub len: u32, } impl IoVecBufferMut { - /// Create an `IoVecBufferMut` from a `DescriptorChain` - pub fn from_descriptor_chain(head: DescriptorChain) -> Result { - let mut vecs = IoVecVec::new(); - let mut len = 0u32; + /// Create an `IoVecBuffer` from a `DescriptorChain` + /// + /// # Safety + /// + /// The descriptor chain cannot be referencing the same memory location as another chain + pub unsafe fn load_descriptor_chain( + &mut self, + head: DescriptorChain, + ) -> Result<(), IoVecError> { + self.clear(); + self.head_index = head.index; - for desc in head { + let mut next_descriptor = Some(head); + while let Some(desc) = next_descriptor { if !desc.is_write_only() { return Err(IoVecError::ReadOnlyDescriptor); } @@ -247,16 +255,47 @@ impl IoVecBufferMut { slice.bitmap().mark_dirty(0, desc.len as usize); let iov_base = slice.ptr_guard_mut().as_ptr().cast::(); - vecs.push(iovec { + self.vecs.push(iovec { iov_base, iov_len: desc.len as size_t, }); - len = len + self.len = self + .len .checked_add(desc.len) .ok_or(IoVecError::OverflowedDescriptor)?; + + next_descriptor = desc.next_descriptor(); } - Ok(Self { vecs, len }) + Ok(()) + } + + /// Create an `IoVecBuffer` from a `DescriptorChain` + /// + /// # Safety + /// + /// The descriptor chain cannot be referencing the same memory location as another chain + pub unsafe fn from_descriptor_chain(head: DescriptorChain) -> Result { + let mut new_buffer = Self::default(); + new_buffer.load_descriptor_chain(head)?; + Ok(new_buffer) + } + + /// Get the index of the haed descriptor from which this IoVecBuffer + /// was built. + pub fn head_index(&self) -> u16 { + self.head_index + } + + /// Get the host pointer to the first buffer in the guest, + /// this buffer points to. + /// + /// # Safety + /// + /// It is assumed that IoVecBuffer will never have 0 elements + /// as it is build from at DescriptorChain with length of at least 1. 
+ pub fn start_address(&self) -> *mut libc::c_void { + self.vecs[0].iov_base } /// Get the total length of the memory regions covered by this `IoVecBuffer` @@ -264,6 +303,12 @@ impl IoVecBufferMut { self.len } + /// Clears the `iovec` array + pub fn clear(&mut self) { + self.vecs.clear(); + self.len = 0u32; + } + /// Writes a number of bytes into the `IoVecBufferMut` starting at a given offset. /// /// This will try to fill `IoVecBufferMut` writing bytes from the `buf` starting from @@ -397,6 +442,7 @@ mod tests { impl From<&mut [u8]> for IoVecBufferMut { fn from(buf: &mut [u8]) -> Self { Self { + head_index: 0, vecs: vec![iovec { iov_base: buf.as_mut_ptr().cast::(), iov_len: buf.len(), @@ -468,11 +514,13 @@ mod tests { let (mut q, _) = read_only_chain(&mem); let head = q.pop(&mem).unwrap(); - IoVecBufferMut::from_descriptor_chain(head).unwrap_err(); + // SAFETY: This descriptor chain is only loaded into one buffer + unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap_err() }; let (mut q, _) = write_only_chain(&mem); let head = q.pop(&mem).unwrap(); - IoVecBufferMut::from_descriptor_chain(head).unwrap(); + // SAFETY: This descriptor chain is only loaded into one buffer + unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; } #[test] @@ -493,7 +541,7 @@ mod tests { let head = q.pop(&mem).unwrap(); // SAFETY: This descriptor chain is only loaded once in this test - let iovec = IoVecBufferMut::from_descriptor_chain(head).unwrap(); + let iovec = unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; assert_eq!(iovec.len(), 4 * 64); } @@ -558,7 +606,8 @@ mod tests { // This is a descriptor chain with 4 elements 64 bytes long each. let head = q.pop(&mem).unwrap(); - let mut iovec = IoVecBufferMut::from_descriptor_chain(head).unwrap(); + // SAFETY: This descriptor chain is only loaded into one buffer + let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; let buf = vec![0u8, 1, 2, 3, 4]; // One test vector for each part of the chain @@ -705,7 +754,12 @@ mod verification { }; let (vecs, len) = create_iovecs(mem, GUEST_MEMORY_SIZE); - Self { vecs, len } + let head_index = kani::any(); + Self { + head_index, + vecs, + len, + } } } diff --git a/src/vmm/src/devices/virtio/net/device.rs b/src/vmm/src/devices/virtio/net/device.rs index e34676b2c31..7b0801e3d68 100755 --- a/src/vmm/src/devices/virtio/net/device.rs +++ b/src/vmm/src/devices/virtio/net/device.rs @@ -9,30 +9,32 @@ use std::io::Read; use std::mem; use std::net::Ipv4Addr; +use std::num::Wrapping; use std::sync::{Arc, Mutex}; use libc::EAGAIN; -use log::{error, warn}; +use log::error; use utils::eventfd::EventFd; use utils::net::mac::MacAddr; -use utils::u64_to_usize; -use vm_memory::GuestMemoryError; +use utils::ring_buffer::RingBuffer; +use utils::{u64_to_usize, usize_to_u64}; +use vm_memory::{GuestAddress, GuestMemory, VolatileMemoryError}; use crate::devices::virtio::device::{DeviceState, IrqTrigger, IrqType, VirtioDevice}; use crate::devices::virtio::gen::virtio_blk::VIRTIO_F_VERSION_1; use crate::devices::virtio::gen::virtio_net::{ virtio_net_hdr_v1, VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4, VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4, - VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, + VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_MRG_RXBUF, }; use crate::devices::virtio::gen::virtio_ring::VIRTIO_RING_F_EVENT_IDX; -use crate::devices::virtio::iovec::IoVecBuffer; +use 
crate::devices::virtio::iovec::{IoVecBuffer, IoVecBufferMut}; use crate::devices::virtio::net::metrics::{NetDeviceMetrics, NetMetricsPerDevice}; use crate::devices::virtio::net::tap::Tap; use crate::devices::virtio::net::{ gen, NetError, NetQueue, MAX_BUFFER_SIZE, NET_QUEUE_SIZES, RX_INDEX, TX_INDEX, }; -use crate::devices::virtio::queue::{DescriptorChain, Queue}; +use crate::devices::virtio::queue::{Queue, UsedElement}; use crate::devices::virtio::{ActivateError, TYPE_NET}; use crate::devices::{report_net_event_fail, DeviceError}; use crate::dumbo::pdu::arp::ETH_IPV4_FRAME_LEN; @@ -41,22 +43,22 @@ use crate::logger::{IncMetric, METRICS}; use crate::mmds::data_store::Mmds; use crate::mmds::ns::MmdsNetworkStack; use crate::rate_limiter::{BucketUpdate, RateLimiter, TokenType}; -use crate::vstate::memory::{ByteValued, Bytes, GuestMemoryMmap}; +use crate::vstate::memory::{ByteValued, GuestMemoryMmap}; const FRAME_HEADER_MAX_LEN: usize = PAYLOAD_OFFSET + ETH_IPV4_FRAME_LEN; #[derive(Debug, thiserror::Error, displaydoc::Display)] enum FrontendError { - /// Add user. - AddUsed, - /// Descriptor chain too mall. - DescriptorChainTooSmall, /// Empty queue. EmptyQueue, - /// Guest memory error: {0} - GuestMemory(GuestMemoryError), - /// Read only descriptor. - ReadOnlyDescriptor, + /// Attempt to write an empty packet. + AttemptToWriteEmptyPacket, + /// Attempt to use more descriptor chains(heads) than it is allowed. + MaxHeadsUsed, + /// Invalid descritor chain. + InvalidDescriptorChain, + /// Error during writing to the IoVecBuffer. + IoVecBufferWrite(VolatileMemoryError), } pub(crate) const fn vnet_hdr_len() -> usize { @@ -103,6 +105,153 @@ pub struct ConfigSpace { // SAFETY: `ConfigSpace` contains only PODs in `repr(C)` or `repr(transparent)`, without padding. unsafe impl ByteValued for ConfigSpace {} +// This struct contains information about partially +// written packet. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct PartialWrite { + // Amount of bytes written so far. + pub bytes_written: usize, + // Amount of descriptor heads used for the packet. + pub used_heads: u16, + // Address of the first buffer used for the packet. + // This will be used to set number of descriptors heads used + // to store the whole packet. + pub packet_start_addr: *mut libc::c_void, +} + +// SAFETY: +// Value of this type is only used by one thread. +unsafe impl Send for PartialWrite {} + +#[derive(Debug, Default, Clone)] +pub struct AvailableDescriptors { + iov_ring: RingBuffer, + valid_ring: RingBuffer, +} + +impl AvailableDescriptors { + /// New with zero sized ring buffer. + pub fn new() -> Self { + Self { + iov_ring: RingBuffer::new(), + valid_ring: RingBuffer::new(), + } + } + + /// Check if there are no available items. + pub fn is_empty(&self) -> bool { + self.iov_ring.is_empty() + } + + // Get number of descriptors + pub fn len(&self) -> usize { + self.iov_ring.len() + } + + /// Update size of the ring buffer. + pub fn update_size(&mut self, size: usize) { + self.iov_ring = RingBuffer::new_with_size(size); + self.valid_ring = RingBuffer::new_with_size(size); + } + + /// Read new descriptor chains from the queue. 
+ pub fn read_new_desc_chains(&mut self, queue: &mut Queue, mem: &GuestMemoryMmap) { + #[repr(C)] + struct AvailRing { + flags: u16, + idx: u16, + ring: [u16; 256], + used_event: u16, + } + #[repr(C)] + #[derive(Default, Clone, Copy)] + struct Descriptor { + addr: u64, + len: u32, + flags: u16, + next: u16, + } + + // SAFETY: + // avail_ring in the queue is a valid guest address + let avail_ring: &AvailRing = + unsafe { std::mem::transmute(mem.get_host_address(queue.avail_ring).unwrap()) }; + + // SAFETY: + // desc_table in the queue is a valid guest address + let desc_table: &[Descriptor; 256] = + unsafe { std::mem::transmute(mem.get_host_address(queue.desc_table).unwrap()) }; + + let avail_idx = queue.avail_idx(mem); + let actual_size = queue.actual_size(); + + while queue.next_avail.0 != avail_idx.0 { + let Some(next_iovec_buf) = self.iov_ring.next_available() else { + break; + }; + + let avail_index = queue.next_avail.0 % actual_size; + queue.next_avail += Wrapping(1); + + let desc_index = avail_ring.ring[avail_index as usize]; + let mut desc = &desc_table[desc_index as usize]; + + next_iovec_buf.clear(); + next_iovec_buf.head_index = desc_index; + + let iov = libc::iovec { + iov_base: mem.get_host_address(GuestAddress(desc.addr)).unwrap().cast(), + iov_len: desc.len as usize, + }; + next_iovec_buf.vecs.push(iov); + next_iovec_buf.len += desc.len; + self.valid_ring.push(true); + + while desc.flags & crate::devices::virtio::queue::VIRTQ_DESC_F_NEXT != 0 { + desc = &desc_table[desc.next as usize]; + let iov = libc::iovec { + iov_base: mem.get_host_address(GuestAddress(desc.addr)).unwrap().cast(), + iov_len: desc.len as usize, + }; + next_iovec_buf.vecs.push(iov); + next_iovec_buf.len += desc.len; + self.valid_ring.push(true); + } + } + + // for _ in 0..queue.len(mem) { + // let Some(next_iovec_buf) = self.iov_ring.next_available() else { + // break; + // }; + // + // match queue.do_pop_unchecked(mem) { + // Some(desc_chain) => { + // // SAFETY: + // // This descriptor chain is only processed once. + // let valid = unsafe { next_iovec_buf.load_descriptor_chain(desc_chain).is_ok() }; + // self.valid_ring.push(valid); + // } + // None => { + // self.valid_ring.push(false); + // } + // } + // } + } + + /// Pop first descriptor chain. + pub fn pop_desc_chain(&mut self) -> Option<(&mut IoVecBufferMut, bool)> { + let (Some(iov), Some(valid)) = (self.iov_ring.pop_front(), self.valid_ring.pop_front()) + else { + return None; + }; + Some((iov, *valid)) + } +} + +// SAFETY: +// Value of this type is only used by one thread. +unsafe impl Send for AvailableDescriptors {} + /// VirtIO network device. /// /// It emulates a network device able to exchange L2 frames between the guest @@ -127,6 +276,8 @@ pub struct Net { rx_bytes_read: usize, rx_frame_buf: [u8; MAX_BUFFER_SIZE], + pub rx_partial_write: Option, + pub rx_avail_desc: AvailableDescriptors, tx_frame_headers: [u8; frame_hdr_len()], @@ -163,6 +314,7 @@ impl Net { | 1 << VIRTIO_NET_F_HOST_TSO4 | 1 << VIRTIO_NET_F_HOST_TSO6 | 1 << VIRTIO_NET_F_HOST_UFO + | 1 << VIRTIO_NET_F_MRG_RXBUF | 1 << VIRTIO_F_VERSION_1 | 1 << VIRTIO_RING_F_EVENT_IDX; @@ -193,6 +345,8 @@ impl Net { rx_deferred_frame: false, rx_bytes_read: 0, rx_frame_buf: [0u8; MAX_BUFFER_SIZE], + rx_partial_write: None, + rx_avail_desc: AvailableDescriptors::new(), tx_frame_headers: [0u8; frame_hdr_len()], irq_trigger: IrqTrigger::new().map_err(NetError::EventFd)?, config_space, @@ -322,7 +476,17 @@ impl Net { } // Attempt frame delivery. 
- let success = self.write_frame_to_guest(); + let success = loop { + // We retry to write a frame if there were internal errors. + // Each new write will use new descriptor chains up to the + // point of consuming all available descriptors, if they are + // all bad. + match self.write_frame_to_guest() { + Ok(()) => break true, + Err(FrontendError::EmptyQueue) => break false, + _ => (), + }; + }; // Undo the tokens consumption if guest delivery failed. if !success { @@ -333,108 +497,182 @@ impl Net { success } - /// Write a slice in a descriptor chain - /// - /// # Errors - /// - /// Returns an error if the descriptor chain is too short or - /// an inappropriate (read only) descriptor is found in the chain - fn write_to_descriptor_chain( - mem: &GuestMemoryMmap, - data: &[u8], - head: DescriptorChain, - net_metrics: &NetDeviceMetrics, - ) -> Result<(), FrontendError> { - let mut chunk = data; - let mut next_descriptor = Some(head); + /// Write packet contained in the internal buffer into guest provided + /// descriptor chains. + fn write_frame_to_guest(&mut self) -> Result<(), FrontendError> { + // This is safe since we checked in the event handler that the device is activated. + let mem = self.device_state.mem().unwrap(); - while let Some(descriptor) = &next_descriptor { - if !descriptor.is_write_only() { - return Err(FrontendError::ReadOnlyDescriptor); + if self.rx_avail_desc.is_empty() && self.queues[RX_INDEX].is_empty(mem) { + self.metrics.no_rx_avail_buffer.inc(); + return Err(FrontendError::EmptyQueue); + } + + let (mut slice, mut packet_start_addr, mut used_heads) = + if let Some(pw) = &self.rx_partial_write { + ( + &self.rx_frame_buf[pw.bytes_written..self.rx_bytes_read], + Some(pw.packet_start_addr), + pw.used_heads, + ) + } else { + (&self.rx_frame_buf[..self.rx_bytes_read], None, 0) + }; + + let max_used_heads = if self.has_feature(u64::from(VIRTIO_NET_F_MRG_RXBUF)) { + // There is no real limit on how much heads we can use, but we will + // never use more than the queue has. + u16::MAX + } else { + // Without VIRTIO_NET_F_MRG_RXBUF only 1 head can be used for the packet. + 1 + }; + + let mut error = None; + while !slice.is_empty() && error.is_none() { + if used_heads == max_used_heads { + error = Some(FrontendError::MaxHeadsUsed); + break; } - let len = std::cmp::min(chunk.len(), descriptor.len as usize); - match mem.write_slice(&chunk[..len], descriptor.addr) { - Ok(()) => { - net_metrics.rx_count.inc(); - chunk = &chunk[len..]; + if self.rx_avail_desc.is_empty() { + self.rx_avail_desc + .read_new_desc_chains(&mut self.queues[RX_INDEX], mem); + } + let Some((iovec_buf, valid)) = self.rx_avail_desc.pop_desc_chain() else { + break; + }; + + let desc_len = if valid { + // If this is the first head of the packet, save it for later. + if packet_start_addr.is_none() { + packet_start_addr = Some(iovec_buf.start_address()); } - Err(err) => { - error!("Failed to write slice: {:?}", err); - if let GuestMemoryError::PartialBuffer { .. } = err { - net_metrics.rx_partial_writes.inc(); + + match iovec_buf.write_all_volatile_at(slice, 0) { + Ok(()) => { + let len = slice.len(); + slice = &[]; + len + } + Err(VolatileMemoryError::PartialBuffer { + expected: _, + completed, + }) => { + slice = &slice[completed..]; + completed + } + Err(e) => { + error = Some(FrontendError::IoVecBufferWrite(e)); + 0 } - return Err(FrontendError::GuestMemory(err)); } - } - - // If chunk is empty we are done here. 
- if chunk.is_empty() { - let len = data.len() as u64; - net_metrics.rx_bytes_count.add(len); - net_metrics.rx_packets_count.inc(); - return Ok(()); - } + } else { + error = Some(FrontendError::InvalidDescriptorChain); + 0 + }; - next_descriptor = descriptor.next_descriptor(); + // At this point descriptor chain was processed. + // We add it to the used_ring. + let next_used_index = self.queues[RX_INDEX].next_used + Wrapping(used_heads); + let used_element = UsedElement { + id: u32::from(iovec_buf.head_index()), + len: u32::try_from(desc_len).unwrap(), + }; + // SAFETY: + // This should never panic as we provide index in + // correct bounds. + self.queues[RX_INDEX] + .write_used_ring(mem, next_used_index.0, used_element) + .unwrap(); + + used_heads += 1; } - warn!("Receiving buffer is too small to hold frame of current size"); - Err(FrontendError::DescriptorChainTooSmall) - } + // The are 2 ways the packet_start_addr can be None: + // 1. the loop was never run because slice was initially empty. + // 2. the very first descriptor chain was invalid + // In second case the error will contain something, so + // we only care abount first case. + if packet_start_addr.is_none() && error.is_none() { + error = Some(FrontendError::AttemptToWriteEmptyPacket); + } - // Copies a single frame from `self.rx_frame_buf` into the guest. - fn do_write_frame_to_guest(&mut self) -> Result<(), FrontendError> { - // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); + let mut end_packet_processing = || { + // We only update queues internals when packet processing has + // finished. This is done to prevent giving information to the guest + // about descriptor heads used for partialy written packets. + // Otherwise guest will see that we used those descriptors and + // will try to process them. + self.queues[RX_INDEX].advance_used_ring(mem, used_heads); + let next_avail = self.queues[RX_INDEX].next_avail.0; + self.queues[RX_INDEX].set_used_ring_avail_event(next_avail, mem); + + // Clear partial write info if there was one + self.rx_partial_write = None; + }; - let queue = &mut self.queues[RX_INDEX]; - let head_descriptor = queue.pop_or_enable_notification(mem).ok_or_else(|| { - self.metrics.no_rx_avail_buffer.inc(); - FrontendError::EmptyQueue - })?; - let head_index = head_descriptor.index; + if let Some(err) = error { + // There was a error during writing. + end_packet_processing(); - let result = Self::write_to_descriptor_chain( - mem, - &self.rx_frame_buf[..self.rx_bytes_read], - head_descriptor, - &self.metrics, - ); - // Mark the descriptor chain as used. If an error occurred, skip the descriptor chain. - let used_len = if result.is_err() { self.metrics.rx_fails.inc(); - 0 - } else { - // Safe to unwrap because a frame must be smaller than 2^16 bytes. - u32::try_from(self.rx_bytes_read).unwrap() - }; - queue.add_used(mem, head_index, used_len).map_err(|err| { - error!("Failed to add available descriptor {}: {}", head_index, err); - FrontendError::AddUsed - })?; - result - } + // We used `used_heads` descriptors to process the packet. + // Because this is an error case, we discard those descriptors. + self.queues[RX_INDEX].discard_used(mem, used_heads); - // Copies a single frame from `self.rx_frame_buf` into the guest. In case of an error retries - // the operation if possible. Returns true if the operation was successfull. 
- fn write_frame_to_guest(&mut self) -> bool { - let max_iterations = self.queues[RX_INDEX].actual_size(); - for _ in 0..max_iterations { - match self.do_write_frame_to_guest() { - Ok(()) => return true, - Err(FrontendError::EmptyQueue) | Err(FrontendError::AddUsed) => { - return false; - } - Err(_) => { - // retry - continue; - } + Err(err) + } else if slice.is_empty() { + // Packet was fully written. + end_packet_processing(); + + self.metrics + .rx_bytes_count + .add(usize_to_u64(self.rx_bytes_read)); + self.metrics.rx_packets_count.inc(); + + // SAFETY: packet_start_addr cannot be None at + // this point. + let packet_start_addr = packet_start_addr.unwrap(); + + // Update number of descriptor heads used to store a packet. + // SAFETY: + // The packet_start_addr is valid address because it was + // obtained from a descriptor chain which was verified during + // its construction. + #[allow(clippy::transmute_ptr_to_ref)] + let header: &mut virtio_net_hdr_v1 = unsafe { + // SAFETY: + // The head address was checked during descriptor chain creation. + std::mem::transmute(packet_start_addr) + }; + header.num_buffers = used_heads; + + Ok(()) + } else { + // Packet could not be fully written to the guest + // Save necessary info to use it during next invocation. + self.metrics.rx_partial_writes.inc(); + + // SAFETY: packet_start_addr cannot be None at + // this point. + let packet_start_addr = packet_start_addr.unwrap(); + + if let Some(pw) = &mut self.rx_partial_write { + pw.bytes_written = self.rx_bytes_read - slice.len(); + pw.used_heads = used_heads; + } else { + let pw = PartialWrite { + bytes_written: self.rx_bytes_read - slice.len(), + used_heads, + packet_start_addr, + }; + self.rx_partial_write = Some(pw); } - } - false + Err(FrontendError::EmptyQueue) + } } // Tries to detour the frame to MMDS and if MMDS doesn't accept it, sends it on the host TAP. @@ -740,7 +978,15 @@ impl Net { // rate limiters present but with _very high_ allowed rate error!("Failed to get rx queue event: {:?}", err); self.metrics.event_fails.inc(); - } else if self.rx_rate_limiter.is_blocked() { + } + + // Guest tells us there are new descriptor chains, so + // we read them here. + let mem = self.device_state.mem().unwrap(); + self.rx_avail_desc + .read_new_desc_chains(&mut self.queues[RX_INDEX], mem); + + if self.rx_rate_limiter.is_blocked() { self.metrics.rx_rate_limiter_throttled.inc(); } else { // If the limiter is not blocked, resume the receiving of bytes. @@ -914,6 +1160,10 @@ impl VirtioDevice for Net { self.tap .set_offload(supported_flags) .map_err(super::super::ActivateError::TapSetOffload)?; + // Initially rx_avail_desc has size of 0. + // So we update it with correct size of the queue here. 
+ self.rx_avail_desc + .update_size(self.queues[RX_INDEX].actual_size() as usize); if self.activate_evt.write(1).is_err() { self.metrics.activate_fails.inc(); @@ -938,6 +1188,7 @@ pub mod tests { use std::{io, mem, thread}; use utils::net::mac::{MacAddr, MAC_ADDR_LEN}; + use vm_memory::GuestAddress; use super::*; use crate::check_metric_after_block; @@ -948,11 +1199,12 @@ pub mod tests { }; use crate::devices::virtio::net::test_utils::test::TestHelper; use crate::devices::virtio::net::test_utils::{ - default_net, if_index, inject_tap_tx_frame, set_mac, NetEvent, NetQueue, ReadTapMock, - TapTrafficSimulator, WriteTapMock, + default_net, if_index, inject_tap_tx_frame, mock_frame_set_num_buffers, set_mac, NetEvent, + NetQueue, ReadTapMock, TapTrafficSimulator, WriteTapMock, }; use crate::devices::virtio::net::NET_QUEUE_SIZES; use crate::devices::virtio::queue::VIRTQ_DESC_F_WRITE; + use crate::devices::virtio::test_utils::{default_mem, VirtQueue}; use crate::dumbo::pdu::arp::{EthIPv4ArpFrame, ETH_IPV4_FRAME_LEN}; use crate::dumbo::pdu::ethernet::ETHERTYPE_ARP; use crate::dumbo::EthernetFrame; @@ -1035,6 +1287,7 @@ pub mod tests { | 1 << VIRTIO_NET_F_HOST_TSO4 | 1 << VIRTIO_NET_F_HOST_TSO6 | 1 << VIRTIO_NET_F_HOST_UFO + | 1 << VIRTIO_NET_F_MRG_RXBUF | 1 << VIRTIO_F_VERSION_1 | 1 << VIRTIO_RING_F_EVENT_IDX; @@ -1139,6 +1392,117 @@ pub mod tests { assert_eq!(new_config, new_config_read); } + #[test] + fn test_available_descriptors_new() { + let avail_desc = AvailableDescriptors::new(); + assert!(avail_desc.iov_ring.is_empty()); + assert_eq!(avail_desc.iov_ring.items.len(), 0); + assert!(avail_desc.valid_ring.is_empty()); + assert_eq!(avail_desc.valid_ring.items.len(), 0); + assert!(avail_desc.is_empty()); + } + + #[test] + fn test_available_descriptors_update_size() { + let mut avail_desc = AvailableDescriptors::new(); + avail_desc.update_size(69); + assert!(avail_desc.iov_ring.is_empty()); + assert_eq!(avail_desc.iov_ring.items.len(), 69); + assert!(avail_desc.valid_ring.is_empty()); + assert_eq!(avail_desc.valid_ring.items.len(), 69); + assert!(avail_desc.is_empty()); + } + + #[test] + fn test_available_descriptors_read_pop_desc() { + let m = &default_mem(); + let vq = VirtQueue::new(GuestAddress(0), m, 16); + vq.dtable[0].set(1, 1, 0x3, 1); + vq.dtable[1].set(2, 2, 0x3, 2); + vq.dtable[2].set(3, 3, 0x2, 0); + + vq.dtable[3].set(4, 4, 0x3, 4); + vq.dtable[4].set(5, 5, 0x3, 5); + vq.dtable[5].set(6, 6, 0x2, 0); + + // Invalid address + vq.dtable[6].set(7, 7, 0x3, 7); + vq.dtable[7].set(u64::MAX, 8, 0x3, 8); + vq.dtable[8].set(9, 9, 0x2, 0); + + // Not write only + vq.dtable[9].set(10, 10, 0x3, 10); + vq.dtable[10].set(11, 11, 0x0, 11); + vq.dtable[11].set(12, 12, 0x2, 0); + + vq.avail.ring[0].set(0); + vq.avail.ring[1].set(3); + vq.avail.ring[2].set(6); + vq.avail.ring[3].set(9); + + let mut rxq = vq.create_queue(); + + vq.avail.idx.set(4); + + // Test if too much descriptors are avalable + { + let mut avail_desc = AvailableDescriptors::new(); + avail_desc.update_size(1); + avail_desc.read_new_desc_chains(&mut rxq, m); + + assert_eq!(avail_desc.iov_ring.len, 1); + assert_eq!(avail_desc.valid_ring.len, 1); + + assert_eq!(avail_desc.iov_ring.start, 0); + assert_eq!(avail_desc.valid_ring.start, 0); + + let (iovec_buf_0, valid_0) = avail_desc.pop_desc_chain().unwrap(); + assert_eq!(iovec_buf_0.head_index(), 0); + assert_eq!(iovec_buf_0.len(), 6); + assert!(valid_0); + + assert!(avail_desc.is_empty()); + + assert!(avail_desc.pop_desc_chain().is_none()); + } + + // Reset queue + rxq.next_avail = 
Wrapping(0); + { + let mut avail_desc = AvailableDescriptors::new(); + avail_desc.update_size(4); + avail_desc.read_new_desc_chains(&mut rxq, m); + + assert_eq!(avail_desc.iov_ring.len, 4); + assert_eq!(avail_desc.valid_ring.len, 4); + + assert_eq!(avail_desc.iov_ring.start, 0); + assert_eq!(avail_desc.valid_ring.start, 0); + + let (iovec_buf_0, valid_0) = avail_desc.pop_desc_chain().unwrap(); + assert_eq!(iovec_buf_0.head_index(), 0); + assert_eq!(iovec_buf_0.len(), 6); + assert!(valid_0); + + let (iovec_buf_1, valid_1) = avail_desc.pop_desc_chain().unwrap(); + assert_eq!(iovec_buf_1.head_index(), 3); + assert_eq!(iovec_buf_1.len(), 15); + assert!(valid_1); + + let (iovec_buf_2, valid_2) = avail_desc.pop_desc_chain().unwrap(); + assert_eq!(iovec_buf_2.head_index(), 6); + assert!(!valid_2); + + let (iovec_buf_3, valid_3) = avail_desc.pop_desc_chain().unwrap(); + assert_eq!(iovec_buf_3.head_index(), 9); + assert!(!valid_3); + + assert!(avail_desc.is_empty()); + + assert!(avail_desc.pop_desc_chain().is_none()); + } + } + #[test] fn test_rx_missing_queue_signal() { let mut th = TestHelper::get_default(); @@ -1156,9 +1520,7 @@ pub mod tests { assert_eq!(th.rxq.used.idx.get(), 0); } - #[test] - fn test_rx_read_only_descriptor() { - let mut th = TestHelper::get_default(); + fn rx_read_only_descriptor(mut th: TestHelper) { th.activate_net(); th.add_desc_chain( @@ -1176,6 +1538,20 @@ pub mod tests { th.check_rx_queue_resume(&frame); } + #[test] + fn test_rx_read_only_descriptor() { + let th = TestHelper::get_default(); + rx_read_only_descriptor(th); + } + + #[test] + fn test_rx_mrg_buf_read_only_descriptor() { + let mut th = TestHelper::get_default(); + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + rx_read_only_descriptor(th); + } + #[test] fn test_rx_short_writable_descriptor() { let mut th = TestHelper::get_default(); @@ -1189,9 +1565,65 @@ pub mod tests { } #[test] - fn test_rx_partial_write() { + fn test_rx_mrg_buf_short_writable_descriptor() { + let mut th = TestHelper::get_default(); + th.activate_net(); + th.net().tap.mocks.set_read_tap(ReadTapMock::TapFrame); + + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + + th.add_desc_chain(NetQueue::Rx, 0, &[(0, 100, VIRTQ_DESC_F_WRITE)]); + + _ = inject_tap_tx_frame(&th.net(), 1000); + // For now only 1 descriptor chain is used, + // but the packet is not fully written yet. + check_metric_after_block!( + th.net().metrics.rx_packets_count, + 0, + th.event_manager.run_with_timeout(100).unwrap() + ); + th.rxq.check_used_elem(0, 0, 100); + + // The write was converted to partial write + assert!(th.net().rx_partial_write.is_some()); + } + + #[test] + fn test_rx_mrg_buf_multiple_short_writable_descriptors() { let mut th = TestHelper::get_default(); th.activate_net(); + th.net().tap.mocks.set_read_tap(ReadTapMock::TapFrame); + + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + + th.add_desc_chain(NetQueue::Rx, 0, &[(0, 500, VIRTQ_DESC_F_WRITE)]); + th.add_desc_chain(NetQueue::Rx, 500, &[(1, 500, VIRTQ_DESC_F_WRITE)]); + + // There will be 2 heads used. 
+ let mut frame = inject_tap_tx_frame(&th.net(), 1000); + mock_frame_set_num_buffers(&mut frame, 2); + + check_metric_after_block!( + th.net().metrics.rx_packets_count, + 1, + th.event_manager.run_with_timeout(100).unwrap() + ); + + assert_eq!(th.rxq.used.idx.get(), 2); + assert!(th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + assert!(!th.net().rx_deferred_frame); + + th.rxq.check_used_elem(0, 0, 500); + th.rxq.check_used_elem(1, 1, 500); + + th.rxq.dtable[0].check_data(&frame[..500]); + th.rxq.dtable[1].check_data(&frame[500..]); + } + + fn rx_invalid_desc_chain(mut th: TestHelper) { + th.activate_net(); // The descriptor chain is created so that the last descriptor doesn't fit in the // guest memory. @@ -1212,11 +1644,85 @@ pub mod tests { } #[test] - fn test_rx_retry() { + fn test_rx_invalid_desc_chain() { + let th = TestHelper::get_default(); + rx_invalid_desc_chain(th); + } + + #[test] + fn test_rx_mrg_buf_invalid_desc_chain() { + let mut th = TestHelper::get_default(); + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + rx_invalid_desc_chain(th); + } + + #[test] + fn test_rx_mrg_buf_partial_write() { let mut th = TestHelper::get_default(); th.activate_net(); th.net().tap.mocks.set_read_tap(ReadTapMock::TapFrame); + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + + // Add descriptor that is not big enough to store the + // whole packet. + th.add_desc_chain(NetQueue::Rx, 0, &[(0, 250, VIRTQ_DESC_F_WRITE)]); + + // There will be 3 heads used. + let mut frame = inject_tap_tx_frame(&th.net(), 1000); + mock_frame_set_num_buffers(&mut frame, 3); + + // For now only 1 descriptor chain is used, + // but the packet is not fully written yet. + check_metric_after_block!( + th.net().metrics.rx_partial_writes, + 1, + th.event_manager.run_with_timeout(100).unwrap() + ); + th.rxq.check_used_elem(0, 0, 250); + + // The write was converted to partial write + assert!(th.net().rx_partial_write.is_some()); + + // Continue writing. + th.add_desc_chain(NetQueue::Rx, 250, &[(1, 250, VIRTQ_DESC_F_WRITE)]); + // Only 500 bytes of 1000 should be written now. + check_metric_after_block!( + th.net().metrics.rx_partial_writes, + 1, + th.event_manager.run_with_timeout(100).unwrap() + ); + th.rxq.check_used_elem(1, 1, 250); + + // The write is still a partial write + assert!(th.net().rx_partial_write.is_some()); + + // Finish writing. + th.add_desc_chain(NetQueue::Rx, 500, &[(2, 500, VIRTQ_DESC_F_WRITE)]); + check_metric_after_block!( + th.net().metrics.rx_packets_count, + 1, + th.event_manager.run_with_timeout(100).unwrap() + ); + assert!(th.net().rx_partial_write.is_none()); + assert_eq!(th.rxq.used.idx.get(), 3); + assert!(th.net().irq_trigger.has_pending_irq(IrqType::Vring)); + + th.rxq.check_used_elem(0, 0, 250); + th.rxq.check_used_elem(1, 1, 250); + th.rxq.check_used_elem(2, 2, 500); + + th.rxq.dtable[0].check_data(&frame[..250]); + th.rxq.dtable[1].check_data(&frame[250..500]); + th.rxq.dtable[2].check_data(&frame[500..]); + } + + fn rx_retry(mut th: TestHelper) { + th.activate_net(); + th.net().tap.mocks.set_read_tap(ReadTapMock::TapFrame); + // Add invalid descriptor chain - read only descriptor. th.add_desc_chain( NetQueue::Rx, @@ -1227,9 +1733,15 @@ pub mod tests { (2, 1000, VIRTQ_DESC_F_WRITE), ], ); - // Add invalid descriptor chain - too short. + + // Without VIRTIO_NET_F_MRG_RXBUF this descriptor is invalid as it is too short. 
+ // With VIRTIO_NET_F_MRG_RXBUF this descriptor is valid, write will be converted into + // partial write. th.add_desc_chain(NetQueue::Rx, 1200, &[(3, 100, VIRTQ_DESC_F_WRITE)]); // Add invalid descriptor chain - invalid memory offset. + // Without VIRTIO_NET_F_MRG_RXBUF this descriptor is invalid. + // With VIRTIO_NET_F_MRG_RXBUF the partial write stated with previous descriptor should halt + // here. th.add_desc_chain( NetQueue::Rx, th.mem.last_addr().raw_value(), @@ -1263,8 +1775,20 @@ pub mod tests { } #[test] - fn test_rx_complex_desc_chain() { + fn test_rx_retry() { + let th = TestHelper::get_default(); + rx_retry(th); + } + + #[test] + fn test_rx_mrg_buf_retry() { let mut th = TestHelper::get_default(); + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + rx_retry(th); + } + + fn rx_complex_desc_chain(mut th: TestHelper) { th.activate_net(); th.net().tap.mocks.set_read_tap(ReadTapMock::TapFrame); @@ -1302,8 +1826,20 @@ pub mod tests { } #[test] - fn test_rx_multiple_frames() { + fn test_rx_complex_desc_chain() { + let th = TestHelper::get_default(); + rx_complex_desc_chain(th); + } + + #[test] + fn test_rx_mrg_buf_complex_desc_chain() { let mut th = TestHelper::get_default(); + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + rx_complex_desc_chain(th); + } + + fn rx_multiple_frames(mut th: TestHelper) { th.activate_net(); th.net().tap.mocks.set_read_tap(ReadTapMock::TapFrame); @@ -1345,6 +1881,20 @@ pub mod tests { th.rxq.dtable[3].check_data(&[0; 500]); } + #[test] + fn test_rx_multiple_frames() { + let th = TestHelper::get_default(); + rx_multiple_frames(th); + } + + #[test] + fn test_rx_mrg_buf_multiple_frames() { + let mut th = TestHelper::get_default(); + // VIRTIO_NET_F_MRG_RXBUF is not enabled by default + th.net().acked_features = 1 << VIRTIO_NET_F_MRG_RXBUF; + rx_multiple_frames(th); + } + #[test] fn test_tx_missing_queue_signal() { let mut th = TestHelper::get_default(); diff --git a/src/vmm/src/devices/virtio/net/persist.rs b/src/vmm/src/devices/virtio/net/persist.rs index 271977a4792..972e36c931c 100644 --- a/src/vmm/src/devices/virtio/net/persist.rs +++ b/src/vmm/src/devices/virtio/net/persist.rs @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize}; use utils::net::mac::MacAddr; use super::device::Net; -use super::NET_NUM_QUEUES; +use super::{NET_NUM_QUEUES, RX_INDEX}; use crate::devices::virtio::device::DeviceState; use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState}; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; @@ -37,6 +37,8 @@ pub struct NetConfigSpaceState { pub struct NetState { id: String, tap_if_name: String, + rx_avail_desc_len: u16, + rx_partial_write_used_heads: u16, rx_rate_limiter_state: RateLimiterState, tx_rate_limiter_state: RateLimiterState, /// The associated MMDS network stack. 
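For reference, a minimal standalone sketch of the VIRTIO_NET_F_MRG_RXBUF bookkeeping exercised by the mock_frame_set_num_buffers tests above: the 12-byte virtio-net header sits at the start of the first receive buffer of a packet, and the device patches num_buffers once it knows how many descriptor chains (heads) it consumed. The struct below mirrors the layout of virtio_net_hdr_v1 and is local to the example:

// virtio-net v1 header layout (little-endian fields).
#[repr(C)]
#[allow(dead_code)]
struct VirtioNetHdrV1 {
    flags: u8,
    gso_type: u8,
    hdr_len: u16,
    gso_size: u16,
    csum_start: u16,
    csum_offset: u16,
    num_buffers: u16,
}

fn main() {
    // First receive buffer of the packet, as handed to the device by the guest.
    let mut first_buf = [0u8; 64];
    // Without VIRTIO_NET_F_MRG_RXBUF this must stay 1; with it, the device
    // reports how many heads the packet actually used.
    let used_heads: u16 = 3;
    let off = std::mem::offset_of!(VirtioNetHdrV1, num_buffers);
    first_buf[off..off + 2].copy_from_slice(&used_heads.to_le_bytes());
    assert_eq!(off, 10);
    assert_eq!(first_buf[10], 3);
}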
@@ -73,9 +75,15 @@ impl Persist<'_> for Net { type Error = NetPersistError; fn save(&self) -> Self::State { + let rx_partial_write_used_heads = + self.rx_partial_write.as_ref().map_or(0, |pw| pw.used_heads); + let rx_avail_desc_len = u16::try_from(self.rx_avail_desc.len()).unwrap(); + NetState { id: self.id().clone(), tap_if_name: self.iface_name(), + rx_avail_desc_len, + rx_partial_write_used_heads, rx_rate_limiter_state: self.rx_rate_limiter.save(), tx_rate_limiter_state: self.tx_rate_limiter.save(), mmds_ns: self.mmds_ns.as_ref().map(|mmds| mmds.save()), @@ -129,6 +137,25 @@ impl Persist<'_> for Net { net.acked_features = state.virtio_state.acked_features; if state.virtio_state.activated { + // Discard descriptors used for partial write. + // We do it here, because it requires modification of queues, but during + // `save` we are only given an immutable reference. + net.queues[RX_INDEX] + .advance_used_ring(&constructor_args.mem, state.rx_partial_write_used_heads); + let next_avail = net.queues[RX_INDEX].next_avail.0; + net.queues[RX_INDEX].set_used_ring_avail_event(next_avail, &constructor_args.mem); + net.queues[RX_INDEX] + .discard_used(&constructor_args.mem, state.rx_partial_write_used_heads); + + // Recreate `avail_desc`. We do it by temporarily + // rollback `next_avail` in the RX queue. The `next_avail` + // will be rolled forward in the `read_new_desc_chains` method. + net.rx_avail_desc + .update_size(net.queues[RX_INDEX].actual_size() as usize); + net.queues[RX_INDEX].next_avail -= state.rx_avail_desc_len; + net.rx_avail_desc + .read_new_desc_chains(&mut net.queues[RX_INDEX], &constructor_args.mem); + net.device_state = DeviceState::Activated(constructor_args.mem); } @@ -146,8 +173,11 @@ mod tests { use crate::devices::virtio::test_utils::default_mem; use crate::snapshot::Snapshot; - fn validate_save_and_restore(net: Net, mmds_ds: Option>>) { - let guest_mem = default_mem(); + fn validate_save_and_restore( + guest_mem: GuestMemoryMmap, + net: Net, + mmds_ds: Option>>, + ) { let mut mem = vec![0; 4096]; let id; @@ -155,6 +185,8 @@ mod tests { let has_mmds_ns; let allow_mmds_requests; let virtio_state; + let rx_partial_write_used_heads; + let rx_avail_desc_len; // Create and save the net device. 
{ @@ -166,6 +198,8 @@ mod tests { has_mmds_ns = net.mmds_ns.is_some(); allow_mmds_requests = has_mmds_ns && mmds_ds.is_some(); virtio_state = VirtioDeviceState::from_device(&net); + rx_partial_write_used_heads = net.rx_partial_write; + rx_avail_desc_len = net.rx_avail_desc.len(); } // Drop the initial net device so that we don't get an error when trying to recreate the @@ -197,6 +231,8 @@ mod tests { assert_eq!(restored_net.mmds_ns.is_some(), allow_mmds_requests); assert_eq!(restored_net.rx_rate_limiter, RateLimiter::default()); assert_eq!(restored_net.tx_rate_limiter, RateLimiter::default()); + assert_eq!(restored_net.rx_partial_write, rx_partial_write_used_heads); + assert_eq!(restored_net.rx_avail_desc.len(), rx_avail_desc_len); } Err(NetPersistError::NoMmdsDataStore) => { assert!(has_mmds_ns && !allow_mmds_requests) @@ -208,17 +244,24 @@ mod tests { #[test] fn test_persistence() { + let guest_mem = default_mem(); + let mmds = Some(Arc::new(Mutex::new(Mmds::default()))); - validate_save_and_restore(default_net(), mmds.as_ref().cloned()); - validate_save_and_restore(default_net_no_mmds(), None); + validate_save_and_restore(guest_mem.clone(), default_net(), mmds.as_ref().cloned()); + validate_save_and_restore(guest_mem.clone(), default_net_no_mmds(), None); + + // Test activated device + let mut net = default_net_no_mmds(); + net.activate(guest_mem.clone()).unwrap(); + validate_save_and_restore(guest_mem.clone(), net, None); // Check what happens if the MMIODeviceManager gives us the reference to the MMDS // data store even if this device does not have mmds ns configured. // The restore should be conservative and not configure the mmds ns. - validate_save_and_restore(default_net_no_mmds(), mmds); + validate_save_and_restore(guest_mem.clone(), default_net_no_mmds(), mmds); // Check what happens if the MMIODeviceManager does not give us the reference to the MMDS // data store. This will return an error. - validate_save_and_restore(default_net(), None); + validate_save_and_restore(guest_mem, default_net(), None); } } diff --git a/src/vmm/src/devices/virtio/net/test_utils.rs b/src/vmm/src/devices/virtio/net/test_utils.rs index 216db273859..958bd141e56 100644 --- a/src/vmm/src/devices/virtio/net/test_utils.rs +++ b/src/vmm/src/devices/virtio/net/test_utils.rs @@ -15,6 +15,7 @@ use std::sync::{Arc, Mutex}; use utils::net::mac::MacAddr; +use crate::devices::virtio::gen::virtio_net::virtio_net_hdr_v1; #[cfg(test)] use crate::devices::virtio::net::device::vnet_hdr_len; use crate::devices::virtio::net::tap::{IfReqBuilder, Tap}; @@ -53,6 +54,11 @@ pub fn default_net() -> Net { MmdsNetworkStack::default_ipv4_addr(), Arc::new(Mutex::new(Mmds::default())), ); + // Minimal setup for queues to be considered `valid` + for q in net.queues.iter_mut() { + q.ready = true; + q.size = 1; + } enable(&net.tap); net @@ -64,7 +70,7 @@ pub fn default_net_no_mmds() -> Net { let guest_mac = default_guest_mac(); - let net = Net::new( + let mut net = Net::new( tap_device_id, "net-device%d", Some(guest_mac), @@ -72,11 +78,41 @@ pub fn default_net_no_mmds() -> Net { RateLimiter::default(), ) .unwrap(); + // Minimal setup for queues to be considered `valid` + for q in net.queues.iter_mut() { + q.ready = true; + q.size = 1; + } enable(&net.tap); net } +pub fn mock_frame(len: usize) -> Vec { + assert!(std::mem::size_of::() <= len); + let mut mock_frame = utils::rand::rand_alphanumerics(len).as_bytes().to_vec(); + // SAFETY: + // Frame is bigger than the header. 
+ unsafe { + let hdr = &mut *mock_frame.as_mut_ptr().cast::(); + let zeroed = std::mem::zeroed::(); + *hdr = zeroed; + // We need to test num_buffers to 1 as the spec says. + hdr.num_buffers = 1; + } + mock_frame +} + +pub fn mock_frame_set_num_buffers(frame: &mut [u8], num_buffers: u16) { + assert!(std::mem::size_of::() <= frame.len()); + // SAFETY: + // Frame is bigger than the header. + unsafe { + let hdr = &mut *frame.as_mut_ptr().cast::(); + hdr.num_buffers = num_buffers; + } +} + #[derive(Debug)] pub enum ReadTapMock { Failure, @@ -119,9 +155,7 @@ impl Mocks { impl Default for Mocks { fn default() -> Mocks { Mocks { - read_tap: ReadTapMock::MockFrame( - utils::rand::rand_alphanumerics(1234).as_bytes().to_vec(), - ), + read_tap: ReadTapMock::MockFrame(mock_frame(1234)), write_tap: WriteTapMock::Success, } } @@ -300,11 +334,8 @@ pub fn enable(tap: &Tap) { pub(crate) fn inject_tap_tx_frame(net: &Net, len: usize) -> Vec { assert!(len >= vnet_hdr_len()); let tap_traffic_simulator = TapTrafficSimulator::new(if_index(&net.tap)); - let mut frame = utils::rand::rand_alphanumerics(len - vnet_hdr_len()) - .as_bytes() - .to_vec(); - tap_traffic_simulator.push_tx_packet(&frame); - frame.splice(0..0, vec![b'\0'; vnet_hdr_len()]); + let frame = mock_frame(len); + tap_traffic_simulator.push_tx_packet(&frame[vnet_hdr_len()..]); frame } diff --git a/src/vmm/src/devices/virtio/queue.rs b/src/vmm/src/devices/virtio/queue.rs index 0fd6882d201..94aac70416c 100644 --- a/src/vmm/src/devices/virtio/queue.rs +++ b/src/vmm/src/devices/virtio/queue.rs @@ -59,10 +59,10 @@ unsafe impl ByteValued for Descriptor {} /// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-430008 /// 2.6.8 The Virtqueue Used Ring #[repr(C)] -#[derive(Default, Clone, Copy)] -struct UsedElement { - id: u32, - len: u32, +#[derive(Default, Debug, Clone, Copy)] +pub struct UsedElement { + pub id: u32, + pub len: u32, } // SAFETY: `UsedElement` is a POD and contains no padding. @@ -113,17 +113,16 @@ impl<'a, M: GuestMemory> DescriptorChain<'a, M> { // bounds. let desc_head = desc_table.unchecked_add(u64::from(index) * 16); - // These reads can't fail unless Guest memory is hopelessly broken. - let desc = match mem.read_obj::(desc_head) { - Ok(ret) => ret, - Err(err) => { - error!( - "Failed to read virtio descriptor from memory at address {:#x}: {}", - desc_head.0, err - ); - return None; - } - }; + // SAFETY: + // This can't fail as we checked the `desc_head` + let ptr = mem.get_host_address(desc_head).unwrap(); + + // SAFETY: + // Safe as we know that `ptr` is inside guest memory and + // following `std::mem::size_of::` bytes belong + // to the descriptor table + let desc: &Descriptor = unsafe { &*ptr.cast::() }; + let chain = DescriptorChain { mem, desc_table, @@ -391,7 +390,7 @@ impl Queue { /// # Important /// This is an internal method that ASSUMES THAT THERE ARE AVAILABLE DESCRIPTORS. Otherwise it /// will retrieve a descriptor that contains garbage data (obsolete/empty). - fn do_pop_unchecked<'b, M: GuestMemory>( + pub fn do_pop_unchecked<'b, M: GuestMemory>( &mut self, mem: &'b M, ) -> Option> { @@ -402,33 +401,33 @@ impl Queue { // In a naive notation, that would be: // `descriptor_table[avail_ring[next_avail]]`. // - // First, we compute the byte-offset (into `self.avail_ring`) of the index of the next - // available descriptor. 
`self.avail_ring` stores the address of a `struct - // virtq_avail`, as defined by the VirtIO spec: - // - // ```C - // struct virtq_avail { - // le16 flags; - // le16 idx; - // le16 ring[QUEUE_SIZE]; - // le16 used_event + // Avail ring has layout: + // struct AvailRing { + // flags: u16, + // idx: u16, + // ring: [u16; ], + // used_event: u16, // } - // ``` - // - // We use `self.next_avail` to store the position, in `ring`, of the next available - // descriptor index, with a twist: we always only increment `self.next_avail`, so the - // actual position will be `self.next_avail % self.actual_size()`. - // We are now looking for the offset of `ring[self.next_avail % self.actual_size()]`. - // `ring` starts after `flags` and `idx` (4 bytes into `struct virtq_avail`), and holds - // 2-byte items, so the offset will be: - let index_offset = 4 + 2 * (self.next_avail.0 % self.actual_size()); + // We calculate offset into `ring` field. + // We use `self.next_avail` to store the position, of the next available descriptor + // index in the `ring` field. Because `self.next_avail` is only incremented, the actual + // index into `AvailRing` is `self.next_avail % self.actual_size()`. + let desc_index_offset = std::mem::size_of::() + + std::mem::size_of::() + + std::mem::size_of::() * usize::from(self.next_avail.0 % self.actual_size()); + let desc_index_address = self + .avail_ring + .unchecked_add(usize_to_u64(desc_index_offset)); // `self.is_valid()` already performed all the bound checks on the descriptor table // and virtq rings, so it's safe to unwrap guest memory reads and to use unchecked // offsets. - let desc_index: u16 = mem - .read_obj(self.avail_ring.unchecked_add(u64::from(index_offset))) + let slice = mem + .get_slice(desc_index_address, std::mem::size_of::()) .unwrap(); + // SAFETY: + // We transforming valid memory slice + let desc_index = unsafe { *slice.ptr_guard().as_ptr().cast::() }; DescriptorChain::checked_new(mem, self.desc_table, self.actual_size(), desc_index).map( |dc| { @@ -453,24 +452,44 @@ impl Queue { ) -> Result<(), QueueError> { debug_assert!(self.is_layout_valid(mem)); - let next_used = self.next_used.0 % self.actual_size(); let used_element = UsedElement { id: u32::from(desc_index), len, }; - self.write_used_ring(mem, next_used, used_element)?; + self.write_used_ring(mem, self.next_used.0, used_element)?; + self.advance_used_ring(mem, 1); + Ok(()) + } - self.num_added += Wrapping(1); - self.next_used += Wrapping(1); + /// Advance number of used descriptor heads by `n`. + pub fn advance_used_ring(&mut self, mem: &M, n: u16) { + self.num_added += Wrapping(n); + self.next_used += Wrapping(n); // This fence ensures all descriptor writes are visible before the index update is. fence(Ordering::Release); self.set_used_ring_idx(self.next_used.0, mem); - Ok(()) } - fn write_used_ring( + /// Discards last `n` descriptors by setting their len to 0. + pub fn discard_used(&mut self, mem: &M, n: u16) { + // `next_used` is pointing to the next descriptor index. + // So we use range 1..n + 1 to get indexes of last n descriptors. + for i in 1..n + 1 { + let next_used_index = self.next_used - Wrapping(i); + let mut used_element = self.read_used_ring(mem, next_used_index.0); + used_element.len = 0; + // SAFETY: + // This should never panic as we only update len of the used_element. + self.write_used_ring(mem, next_used_index.0, used_element) + .unwrap(); + } + } + + /// Read used element to the used ring at specified index. 
+ #[inline(always)] + pub fn write_used_ring( &self, mem: &M, index: u16, @@ -494,11 +513,42 @@ impl Queue { // We calculate offset into `ring` field. let used_ring_offset = std::mem::size_of::() + std::mem::size_of::() - + std::mem::size_of::() * usize::from(index); + + std::mem::size_of::() * usize::from(index % self.actual_size()); + let used_element_address = self.used_ring.unchecked_add(usize_to_u64(used_ring_offset)); + + // SAFETY: + // `used_element_address` param is bounded by size of the queue as `index` is + // modded by `actual_size()`. + // `self.is_valid()` already performed all the bound checks on the descriptor table + // and virtq rings, so it's safe to unwrap guest memory reads and to use unchecked + // offsets. + mem.write_obj(used_element, used_element_address).unwrap(); + Ok(()) + } + + /// Read used element from a used ring at specified index. + #[inline(always)] + fn read_used_ring(&self, mem: &M, index: u16) -> UsedElement { + // Used ring has layout: + // struct UsedRing { + // flags: u16, + // idx: u16, + // ring: [UsedElement; ], + // avail_event: u16, + // } + // We calculate offset into `ring` field. + let used_ring_offset = std::mem::size_of::() + + std::mem::size_of::() + + std::mem::size_of::() * usize::from(index % self.actual_size()); let used_element_address = self.used_ring.unchecked_add(usize_to_u64(used_ring_offset)); - mem.write_obj(used_element, used_element_address) - .map_err(QueueError::UsedRing) + // SAFETY: + // `used_element_address` param is bounded by size of the queue as `index` is + // modded by `actual_size()`. + // `self.is_valid()` already performed all the bound checks on the descriptor table + // and virtq rings, so it's safe to unwrap guest memory reads and to use unchecked + // offsets. + mem.read_obj(used_element_address).unwrap() } /// Fetch the available ring index (`virtq_avail->idx`) from guest memory. @@ -529,7 +579,7 @@ impl Queue { /// Helper method that writes to the `avail_event` field of the used ring. #[inline(always)] - fn set_used_ring_avail_event(&mut self, avail_event: u16, mem: &M) { + pub fn set_used_ring_avail_event(&mut self, avail_event: u16, mem: &M) { debug_assert!(self.is_layout_valid(mem)); // Used ring has layout: @@ -552,7 +602,7 @@ impl Queue { /// Helper method that writes to the `idx` field of the used ring. #[inline(always)] - fn set_used_ring_idx(&mut self, next_used: u16, mem: &M) { + pub fn set_used_ring_idx(&mut self, next_used: u16, mem: &M) { debug_assert!(self.is_layout_valid(mem)); // Used ring has layout: @@ -1175,9 +1225,6 @@ mod tests { // index >= queue_size assert!(DescriptorChain::checked_new(m, vq.dtable_start(), 16, 16).is_none()); - // desc_table address is way off - assert!(DescriptorChain::checked_new(m, GuestAddress(0x00ff_ffff_ffff), 16, 0).is_none()); - // Let's create an invalid chain. { // The first desc has a normal len, and the next_descriptor flag is set. @@ -1211,6 +1258,16 @@ mod tests { } } + #[test] + #[should_panic] + fn test_checked_new_descriptor_chain_panic() { + let m = &multi_region_mem(&[(GuestAddress(0), 0x10000)]); + + // `checked_new` does assume that `desc_table` is valid. + // When desc_table address is way off, it should panic. 
+ DescriptorChain::checked_new(m, GuestAddress(0x00ff_ffff_ffff), 16, 0); + } + #[test] fn test_queue_validation() { let m = &default_mem(); diff --git a/src/vmm/src/devices/virtio/rng/device.rs b/src/vmm/src/devices/virtio/rng/device.rs index bb01ce5e44e..f671f00e554 100644 --- a/src/vmm/src/devices/virtio/rng/device.rs +++ b/src/vmm/src/devices/virtio/rng/device.rs @@ -132,7 +132,10 @@ impl Entropy { let index = desc.index; METRICS.entropy_event_count.inc(); - let bytes = match IoVecBufferMut::from_descriptor_chain(desc) { + // SAFETY: This descriptor chain is only loaded once + // virtio requests are handled sequentially so no two IoVecBuffers + // are live at the same time, meaning this has exclusive ownership over the memory + let bytes = match unsafe { IoVecBufferMut::from_descriptor_chain(desc) } { Ok(mut iovec) => { debug!( "entropy: guest request for {} bytes of entropy", @@ -428,13 +431,15 @@ mod tests { // This should succeed, we just added two descriptors let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop(&mem).unwrap(); assert!(matches!( - IoVecBufferMut::from_descriptor_chain(desc), + // SAFETY: This descriptor chain is only loaded into one buffer + unsafe { IoVecBufferMut::from_descriptor_chain(desc) }, Err(crate::devices::virtio::iovec::IoVecError::ReadOnlyDescriptor) )); // This should succeed, we should have one more descriptor let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop(&mem).unwrap(); - let mut iovec = IoVecBufferMut::from_descriptor_chain(desc).unwrap(); + // SAFETY: This descriptor chain is only loaded into one buffer + let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(desc).unwrap() }; entropy_dev.handle_one(&mut iovec).unwrap(); } diff --git a/src/vmm/src/devices/virtio/vsock/packet.rs b/src/vmm/src/devices/virtio/vsock/packet.rs index 952f8b1511e..c18b45b9a94 100644 --- a/src/vmm/src/devices/virtio/vsock/packet.rs +++ b/src/vmm/src/devices/virtio/vsock/packet.rs @@ -161,7 +161,10 @@ impl VsockPacket { /// Returns [`VsockError::DescChainTooShortForHeader`] if the descriptor chain's total buffer /// length is insufficient to hold the 44 byte vsock header pub fn from_rx_virtq_head(chain: DescriptorChain) -> Result { - let buffer = IoVecBufferMut::from_descriptor_chain(chain)?; + // SAFETY: This descriptor chain is only loaded once + // virtio requests are handled sequentially so no two IoVecBuffers + // are live at the same time, meaning this has exclusive ownership over the memory + let buffer = unsafe { IoVecBufferMut::from_descriptor_chain(chain)? }; if buffer.len() < VSOCK_PKT_HDR_SIZE { return Err(VsockError::DescChainTooShortForHeader(buffer.len() as usize));
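To make the avail-ring indirection behind read_new_desc_chains and Queue::do_pop_unchecked easier to follow, here is a standalone sketch that uses plain arrays in place of guest memory; the constant and types are local to the example and only mirror the split-virtqueue layout:

const VIRTQ_DESC_F_NEXT: u16 = 0x1;

#[repr(C)]
#[derive(Clone, Copy, Default)]
#[allow(dead_code)]
struct Descriptor {
    addr: u64,
    len: u32,
    flags: u16,
    next: u16,
}

// Walk one chain starting at `head`, summing buffer lengths, the same way the
// device sizes the guest buffers it collects into an IoVecBufferMut per head.
fn chain_len(table: &[Descriptor], mut head: u16) -> u32 {
    let mut total = 0;
    loop {
        let desc = table[head as usize];
        total += desc.len;
        if desc.flags & VIRTQ_DESC_F_NEXT == 0 {
            break total;
        }
        head = desc.next;
    }
}

fn main() {
    let mut table = [Descriptor::default(); 4];
    table[0] = Descriptor { addr: 0x1000, len: 100, flags: VIRTQ_DESC_F_NEXT, next: 1 };
    table[1] = Descriptor { addr: 0x2000, len: 50, flags: 0, next: 0 };
    // The driver publishes head indices in avail.ring; the device picks the slot
    // at next_avail % queue_size and then follows the VIRTQ_DESC_F_NEXT links.
    let avail_ring = [0u16, 2, 0, 0];
    assert_eq!(chain_len(&table, avail_ring[0]), 150);
}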