|  | 
| 1 |  | -use super::ExpandMsg; | 
| 2 |  | -use core::marker::PhantomData; | 
| 3 |  | -use digest::{ | 
|  | 1 | +use super::{Domain, ExpandMsg}; | 
|  | 2 | +use digest_traits::{ | 
| 4 | 3 |     generic_array::{typenum::Unsigned, GenericArray}, | 
| 5 | 4 |     BlockInput, Digest, | 
| 6 | 5 | }; | 
| 7 |  | -use subtle::{Choice, ConditionallySelectable}; | 
| 8 | 6 | 
 | 
| 9 | 7 | /// Placeholder type for implementing expand_message_xmd based on a hash function | 
| 10 |  | -#[derive(Debug)] | 
| 11 |  | -pub struct ExpandMsgXmd<HashT> { | 
| 12 |  | -    phantom: PhantomData<HashT>, | 
|  | 8 | +pub struct ExpandMsgXmd<HashT> | 
|  | 9 | +where | 
|  | 10 | +    HashT: Digest + BlockInput, | 
|  | 11 | +{ | 
|  | 12 | +    b_0: GenericArray<u8, HashT::OutputSize>, | 
|  | 13 | +    b_vals: GenericArray<u8, HashT::OutputSize>, | 
|  | 14 | +    domain: Domain<HashT::OutputSize>, | 
|  | 15 | +    index: usize, | 
|  | 16 | +    offset: usize, | 
|  | 17 | +    ell: usize, | 
|  | 18 | +} | 
|  | 19 | + | 
|  | 20 | +impl<HashT> ExpandMsgXmd<HashT> | 
|  | 21 | +where | 
|  | 22 | +    HashT: Digest + BlockInput, | 
|  | 23 | +{ | 
|  | 24 | +    fn next(&mut self) -> bool { | 
|  | 25 | +        if self.index < self.ell { | 
|  | 26 | +            self.index += 1; | 
|  | 27 | +            self.offset = 0; | 
|  | 28 | +            // b_0 XOR b_(idx - 1) | 
|  | 29 | +            let mut tmp = GenericArray::<u8, HashT::OutputSize>::default(); | 
|  | 30 | +            self.b_0 | 
|  | 31 | +                .iter() | 
|  | 32 | +                .zip(&self.b_vals[..]) | 
|  | 33 | +                .enumerate() | 
|  | 34 | +                .for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val); | 
|  | 35 | +            self.b_vals = HashT::new() | 
|  | 36 | +                .chain(tmp) | 
|  | 37 | +                .chain([self.index as u8]) | 
|  | 38 | +                .chain(self.domain.data()) | 
|  | 39 | +                .chain([self.domain.len() as u8]) | 
|  | 40 | +                .finalize(); | 
|  | 41 | +            true | 
|  | 42 | +        } else { | 
|  | 43 | +            false | 
|  | 44 | +        } | 
|  | 45 | +    } | 
| 13 | 46 | } | 
| 14 | 47 | 
 | 
| 15 | 48 | /// ExpandMsgXmd implements expand_message_xmd for the ExpandMsg trait | 
| 16 |  | -impl<HashT, const LEN_IN_BYTES: usize> ExpandMsg<LEN_IN_BYTES> for ExpandMsgXmd<HashT> | 
|  | 49 | +impl<HashT> ExpandMsg for ExpandMsgXmd<HashT> | 
| 17 | 50 | where | 
| 18 | 51 |     HashT: Digest + BlockInput, | 
| 19 | 52 | { | 
| 20 |  | -    fn expand_message(msg: &[u8], dst: &[u8]) -> [u8; LEN_IN_BYTES] { | 
|  | 53 | +    fn expand_message(msg: &[u8], dst: &'static [u8], len_in_bytes: usize) -> Self { | 
| 21 | 54 |         let b_in_bytes = HashT::OutputSize::to_usize(); | 
| 22 |  | -        let ell = (LEN_IN_BYTES + b_in_bytes - 1) / b_in_bytes; | 
|  | 55 | +        let ell = (len_in_bytes + b_in_bytes - 1) / b_in_bytes; | 
| 23 | 56 |         if ell > 255 { | 
| 24 | 57 |             panic!("ell was too big in expand_message_xmd"); | 
| 25 | 58 |         } | 
|  | 59 | +        let domain = Domain::xmd::<HashT>(dst); | 
| 26 | 60 |         let b_0 = HashT::new() | 
| 27 | 61 |             .chain(GenericArray::<u8, HashT::BlockSize>::default()) | 
| 28 | 62 |             .chain(msg) | 
| 29 |  | -            .chain([(LEN_IN_BYTES >> 8) as u8, LEN_IN_BYTES as u8, 0u8]) | 
| 30 |  | -            .chain(dst) | 
| 31 |  | -            .chain([dst.len() as u8]) | 
|  | 63 | +            .chain([(len_in_bytes >> 8) as u8, len_in_bytes as u8, 0u8]) | 
|  | 64 | +            .chain(domain.data()) | 
|  | 65 | +            .chain([domain.len() as u8]) | 
| 32 | 66 |             .finalize(); | 
| 33 | 67 | 
 | 
| 34 |  | -        let mut b_vals = HashT::new() | 
|  | 68 | +        let b_vals = HashT::new() | 
| 35 | 69 |             .chain(&b_0[..]) | 
| 36 | 70 |             .chain([1u8]) | 
| 37 |  | -            .chain(dst) | 
| 38 |  | -            .chain([dst.len() as u8]) | 
|  | 71 | +            .chain(domain.data()) | 
|  | 72 | +            .chain([domain.len() as u8]) | 
| 39 | 73 |             .finalize(); | 
| 40 | 74 | 
 | 
| 41 |  | -        let mut buf = [0u8; LEN_IN_BYTES]; | 
| 42 |  | -        let mut offset = 0; | 
|  | 75 | +        Self { | 
|  | 76 | +            b_0, | 
|  | 77 | +            b_vals, | 
|  | 78 | +            domain, | 
|  | 79 | +            index: 1, | 
|  | 80 | +            offset: 0, | 
|  | 81 | +            ell, | 
|  | 82 | +        } | 
|  | 83 | +    } | 
| 43 | 84 | 
 | 
| 44 |  | -        for i in 1..ell { | 
| 45 |  | -            // b_0 XOR b_(idx - 1) | 
| 46 |  | -            let mut tmp = GenericArray::<u8, HashT::OutputSize>::default(); | 
| 47 |  | -            b_0.iter() | 
| 48 |  | -                .zip(&b_vals[..]) | 
| 49 |  | -                .enumerate() | 
| 50 |  | -                .for_each(|(j, (b0val, bi1val))| tmp[j] = b0val ^ bi1val); | 
| 51 |  | -            for b in b_vals { | 
| 52 |  | -                buf[offset % LEN_IN_BYTES].conditional_assign( | 
| 53 |  | -                    &b, | 
| 54 |  | -                    Choice::from(if offset < LEN_IN_BYTES { 1 } else { 0 }), | 
| 55 |  | -                ); | 
| 56 |  | -                offset += 1; | 
|  | 85 | +    fn fill_bytes(&mut self, okm: &mut [u8]) { | 
|  | 86 | +        for b in okm { | 
|  | 87 | +            if self.offset == self.b_vals.len() && !self.next() { | 
|  | 88 | +                return; | 
| 57 | 89 |             } | 
| 58 |  | -            b_vals = HashT::new() | 
| 59 |  | -                .chain(tmp) | 
| 60 |  | -                .chain([(i + 1) as u8]) | 
| 61 |  | -                .chain(dst) | 
| 62 |  | -                .chain([dst.len() as u8]) | 
| 63 |  | -                .finalize(); | 
| 64 |  | -        } | 
| 65 |  | -        for b in b_vals { | 
| 66 |  | -            buf[offset % LEN_IN_BYTES] | 
| 67 |  | -                .conditional_assign(&b, Choice::from(if offset < LEN_IN_BYTES { 1 } else { 0 })); | 
| 68 |  | -            offset += 1; | 
|  | 90 | +            *b = self.b_vals[self.offset]; | 
|  | 91 | +            self.offset += 1; | 
| 69 | 92 |         } | 
| 70 |  | -        buf | 
| 71 | 93 |     } | 
| 72 | 94 | } | 
0 commit comments