diff --git a/risc0/zkvm/platform/io.h b/risc0/zkvm/platform/io.h index 83e20751c2..55d3436b2a 100644 --- a/risc0/zkvm/platform/io.h +++ b/risc0/zkvm/platform/io.h @@ -26,6 +26,7 @@ constexpr size_t kGPIO_GetKey = 0x01F0010; constexpr size_t kGPIO_SendRecvChannel = 0x01F00014; constexpr size_t kGPIO_SendRecvSize = 0x01F00018; constexpr size_t kGPIO_SendRecvAddr = 0x01F0001C; +constexpr size_t kGPIO_Mul = 0x01F00020; // Standard ZKVM channels; must match zkvm/sdk/rust/platform/src/io.rs. @@ -67,6 +68,15 @@ struct ShaDescriptor { uint32_t digest; }; +struct MulDescriptor { + // Address of first byte of MUL data to process + // 64 bits for first operand and 64 bits for second + uint32_t source; + + // 64 bit result + uint32_t result; +}; + inline volatile ShaDescriptor* volatile* GPIO_SHA() { return reinterpret_cast(kGPIO_SHA); } diff --git a/risc0/zkvm/platform/memory.h b/risc0/zkvm/platform/memory.h index 7b427a43f9..66723d37f3 100644 --- a/risc0/zkvm/platform/memory.h +++ b/risc0/zkvm/platform/memory.h @@ -43,8 +43,9 @@ MEM_REGION(Input, 0x01E00000, k1MB) MEM_REGION(GPIO, 0x01F00000, k1MB) MEM_REGION(Prog, 0x02000000, 10 * k1MB) MEM_REGION(SHA, 0x02A00000, k1MB) -MEM_REGION(WOM, 0x02B00000, 21 * k1MB) -MEM_REGION(Output, 0x02B00000, 20 * k1MB) +MEM_REGION(MUL, 0x02B00000, k1MB) +MEM_REGION(WOM, 0x02C00000, 20 * k1MB) +MEM_REGION(Output, 0x02C00000, 19 * k1MB) MEM_REGION(Commit, 0x03F00000, k1MB) // clang-format on diff --git a/risc0/zkvm/platform/risc0.ld b/risc0/zkvm/platform/risc0.ld index ed4e89e238..7b51400427 100644 --- a/risc0/zkvm/platform/risc0.ld +++ b/risc0/zkvm/platform/risc0.ld @@ -29,7 +29,8 @@ MEMORY { gpio : ORIGIN = 0x01F00000, LENGTH = 1M prog (X) : ORIGIN = 0x02000000, LENGTH = 10M sha : ORIGIN = 0x02A00000, LENGTH = 1M - wom : ORIGIN = 0x02B00000, LENGTH = 21M + mul : ORIGIN = 0x02B00000, LENGTH = 1M + wom : ORIGIN = 0x02C00000, LENGTH = 20M } SECTIONS { diff --git a/risc0/zkvm/prove/io_handler.cpp b/risc0/zkvm/prove/io_handler.cpp index 68784408ef..d34070f20d 100644 --- a/risc0/zkvm/prove/io_handler.cpp +++ b/risc0/zkvm/prove/io_handler.cpp @@ -46,6 +46,33 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) { } } +static void processMul(MemoryState& mem, const MulDescriptor& desc) { + uint32_t a_hi = mem.load(desc.source); + LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a_hi)); + uint32_t a_lo = mem.load(desc.source + 4); + LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a_lo)); + uint32_t b_hi = mem.load(desc.source + 8); + LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(b_hi)); + uint32_t b_lo = mem.load(desc.source + 12); + LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(b_lo)); + + uint64_t first = a_lo | (uint64_t(a_hi) << 32); + uint64_t second = b_lo | (uint64_t(b_hi) << 32); + + __uint128_t result = __uint128_t(first) * __uint128_t(second); + + // goldilocks + uint64_t moded_result = result % 0xFFFFFFFF00000001; + + uint32_t high = (uint32_t)((moded_result & 0xFFFFFFFF00000000LL) >> 32); + uint32_t low = (uint32_t)(moded_result & 0xFFFFFFFFLL); + + LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(high)); + mem.store(desc.result, high); + LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(low)); + mem.store(desc.result + 4, low); +} + void IoHandler::onFault(const std::string& msg) { throw std::runtime_error(msg); } @@ -63,6 +90,13 @@ void MemoryHandler::onInit(MemoryState& mem) { void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) { LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value)); switch (addr) { + case kGPIO_Mul: { + LOG(1, "MemoryHandler::onWrite> GPIO_MUL"); + MulDescriptor desc; + mem.loadRegion(value, &desc, sizeof(desc)); + processMul(mem, desc); + break; + } case kGPIO_SHA: { LOG(1, "MemoryHandler::onWrite> GPIO_SHA"); ShaDescriptor desc; diff --git a/risc0/zkvm/sdk/rust/guest/src/lib.rs b/risc0/zkvm/sdk/rust/guest/src/lib.rs index a7fd6745da..094fc776d4 100644 --- a/risc0/zkvm/sdk/rust/guest/src/lib.rs +++ b/risc0/zkvm/sdk/rust/guest/src/lib.rs @@ -32,6 +32,9 @@ pub mod env; /// Functions for computing SHA-256 hashes. pub mod sha; +/// mul +pub mod mul; + /// Functions for handling input and output pub mod io; diff --git a/risc0/zkvm/sdk/rust/guest/src/mul.rs b/risc0/zkvm/sdk/rust/guest/src/mul.rs new file mode 100644 index 0000000000..22593744f7 --- /dev/null +++ b/risc0/zkvm/sdk/rust/guest/src/mul.rs @@ -0,0 +1,70 @@ +use core::{cell::UnsafeCell, mem}; + +use crate::env::log; +use _alloc::format; +use _alloc::{boxed::Box, vec::Vec}; +use risc0_zkvm::platform::{ + io::{MulDescriptor, GPIO_MUL}, + memory, +}; + +// Current sha descriptor index. +struct CurOutput(UnsafeCell); + +// SAFETY: single threaded environment +unsafe impl Sync for CurOutput {} + +static CUR_OUTPUT: CurOutput = CurOutput(UnsafeCell::new(0)); + +/// Result of multiply goldilocks +pub struct MulGoldilocks([u32; 2]); + +impl MulGoldilocks { + /// Get the result as u64 + pub fn get_u64(&self) -> u64 { + (self.0[1] as u64) | ((self.0[0] as u64) << 32) + } +} + +fn alloc_output() -> *mut MulDescriptor { + // SAFETY: Single threaded and this is the only place we use CUR_DESC. + unsafe { + let cur_desc = CUR_OUTPUT.0.get(); + let ptr = (memory::MUL.start() as *mut MulDescriptor).add(*cur_desc); + *cur_desc += 1; + ptr + } +} + +/// Multiply goldilocks oracle, verification is done separately +pub fn mul_goldilocks(a: &u64, b: &u64) -> &'static MulGoldilocks { + let a_hi = ((a & 0xFFFFFFFF00000000) >> 32) as u32; + let a_lo = (a & 0xFFFFFFFF) as u32; + + let b_hi = ((b & 0xFFFFFFFF00000000) >> 32) as u32; + let b_lo = (b & 0xFFFFFFFF) as u32; + + let buf = [a_hi, a_lo, b_hi, b_lo]; + + unsafe { + let alloced = Box::>::new( + mem::MaybeUninit::::uninit(), + ); + let output = (*Box::into_raw(alloced)).as_mut_ptr(); + mul_raw(&buf[..], output); + &*output + } +} + +pub(crate) unsafe fn mul_raw(data: &[u32], result: *mut MulGoldilocks) { + let output_ptr = alloc_output(); + + let ptr = data.as_ptr(); + super::memory_barrier(ptr); + output_ptr.write_volatile(MulDescriptor { + source: ptr as usize, + result: result as usize, + }); + + GPIO_MUL.as_ptr().write_volatile(output_ptr); +} diff --git a/risc0/zkvm/sdk/rust/platform/src/io.rs b/risc0/zkvm/sdk/rust/platform/src/io.rs index b75b641db3..34a950351f 100644 --- a/risc0/zkvm/sdk/rust/platform/src/io.rs +++ b/risc0/zkvm/sdk/rust/platform/src/io.rs @@ -49,6 +49,8 @@ pub const GPIO_SENDRECV_CHANNEL: Gpio = Gpio::new(0x01F0_0014); pub const GPIO_SENDRECV_SIZE: Gpio = Gpio::new(0x01F0_0018); pub const GPIO_SENDRECV_ADDR: Gpio<*const u8> = Gpio::new(0x01F0_001C); +pub const GPIO_MUL: Gpio<*const MulDescriptor> = Gpio::new(0x01F0_0020); + pub mod addr { pub const GPIO_SHA: u32 = super::GPIO_SHA.addr(); pub const GPIO_COMMIT: u32 = super::GPIO_COMMIT.addr(); @@ -59,6 +61,8 @@ pub mod addr { pub const GPIO_SENDRECV_CHANNEL: u32 = super::GPIO_SENDRECV_CHANNEL.addr(); pub const GPIO_SENDRECV_SIZE: u32 = super::GPIO_SENDRECV_SIZE.addr(); pub const GPIO_SENDRECV_ADDR: u32 = super::GPIO_SENDRECV_ADDR.addr(); + + pub const GPIO_MUL: u32 = super::GPIO_MUL.addr(); } #[repr(C)] @@ -75,6 +79,12 @@ pub struct SHADescriptor { pub digest: usize, } +#[repr(C)] +pub struct MulDescriptor { + pub source: usize, + pub result: usize, +} + #[repr(C)] pub struct GetKeyDescriptor { pub name: u32, diff --git a/risc0/zkvm/sdk/rust/platform/src/memory.rs b/risc0/zkvm/sdk/rust/platform/src/memory.rs index e9034d9d4a..5b2fea9209 100644 --- a/risc0/zkvm/sdk/rust/platform/src/memory.rs +++ b/risc0/zkvm/sdk/rust/platform/src/memory.rs @@ -62,6 +62,7 @@ pub const INPUT: Region = Region::new(0x01E0_0000, mb(1)); pub const GPIO: Region = Region::new(0x01F0_0000, mb(1)); pub const PROG: Region = Region::new(0x0200_0000, mb(10)); pub const SHA: Region = Region::new(0x02A0_0000, mb(1)); -pub const WOM: Region = Region::new(0x02B0_0000, mb(21)); -pub const OUTPUT: Region = Region::new(0x02B0_0000, mb(20)); +pub const MUL: Region = Region::new(0x02B0_0000, mb(1)); +pub const WOM: Region = Region::new(0x02C0_0000, mb(20)); +pub const OUTPUT: Region = Region::new(0x02C0_0000, mb(19)); pub const COMMIT: Region = Region::new(0x03F0_0000, mb(1));