Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 108 additions & 40 deletions src/basefold_verifier/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,8 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
let log2_max_codeword_size: Var<C::N> =
builder.eval(input.max_num_var.clone() + get_rate_log::<C>());

let zero: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);

iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each(
|ptr_vec, builder| {
// TODO: change type of input.indices to be `Array<C, Array<C, Var<C::N>>>`
Expand All @@ -363,7 +365,14 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(

let reduced_codeword_by_height: Array<C, PackedCodeword<C>> =
builder.dyn_array(log2_max_codeword_size.clone());

// initialize reduced_codeword_by_height with zeroes
iter_zip!(builder, reduced_codeword_by_height).for_each(|ptr_vec, builder| {
let zero_codeword = PackedCodeword {
low: zero.clone(),
high: zero.clone(),
};
builder.set_value(&reduced_codeword_by_height, ptr_vec[0], zero_codeword);
});
let query = builder.iter_ptr_get(&input.proof.query_opening_proof, ptr_vec[1]);
let batch_coeffs_offset: Var<C::N> = builder.constant(C::N::ZERO);

Expand Down Expand Up @@ -440,8 +449,14 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
builder.assign(&high, high + coeff * high_value);
},
);
let codeword = PackedCodeword { low, high };
builder.set_value(&reduced_codeword_by_height, log2_height, codeword);
let codeword: PackedCodeword<C> = PackedCodeword { low, high };
let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height);

// reduced_openings[log2_height] += codeword
builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low);
builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high);

builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc);
builder.assign(&batch_coeffs_offset, batch_coeffs_next_offset);
});
});
Expand Down Expand Up @@ -491,13 +506,21 @@ pub(crate) fn batch_verifier_query_phase<C: Config + Debug>(
builder.assign(&log2_height, log2_height - Usize::from(1));

let folded_idx = builder.get(&idx_bits, i);
// TODO: absorb smaller codeword
let new_involved_codeword: Ext<C::F, C::EF> = builder.constant(C::EF::ZERO);
let new_involved_packed_codeword =
builder.get(&reduced_codeword_by_height, log2_height.clone());

builder.if_eq(folded_idx, Usize::from(0)).then_or_else(
|builder| {
builder.assign(&folded, folded + new_involved_packed_codeword.low);
},
|builder| {
builder.assign(&folded, folded + new_involved_packed_codeword.high);
},
);

// leafs
let leafs = builder.dyn_array(2);
let sibling_idx = builder.eval_expr(RVar::from(1) - folded_idx);
builder.assign(&folded, folded + new_involved_codeword);
builder.set_value(&leafs, folded_idx, folded);
builder.set_value(&leafs, sibling_idx, sibling_value);

Expand Down Expand Up @@ -652,8 +675,8 @@ pub mod tests {
use ff_ext::{BabyBearExt4, FromUniformBytes};
use itertools::Itertools;
use mpcs::{
pcs_batch_commit, pcs_setup, pcs_trim, util::hash::write_digest_to_transcript,
BasefoldDefault, PolynomialCommitmentScheme,
pcs_batch_commit, pcs_trim, util::hash::write_digest_to_transcript, BasefoldDefault,
PolynomialCommitmentScheme,
};
use mpcs::{BasefoldRSParams, BasefoldSpec, PCSFriParam};
use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor};
Expand All @@ -662,8 +685,6 @@ pub mod tests {
use openvm_native_recursion::hints::Hintable;
use openvm_stark_backend::p3_challenger::GrindingChallenger;
use openvm_stark_sdk::p3_baby_bear::BabyBear;
use p3_field::Field;
use p3_field::FieldAlgebra;
use rand::thread_rng;

type F = BabyBear;
Expand All @@ -672,11 +693,10 @@ pub mod tests {

use crate::basefold_verifier::basefold::{Round, RoundOpening};
use crate::basefold_verifier::query_phase::PointAndEvals;
use crate::tower_verifier::binding::{Point, PointAndEval};
use crate::tower_verifier::binding::Point;

use super::{batch_verifier_query_phase, QueryPhaseVerifierInput};

#[allow(dead_code)]
pub fn build_batch_verifier_query_phase(
input: QueryPhaseVerifierInput,
) -> (Program<BabyBear>, Vec<Vec<BabyBear>>) {
Expand All @@ -694,42 +714,69 @@ pub mod tests {
(program, witness_stream)
}

#[test]
fn test_verify_query_phase_batch() {
fn construct_test(dimensions: Vec<(usize, usize)>) {
let mut rng = thread_rng();
let m1 = ceno_witness::RowMajorMatrix::<F>::rand(&mut rng, 1 << 10, 10);
let mles_1 = m1.to_mles();
let matrices = vec![m1];

// setup PCS
let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap();
let (pp, vp) = pcs_trim::<E, PCS>(pp, 1 << 20).unwrap();

let mut num_total_polys = 0;
let (matrices, mles): (Vec<_>, Vec<_>) = dimensions
.into_iter()
.map(|(num_vars, width)| {
let m = ceno_witness::RowMajorMatrix::<F>::rand(&mut rng, 1 << num_vars, width);
let mles = m.to_mles();
num_total_polys += width;

(m, mles)
})
.unzip();

// commit to matrices
let pcs_data = pcs_batch_commit::<E, PCS>(&pp, matrices).unwrap();
let comm = PCS::get_pure_commitment(&pcs_data);

let point = E::random_vec(10, &mut rng);
let evals = mles_1.iter().map(|mle| mle.evaluate(&point)).collect_vec();
let point_and_evals = mles
.iter()
.map(|mles| {
let point = E::random_vec(mles[0].num_vars(), &mut rng);
let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec();

// let evals = mles_1
// .iter()
// .map(|mle| points.iter().map(|p| mle.evaluate(&p)).collect_vec())
// .collect::<Vec<_>>();
(point, evals)
})
.collect_vec();

// batch open
let mut transcript = BasicTranscript::<E>::new(&[]);
let rounds = vec![(&pcs_data, vec![(point.clone(), evals.clone())])];
let rounds = vec![(&pcs_data, point_and_evals.clone())];
let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap();

// batch verify
let mut transcript = BasicTranscript::<E>::new(&[]);
let rounds = vec![(comm, vec![(point.len(), (point, evals.clone()))])];
let rounds = vec![(
comm,
point_and_evals
.iter()
.map(|(point, evals)| (point.len(), (point.clone(), evals.clone())))
.collect_vec(),
)];
PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript)
.expect("Native verification failed");

let mut transcript = BasicTranscript::<E>::new(&[]);
let batch_coeffs = transcript.sample_and_append_challenge_pows(10, b"batch coeffs");

let max_num_var = 10;
let batch_coeffs =
transcript.sample_and_append_challenge_pows(num_total_polys, b"batch coeffs");

let max_num_var = point_and_evals
.iter()
.map(|(point, _)| point.len())
.max()
.unwrap();
let num_rounds = max_num_var; // The final message is of length 1

// prepare folding challenges via sumcheck round msg + FRI commitment
let mut fold_challenges: Vec<E> = Vec::with_capacity(10);
let mut fold_challenges: Vec<E> = Vec::with_capacity(num_rounds);
let commits = &opening_proof.commits;

let sumcheck_messages = opening_proof.sumcheck_proof.as_ref().unwrap();
Expand Down Expand Up @@ -759,26 +806,23 @@ pub mod tests {
);

let query_input = QueryPhaseVerifierInput {
// t_inv_halves: vp.encoding_params.t_inv_halves,
max_num_var: 10,
max_num_var,
fold_challenges,
batch_coeffs,
indices: queries,
proof: opening_proof.into(),
rounds: rounds
.iter()
.into_iter()
.map(|round| Round {
commit: round.0.clone().into(),
commit: round.0.into(),
openings: round
.1
.iter()
.map(|opening| RoundOpening {
num_var: opening.0,
.into_iter()
.map(|(num_var, (point, evals))| RoundOpening {
num_var,
point_and_evals: PointAndEvals {
point: Point {
fs: opening.1.clone().0,
},
evals: opening.1.clone().1,
point: Point { fs: point },
evals,
},
})
.collect(),
Expand All @@ -801,4 +845,28 @@ pub mod tests {
println!("=> cycle count: {:?}", seg.metrics.cycle_count);
}
}

#[test]
fn test_simple_batch() {
for num_var in 5..20 {
construct_test(vec![(num_var, 20)]);
}
}

#[test]
fn test_decreasing_batch() {
construct_test(vec![
(14, 20),
(14, 40),
(13, 30),
(12, 30),
(11, 10),
(10, 15),
]);
}

#[test]
fn test_random_batch() {
construct_test(vec![(10, 20), (12, 30), (11, 10), (12, 15)]);
}
}