diff --git a/src/vmm/benches/block_request.rs b/src/vmm/benches/block_request.rs index c8c16b029b5..17937dd0242 100644 --- a/src/vmm/benches/block_request.rs +++ b/src/vmm/benches/block_request.rs @@ -24,7 +24,7 @@ pub fn block_request_benchmark(c: &mut Criterion) { chain.set_header(request_header); let mut queue = virt_queue.create_queue(); - let desc = queue.pop(&mem).unwrap(); + let desc = queue.pop().unwrap(); c.bench_function("request_parse", |b| { b.iter(|| { diff --git a/src/vmm/benches/queue.rs b/src/vmm/benches/queue.rs index 49a62b82edd..9724e8227e4 100644 --- a/src/vmm/benches/queue.rs +++ b/src/vmm/benches/queue.rs @@ -61,7 +61,7 @@ pub fn queue_benchmark(c: &mut Criterion) { set_dtable_one_chain(&rxq, 1); queue.next_avail = Wrapping(0); - let desc = queue.pop(&mem).unwrap(); + let desc = queue.pop().unwrap(); c.bench_function("next_descriptor_1", |b| { b.iter(|| { let mut head = Some(desc.clone()); @@ -73,7 +73,7 @@ pub fn queue_benchmark(c: &mut Criterion) { set_dtable_one_chain(&rxq, 2); queue.next_avail = Wrapping(0); - let desc = queue.pop(&mem).unwrap(); + let desc = queue.pop().unwrap(); c.bench_function("next_descriptor_2", |b| { b.iter(|| { let mut head = Some(desc.clone()); @@ -85,7 +85,7 @@ pub fn queue_benchmark(c: &mut Criterion) { set_dtable_one_chain(&rxq, 4); queue.next_avail = Wrapping(0); - let desc = queue.pop(&mem).unwrap(); + let desc = queue.pop().unwrap(); c.bench_function("next_descriptor_4", |b| { b.iter(|| { let mut head = Some(desc.clone()); @@ -97,7 +97,7 @@ pub fn queue_benchmark(c: &mut Criterion) { set_dtable_one_chain(&rxq, 16); queue.next_avail = Wrapping(0); - let desc = queue.pop(&mem).unwrap(); + let desc = queue.pop().unwrap(); c.bench_function("next_descriptor_16", |b| { b.iter(|| { let mut head = Some(desc.clone()); @@ -113,7 +113,7 @@ pub fn queue_benchmark(c: &mut Criterion) { c.bench_function("queue_pop_1", |b| { b.iter(|| { queue.next_avail = Wrapping(0); - while let Some(desc) = queue.pop(&mem) { + while let Some(desc) = queue.pop() { std::hint::black_box(desc); } }) @@ -123,7 +123,7 @@ pub fn queue_benchmark(c: &mut Criterion) { c.bench_function("queue_pop_4", |b| { b.iter(|| { queue.next_avail = Wrapping(0); - while let Some(desc) = queue.pop(&mem) { + while let Some(desc) = queue.pop() { std::hint::black_box(desc); } }) @@ -133,7 +133,7 @@ pub fn queue_benchmark(c: &mut Criterion) { c.bench_function("queue_pop_16", |b| { b.iter(|| { queue.next_avail = Wrapping(0); - while let Some(desc) = queue.pop(&mem) { + while let Some(desc) = queue.pop() { std::hint::black_box(desc); } }) @@ -146,7 +146,7 @@ pub fn queue_benchmark(c: &mut Criterion) { for i in 0_u16..1_u16 { let index = std::hint::black_box(i); let len = std::hint::black_box(i + 1); - _ = queue.add_used(&mem, index as u16, len as u32); + _ = queue.add_used(index as u16, len as u32); } }) }); @@ -158,7 +158,7 @@ pub fn queue_benchmark(c: &mut Criterion) { for i in 0_u16..16_u16 { let index = std::hint::black_box(i); let len = std::hint::black_box(i + 1); - _ = queue.add_used(&mem, index as u16, len as u32); + _ = queue.add_used(index as u16, len as u32); } }) }); @@ -170,7 +170,7 @@ pub fn queue_benchmark(c: &mut Criterion) { for i in 0_u16..256_u16 { let index = std::hint::black_box(i); let len = std::hint::black_box(i + 1); - _ = queue.add_used(&mem, index as u16, len as u32); + _ = queue.add_used(index as u16, len as u32); } }) }); diff --git a/src/vmm/src/devices/virtio/balloon/device.rs b/src/vmm/src/devices/virtio/balloon/device.rs index b294da9b9c0..98ed8332ac8 100644 
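[Reviewer note - illustrative, not part of the patch] The benchmark changes above show the new Queue API: `pop`, `add_used` and `prepare_kick` no longer take a guest-memory reference, because the queue caches host pointers to its rings when it is initialized. A minimal sketch of the intended calling pattern (the `queue`, `mem` and `irq_trigger` values are assumed from a device's context):

    // Illustrative only.
    queue.initialize(&mem)?;                  // once, at device activation
    while let Some(head) = queue.pop() {      // no &mem argument any more
        // ... process the descriptor chain starting at `head` ...
        queue.add_used(head.index, 0)?;       // 0 bytes written back to the guest
    }
    if queue.prepare_kick() {
        irq_trigger.trigger_irq(IrqType::Vring)?;
    }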
--- a/src/vmm/src/devices/virtio/balloon/device.rs +++ b/src/vmm/src/devices/virtio/balloon/device.rs @@ -297,7 +297,7 @@ impl Balloon { // Internal loop processes descriptors and acummulates the pfns in `pfn_buffer`. // Breaks out when there is not enough space in `pfn_buffer` to completely process // the next descriptor. - while let Some(head) = queue.pop(mem) { + while let Some(head) = queue.pop() { let len = head.len as usize; let max_len = MAX_PAGES_IN_DESC * SIZE_OF_U32; valid_descs_found = true; @@ -339,9 +339,7 @@ impl Balloon { // Acknowledge the receipt of the descriptor. // 0 is number of bytes the device has written to memory. - queue - .add_used(mem, head.index, 0) - .map_err(BalloonError::Queue)?; + queue.add_used(head.index, 0).map_err(BalloonError::Queue)?; needs_interrupt = true; } @@ -372,17 +370,13 @@ impl Balloon { } pub(crate) fn process_deflate_queue(&mut self) -> Result<(), BalloonError> { - // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); METRICS.deflate_count.inc(); let queue = &mut self.queues[DEFLATE_INDEX]; let mut needs_interrupt = false; - while let Some(head) = queue.pop(mem) { - queue - .add_used(mem, head.index, 0) - .map_err(BalloonError::Queue)?; + while let Some(head) = queue.pop() { + queue.add_used(head.index, 0).map_err(BalloonError::Queue)?; needs_interrupt = true; } @@ -398,13 +392,13 @@ impl Balloon { let mem = self.device_state.mem().unwrap(); METRICS.stats_updates_count.inc(); - while let Some(head) = self.queues[STATS_INDEX].pop(mem) { + while let Some(head) = self.queues[STATS_INDEX].pop() { if let Some(prev_stats_desc) = self.stats_desc_index { // We shouldn't ever have an extra buffer if the driver follows // the protocol, but return it if we find one. error!("balloon: driver is not compliant, more than one stats buffer received"); self.queues[STATS_INDEX] - .add_used(mem, prev_stats_desc, 0) + .add_used(prev_stats_desc, 0) .map_err(BalloonError::Queue)?; } for index in (0..head.len).step_by(SIZE_OF_STAT) { @@ -450,14 +444,11 @@ impl Balloon { } fn trigger_stats_update(&mut self) -> Result<(), BalloonError> { - // This is safe since we checked in the event handler that the device is activated. 
- let mem = self.device_state.mem().unwrap(); - // The communication is driven by the device by using the buffer // and sending a used buffer notification if let Some(index) = self.stats_desc_index.take() { self.queues[STATS_INDEX] - .add_used(mem, index, 0) + .add_used(index, 0) .map_err(BalloonError::Queue)?; self.signal_used_queue() } else { @@ -611,6 +602,11 @@ impl VirtioDevice for Balloon { } fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + for q in self.queues.iter_mut() { + q.initialize(&mem) + .map_err(ActivateError::QueueMemoryError)?; + } + self.device_state = DeviceState::Activated(mem); if self.activate_evt.write(1).is_err() { METRICS.activate_fails.inc(); diff --git a/src/vmm/src/devices/virtio/block/vhost_user/device.rs b/src/vmm/src/devices/virtio/block/vhost_user/device.rs index 99b27915598..627984d1bb5 100644 --- a/src/vmm/src/devices/virtio/block/vhost_user/device.rs +++ b/src/vmm/src/devices/virtio/block/vhost_user/device.rs @@ -331,6 +331,11 @@ impl VirtioDevice for VhostUserBlock } fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + for q in self.queues.iter_mut() { + q.initialize(&mem) + .map_err(ActivateError::QueueMemoryError)?; + } + let start_time = utils::time::get_time_us(utils::time::ClockType::Monotonic); // Setting features again, because now we negotiated them // with guest driver as well. diff --git a/src/vmm/src/devices/virtio/block/virtio/device.rs b/src/vmm/src/devices/virtio/block/virtio/device.rs index a0d01413b76..5c8fb95886e 100644 --- a/src/vmm/src/devices/virtio/block/virtio/device.rs +++ b/src/vmm/src/devices/virtio/block/virtio/device.rs @@ -388,15 +388,14 @@ impl VirtioBlock { queue: &mut Queue, index: u16, len: u32, - mem: &GuestMemoryMmap, irq_trigger: &IrqTrigger, block_metrics: &BlockDeviceMetrics, ) { - queue.add_used(mem, index, len).unwrap_or_else(|err| { + queue.add_used(index, len).unwrap_or_else(|err| { error!("Failed to add available descriptor head {}: {}", index, err) }); - if queue.prepare_kick(mem) { + if queue.prepare_kick() { irq_trigger.trigger_irq(IrqType::Vring).unwrap_or_else(|_| { block_metrics.event_fails.inc(); }); @@ -411,8 +410,8 @@ impl VirtioBlock { let queue = &mut self.queues[queue_index]; let mut used_any = false; - while let Some(head) = queue.pop_or_enable_notification(mem) { - self.metrics.remaining_reqs_count.add(queue.len(mem).into()); + while let Some(head) = queue.pop_or_enable_notification() { + self.metrics.remaining_reqs_count.add(queue.len().into()); let processing_result = match Request::parse(&head, mem, self.disk.nsectors) { Ok(request) => { if request.rate_limit(&mut self.rate_limiter) { @@ -448,7 +447,6 @@ impl VirtioBlock { queue, head.index, finished.num_bytes_to_mem, - mem, &self.irq_trigger, &self.metrics, ); @@ -500,7 +498,6 @@ impl VirtioBlock { queue, finished.desc_idx, finished.num_bytes_to_mem, - mem, &self.irq_trigger, &self.metrics, ); @@ -633,6 +630,11 @@ impl VirtioDevice for VirtioBlock { } fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + for q in self.queues.iter_mut() { + q.initialize(&mem) + .map_err(ActivateError::QueueMemoryError)?; + } + let event_idx = self.has_feature(u64::from(VIRTIO_RING_F_EVENT_IDX)); if event_idx { for queue in &mut self.queues { diff --git a/src/vmm/src/devices/virtio/block/virtio/request.rs b/src/vmm/src/devices/virtio/block/virtio/request.rs index 419124a81eb..c9c3cf112d1 100644 --- a/src/vmm/src/devices/virtio/block/virtio/request.rs +++ 
b/src/vmm/src/devices/virtio/block/virtio/request.rs @@ -473,7 +473,7 @@ mod tests { let memory = self.driver_queue.memory(); assert!(matches!( - Request::parse(&q.pop(memory).unwrap(), memory, NUM_DISK_SECTORS), + Request::parse(&q.pop().unwrap(), memory, NUM_DISK_SECTORS), Err(_e) )); } @@ -481,8 +481,7 @@ mod tests { fn check_parse(&self, check_data: bool) { let mut q = self.driver_queue.create_queue(); let memory = self.driver_queue.memory(); - let request = - Request::parse(&q.pop(memory).unwrap(), memory, NUM_DISK_SECTORS).unwrap(); + let request = Request::parse(&q.pop().unwrap(), memory, NUM_DISK_SECTORS).unwrap(); let expected_header = self.header(); assert_eq!( @@ -949,7 +948,7 @@ mod tests { fn parse_random_requests() { let cfg = ProptestConfig::with_cases(1000); proptest!(cfg, |(mut request in random_request_parse())| { - let result = Request::parse(&request.2.pop(&request.1).unwrap(), &request.1, NUM_DISK_SECTORS); + let result = Request::parse(&request.2.pop().unwrap(), &request.1, NUM_DISK_SECTORS); match result { Ok(r) => prop_assert!(r == request.0.unwrap()), Err(err) => { diff --git a/src/vmm/src/devices/virtio/device.rs b/src/vmm/src/devices/virtio/device.rs index 86c8327d8c8..7ae93b46a40 100644 --- a/src/vmm/src/devices/virtio/device.rs +++ b/src/vmm/src/devices/virtio/device.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use utils::eventfd::EventFd; use super::mmio::{VIRTIO_MMIO_INT_CONFIG, VIRTIO_MMIO_INT_VRING}; -use super::queue::Queue; +use super::queue::{Queue, QueueError}; use super::ActivateError; use crate::devices::virtio::AsAny; use crate::logger::{error, warn}; @@ -180,6 +180,14 @@ pub trait VirtioDevice: AsAny + Send { fn reset(&mut self) -> Option<(EventFd, Vec)> { None } + + /// Mark pages used by queues as dirty. + fn mark_queue_memory_dirty(&self, mem: &GuestMemoryMmap) -> Result<(), QueueError> { + for queue in self.queues() { + queue.mark_memory_dirty(mem)? + } + Ok(()) + } } impl fmt::Debug for dyn VirtioDevice { diff --git a/src/vmm/src/devices/virtio/iovec.rs b/src/vmm/src/devices/virtio/iovec.rs index 29a45fb5462..688f27a0a67 100644 --- a/src/vmm/src/devices/virtio/iovec.rs +++ b/src/vmm/src/devices/virtio/iovec.rs @@ -5,12 +5,13 @@ use std::io::ErrorKind; use libc::{c_void, iovec, size_t}; use smallvec::SmallVec; +use vm_memory::bitmap::Bitmap; use vm_memory::{ - GuestMemoryError, ReadVolatile, VolatileMemoryError, VolatileSlice, WriteVolatile, + GuestMemory, GuestMemoryError, ReadVolatile, VolatileMemoryError, VolatileSlice, WriteVolatile, }; use crate::devices::virtio::queue::DescriptorChain; -use crate::vstate::memory::{Bitmap, GuestMemory}; +use crate::vstate::memory::GuestMemoryMmap; #[derive(Debug, thiserror::Error, displaydoc::Display)] pub enum IoVecError { @@ -57,6 +58,7 @@ impl IoVecBuffer { /// The descriptor chain cannot be referencing the same memory location as another chain pub unsafe fn load_descriptor_chain( &mut self, + mem: &GuestMemoryMmap, head: DescriptorChain, ) -> Result<(), IoVecError> { self.clear(); @@ -70,8 +72,7 @@ impl IoVecBuffer { // We use get_slice instead of `get_host_address` here in order to have the whole // range of the descriptor chain checked, i.e. [addr, addr + len) is a valid memory // region in the GuestMemoryMmap. - let iov_base = desc - .mem + let iov_base = mem .get_slice(desc.addr, desc.len as usize)? 
.ptr_guard_mut() .as_ptr() @@ -96,9 +97,12 @@ impl IoVecBuffer { /// # Safety /// /// The descriptor chain cannot be referencing the same memory location as another chain - pub unsafe fn from_descriptor_chain(head: DescriptorChain) -> Result { + pub unsafe fn from_descriptor_chain( + mem: &GuestMemoryMmap, + head: DescriptorChain, + ) -> Result { let mut new_buffer = Self::default(); - new_buffer.load_descriptor_chain(head)?; + new_buffer.load_descriptor_chain(mem, head)?; Ok(new_buffer) } @@ -231,6 +235,7 @@ impl IoVecBufferMut { /// The descriptor chain cannot be referencing the same memory location as another chain pub unsafe fn load_descriptor_chain( &mut self, + mem: &GuestMemoryMmap, head: DescriptorChain, ) -> Result<(), IoVecError> { self.clear(); @@ -244,7 +249,7 @@ impl IoVecBufferMut { // We use get_slice instead of `get_host_address` here in order to have the whole // range of the descriptor chain checked, i.e. [addr, addr + len) is a valid memory // region in the GuestMemoryMmap. - let slice = desc.mem.get_slice(desc.addr, desc.len as usize)?; + let slice = mem.get_slice(desc.addr, desc.len as usize)?; // We need to mark the area of guest memory that will be mutated through this // IoVecBufferMut as dirty ahead of time, as we loose access to all @@ -272,9 +277,12 @@ impl IoVecBufferMut { /// # Safety /// /// The descriptor chain cannot be referencing the same memory location as another chain - pub unsafe fn from_descriptor_chain(head: DescriptorChain) -> Result { + pub unsafe fn from_descriptor_chain( + mem: &GuestMemoryMmap, + head: DescriptorChain, + ) -> Result { let mut new_buffer = Self::default(); - new_buffer.load_descriptor_chain(head)?; + new_buffer.load_descriptor_chain(mem, head)?; Ok(new_buffer) } @@ -482,34 +490,34 @@ mod tests { fn test_access_mode() { let mem = default_mem(); let (mut q, _) = read_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded into one buffer - unsafe { IoVecBuffer::from_descriptor_chain(head).unwrap() }; + unsafe { IoVecBuffer::from_descriptor_chain(&mem, head).unwrap() }; let (mut q, _) = write_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded into one buffer - unsafe { IoVecBuffer::from_descriptor_chain(head).unwrap_err() }; + unsafe { IoVecBuffer::from_descriptor_chain(&mem, head).unwrap_err() }; let (mut q, _) = read_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded into one buffer - unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap_err() }; + unsafe { IoVecBufferMut::from_descriptor_chain(&mem, head).unwrap_err() }; let (mut q, _) = write_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded into one buffer - unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; + unsafe { IoVecBufferMut::from_descriptor_chain(&mem, head).unwrap() }; } #[test] fn test_iovec_length() { let mem = default_mem(); let (mut q, _) = read_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded once in this test - let iovec = unsafe { IoVecBuffer::from_descriptor_chain(head).unwrap() }; + let iovec = unsafe { IoVecBuffer::from_descriptor_chain(&mem, head).unwrap() }; assert_eq!(iovec.len(), 4 * 64); } @@ -517,10 +525,10 @@ mod tests { fn 
test_iovec_mut_length() { let mem = default_mem(); let (mut q, _) = write_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded once in this test - let iovec = unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; + let iovec = unsafe { IoVecBufferMut::from_descriptor_chain(&mem, head).unwrap() }; assert_eq!(iovec.len(), 4 * 64); } @@ -528,10 +536,10 @@ mod tests { fn test_iovec_read_at() { let mem = default_mem(); let (mut q, _) = read_only_chain(&mem); - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded once in this test - let iovec = unsafe { IoVecBuffer::from_descriptor_chain(head).unwrap() }; + let iovec = unsafe { IoVecBuffer::from_descriptor_chain(&mem, head).unwrap() }; let mut buf = vec![0u8; 257]; assert_eq!( @@ -583,10 +591,10 @@ mod tests { let (mut q, vq) = write_only_chain(&mem); // This is a descriptor chain with 4 elements 64 bytes long each. - let head = q.pop(&mem).unwrap(); + let head = q.pop().unwrap(); // SAFETY: This descriptor chain is only loaded into one buffer - let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(head).unwrap() }; + let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(&mem, head).unwrap() }; let buf = vec![0u8, 1, 2, 3, 4]; // One test vector for each part of the chain diff --git a/src/vmm/src/devices/virtio/mmio.rs b/src/vmm/src/devices/virtio/mmio.rs index 7cf09f723a5..b4aa9c88875 100644 --- a/src/vmm/src/devices/virtio/mmio.rs +++ b/src/vmm/src/devices/virtio/mmio.rs @@ -340,12 +340,12 @@ impl MmioTransport { } } 0x70 => self.set_device_status(v), - 0x80 => self.update_queue_field(|q| lo(&mut q.desc_table, v)), - 0x84 => self.update_queue_field(|q| hi(&mut q.desc_table, v)), - 0x90 => self.update_queue_field(|q| lo(&mut q.avail_ring, v)), - 0x94 => self.update_queue_field(|q| hi(&mut q.avail_ring, v)), - 0xa0 => self.update_queue_field(|q| lo(&mut q.used_ring, v)), - 0xa4 => self.update_queue_field(|q| hi(&mut q.used_ring, v)), + 0x80 => self.update_queue_field(|q| lo(&mut q.desc_table_address, v)), + 0x84 => self.update_queue_field(|q| hi(&mut q.desc_table_address, v)), + 0x90 => self.update_queue_field(|q| lo(&mut q.avail_ring_address, v)), + 0x94 => self.update_queue_field(|q| hi(&mut q.avail_ring_address, v)), + 0xa0 => self.update_queue_field(|q| lo(&mut q.used_ring_address, v)), + 0xa4 => self.update_queue_field(|q| hi(&mut q.used_ring_address, v)), _ => { warn!("unknown virtio mmio register write: 0x{:x}", offset); } @@ -696,32 +696,35 @@ pub(crate) mod tests { d.bus_write(0x44, &buf[..]); assert!(d.locked_device().queues()[0].ready); - assert_eq!(d.locked_device().queues()[0].desc_table.0, 0); + assert_eq!(d.locked_device().queues()[0].desc_table_address.0, 0); write_le_u32(&mut buf[..], 123); d.bus_write(0x80, &buf[..]); - assert_eq!(d.locked_device().queues()[0].desc_table.0, 123); + assert_eq!(d.locked_device().queues()[0].desc_table_address.0, 123); d.bus_write(0x84, &buf[..]); assert_eq!( - d.locked_device().queues()[0].desc_table.0, + d.locked_device().queues()[0].desc_table_address.0, 123 + (123 << 32) ); - assert_eq!(d.locked_device().queues()[0].avail_ring.0, 0); + assert_eq!(d.locked_device().queues()[0].avail_ring_address.0, 0); write_le_u32(&mut buf[..], 124); d.bus_write(0x90, &buf[..]); - assert_eq!(d.locked_device().queues()[0].avail_ring.0, 124); + assert_eq!(d.locked_device().queues()[0].avail_ring_address.0, 124); d.bus_write(0x94, &buf[..]); 
            assert_eq!(
-                d.locked_device().queues()[0].avail_ring.0,
+                d.locked_device().queues()[0].avail_ring_address.0,
                 124 + (124 << 32)
             );
 
-            assert_eq!(d.locked_device().queues()[0].used_ring.0, 0);
+            assert_eq!(d.locked_device().queues()[0].used_ring_address.0, 0);
             write_le_u32(&mut buf[..], 125);
             d.bus_write(0xa0, &buf[..]);
-            assert_eq!(d.locked_device().queues()[0].used_ring.0, 125);
+            assert_eq!(d.locked_device().queues()[0].used_ring_address.0, 125);
             d.bus_write(0xa4, &buf[..]);
-            assert_eq!(d.locked_device().queues()[0].used_ring.0, 125 + (125 << 32));
+            assert_eq!(
+                d.locked_device().queues()[0].used_ring_address.0,
+                125 + (125 << 32)
+            );
 
             set_device_status(
                 &mut d,
diff --git a/src/vmm/src/devices/virtio/mod.rs b/src/vmm/src/devices/virtio/mod.rs
index 9edf96514b0..f68c2a123c9 100644
--- a/src/vmm/src/devices/virtio/mod.rs
+++ b/src/vmm/src/devices/virtio/mod.rs
@@ -9,6 +9,7 @@
 use std::any::Any;
 
+use self::queue::QueueError;
 use crate::devices::virtio::net::TapError;
 
 pub mod balloon;
@@ -70,6 +71,8 @@ pub enum ActivateError {
     VhostUser(vhost_user::VhostUserError),
     /// Setting tap interface offload flags failed: {0}
     TapSetOffload(TapError),
+    /// Error setting pointers in the queue: {0}
+    QueueMemoryError(QueueError),
 }
 
 /// Trait that helps in upcasting an object to Any
diff --git a/src/vmm/src/devices/virtio/net/device.rs b/src/vmm/src/devices/virtio/net/device.rs
index e34676b2c31..29b99cfb775 100755
--- a/src/vmm/src/devices/virtio/net/device.rs
+++ b/src/vmm/src/devices/virtio/net/device.rs
@@ -272,15 +272,12 @@ impl Net {
     /// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-320005
     /// 2.6.7.1 Driver Requirements: Used Buffer Notification Suppression
     fn try_signal_queue(&mut self, queue_type: NetQueue) -> Result<(), DeviceError> {
-        // This is safe since we checked in the event handler that the device is activated.
-        let mem = self.device_state.mem().unwrap();
-
         let queue = match queue_type {
             NetQueue::Rx => &mut self.queues[RX_INDEX],
             NetQueue::Tx => &mut self.queues[TX_INDEX],
         };
 
-        if queue.prepare_kick(mem) {
+        if queue.prepare_kick() {
             self.irq_trigger
                 .trigger_irq(IrqType::Vring)
                 .map_err(|err| {
@@ -389,7 +386,7 @@ impl Net {
         let mem = self.device_state.mem().unwrap();
 
         let queue = &mut self.queues[RX_INDEX];
-        let head_descriptor = queue.pop_or_enable_notification(mem).ok_or_else(|| {
+        let head_descriptor = queue.pop_or_enable_notification().ok_or_else(|| {
            self.metrics.no_rx_avail_buffer.inc();
            FrontendError::EmptyQueue
        })?;
@@ -409,7 +406,7 @@ impl Net {
             // Safe to unwrap because a frame must be smaller than 2^16 bytes.
u32::try_from(self.rx_bytes_read).unwrap() }; - queue.add_used(mem, head_index, used_len).map_err(|err| { + queue.add_used(head_index, used_len).map_err(|err| { error!("Failed to add available descriptor {}: {}", head_index, err); FrontendError::AddUsed })?; @@ -594,19 +591,19 @@ impl Net { let mut used_any = false; let tx_queue = &mut self.queues[TX_INDEX]; - while let Some(head) = tx_queue.pop_or_enable_notification(mem) { + while let Some(head) = tx_queue.pop_or_enable_notification() { self.metrics .tx_remaining_reqs_count - .add(tx_queue.len(mem).into()); + .add(tx_queue.len().into()); let head_index = head.index; // Parse IoVecBuffer from descriptor head // SAFETY: This descriptor chain is only loaded once // virtio requests are handled sequentially so no two IoVecBuffers // are live at the same time, meaning this has exclusive ownership over the memory - if unsafe { self.tx_buffer.load_descriptor_chain(head).is_err() } { + if unsafe { self.tx_buffer.load_descriptor_chain(mem, head).is_err() } { self.metrics.tx_fails.inc(); tx_queue - .add_used(mem, head_index, 0) + .add_used(head_index, 0) .map_err(DeviceError::QueueError)?; continue; }; @@ -616,7 +613,7 @@ impl Net { error!("net: received too big frame from driver"); self.metrics.tx_malformed_frames.inc(); tx_queue - .add_used(mem, head_index, 0) + .add_used(head_index, 0) .map_err(DeviceError::QueueError)?; continue; } @@ -646,7 +643,7 @@ impl Net { } tx_queue - .add_used(mem, head_index, 0) + .add_used(head_index, 0) .map_err(DeviceError::QueueError)?; used_any = true; } @@ -751,14 +748,13 @@ impl Net { pub fn process_tap_rx_event(&mut self) { // This is safe since we checked in the event handler that the device is activated. - let mem = self.device_state.mem().unwrap(); self.metrics.rx_tap_event_count.inc(); // While there are no available RX queue buffers and there's a deferred_frame // don't process any more incoming. Otherwise start processing a frame. In the // process the deferred_frame flag will be set in order to avoid freezing the // RX queue. 
- if self.queues[RX_INDEX].is_empty(mem) && self.rx_deferred_frame { + if self.queues[RX_INDEX].is_empty() && self.rx_deferred_frame { self.metrics.no_rx_avail_buffer.inc(); return; } @@ -903,6 +899,11 @@ impl VirtioDevice for Net { } fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + for q in self.queues.iter_mut() { + q.initialize(&mem) + .map_err(ActivateError::QueueMemoryError)?; + } + let event_idx = self.has_feature(u64::from(VIRTIO_RING_F_EVENT_IDX)); if event_idx { for queue in &mut self.queues { diff --git a/src/vmm/src/devices/virtio/net/test_utils.rs b/src/vmm/src/devices/virtio/net/test_utils.rs index 216db273859..c2bffa1d2a9 100644 --- a/src/vmm/src/devices/virtio/net/test_utils.rs +++ b/src/vmm/src/devices/virtio/net/test_utils.rs @@ -312,7 +312,7 @@ pub(crate) fn inject_tap_tx_frame(net: &Net, len: usize) -> Vec { pub fn write_element_in_queue(net: &Net, idx: u16, val: u64) -> Result<(), DeviceError> { if idx as usize > net.queue_evts.len() { return Err(DeviceError::QueueError(QueueError::DescIndexOutOfBounds( - u32::from(idx), + idx, ))); } net.queue_evts[idx as usize].write(val).unwrap(); @@ -322,7 +322,7 @@ pub fn write_element_in_queue(net: &Net, idx: u16, val: u64) -> Result<(), Devic pub fn get_element_from_queue(net: &Net, idx: u16) -> Result { if idx as usize > net.queue_evts.len() { return Err(DeviceError::QueueError(QueueError::DescIndexOutOfBounds( - u32::from(idx), + idx, ))); } Ok(u64::try_from(net.queue_evts[idx as usize].as_raw_fd()).unwrap()) diff --git a/src/vmm/src/devices/virtio/persist.rs b/src/vmm/src/devices/virtio/persist.rs index 3497bc58c34..f80d5382104 100644 --- a/src/vmm/src/devices/virtio/persist.rs +++ b/src/vmm/src/devices/virtio/persist.rs @@ -9,6 +9,7 @@ use std::sync::{Arc, Mutex}; use serde::{Deserialize, Serialize}; +use super::queue::QueueError; use crate::devices::virtio::device::VirtioDevice; use crate::devices::virtio::gen::virtio_ring::VIRTIO_RING_F_EVENT_IDX; use crate::devices::virtio::mmio::MmioTransport; @@ -21,6 +22,8 @@ use crate::vstate::memory::{GuestAddress, GuestMemoryMmap}; pub enum PersistError { /// Snapshot state contains invalid queue info. InvalidInput, + /// Could not restore queue. + QueueConstruction(QueueError), } /// Queue information saved in snapshot. @@ -51,38 +54,59 @@ pub struct QueueState { num_added: Wrapping, } +/// Auxiliary structure for restoring queues. +#[derive(Debug, Clone)] +pub struct QueueConstructorArgs { + /// Pointer to guest memory. 
+    pub mem: GuestMemoryMmap,
+    /// Whether the device this queue belongs to is activated
+    pub is_activated: bool,
+}
+
 impl Persist<'_> for Queue {
     type State = QueueState;
-    type ConstructorArgs = ();
-    type Error = ();
+    type ConstructorArgs = QueueConstructorArgs;
+    type Error = QueueError;
 
     fn save(&self) -> Self::State {
         QueueState {
             max_size: self.max_size,
             size: self.size,
             ready: self.ready,
-            desc_table: self.desc_table.0,
-            avail_ring: self.avail_ring.0,
-            used_ring: self.used_ring.0,
+            desc_table: self.desc_table_address.0,
+            avail_ring: self.avail_ring_address.0,
+            used_ring: self.used_ring_address.0,
             next_avail: self.next_avail,
             next_used: self.next_used,
             num_added: self.num_added,
         }
     }
 
-    fn restore(_: Self::ConstructorArgs, state: &Self::State) -> Result<Self, Self::Error> {
-        Ok(Queue {
+    fn restore(
+        constructor_args: Self::ConstructorArgs,
+        state: &Self::State,
+    ) -> Result<Self, Self::Error> {
+        let mut queue = Queue {
             max_size: state.max_size,
             size: state.size,
             ready: state.ready,
-            desc_table: GuestAddress(state.desc_table),
-            avail_ring: GuestAddress(state.avail_ring),
-            used_ring: GuestAddress(state.used_ring),
+            desc_table_address: GuestAddress(state.desc_table),
+            avail_ring_address: GuestAddress(state.avail_ring),
+            used_ring_address: GuestAddress(state.used_ring),
+
+            desc_table_ptr: std::ptr::null(),
+            avail_ring_ptr: std::ptr::null_mut(),
+            used_ring_ptr: std::ptr::null_mut(),
+
             next_avail: state.next_avail,
             next_used: state.next_used,
             uses_notif_suppression: false,
             num_added: state.num_added,
-        })
+        };
+        if constructor_args.is_activated {
+            queue.initialize(&constructor_args.mem)?;
+        }
+        Ok(queue)
     }
 }
 
@@ -137,18 +161,24 @@ impl VirtioDeviceState {
         }
 
         let uses_notif_suppression = (self.acked_features & 1u64 << VIRTIO_RING_F_EVENT_IDX) != 0;
+        let queue_construction_args = QueueConstructorArgs {
+            mem: mem.clone(),
+            is_activated: self.activated,
+        };
         let queues: Vec<Queue> = self
             .queues
             .iter()
             .map(|queue_state| {
-                // Safe to unwrap, `Queue::restore` has no error case.
-                let mut queue = Queue::restore((), queue_state).unwrap();
-                if uses_notif_suppression {
-                    queue.enable_notif_suppression();
-                }
-                queue
+                Queue::restore(queue_construction_args.clone(), queue_state)
+                    .map(|mut queue| {
+                        if uses_notif_suppression {
+                            queue.enable_notif_suppression();
+                        }
+                        queue
+                    })
+                    .map_err(PersistError::QueueConstruction)
             })
-            .collect();
+            .collect::<Result<Vec<Queue>, PersistError>>()?;
 
         for q in &queues {
             // Sanity check queue size and queue max size.
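[Reviewer note - illustrative, not part of the patch] The raw ring pointers are intentionally absent from `QueueState`: they are host-virtual addresses that are only valid inside the current process, so on snapshot restore they are re-derived from guest memory by calling `initialize` whenever the device was activated. A hedged sketch of the restore call, mirroring the `test_queue_persistence` test that follows (the `state` and `mem` values are assumed):

    // Illustrative only.
    let args = QueueConstructorArgs { mem: mem.clone(), is_activated: true };
    let queue = Queue::restore(args, &state).map_err(PersistError::QueueConstruction)?;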
@@ -316,14 +346,21 @@ mod tests { #[test] fn test_queue_persistence() { - let queue = Queue::new(128); + let mem = default_mem(); - let mut mem = vec![0; 4096]; + let mut queue = Queue::new(128); + queue.initialize(&mem).unwrap(); + + let mut bytes = vec![0; 4096]; - Snapshot::serialize(&mut mem.as_mut_slice(), &queue.save()).unwrap(); + Snapshot::serialize(&mut bytes.as_mut_slice(), &queue.save()).unwrap(); + let ca = QueueConstructorArgs { + mem, + is_activated: true, + }; let restored_queue = - Queue::restore((), &Snapshot::deserialize(&mut mem.as_slice()).unwrap()).unwrap(); + Queue::restore(ca, &Snapshot::deserialize(&mut bytes.as_slice()).unwrap()).unwrap(); assert_eq!(restored_queue, queue); } diff --git a/src/vmm/src/devices/virtio/queue.rs b/src/vmm/src/devices/virtio/queue.rs index 8bf144a2018..9dc3f2042ae 100644 --- a/src/vmm/src/devices/virtio/queue.rs +++ b/src/vmm/src/devices/virtio/queue.rs @@ -9,12 +9,8 @@ use std::cmp::min; use std::num::Wrapping; use std::sync::atomic::{fence, Ordering}; -use utils::usize_to_u64; - use crate::logger::error; -use crate::vstate::memory::{ - Address, ByteValued, Bytes, GuestAddress, GuestMemory, GuestMemoryMmap, -}; +use crate::vstate::memory::{Address, Bitmap, ByteValued, GuestAddress, GuestMemory}; pub const VIRTQ_DESC_F_NEXT: u16 = 0x1; pub const VIRTQ_DESC_F_WRITE: u16 = 0x2; @@ -32,10 +28,12 @@ pub(super) const FIRECRACKER_MAX_QUEUE_SIZE: u16 = 256; #[derive(Debug, thiserror::Error, displaydoc::Display)] pub enum QueueError { + /// Virtio queue number of available descriptors {0} is greater than queue size {1}. + InvalidQueueSize(u16, u16), /// Descriptor index out of bounds: {0}. - DescIndexOutOfBounds(u32), + DescIndexOutOfBounds(u16), /// Failed to write value into the virtio queue used ring: {0} - UsedRing(#[from] vm_memory::GuestMemoryError), + MemoryError(#[from] vm_memory::GuestMemoryError), } /// A virtio descriptor constraints with C representative. @@ -43,12 +41,12 @@ pub enum QueueError { /// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-430008 /// 2.6.5 The Virtqueue Descriptor Table #[repr(C)] -#[derive(Default, Clone, Copy)] -struct Descriptor { - addr: u64, - len: u32, - flags: u16, - next: u16, +#[derive(Debug, Default, Clone, Copy)] +pub struct Descriptor { + pub addr: u64, + pub len: u32, + pub flags: u16, + pub next: u16, } // SAFETY: `Descriptor` is a POD and contains no padding. @@ -59,10 +57,10 @@ unsafe impl ByteValued for Descriptor {} /// https://docs.oasis-open.org/virtio/virtio/v1.1/csprd01/virtio-v1.1-csprd01.html#x1-430008 /// 2.6.8 The Virtqueue Used Ring #[repr(C)] -#[derive(Default, Clone, Copy)] -struct UsedElement { - id: u32, - len: u32, +#[derive(Debug, Default, Clone, Copy)] +pub struct UsedElement { + pub id: u32, + pub len: u32, } // SAFETY: `UsedElement` is a POD and contains no padding. @@ -70,14 +68,12 @@ unsafe impl ByteValued for UsedElement {} /// A virtio descriptor chain. 
 #[derive(Debug, Copy, Clone)]
-pub struct DescriptorChain<'a, M: GuestMemory = GuestMemoryMmap> {
-    desc_table: GuestAddress,
+pub struct DescriptorChain {
+    desc_table_ptr: *const Descriptor,
+
     queue_size: u16,
     ttl: u16, // used to prevent infinite chain cycles
 
-    /// Reference to guest memory
-    pub mem: &'a M,
-
     /// Index into the descriptor table
     pub index: u16,
@@ -95,38 +91,20 @@ pub struct DescriptorChain<'a, M: GuestMemory = GuestMemoryMmap> {
     pub next: u16,
 }
 
-impl<'a, M: GuestMemory> DescriptorChain<'a, M> {
+impl DescriptorChain {
     /// Creates a new `DescriptorChain` from the given memory and descriptor table.
     ///
     /// Note that the desc_table and queue_size are assumed to be validated by the caller.
-    fn checked_new(
-        mem: &'a M,
-        desc_table: GuestAddress,
-        queue_size: u16,
-        index: u16,
-    ) -> Option<Self> {
-        if index >= queue_size {
+    fn checked_new(desc_table_ptr: *const Descriptor, queue_size: u16, index: u16) -> Option<Self> {
+        if queue_size <= index {
             return None;
         }
 
-        // There's no need for checking as we already validated the descriptor table and index
-        // bounds.
-        let desc_head = desc_table.unchecked_add(u64::from(index) * 16);
-
-        // These reads can't fail unless Guest memory is hopelessly broken.
-        let desc = match mem.read_obj::<Descriptor>(desc_head) {
-            Ok(ret) => ret,
-            Err(err) => {
-                error!(
-                    "Failed to read virtio descriptor from memory at address {:#x}: {}",
-                    desc_head.0, err
-                );
-                return None;
-            }
-        };
+        // SAFETY:
+        // index is in 0..queue_size bounds
+        let desc = unsafe { desc_table_ptr.add(usize::from(index)).read_volatile() };
 
         let chain = DescriptorChain {
-            mem,
-            desc_table,
+            desc_table_ptr,
             queue_size,
             ttl: queue_size,
             index,
@@ -166,7 +144,7 @@
     /// the head of the next _available_ descriptor chain.
     pub fn next_descriptor(&self) -> Option<Self> {
         if self.has_next() {
-            DescriptorChain::checked_new(self.mem, self.desc_table, self.queue_size, self.next).map(
+            DescriptorChain::checked_new(self.desc_table_ptr, self.queue_size, self.next).map(
                 |mut c| {
                     c.ttl = self.ttl - 1;
                     c
@@ -179,19 +157,19 @@
 }
 
 #[derive(Debug)]
-pub struct DescriptorIterator<'a>(Option<DescriptorChain<'a>>);
+pub struct DescriptorIterator(Option<DescriptorChain>);
 
-impl<'a> IntoIterator for DescriptorChain<'a> {
-    type Item = DescriptorChain<'a>;
-    type IntoIter = DescriptorIterator<'a>;
+impl IntoIterator for DescriptorChain {
+    type Item = DescriptorChain;
+    type IntoIter = DescriptorIterator;
 
     fn into_iter(self) -> Self::IntoIter {
         DescriptorIterator(Some(self))
     }
 }
 
-impl<'a> Iterator for DescriptorIterator<'a> {
-    type Item = DescriptorChain<'a>;
+impl Iterator for DescriptorIterator {
+    type Item = DescriptorChain;
 
     fn next(&mut self) -> Option<Self::Item> {
         self.0.take().map(|desc| {
@@ -214,13 +192,57 @@ pub struct Queue {
     pub ready: bool,
 
     /// Guest physical address of the descriptor table
-    pub desc_table: GuestAddress,
+    pub desc_table_address: GuestAddress,
 
     /// Guest physical address of the available ring
-    pub avail_ring: GuestAddress,
+    pub avail_ring_address: GuestAddress,
 
     /// Guest physical address of the used ring
-    pub used_ring: GuestAddress,
+    pub used_ring_address: GuestAddress,
+
+    /// Host virtual address pointer to the descriptor table
+    /// in the guest memory.
+    /// Getting access to the underlying
+    /// data structure should only occur after the
+    /// struct is initialized with `initialize()`.
+    ///
+    /// Representation of the in-memory struct layout:
+    /// struct DescriptorTable = [Descriptor; <queue size>]
+    pub desc_table_ptr: *const Descriptor,
+
+    /// Host virtual address pointer to the available ring
+    /// in the guest memory.
+    /// Getting access to the underlying
+    /// data structure should only occur after the
+    /// struct is initialized with `initialize()`.
+    ///
+    /// Representation of the in-memory struct layout:
+    /// struct AvailRing {
+    ///     flags: u16,
+    ///     idx: u16,
+    ///     ring: [u16; <queue size>],
+    ///     used_event: u16,
+    /// }
+    ///
+    /// Because all types in the AvailRing are u16,
+    /// we store the pointer as *mut u16 for simplicity.
+    pub avail_ring_ptr: *mut u16,
+
+    /// Host virtual address pointer to the used ring
+    /// in the guest memory.
+    /// Getting access to the underlying
+    /// data structure should only occur after the
+    /// struct is initialized with `initialize()`.
+    ///
+    /// Representation of the in-memory struct layout:
+    /// struct UsedRing {
+    ///     flags: u16,
+    ///     idx: u16,
+    ///     ring: [UsedElement; <queue size>],
+    ///     avail_event: u16,
+    /// }
+    ///
+    /// Because the types in the UsedRing differ (u16 and u32),
+    /// we store the pointer as *mut u8.
+    pub used_ring_ptr: *mut u8,
 
     pub next_avail: Wrapping<u16>,
     pub next_used: Wrapping<u16>,
@@ -231,6 +253,12 @@ pub struct Queue {
     pub num_added: Wrapping<u16>,
 }
 
+/// SAFETY: Queue is Send, because we use volatile memory accesses when
+/// working with the pointers. These pointers are not copied or stored anywhere
+/// else. We assume the guest will not give different queues the same guest
+/// memory addresses.
+unsafe impl Send for Queue {}
+
 #[allow(clippy::len_without_is_empty)]
 impl Queue {
     /// Constructs an empty virtio queue with the given `max_size`.
@@ -239,9 +267,14 @@ impl Queue {
             max_size,
             size: 0,
             ready: false,
-            desc_table: GuestAddress(0),
-            avail_ring: GuestAddress(0),
-            used_ring: GuestAddress(0),
+            desc_table_address: GuestAddress(0),
+            avail_ring_address: GuestAddress(0),
+            used_ring_address: GuestAddress(0),
+
+            desc_table_ptr: std::ptr::null(),
+            avail_ring_ptr: std::ptr::null_mut(),
+            used_ring_ptr: std::ptr::null_mut(),
+
             next_avail: Wrapping(0),
             next_used: Wrapping(0),
             uses_notif_suppression: false,
@@ -249,6 +282,150 @@ impl Queue {
         }
     }
 
+    fn desc_table_size(&self) -> usize {
+        std::mem::size_of::<Descriptor>() * usize::from(self.size)
+    }
+
+    fn avail_ring_size(&self) -> usize {
+        std::mem::size_of::<u16>()
+            + std::mem::size_of::<u16>()
+            + std::mem::size_of::<u16>() * usize::from(self.size)
+            + std::mem::size_of::<u16>()
+    }
+
+    fn used_ring_size(&self) -> usize {
+        std::mem::size_of::<u16>()
+            + std::mem::size_of::<u16>()
+            + std::mem::size_of::<UsedElement>() * usize::from(self.size)
+            + std::mem::size_of::<u16>()
+    }
+
+    fn get_slice_ptr<M: GuestMemory>(
+        &self,
+        mem: &M,
+        addr: GuestAddress,
+        len: usize,
+    ) -> Result<*mut u8, QueueError> {
+        let slice = mem.get_slice(addr, len).map_err(QueueError::MemoryError)?;
+        slice.bitmap().mark_dirty(0, len);
+        Ok(slice.ptr_guard_mut().as_ptr())
+    }
+
+    /// Set up pointers to the queue objects in the guest memory
+    /// and mark the memory dirty for those objects.
+    pub fn initialize<M: GuestMemory>(&mut self, mem: &M) -> Result<(), QueueError> {
+        self.desc_table_ptr = self
+            .get_slice_ptr(mem, self.desc_table_address, self.desc_table_size())?
+            .cast();
+        self.avail_ring_ptr = self
+            .get_slice_ptr(mem, self.avail_ring_address, self.avail_ring_size())?
+            .cast();
+        self.used_ring_ptr = self
+            .get_slice_ptr(mem, self.used_ring_address, self.used_ring_size())?
+            .cast();
+
+        // Disable this check for kani tests, otherwise it will hit this assertion
+        // and fail.
+        #[cfg(not(kani))]
+        if self.actual_size() < self.len() {
+            return Err(QueueError::InvalidQueueSize(self.len(), self.actual_size()));
+        }
+
+        Ok(())
+    }
+
+    /// Mark memory used for queue objects as dirty.
+    pub fn mark_memory_dirty<M: GuestMemory>(&self, mem: &M) -> Result<(), QueueError> {
+        _ = self.get_slice_ptr(mem, self.desc_table_address, self.desc_table_size())?;
+        _ = self.get_slice_ptr(mem, self.avail_ring_address, self.avail_ring_size())?;
+        _ = self.get_slice_ptr(mem, self.used_ring_address, self.used_ring_size())?;
+        Ok(())
+    }
+
+    /// Get AvailRing.idx
+    #[inline(always)]
+    pub fn avail_ring_idx_get(&self) -> u16 {
+        // SAFETY: `idx` is 1 u16 away from the start
+        unsafe { self.avail_ring_ptr.add(1).read_volatile() }
+    }
+
+    /// Get element from AvailRing.ring at index
+    /// # Safety
+    /// The `index` parameter should be in 0..queue_size bounds
+    #[inline(always)]
+    unsafe fn avail_ring_ring_get(&self, index: usize) -> u16 {
+        // SAFETY: `ring` is 2 u16 away from the start
+        unsafe { self.avail_ring_ptr.add(2).add(index).read_volatile() }
+    }
+
+    /// Get AvailRing.used_event
+    #[inline(always)]
+    pub fn avail_ring_used_event_get(&self) -> u16 {
+        // SAFETY: `used_event` is 2 + self.size u16 away from the start
+        unsafe {
+            self.avail_ring_ptr
+                .add(2_usize.unchecked_add(usize::from(self.size)))
+                .read_volatile()
+        }
+    }
+
+    /// Set UsedRing.idx
+    #[inline(always)]
+    pub fn used_ring_idx_set(&mut self, val: u16) {
+        // SAFETY: `idx` is 1 u16 away from the start
+        unsafe {
+            self.used_ring_ptr
+                .add(std::mem::size_of::<u16>())
+                .cast::<u16>()
+                .write_volatile(val)
+        }
+    }
+
+    /// Set element of UsedRing.ring at index
+    /// # Safety
+    /// The `index` parameter should be in 0..queue_size bounds
+    #[inline(always)]
+    unsafe fn used_ring_ring_set(&mut self, index: usize, val: UsedElement) {
+        // SAFETY: `ring` is 2 u16 away from the start
+        unsafe {
+            self.used_ring_ptr
+                .add(std::mem::size_of::<u16>().unchecked_mul(2))
+                .cast::<UsedElement>()
+                .add(index)
+                .write_volatile(val)
+        }
+    }
+
+    /// Get UsedRing.avail_event
+    #[cfg(any(test, kani))]
+    #[inline(always)]
+    pub fn used_ring_avail_event_get(&mut self) -> u16 {
+        // SAFETY: `avail_event` is 2 * u16 and self.size * UsedElement away from the start
+        unsafe {
+            self.used_ring_ptr
+                .add(
+                    std::mem::size_of::<u16>().unchecked_mul(2)
+                        + std::mem::size_of::<UsedElement>().unchecked_mul(usize::from(self.size)),
+                )
+                .cast::<u16>()
+                .read_volatile()
+        }
+    }
+
+    /// Set UsedRing.avail_event
+    #[inline(always)]
+    pub fn used_ring_avail_event_set(&mut self, val: u16) {
+        // SAFETY: `avail_event` is 2 * u16 and self.size * UsedElement away from the start
+        unsafe {
+            self.used_ring_ptr
+                .add(
+                    std::mem::size_of::<u16>().unchecked_mul(2)
+                        + std::mem::size_of::<UsedElement>().unchecked_mul(usize::from(self.size)),
+                )
+                .cast::<u16>()
+                .write_volatile(val)
+        }
+    }
+
     /// Maximum size of the queue.
     pub fn get_max_size(&self) -> u16 {
         self.max_size
@@ -261,14 +438,13 @@ impl Queue {
     }
 
     /// Validates the queue's in-memory layout is correct.
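[Reviewer note - illustrative, not part of the patch] The accessors above hard-code the split-ring layout from section 2.6 of the VirtIO 1.1 spec, so the byte offsets they compute can be checked by hand: for a queue of size n, `used_event` sits 4 + 2*n bytes into the available ring, and `avail_event` sits 4 + 8*n bytes into the used ring. A small sketch of that arithmetic:

    // Illustrative only; `n` is the queue size configured by the driver.
    let n: usize = 256;
    let used_event_offset = 2 + 2 + 2 * n;  // avail ring: flags (u16) + idx (u16) + ring of u16 indices
    let avail_event_offset = 2 + 2 + 8 * n; // used ring: flags (u16) + idx (u16) + ring of UsedElement (u32 id + u32 len)
    assert_eq!(used_event_offset, 516);
    assert_eq!(avail_event_offset, 2052);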
- pub fn is_layout_valid(&self, mem: &M) -> bool { - let queue_size = usize::from(self.actual_size()); - let desc_table = self.desc_table; - let desc_table_size = 16 * queue_size; - let avail_ring = self.avail_ring; - let avail_ring_size = 6 + 2 * queue_size; - let used_ring = self.used_ring; - let used_ring_size = 6 + 8 * queue_size; + pub fn is_valid(&self, mem: &M) -> bool { + let desc_table = self.desc_table_address; + let desc_table_size = self.desc_table_size(); + let avail_ring = self.avail_ring_address; + let avail_ring_size = self.avail_ring_size(); + let used_ring = self.used_ring_address; + let used_ring_size = self.used_ring_size(); if !self.ready { error!("attempt to use virtio queue that is not marked ready"); @@ -313,39 +489,19 @@ impl Queue { } } - /// Validates that the queue's representation is correct. - pub fn is_valid(&self, mem: &M) -> bool { - if !self.is_layout_valid(mem) { - false - } else if self.len(mem) > self.max_size { - error!( - "virtio queue number of available descriptors {} is greater than queue max size {}", - self.len(mem), - self.max_size - ); - false - } else { - true - } - } - /// Returns the number of yet-to-be-popped descriptor chains in the avail ring. - pub fn len(&self, mem: &M) -> u16 { - debug_assert!(self.is_layout_valid(mem)); - - (self.avail_idx(mem) - self.next_avail).0 + pub fn len(&self) -> u16 { + (Wrapping(self.avail_ring_idx_get()) - self.next_avail).0 } /// Checks if the driver has made any descriptor chains available in the avail ring. - pub fn is_empty(&self, mem: &M) -> bool { - self.len(mem) == 0 + pub fn is_empty(&self) -> bool { + self.len() == 0 } /// Pop the first available descriptor chain from the avail ring. - pub fn pop<'b, M: GuestMemory>(&mut self, mem: &'b M) -> Option> { - debug_assert!(self.is_layout_valid(mem)); - - let len = self.len(mem); + pub fn pop(&mut self) -> Option { + let len = self.len(); // The number of descriptor chain heads to process should always // be smaller or equal to the queue size, as the driver should // never ask the VMM to process a available ring entry more than @@ -366,24 +522,21 @@ impl Queue { return None; } - self.do_pop_unchecked(mem) + self.pop_unchecked() } /// Try to pop the first available descriptor chain from the avail ring. /// If no descriptor is available, enable notifications. - pub fn pop_or_enable_notification<'b, M: GuestMemory>( - &mut self, - mem: &'b M, - ) -> Option> { + pub fn pop_or_enable_notification(&mut self) -> Option { if !self.uses_notif_suppression { - return self.pop(mem); + return self.pop(); } - if self.try_enable_notification(mem) { + if self.try_enable_notification() { return None; } - self.do_pop_unchecked(mem) + self.pop_unchecked() } /// Pop the first available descriptor chain from the avail ring. @@ -391,10 +544,7 @@ impl Queue { /// # Important /// This is an internal method that ASSUMES THAT THERE ARE AVAILABLE DESCRIPTORS. Otherwise it /// will retrieve a descriptor that contains garbage data (obsolete/empty). - fn do_pop_unchecked<'b, M: GuestMemory>( - &mut self, - mem: &'b M, - ) -> Option> { + fn pop_unchecked(&mut self) -> Option { // This fence ensures all subsequent reads see the updated driver writes. fence(Ordering::Acquire); @@ -402,30 +552,15 @@ impl Queue { // In a naive notation, that would be: // `descriptor_table[avail_ring[next_avail]]`. // - // Avail ring has layout: - // struct AvailRing { - // flags: u16, - // idx: u16, - // ring: [u16; ], - // used_event: u16, - // } - // We calculate offset into `ring` field. 
- // We use `self.next_avail` to store the position, of the next available descriptor - // index in the `ring` field. Because `self.next_avail` is only incremented, the actual - // index into `AvailRing` is `self.next_avail % self.actual_size()`. - let desc_index_offset = std::mem::size_of::() - + std::mem::size_of::() - + std::mem::size_of::() * usize::from(self.next_avail.0 % self.actual_size()); - let desc_index_address = self - .avail_ring - .unchecked_add(usize_to_u64(desc_index_offset)); - - // `self.is_valid()` already performed all the bound checks on the descriptor table - // and virtq rings, so it's safe to unwrap guest memory reads and to use unchecked - // offsets. - let desc_index: u16 = mem.read_obj(desc_index_address).unwrap(); - - DescriptorChain::checked_new(mem, self.desc_table, self.actual_size(), desc_index).map( + // We use `self.next_avail` to store the position, in `ring`, of the next available + // descriptor index, with a twist: we always only increment `self.next_avail`, so the + // actual position will be `self.next_avail % self.actual_size()`. + let idx = self.next_avail.0 % self.actual_size(); + // SAFETY: + // index is bound by the queue size + let desc_index = unsafe { self.avail_ring_ring_get(usize::from(idx)) }; + + DescriptorChain::checked_new(self.desc_table_ptr, self.actual_size(), desc_index).map( |dc| { self.next_avail += Wrapping(1); dc @@ -440,20 +575,25 @@ impl Queue { } /// Puts an available descriptor head into the used ring for use by the guest. - pub fn add_used( - &mut self, - mem: &M, - desc_index: u16, - len: u32, - ) -> Result<(), QueueError> { - debug_assert!(self.is_layout_valid(mem)); + pub fn add_used(&mut self, desc_index: u16, len: u32) -> Result<(), QueueError> { + if self.actual_size() <= desc_index { + error!( + "attempted to add out of bounds descriptor to used ring: {}", + desc_index + ); + return Err(QueueError::DescIndexOutOfBounds(desc_index)); + } let next_used = self.next_used.0 % self.actual_size(); let used_element = UsedElement { id: u32::from(desc_index), len, }; - self.write_used_ring(mem, next_used, used_element)?; + // SAFETY: + // index is bound by the queue size + unsafe { + self.used_ring_ring_set(usize::from(next_used), used_element); + } self.num_added += Wrapping(1); self.next_used += Wrapping(1); @@ -461,122 +601,22 @@ impl Queue { // This fence ensures all descriptor writes are visible before the index update is. fence(Ordering::Release); - self.set_used_ring_idx(self.next_used.0, mem); + self.used_ring_idx_set(self.next_used.0); Ok(()) } - fn write_used_ring( - &self, - mem: &M, - index: u16, - used_element: UsedElement, - ) -> Result<(), QueueError> { - if used_element.id >= u32::from(self.actual_size()) { - error!( - "attempted to add out of bounds descriptor to used ring: {}", - used_element.id - ); - return Err(QueueError::DescIndexOutOfBounds(used_element.id)); - } - - // Used ring has layout: - // struct UsedRing { - // flags: u16, - // idx: u16, - // ring: [UsedElement; ], - // avail_event: u16, - // } - // We calculate offset into `ring` field. - let used_ring_offset = std::mem::size_of::() - + std::mem::size_of::() - + std::mem::size_of::() * usize::from(index); - let used_element_address = self.used_ring.unchecked_add(usize_to_u64(used_ring_offset)); - - mem.write_obj(used_element, used_element_address) - .map_err(QueueError::UsedRing) - } - - /// Fetch the available ring index (`virtq_avail->idx`) from guest memory. 
- /// This is written by the driver, to indicate the next slot that will be filled in the avail - /// ring. - pub fn avail_idx(&self, mem: &M) -> Wrapping { - // Bound checks for queue inner data have already been performed, at device activation time, - // via `self.is_valid()`, so it's safe to unwrap and use unchecked offsets here. - // Note: the `MmioTransport` code ensures that queue addresses cannot be changed by the - // guest after device activation, so we can be certain that no change has - // occurred since the last `self.is_valid()` check. - let addr = self.avail_ring.unchecked_add(2); - Wrapping(mem.read_obj::(addr).unwrap()) - } - - /// Get the value of the used event field of the avail ring. - #[inline(always)] - pub fn used_event(&self, mem: &M) -> Wrapping { - debug_assert!(self.is_layout_valid(mem)); - - // We need to find the `used_event` field from the avail ring. - let used_event_addr = self - .avail_ring - .unchecked_add(u64::from(4 + 2 * self.actual_size())); - - Wrapping(mem.read_obj::(used_event_addr).unwrap()) - } - - /// Helper method that writes to the `avail_event` field of the used ring. - #[inline(always)] - fn set_used_ring_avail_event(&mut self, avail_event: u16, mem: &M) { - debug_assert!(self.is_layout_valid(mem)); - - // Used ring has layout: - // struct UsedRing { - // flags: u16, - // idx: u16, - // ring: [UsedElement; ], - // avail_event: u16, - // } - // We calculate offset into `avail_event` field. - let avail_event_offset = std::mem::size_of::() - + std::mem::size_of::() - + std::mem::size_of::() * usize::from(self.actual_size()); - let avail_event_addr = self - .used_ring - .unchecked_add(usize_to_u64(avail_event_offset)); - - mem.write_obj(avail_event, avail_event_addr).unwrap(); - } - - /// Helper method that writes to the `idx` field of the used ring. - #[inline(always)] - fn set_used_ring_idx(&mut self, next_used: u16, mem: &M) { - debug_assert!(self.is_layout_valid(mem)); - - // Used ring has layout: - // struct UsedRing { - // flags: u16, - // idx: u16, - // ring: [UsedElement; ], - // avail_event: u16, - // } - // We calculate offset into `idx` field. - let idx_offset = std::mem::size_of::(); - let next_used_addr = self.used_ring.unchecked_add(usize_to_u64(idx_offset)); - mem.write_obj(next_used, next_used_addr).unwrap(); - } - /// Try to enable notification events from the guest driver. Returns true if notifications were /// successfully enabled. Otherwise it means that one or more descriptors can still be consumed /// from the available ring and we can't guarantee that there will be a notification. In this /// case the caller might want to consume the mentioned descriptors and call this method again. - pub fn try_enable_notification(&mut self, mem: &M) -> bool { - debug_assert!(self.is_layout_valid(mem)); - + pub fn try_enable_notification(&mut self) -> bool { // If the device doesn't use notification suppression, we'll continue to get notifications // no matter what. if !self.uses_notif_suppression { return true; } - let len = self.len(mem); + let len = self.len(); if len != 0 { // The number of descriptor chain heads to process should always // be smaller or equal to the queue size. @@ -595,14 +635,14 @@ impl Queue { } // Set the next expected avail_idx as avail_event. - self.set_used_ring_avail_event(self.next_avail.0, mem); + self.used_ring_avail_event_set(self.next_avail.0); - // Make sure all subsequent reads are performed after `set_used_ring_avail_event`. + // Make sure all subsequent reads are performed after we set avail_event. 
fence(Ordering::SeqCst); // If the actual avail_idx is different than next_avail one or more descriptors can still // be consumed from the available ring. - self.next_avail.0 == self.avail_idx(mem).0 + self.next_avail.0 == self.avail_ring_idx_get() } /// Enable notification suppression. @@ -617,9 +657,7 @@ impl Queue { /// updates `used_event` and/or the notification conditions hold once more. /// /// This is similar to the `vring_need_event()` method implemented by the Linux kernel. - pub fn prepare_kick(&mut self, mem: &M) -> bool { - debug_assert!(self.is_layout_valid(mem)); - + pub fn prepare_kick(&mut self) -> bool { // If the device doesn't use notification suppression, always return true if !self.uses_notif_suppression { return true; @@ -630,7 +668,7 @@ impl Queue { let new = self.next_used; let old = self.next_used - self.num_added; - let used_event = self.used_event(mem); + let used_event = Wrapping(self.avail_ring_used_event_get()); self.num_added = Wrapping(0); @@ -647,9 +685,7 @@ mod verification { use vm_memory::guest_memory::GuestMemoryIterator; use vm_memory::{GuestMemoryRegion, MemoryRegionAddress}; - use crate::devices::virtio::queue::{ - Descriptor, DescriptorChain, Queue, FIRECRACKER_MAX_QUEUE_SIZE, VIRTQ_DESC_F_NEXT, - }; + use super::*; use crate::vstate::memory::{Bytes, FileOffset, GuestAddress, GuestMemory, MmapRegion}; /// A made-for-kani version of `vm_memory::GuestMemoryMmap`. Unlike the real @@ -790,15 +826,6 @@ mod verification { ) } - fn setup_zeroed_guest_memory() -> ProofGuestMemory { - guest_memory(unsafe { - std::alloc::alloc_zeroed(std::alloc::Layout::from_size_align_unchecked( - GUEST_MEMORY_SIZE, - 16, - )) - }) - } - // Constants describing the in-memory layout of a queue of size FIRECRACKER_MAX_SIZE starting // at the beginning of guest memory. These are based on Section 2.6 of the VirtIO 1.1 // specification. @@ -823,9 +850,9 @@ mod verification { queue.size = FIRECRACKER_MAX_QUEUE_SIZE; queue.ready = true; - queue.desc_table = GuestAddress(QUEUE_BASE_ADDRESS); - queue.avail_ring = GuestAddress(AVAIL_RING_BASE_ADDRESS); - queue.used_ring = GuestAddress(USED_RING_BASE_ADDRESS); + queue.desc_table_address = GuestAddress(QUEUE_BASE_ADDRESS); + queue.avail_ring_address = GuestAddress(AVAIL_RING_BASE_ADDRESS); + queue.used_ring_address = GuestAddress(USED_RING_BASE_ADDRESS); queue.next_avail = Wrapping(kani::any()); queue.next_used = Wrapping(kani::any()); queue.uses_notif_suppression = kani::any(); @@ -839,20 +866,10 @@ mod verification { /// fixed to a known valid one pub fn bounded_queue() -> Self { let mem = setup_kani_guest_memory(); - let queue = less_arbitrary_queue(); - - assert!(queue.is_layout_valid(&mem)); - - ProofContext(queue, mem) - } + let mut queue = less_arbitrary_queue(); + queue.initialize(&mem).unwrap(); - /// Creates a [`ProofContext`] where the queue layout is fixed to a valid one and where - /// guest memory is initialized to all zeros. 
- pub fn bounded() -> Self { - let mem = setup_zeroed_guest_memory(); - let queue = less_arbitrary_queue(); - - assert!(queue.is_layout_valid(&mem)); + assert!(queue.is_valid(&mem)); ProofContext(queue, mem) } @@ -861,9 +878,10 @@ mod verification { impl kani::Arbitrary for ProofContext { fn any() -> Self { let mem = setup_kani_guest_memory(); - let queue: Queue = kani::any(); + let mut queue: Queue = kani::any(); - kani::assume(queue.is_layout_valid(&mem)); + kani::assume(queue.is_valid(&mem)); + queue.initialize(&mem).unwrap(); ProofContext(queue, mem) } @@ -876,9 +894,9 @@ mod verification { queue.size = kani::any(); queue.ready = true; - queue.desc_table = GuestAddress(kani::any()); - queue.avail_ring = GuestAddress(kani::any()); - queue.used_ring = GuestAddress(kani::any()); + queue.desc_table_address = GuestAddress(kani::any()); + queue.avail_ring_address = GuestAddress(kani::any()); + queue.used_ring_address = GuestAddress(kani::any()); queue.next_avail = Wrapping(kani::any()); queue.next_used = Wrapping(kani::any()); queue.uses_notif_suppression = kani::any(); @@ -899,21 +917,6 @@ mod verification { } } - mod stubs { - use super::*; - - // Calls to set_used_ring_avail_event tend to cause memory to grow unboundedly during - // verification. The function writes to the `avail_event` of the virtio queue, which - // is not read from by the device. It is only intended to be used by guest. - // Therefore, it does not affect any device functionality (e.g. its only call site, - // try_enable_notification, will behave independently of what value was written - // here). Thus we can stub it out with a no-op. Note that we have a separate harness - // for set_used_ring_avail_event, to ensure the function itself is sound. - fn set_used_ring_avail_event(_self: &mut Queue, _val: u16, _mem: &M) { - // do nothing - } - } - #[kani::proof] #[kani::unwind(0)] // There are no loops anywhere, but kani really enjoys getting stuck in std::ptr::drop_in_place. // This is a compiler intrinsic that has a "dummy" implementation in stdlib that just @@ -926,10 +929,10 @@ mod verification { // has been processed. This is done by the driver // defining a "used_event" index, which tells the device "please do not notify me until // used.ring[used_event] has been written to by you". - let ProofContext(mut queue, mem) = ProofContext::bounded_queue(); + let ProofContext(mut queue, _) = kani::any(); let num_added_old = queue.num_added.0; - let needs_notification = queue.prepare_kick(&mem); + let needs_notification = queue.prepare_kick(); // uses_notif_suppression equivalent to VIRTIO_F_EVENT_IDX negotiated if !queue.uses_notif_suppression { @@ -937,12 +940,14 @@ mod verification { // After the device writes a descriptor index into the used ring: // – If flags is 1, the device SHOULD NOT send a notification. // – If flags is 0, the device MUST send a notification - // flags is the first field in the avail_ring, which we completely ignore. We + // flags is the first field in the avail_ring_address, which we completely ignore. 
We // always send a notification, and as there only is a SHOULD NOT, that is okay assert!(needs_notification); } else { // next_used - 1 is where the previous descriptor was placed - if queue.used_event(&mem) == queue.next_used - Wrapping(1) && num_added_old > 0 { + if Wrapping(queue.avail_ring_used_event_get()) == queue.next_used - Wrapping(1) + && num_added_old > 0 + { // If the idx field in the used ring (which determined where that descriptor index // was placed) was equal to used_event, the device MUST send a // notification. @@ -965,7 +970,7 @@ mod verification { // number of added descriptors being counted in Queue.num_added), and then use // "prepare_kick" to check if any of those descriptors should have triggered a // notification. - let ProofContext(mut queue, mem) = ProofContext::bounded_queue(); + let ProofContext(mut queue, _) = kani::any(); queue.enable_notif_suppression(); assert!(queue.uses_notif_suppression); @@ -982,7 +987,7 @@ mod verification { // [next_used - num_added - 1, u16::MAX] ∪ [0, next_used - 1]. Since queue size is at most // 2^15, intervals can only wrap at most once. This gives us the following logic: - let used_event = queue.used_event(&mem); + let used_event = Wrapping(queue.avail_ring_used_event_get()); let interval_start = queue.next_used - queue.num_added; let interval_end = queue.next_used - Wrapping(1); let needs_notification = if queue.num_added.0 == 0 { @@ -993,15 +998,51 @@ mod verification { used_event >= interval_start && used_event <= interval_end }; - assert_eq!(queue.prepare_kick(&mem), needs_notification); + assert_eq!(queue.prepare_kick(), needs_notification); + } + + #[kani::proof] + #[kani::unwind(0)] + fn verify_add_used() { + let ProofContext(mut queue, _) = kani::any(); + + // The spec here says (2.6.8.2): + // + // The device MUST set len prior to updating the used idx. + // The device MUST write at least len bytes to descriptor, beginning at the first + // device-writable buffer, prior to updating the used idx. + // The device MAY write more than len bytes to descriptor. + // + // We can't really verify any of these. We can verify that guest memory is updated correctly + // though + + // index into used ring at which the index of the descriptor to which + // the device wrote. + let used_idx = queue.next_used; + + let used_desc_table_index = kani::any(); + if queue.add_used(used_desc_table_index, kani::any()).is_ok() { + assert_eq!(queue.next_used, used_idx + Wrapping(1)); + } else { + assert_eq!(queue.next_used, used_idx); + + // Ideally, here we would want to actually read the relevant values from memory and + // assert they are unchanged. However, kani will run out of memory if we try to do so, + // so we instead verify the following "proxy property": If an error happened, then + // it happened at the very beginning of add_used, meaning no memory accesses were + // done. This is relying on implementation details of add_used, namely that + // the check for out-of-bounds descriptor index happens at the very beginning of the + // function. 
+ assert!(used_desc_table_index >= queue.actual_size()); + } } #[kani::proof] #[kani::unwind(0)] fn verify_is_empty() { - let ProofContext(queue, mem) = ProofContext::bounded_queue(); + let ProofContext(queue, _) = kani::any(); - assert_eq!(queue.len(&mem) == 0, queue.is_empty(&mem)); + assert_eq!(queue.len() == 0, queue.is_empty()); } #[kani::proof] @@ -1021,9 +1062,9 @@ mod verification { } } - assert!(alignment_of(queue.desc_table.0) >= 16); - assert!(alignment_of(queue.avail_ring.0) >= 2); - assert!(alignment_of(queue.used_ring.0) >= 4); + assert!(alignment_of(queue.desc_table_address.0) >= 16); + assert!(alignment_of(queue.avail_ring_address.0) >= 2); + assert!(alignment_of(queue.used_ring_address.0) >= 4); // length of queue must be power-of-two, and at most 2^15 assert_eq!(queue.size.count_ones(), 1); @@ -1042,29 +1083,71 @@ mod verification { #[kani::proof] #[kani::unwind(0)] - fn verify_set_used_ring_avail_event() { - let ProofContext(mut queue, mem) = ProofContext::bounded_queue(); + fn verify_avail_ring_idx_get() { + let ProofContext(queue, _) = kani::any(); + _ = queue.avail_ring_idx_get(); + } + + #[kani::proof] + #[kani::unwind(0)] + fn verify_avail_ring_ring_get() { + let ProofContext(queue, _) = kani::any(); + let x: usize = kani::any_where(|x| *x < usize::from(queue.size)); + unsafe { _ = queue.avail_ring_ring_get(x) }; + } + + #[kani::proof] + #[kani::unwind(0)] + fn verify_avail_ring_used_event_get() { + let ProofContext(queue, _) = kani::any(); + _ = queue.avail_ring_used_event_get(); + } + + #[kani::proof] + #[kani::unwind(0)] + fn verify_used_ring_idx_set() { + let ProofContext(mut queue, _) = kani::any(); + queue.used_ring_idx_set(kani::any()); + } + + #[kani::proof] + #[kani::unwind(0)] + fn verify_used_ring_ring_set() { + let ProofContext(mut queue, _) = kani::any(); + let x: usize = kani::any_where(|x| *x < usize::from(queue.size)); + let used_element = UsedElement { + id: kani::any(), + len: kani::any(), + }; + unsafe { queue.used_ring_ring_set(x, used_element) }; + } - queue.set_used_ring_avail_event(kani::any(), &mem); + #[kani::proof] + #[kani::unwind(0)] + fn verify_used_ring_avail_event() { + let ProofContext(mut queue, _) = kani::any(); + let x = kani::any(); + queue.used_ring_avail_event_set(x); + assert_eq!(x, queue.used_ring_avail_event_get()); } #[kani::proof] #[kani::unwind(0)] #[kani::solver(cadical)] fn verify_pop() { - let ProofContext(mut queue, mem) = ProofContext::bounded_queue(); + let ProofContext(mut queue, _) = kani::any(); // This is an assertion in pop which we use to abort firecracker in a ddos scenario // This condition being false means that the guest is asking us to process every element // in the queue multiple times. It cannot be checked by is_valid, as that function // is called when the queue is being initialized, e.g. empty. We compute it using // local variables here to make things easier on kani: One less roundtrip through vm-memory. 
- let queue_len = queue.len(&mem); + let queue_len = queue.len(); kani::assume(queue_len <= queue.actual_size()); let next_avail = queue.next_avail; - if let Some(_) = queue.pop(&mem) { + if let Some(_) = queue.pop() { // Can't get anything out of an empty queue, assert queue_len != 0 assert_ne!(queue_len, 0); assert_eq!(queue.next_avail, next_avail + Wrapping(1)); @@ -1075,13 +1158,13 @@ mod verification { #[kani::unwind(0)] #[kani::solver(cadical)] fn verify_undo_pop() { - let ProofContext(mut queue, mem) = ProofContext::bounded_queue(); + let ProofContext(mut queue, _) = kani::any(); // See verify_pop for explanation - kani::assume(queue.len(&mem) <= queue.actual_size()); + kani::assume(queue.len() <= queue.actual_size()); let queue_clone = queue.clone(); - if let Some(_) = queue.pop(&mem) { + if let Some(_) = queue.pop() { queue.undo_pop(); assert_eq!(queue, queue_clone); @@ -1091,18 +1174,17 @@ mod verification { #[kani::proof] #[kani::unwind(0)] - #[kani::stub(Queue::set_used_ring_avail_event, stubs::set_used_ring_avail_event)] fn verify_try_enable_notification() { - let ProofContext(mut queue, mem) = ProofContext::bounded_queue(); + let ProofContext(mut queue, _) = ProofContext::bounded_queue(); - kani::assume(queue.len(&mem) <= queue.actual_size()); + kani::assume(queue.len() <= queue.actual_size()); - if queue.try_enable_notification(&mem) && queue.uses_notif_suppression { + if queue.try_enable_notification() && queue.uses_notif_suppression { // We only require new notifications if the queue is empty (e.g. we've processed // everything we've been notified about), or if suppression is disabled. - assert!(queue.is_empty(&mem)); + assert!(queue.is_empty()); - assert_eq!(queue.avail_idx(&mem), queue.next_avail) + assert_eq!(Wrapping(queue.avail_ring_idx_get()), queue.next_avail) } } @@ -1110,11 +1192,11 @@ mod verification { #[kani::unwind(0)] #[kani::solver(cadical)] fn verify_checked_new() { - let ProofContext(queue, mem) = ProofContext::bounded_queue(); + let ProofContext(queue, mem) = kani::any(); let index = kani::any(); let maybe_chain = - DescriptorChain::checked_new(&mem, queue.desc_table, queue.actual_size(), index); + DescriptorChain::checked_new(queue.desc_table_ptr, queue.actual_size(), index); if index >= queue.actual_size() { assert!(maybe_chain.is_none()) @@ -1123,7 +1205,7 @@ mod verification { // able to compute the address of the descriptor table entry without going out // of bounds anywhere, and also read from that address. 
let desc_head = mem - .checked_offset(queue.desc_table, (index as usize) * 16) + .checked_offset(queue.desc_table_address, (index as usize) * 16) .unwrap(); mem.checked_offset(desc_head, 16).unwrap(); let desc = mem.read_obj::(desc_head).unwrap(); @@ -1144,34 +1226,25 @@ mod verification { #[cfg(test)] mod tests { + use vm_memory::Bytes; + pub use super::*; - use crate::devices::virtio::queue::QueueError::{DescIndexOutOfBounds, UsedRing}; + use crate::devices::virtio::queue::QueueError::DescIndexOutOfBounds; use crate::devices::virtio::test_utils::{default_mem, VirtQueue}; use crate::utilities::test_utils::{multi_region_mem, single_region_mem}; - use crate::vstate::memory::{GuestAddress, GuestMemoryMmap}; - - impl Queue { - fn avail_event(&self, mem: &GuestMemoryMmap) -> u16 { - let avail_event_addr = self - .used_ring - .unchecked_add(u64::from(4 + 8 * self.actual_size())); - - mem.read_obj::(avail_event_addr).unwrap() - } - } + use crate::vstate::memory::GuestAddress; #[test] fn test_checked_new_descriptor_chain() { let m = &multi_region_mem(&[(GuestAddress(0), 0x10000), (GuestAddress(0x20000), 0x2000)]); let vq = VirtQueue::new(GuestAddress(0), m, 16); + let mut q = vq.create_queue(); + q.initialize(m).unwrap(); assert!(vq.end().0 < 0x1000); // index >= queue_size - assert!(DescriptorChain::checked_new(m, vq.dtable_start(), 16, 16).is_none()); - - // desc_table address is way off - assert!(DescriptorChain::checked_new(m, GuestAddress(0x00ff_ffff_ffff), 16, 0).is_none()); + assert!(DescriptorChain::checked_new(q.desc_table_ptr, 16, 16).is_none()); // Let's create an invalid chain. { @@ -1182,7 +1255,7 @@ mod tests { // .. but the index of the next descriptor is too large vq.dtable[0].next.set(16); - assert!(DescriptorChain::checked_new(m, vq.dtable_start(), 16, 0).is_none()); + assert!(DescriptorChain::checked_new(q.desc_table_ptr, 16, 0).is_none()); } // Finally, let's test an ok chain. 
@@ -1190,10 +1263,9 @@ mod tests { vq.dtable[0].next.set(1); vq.dtable[1].set(0x2000, 0x1000, 0, 0); - let c = DescriptorChain::checked_new(m, vq.dtable_start(), 16, 0).unwrap(); + let c = DescriptorChain::checked_new(q.desc_table_ptr, 16, 0).unwrap(); - assert_eq!(c.mem as *const GuestMemoryMmap, m as *const GuestMemoryMmap); - assert_eq!(c.desc_table, vq.dtable_start()); + assert_eq!(c.desc_table_ptr, q.desc_table_ptr); assert_eq!(c.queue_size, 16); assert_eq!(c.ttl, c.queue_size); assert_eq!(c.index, 0); @@ -1236,43 +1308,31 @@ mod tests { assert!(!q.is_valid(m)); q.size = q.max_size; - // or when avail_idx - next_avail > max_size - q.next_avail = Wrapping(5); - assert!(!q.is_valid(m)); - // avail_ring + 2 is the address of avail_idx in guest mem - m.write_obj::(64_u16, q.avail_ring.unchecked_add(2)) - .unwrap(); - assert!(!q.is_valid(m)); - m.write_obj::(5_u16, q.avail_ring.unchecked_add(2)) - .unwrap(); - q.max_size = 2; - assert!(!q.is_valid(m)); - // reset dirtied values q.max_size = 16; q.next_avail = Wrapping(0); - m.write_obj::(0, q.avail_ring.unchecked_add(2)) + m.write_obj::(0, q.avail_ring_address.unchecked_add(2)) .unwrap(); // or if the various addresses are off - q.desc_table = GuestAddress(0xffff_ffff); + q.desc_table_address = GuestAddress(0xffff_ffff); assert!(!q.is_valid(m)); - q.desc_table = GuestAddress(0x1001); + q.desc_table_address = GuestAddress(0x1001); assert!(!q.is_valid(m)); - q.desc_table = vq.dtable_start(); + q.desc_table_address = vq.dtable_start(); - q.avail_ring = GuestAddress(0xffff_ffff); + q.avail_ring_address = GuestAddress(0xffff_ffff); assert!(!q.is_valid(m)); - q.avail_ring = GuestAddress(0x1001); + q.avail_ring_address = GuestAddress(0x1001); assert!(!q.is_valid(m)); - q.avail_ring = vq.avail_start(); + q.avail_ring_address = vq.avail_start(); - q.used_ring = GuestAddress(0xffff_ffff); + q.used_ring_address = GuestAddress(0xffff_ffff); assert!(!q.is_valid(m)); - q.used_ring = GuestAddress(0x1001); + q.used_ring_address = GuestAddress(0x1001); assert!(!q.is_valid(m)); - q.used_ring = vq.used_start(); + q.used_ring_address = vq.used_start(); } #[test] @@ -1297,19 +1357,19 @@ mod tests { vq.avail.idx.set(2); // We've just set up two chains. - assert_eq!(q.len(m), 2); + assert_eq!(q.len(), 2); // The first chain should hold exactly two descriptors. - let d = q.pop(m).unwrap().next_descriptor().unwrap(); + let d = q.pop().unwrap().next_descriptor().unwrap(); assert!(!d.has_next()); assert!(d.next_descriptor().is_none()); // We popped one chain, so there should be only one left. - assert_eq!(q.len(m), 1); + assert_eq!(q.len(), 1); // The next chain holds three descriptors. let d = q - .pop(m) + .pop() .unwrap() .next_descriptor() .unwrap() @@ -1319,16 +1379,16 @@ mod tests { assert!(d.next_descriptor().is_none()); // We've popped both chains, so the queue should be empty. - assert!(q.is_empty(m)); - assert!(q.pop(m).is_none()); + assert!(q.is_empty()); + assert!(q.pop().is_none()); // Undoing the last pop should let us walk the last chain again. q.undo_pop(); - assert_eq!(q.len(m), 1); + assert_eq!(q.len(), 1); // Walk the last chain again (three descriptors). let d = q - .pop(m) + .pop() .unwrap() .next_descriptor() .unwrap() @@ -1339,11 +1399,11 @@ mod tests { // Undoing the last pop should let us walk the last chain again. q.undo_pop(); - assert_eq!(q.len(m), 1); + assert_eq!(q.len(), 1); // Walk the last chain again (three descriptors) using pop_or_enable_notification(). 
let d = q - .pop_or_enable_notification(m) + .pop_or_enable_notification() .unwrap() .next_descriptor() .unwrap() @@ -1354,15 +1414,15 @@ mod tests { // There are no more descriptors, but notification suppression is disabled. // Calling pop_or_enable_notification() should simply return None. - assert_eq!(q.avail_event(m), 0); - assert!(q.pop_or_enable_notification(m).is_none()); - assert_eq!(q.avail_event(m), 0); + assert_eq!(q.used_ring_avail_event_get(), 0); + assert!(q.pop_or_enable_notification().is_none()); + assert_eq!(q.used_ring_avail_event_get(), 0); // There are no more descriptors and notification suppression is enabled. Calling // pop_or_enable_notification() should enable notifications. q.enable_notif_suppression(); - assert!(q.pop_or_enable_notification(m).is_none()); - assert_eq!(q.avail_event(m), 2); + assert!(q.pop_or_enable_notification().is_none()); + assert_eq!(q.used_ring_avail_event_get(), 2); } #[test] @@ -1393,21 +1453,21 @@ mod tests { vq.avail.idx.set(2); // We've just set up two chains. - assert_eq!(q.len(m), 2); + assert_eq!(q.len(), 2); // We process the first descriptor. - let d = q.pop(m).unwrap().next_descriptor(); + let d = q.pop().unwrap().next_descriptor(); assert!(matches!(d, Some(x) if !x.has_next())); // We confuse the device and set the available index as being 6. vq.avail.idx.set(6); // We've actually just popped a descriptor so 6 - 1 = 5. - assert_eq!(q.len(m), 5); + assert_eq!(q.len(), 5); // However, since the apparent length set by the driver is more than the queue size, // we would be running the risk of going through some descriptors more than once. // As such, we expect to panic. - q.pop(m); + q.pop(); } #[test] @@ -1438,7 +1498,7 @@ mod tests { // driver sets available index to suspicious value. vq.avail.idx.set(6); - q.pop_or_enable_notification(m); + q.pop_or_enable_notification(); } #[test] @@ -1452,13 +1512,13 @@ mod tests { // Valid queue addresses configuration { // index too large - match q.add_used(m, 16, 0x1000) { + match q.add_used(16, 0x1000) { Err(DescIndexOutOfBounds(16)) => (), _ => unreachable!(), } // should be ok - q.add_used(m, 1, 0x1000).unwrap(); + q.add_used(1, 0x1000).unwrap(); assert_eq!(vq.used.idx.get(), 1); let x = vq.used.ring[0].get(); assert_eq!(x.id, 1); @@ -1472,13 +1532,13 @@ mod tests { let vq = VirtQueue::new(GuestAddress(0), m, 16); let q = vq.create_queue(); - assert_eq!(q.used_event(m), Wrapping(0)); + assert_eq!(q.avail_ring_used_event_get(), 0); vq.avail.event.set(10); - assert_eq!(q.used_event(m), Wrapping(10)); + assert_eq!(q.avail_ring_used_event_get(), 10); vq.avail.event.set(u16::MAX); - assert_eq!(q.used_event(m), Wrapping(u16::MAX)); + assert_eq!(q.avail_ring_used_event_get(), u16::MAX); } #[test] @@ -1489,10 +1549,10 @@ mod tests { let mut q = vq.create_queue(); assert_eq!(vq.used.event.get(), 0); - q.set_used_ring_avail_event(10, m); + q.used_ring_avail_event_set(10); assert_eq!(vq.used.event.get(), 10); - q.set_used_ring_avail_event(u16::MAX, m); + q.used_ring_avail_event_set(u16::MAX); assert_eq!(vq.used.event.get(), u16::MAX); } @@ -1512,7 +1572,7 @@ mod tests { q.next_used = Wrapping(used_idx); vq.avail.event.set(used_event); q.num_added = Wrapping(num_added); - assert!(q.prepare_kick(m)); + assert!(q.prepare_kick()); } } } @@ -1524,7 +1584,7 @@ mod tests { q.next_used = Wrapping(10); vq.avail.event.set(6); q.num_added = Wrapping(5); - assert!(q.prepare_kick(m)); + assert!(q.prepare_kick()); } { @@ -1532,7 +1592,7 @@ mod tests { q.next_used = Wrapping(10); vq.avail.event.set(6); q.num_added 
= Wrapping(4); - assert!(q.prepare_kick(m)); + assert!(q.prepare_kick()); } { @@ -1540,7 +1600,7 @@ mod tests { q.next_used = Wrapping(10); vq.avail.event.set(6); q.num_added = Wrapping(3); - assert!(!q.prepare_kick(m)); + assert!(!q.prepare_kick()); } } @@ -1557,27 +1617,27 @@ mod tests { vq.avail.ring[0].set(0); vq.avail.idx.set(1); - assert_eq!(q.len(m), 1); + assert_eq!(q.len(), 1); // Notification suppression is disabled. try_enable_notification shouldn't do anything. - assert!(q.try_enable_notification(m)); - assert_eq!(q.avail_event(m), 0); + assert!(q.try_enable_notification()); + assert_eq!(q.used_ring_avail_event_get(), 0); // Enable notification suppression and check again. There is 1 available descriptor chain. // Again nothing should happen. q.enable_notif_suppression(); - assert!(!q.try_enable_notification(m)); - assert_eq!(q.avail_event(m), 0); + assert!(!q.try_enable_notification()); + assert_eq!(q.used_ring_avail_event_get(), 0); // Consume the descriptor. avail_event should be modified - assert!(q.pop(m).is_some()); - assert!(q.try_enable_notification(m)); - assert_eq!(q.avail_event(m), 1); + assert!(q.pop().is_some()); + assert!(q.try_enable_notification()); + assert_eq!(q.used_ring_avail_event_get(), 1); } #[test] fn test_queue_error_display() { - let err = UsedRing(vm_memory::GuestMemoryError::InvalidGuestAddress( + let err = QueueError::MemoryError(vm_memory::GuestMemoryError::InvalidGuestAddress( GuestAddress(0), )); let _ = format!("{}{:?}", err, err); diff --git a/src/vmm/src/devices/virtio/rng/device.rs b/src/vmm/src/devices/virtio/rng/device.rs index f671f00e554..43742ab0327 100644 --- a/src/vmm/src/devices/virtio/rng/device.rs +++ b/src/vmm/src/devices/virtio/rng/device.rs @@ -42,7 +42,7 @@ pub struct Entropy { // Transport fields device_state: DeviceState, - queues: Vec, + pub(crate) queues: Vec, queue_events: Vec, irq_trigger: IrqTrigger, @@ -128,14 +128,14 @@ impl Entropy { let mem = self.device_state.mem().unwrap(); let mut used_any = false; - while let Some(desc) = self.queues[RNG_QUEUE].pop(mem) { + while let Some(desc) = self.queues[RNG_QUEUE].pop() { let index = desc.index; METRICS.entropy_event_count.inc(); // SAFETY: This descriptor chain is only loaded once // virtio requests are handled sequentially so no two IoVecBuffers // are live at the same time, meaning this has exclusive ownership over the memory - let bytes = match unsafe { IoVecBufferMut::from_descriptor_chain(desc) } { + let bytes = match unsafe { IoVecBufferMut::from_descriptor_chain(mem, desc) } { Ok(mut iovec) => { debug!( "entropy: guest request for {} bytes of entropy", @@ -165,7 +165,7 @@ impl Entropy { } }; - match self.queues[RNG_QUEUE].add_used(mem, index, bytes) { + match self.queues[RNG_QUEUE].add_used(index, bytes) { Ok(_) => { used_any = true; METRICS.entropy_bytes.add(bytes.into()); @@ -287,6 +287,11 @@ impl VirtioDevice for Entropy { } fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + for q in self.queues.iter_mut() { + q.initialize(&mem) + .map_err(ActivateError::QueueMemoryError)?; + } + self.activate_event.write(1).map_err(|_| { METRICS.activate_fails.inc(); ActivateError::EventFd @@ -429,17 +434,17 @@ mod tests { let mut entropy_dev = th.device(); // This should succeed, we just added two descriptors - let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop(&mem).unwrap(); + let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop().unwrap(); assert!(matches!( // SAFETY: This descriptor chain is only loaded into one buffer - unsafe { 
IoVecBufferMut::from_descriptor_chain(desc) }, + unsafe { IoVecBufferMut::from_descriptor_chain(&mem, desc) }, Err(crate::devices::virtio::iovec::IoVecError::ReadOnlyDescriptor) )); // This should succeed, we should have one more descriptor - let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop(&mem).unwrap(); + let desc = entropy_dev.queues_mut()[RNG_QUEUE].pop().unwrap(); // SAFETY: This descriptor chain is only loaded into one buffer - let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(desc).unwrap() }; + let mut iovec = unsafe { IoVecBufferMut::from_descriptor_chain(&mem, desc).unwrap() }; entropy_dev.handle_one(&mut iovec).unwrap(); } diff --git a/src/vmm/src/devices/virtio/test_utils.rs b/src/vmm/src/devices/virtio/test_utils.rs index daa0acd4f84..32807d7381e 100644 --- a/src/vmm/src/devices/virtio/test_utils.rs +++ b/src/vmm/src/devices/virtio/test_utils.rs @@ -289,9 +289,11 @@ impl<'a> VirtQueue<'a> { q.size = self.size(); q.ready = true; - q.desc_table = self.dtable_start(); - q.avail_ring = self.avail_start(); - q.used_ring = self.used_start(); + q.desc_table_address = self.dtable_start(); + q.avail_ring_address = self.avail_start(); + q.used_ring_address = self.used_start(); + + q.initialize(self.memory()).unwrap(); q } diff --git a/src/vmm/src/devices/virtio/vhost_user.rs b/src/vmm/src/devices/virtio/vhost_user.rs index 82d87a6bb6b..3e3ce0a2dca 100644 --- a/src/vmm/src/devices/virtio/vhost_user.rs +++ b/src/vmm/src/devices/virtio/vhost_user.rs @@ -420,14 +420,14 @@ impl VhostUserHandleImpl { queue_size: queue.actual_size(), flags: 0u32, desc_table_addr: mem - .get_host_address(queue.desc_table) + .get_host_address(queue.desc_table_address) .map_err(VhostUserError::DescriptorTableAddress)? as u64, used_ring_addr: mem - .get_host_address(queue.used_ring) + .get_host_address(queue.used_ring_address) .map_err(VhostUserError::UsedAddress)? as u64, avail_ring_addr: mem - .get_host_address(queue.avail_ring) + .get_host_address(queue.avail_ring_address) .map_err(VhostUserError::AvailAddress)? 
as u64, log_addr: None, }; @@ -436,7 +436,7 @@ impl VhostUserHandleImpl { .set_vring_addr(*queue_index, &config_data) .map_err(VhostUserError::VhostUserSetVringAddr)?; self.vu - .set_vring_base(*queue_index, queue.avail_idx(mem).0) + .set_vring_base(*queue_index, queue.avail_ring_idx_get()) .map_err(VhostUserError::VhostUserSetVringBase)?; // No matter the queue, we set irq_evt for signaling the guest that buffers were @@ -891,7 +891,9 @@ mod tests { let guest_memory = GuestMemoryMmap::from_raw_regions_file(regions, false, false).unwrap(); - let queue = Queue::new(69); + let mut queue = Queue::new(69); + queue.initialize(&guest_memory).unwrap(); + let event_fd = EventFd::new(0).unwrap(); let irq_trigger = IrqTrigger::new().unwrap(); @@ -909,12 +911,18 @@ mod tests { queue_max_size: 69, queue_size: 0, flags: 0, - desc_table_addr: guest_memory.get_host_address(queue.desc_table).unwrap() as u64, - used_ring_addr: guest_memory.get_host_address(queue.used_ring).unwrap() as u64, - avail_ring_addr: guest_memory.get_host_address(queue.avail_ring).unwrap() as u64, + desc_table_addr: guest_memory + .get_host_address(queue.desc_table_address) + .unwrap() as u64, + used_ring_addr: guest_memory + .get_host_address(queue.used_ring_address) + .unwrap() as u64, + avail_ring_addr: guest_memory + .get_host_address(queue.avail_ring_address) + .unwrap() as u64, log_addr: None, }, - base: queue.avail_idx(&guest_memory).0, + base: queue.avail_ring_idx_get(), call: irq_trigger.irq_evt.as_raw_fd(), kick: event_fd.as_raw_fd(), enable: true, diff --git a/src/vmm/src/devices/virtio/vsock/csm/connection.rs b/src/vmm/src/devices/virtio/vsock/csm/connection.rs index e38d9bab974..9fe744058c0 100644 --- a/src/vmm/src/devices/virtio/vsock/csm/connection.rs +++ b/src/vmm/src/devices/virtio/vsock/csm/connection.rs @@ -862,15 +862,13 @@ mod tests { let mut handler_ctx = vsock_test_ctx.create_event_handler_context(); let stream = TestStream::new(); let mut rx_pkt = VsockPacket::from_rx_virtq_head( - handler_ctx.device.queues[RXQ_INDEX] - .pop(&vsock_test_ctx.mem) - .unwrap(), + &vsock_test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) .unwrap(); let tx_pkt = VsockPacket::from_tx_virtq_head( - handler_ctx.device.queues[TXQ_INDEX] - .pop(&vsock_test_ctx.mem) - .unwrap(), + &vsock_test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) .unwrap(); let conn = match conn_state { diff --git a/src/vmm/src/devices/virtio/vsock/device.rs b/src/vmm/src/devices/virtio/vsock/device.rs index d48782901ec..17b1789191f 100644 --- a/src/vmm/src/devices/virtio/vsock/device.rs +++ b/src/vmm/src/devices/virtio/vsock/device.rs @@ -145,9 +145,9 @@ where let mut have_used = false; - while let Some(head) = self.queues[RXQ_INDEX].pop(mem) { + while let Some(head) = self.queues[RXQ_INDEX].pop() { let index = head.index; - let used_len = match VsockPacket::from_rx_virtq_head(head) { + let used_len = match VsockPacket::from_rx_virtq_head(mem, head) { Ok(mut pkt) => { if self.backend.recv_pkt(&mut pkt).is_ok() { match pkt.commit_hdr() { @@ -180,7 +180,7 @@ where have_used = true; self.queues[RXQ_INDEX] - .add_used(mem, index, used_len) + .add_used(index, used_len) .unwrap_or_else(|err| { error!("Failed to add available descriptor {}: {}", index, err) }); @@ -198,15 +198,15 @@ where let mut have_used = false; - while let Some(head) = self.queues[TXQ_INDEX].pop(mem) { + while let Some(head) = self.queues[TXQ_INDEX].pop() { let index = head.index; - let pkt = match VsockPacket::from_tx_virtq_head(head) { + let pkt = match 
VsockPacket::from_tx_virtq_head(mem, head) { Ok(pkt) => pkt, Err(err) => { error!("vsock: error reading TX packet: {:?}", err); have_used = true; self.queues[TXQ_INDEX] - .add_used(mem, index, 0) + .add_used(index, 0) .unwrap_or_else(|err| { error!("Failed to add available descriptor {}: {}", index, err); }); @@ -221,7 +221,7 @@ where have_used = true; self.queues[TXQ_INDEX] - .add_used(mem, index, 0) + .add_used(index, 0) .unwrap_or_else(|err| { error!("Failed to add available descriptor {}: {}", index, err); }); @@ -237,7 +237,7 @@ where // This is safe since we checked in the caller function that the device is activated. let mem = self.device_state.mem().unwrap(); - let head = self.queues[EVQ_INDEX].pop(mem).ok_or_else(|| { + let head = self.queues[EVQ_INDEX].pop().ok_or_else(|| { METRICS.ev_queue_event_fails.inc(); DeviceError::VsockError(VsockError::EmptyQueue) })?; @@ -246,7 +246,7 @@ where .unwrap_or_else(|err| error!("Failed to write virtio vsock reset event: {:?}", err)); self.queues[EVQ_INDEX] - .add_used(mem, head.index, head.len) + .add_used(head.index, head.len) .unwrap_or_else(|err| { error!("Failed to add used descriptor {}: {}", head.index, err); }); @@ -323,6 +323,11 @@ where } fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> { + for q in self.queues.iter_mut() { + q.initialize(&mem) + .map_err(ActivateError::QueueMemoryError)?; + } + if self.queues.len() != defs::VSOCK_NUM_QUEUES { METRICS.activate_fails.inc(); return Err(ActivateError::QueueMismatch { diff --git a/src/vmm/src/devices/virtio/vsock/event_handler.rs b/src/vmm/src/devices/virtio/vsock/event_handler.rs index 80225eaa543..99640f1b5c2 100755 --- a/src/vmm/src/devices/virtio/vsock/event_handler.rs +++ b/src/vmm/src/devices/virtio/vsock/event_handler.rs @@ -438,8 +438,8 @@ mod tests { ctx.guest_rxvq.dtable[desc_idx].len.set(len); // If the descriptor chain is already declared invalid, there's no reason to assemble // a packet. - if let Some(rx_desc) = ctx.device.queues[RXQ_INDEX].pop(&test_ctx.mem) { - VsockPacket::from_rx_virtq_head(rx_desc).unwrap_err(); + if let Some(rx_desc) = ctx.device.queues[RXQ_INDEX].pop() { + VsockPacket::from_rx_virtq_head(&test_ctx.mem, rx_desc).unwrap_err(); } } @@ -460,8 +460,8 @@ mod tests { ctx.guest_txvq.dtable[desc_idx].addr.set(addr); ctx.guest_txvq.dtable[desc_idx].len.set(len); - if let Some(tx_desc) = ctx.device.queues[TXQ_INDEX].pop(&test_ctx.mem) { - VsockPacket::from_tx_virtq_head(tx_desc).unwrap_err(); + if let Some(tx_desc) = ctx.device.queues[TXQ_INDEX].pop() { + VsockPacket::from_tx_virtq_head(&test_ctx.mem, tx_desc).unwrap_err(); } } } @@ -485,14 +485,14 @@ mod tests { // The default configured descriptor chains are valid. { let mut ctx = test_ctx.create_event_handler_context(); - let rx_desc = ctx.device.queues[RXQ_INDEX].pop(&test_ctx.mem).unwrap(); - VsockPacket::from_rx_virtq_head(rx_desc).unwrap(); + let rx_desc = ctx.device.queues[RXQ_INDEX].pop().unwrap(); + VsockPacket::from_rx_virtq_head(&test_ctx.mem, rx_desc).unwrap(); } { let mut ctx = test_ctx.create_event_handler_context(); - let tx_desc = ctx.device.queues[TXQ_INDEX].pop(&test_ctx.mem).unwrap(); - VsockPacket::from_tx_virtq_head(tx_desc).unwrap(); + let tx_desc = ctx.device.queues[TXQ_INDEX].pop().unwrap(); + VsockPacket::from_tx_virtq_head(&test_ctx.mem, tx_desc).unwrap(); } // Let's check what happens when the header descriptor is right before the gap. 
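The vsock hunks above and the packet.rs hunk below settle on one calling convention: Queue::pop() and add_used() work through the mapping cached by initialize(), while the VsockPacket constructors take guest memory explicitly. The sketch below only illustrates that convention; it is not part of the patch, the drain_tx helper is hypothetical, and the import paths are assumed from the module layout visible in this diff.

    use crate::devices::virtio::queue::Queue;
    use crate::devices::virtio::vsock::packet::VsockPacket;
    use crate::vstate::memory::GuestMemoryMmap;

    /// Hypothetical helper: drain a TX queue with the post-refactor API.
    /// Assumes the queue was bound to `mem` via `initialize()` at activation.
    fn drain_tx(queue: &mut Queue, mem: &GuestMemoryMmap) -> usize {
        let mut done = 0;
        // pop() no longer takes guest memory; it reads through the cached mapping.
        while let Some(head) = queue.pop() {
            let index = head.index;
            // The packet constructor still needs `mem` to map the chain's buffers.
            let _pkt = VsockPacket::from_tx_virtq_head(mem, head);
            // add_used() likewise drops the `mem` argument.
            if queue.add_used(index, 0).is_ok() {
                done += 1;
            }
        }
        done
    }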
diff --git a/src/vmm/src/devices/virtio/vsock/packet.rs b/src/vmm/src/devices/virtio/vsock/packet.rs
index c18b45b9a94..4c7e68ccb54 100644
--- a/src/vmm/src/devices/virtio/vsock/packet.rs
+++ b/src/vmm/src/devices/virtio/vsock/packet.rs
@@ -23,7 +23,7 @@ use vm_memory::{GuestMemoryError, ReadVolatile, WriteVolatile};
 use super::{defs, VsockError};
 use crate::devices::virtio::iovec::{IoVecBuffer, IoVecBufferMut};
 use crate::devices::virtio::queue::DescriptorChain;
-use crate::vstate::memory::ByteValued;
+use crate::vstate::memory::{ByteValued, GuestMemoryMmap};
 
 // The vsock packet header is defined by the C struct:
 //
@@ -123,11 +123,14 @@ impl VsockPacket {
     /// length would exceed [`defs::MAX_PKT_BUF_SIZE`].
     /// - [`VsockError::DescChainTooShortForPacket`] if the contained vsock header describes a vsock
     ///   packet whose length exceeds the descriptor chain's actual total buffer length.
-    pub fn from_tx_virtq_head(chain: DescriptorChain) -> Result<Self, VsockError> {
+    pub fn from_tx_virtq_head(
+        mem: &GuestMemoryMmap,
+        chain: DescriptorChain,
+    ) -> Result<Self, VsockError> {
         // SAFETY: This descriptor chain is only loaded once
         // virtio requests are handled sequentially so no two IoVecBuffers
         // are live at the same time, meaning this has exclusive ownership over the memory
-        let buffer = unsafe { IoVecBuffer::from_descriptor_chain(chain)? };
+        let buffer = unsafe { IoVecBuffer::from_descriptor_chain(mem, chain)? };
 
         let mut hdr = VsockPacketHeader::default();
         match buffer.read_exact_volatile_at(hdr.as_mut_slice(), 0) {
@@ -160,11 +163,14 @@ impl VsockPacket {
     /// ## Errors
     /// Returns [`VsockError::DescChainTooShortForHeader`] if the descriptor chain's total buffer
     /// length is insufficient to hold the 44 byte vsock header
-    pub fn from_rx_virtq_head(chain: DescriptorChain) -> Result<Self, VsockError> {
+    pub fn from_rx_virtq_head(
+        mem: &GuestMemoryMmap,
+        chain: DescriptorChain,
+    ) -> Result<Self, VsockError> {
         // SAFETY: This descriptor chain is only loaded once
         // virtio requests are handled sequentially so no two IoVecBuffers
         // are live at the same time, meaning this has exclusive ownership over the memory
-        let buffer = unsafe { IoVecBufferMut::from_descriptor_chain(chain)? };
+        let buffer = unsafe { IoVecBufferMut::from_descriptor_chain(mem, chain)?
}; if buffer.len() < VSOCK_PKT_HDR_SIZE { return Err(VsockError::DescChainTooShortForHeader(buffer.len() as usize)); @@ -395,9 +401,8 @@ mod tests { }; ($test_ctx:expr, $handler_ctx:expr, $err:pat, $ctor:ident, $vq_index:ident) => { let result = VsockPacket::$ctor( - $handler_ctx.device.queues[$vq_index] - .pop(&$test_ctx.mem) - .unwrap(), + &$test_ctx.mem, + $handler_ctx.device.queues[$vq_index].pop().unwrap(), ); assert!(matches!(result, Err($err)), "{:?}", result) }; @@ -426,9 +431,8 @@ mod tests { create_context!(test_ctx, handler_ctx); let pkt = VsockPacket::from_tx_virtq_head( - handler_ctx.device.queues[TXQ_INDEX] - .pop(&test_ctx.mem) - .unwrap(), + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) .unwrap(); @@ -467,9 +471,8 @@ mod tests { create_context!(test_ctx, handler_ctx); set_pkt_len(0, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); VsockPacket::from_tx_virtq_head( - handler_ctx.device.queues[TXQ_INDEX] - .pop(&test_ctx.mem) - .unwrap(), + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) .unwrap(); } @@ -530,9 +533,8 @@ mod tests { { create_context!(test_ctx, handler_ctx); let pkt = VsockPacket::from_rx_virtq_head( - handler_ctx.device.queues[RXQ_INDEX] - .pop(&test_ctx.mem) - .unwrap(), + &test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) .unwrap(); assert_eq!( @@ -580,9 +582,8 @@ mod tests { create_context!(test_ctx, handler_ctx); let mut pkt = VsockPacket::from_rx_virtq_head( - handler_ctx.device.queues[RXQ_INDEX] - .pop(&test_ctx.mem) - .unwrap(), + &test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) .unwrap(); @@ -634,15 +635,13 @@ mod tests { // same area of memory. We need both a rx-view and a tx-view into the packet, as tx-queue // buffers are read only, while rx queue buffers are write-only let mut pkt = VsockPacket::from_rx_virtq_head( - handler_ctx.device.queues[RXQ_INDEX] - .pop(&test_ctx.mem) - .unwrap(), + &test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) .unwrap(); let pkt2 = VsockPacket::from_tx_virtq_head( - handler_ctx.device.queues[TXQ_INDEX] - .pop(&test_ctx.mem) - .unwrap(), + &test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) .unwrap(); diff --git a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs index 00bf511a209..d543a2799dd 100644 --- a/src/vmm/src/devices/virtio/vsock/unix/muxer.rs +++ b/src/vmm/src/devices/virtio/vsock/unix/muxer.rs @@ -833,15 +833,13 @@ mod tests { let vsock_test_ctx = VsockTestContext::new(); let mut handler_ctx = vsock_test_ctx.create_event_handler_context(); let rx_pkt = VsockPacket::from_rx_virtq_head( - handler_ctx.device.queues[RXQ_INDEX] - .pop(&vsock_test_ctx.mem) - .unwrap(), + &vsock_test_ctx.mem, + handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(), ) .unwrap(); let tx_pkt = VsockPacket::from_tx_virtq_head( - handler_ctx.device.queues[TXQ_INDEX] - .pop(&vsock_test_ctx.mem) - .unwrap(), + &vsock_test_ctx.mem, + handler_ctx.device.queues[TXQ_INDEX].pop().unwrap(), ) .unwrap(); diff --git a/src/vmm/src/persist.rs b/src/vmm/src/persist.rs index c7ad27a5777..d9183178797 100644 --- a/src/vmm/src/persist.rs +++ b/src/vmm/src/persist.rs @@ -267,6 +267,23 @@ fn snapshot_memory_to_file( dump_res } }?; + // We need to mark queues as dirty again for all activated devices. The reason we + // do it here is because we don't mark pages as dirty during runtime + // for queue objects. 
+    // SAFETY:
+    // This should never fail as we only mark pages if the device has already been activated,
+    // and the address validation was already performed on device activation.
+    vmm.mmio_device_manager
+        .for_each_virtio_device(|_, _, _, dev| {
+            let d = dev.lock().unwrap();
+            if d.is_activated() {
+                d.mark_queue_memory_dirty(vmm.guest_memory())
+            } else {
+                Ok(())
+            }
+        })
+        .unwrap();
+
     file.flush()
         .map_err(|err| MemoryBackingFile("flush", err))?;
     file.sync_all()
diff --git a/tests/integration_tests/functional/test_dirty_pages_in_full_snapshot.py b/tests/integration_tests/functional/test_dirty_pages_in_full_snapshot.py
index 75f0cdda2d6..af5c03d1ea1 100644
--- a/tests/integration_tests/functional/test_dirty_pages_in_full_snapshot.py
+++ b/tests/integration_tests/functional/test_dirty_pages_in_full_snapshot.py
@@ -23,8 +23,11 @@ def test_dirty_pages_after_full_snapshot(uvm_plain):
     # file size is the same, but the `diff` snapshot is actually a sparse file
     assert snap_full.mem.stat().st_size == snap_diff.mem.stat().st_size
 
-    # diff -> diff there should be no differences
-    assert snap_diff2.mem.stat().st_blocks == 0
+    # full -> diff: the full snapshot should contain more blocks than the diff snapshots.
+    # Diff snapshots will contain some pages, because we always mark
+    # pages used for virt queues as dirty.
+    assert snap_diff.mem.stat().st_blocks < snap_full.mem.stat().st_blocks
+    assert snap_diff2.mem.stat().st_blocks < snap_full.mem.stat().st_blocks
 
-    # full -> diff there should be no differences
-    assert snap_diff.mem.stat().st_blocks == 0
+    # diff -> diff: there should be no differences
+    assert snap_diff.mem.stat().st_blocks == snap_diff2.mem.stat().st_blocks
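Taken together with the persist.rs change above, the activation-time queue setup is what guarantees that only activated devices have queue pages to re-mark as dirty. A minimal sketch of that setup step follows; ExampleDevice is hypothetical and the ActivateError import path is assumed, while the initialize() call and the QueueMemoryError mapping mirror the rng and vsock activate() hunks in this diff.

    use crate::devices::virtio::queue::Queue;
    use crate::devices::virtio::ActivateError;
    use crate::vstate::memory::GuestMemoryMmap;

    /// Hypothetical device holding only the queue-related state.
    struct ExampleDevice {
        queues: Vec<Queue>,
    }

    impl ExampleDevice {
        fn activate(&mut self, mem: GuestMemoryMmap) -> Result<(), ActivateError> {
            // Bind each queue to guest memory once; afterwards pop() and add_used()
            // run against the cached mapping and need no `mem` argument.
            for q in self.queues.iter_mut() {
                q.initialize(&mem)
                    .map_err(ActivateError::QueueMemoryError)?;
            }
            Ok(())
        }
    }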