diff --git a/lightning/src/offers/merkle.rs b/lightning/src/offers/merkle.rs index 2afd001017c..f39545d7a9e 100644 --- a/lightning/src/offers/merkle.rs +++ b/lightning/src/offers/merkle.rs @@ -46,6 +46,59 @@ impl TaggedHash { Self::from_tlv_stream(tag, tlv_stream) } + /// Creates a tagged hash with the given parameters, validating the TLV stream. + /// + /// This is a low-level function exposed for specific use cases like command-line tools + /// and testing. For production use, prefer higher-level methods like + /// [`Bolt12Invoice::try_from`] which handle validation automatically. + /// + /// Returns an error if `bytes` is not a well-formed TLV stream containing at least one TLV record. + /// + /// [`Bolt12Invoice::try_from`]: crate::offers::invoice::Bolt12Invoice::try_from + pub fn from_tlv_stream_bytes(tag: &'static str, bytes: &[u8]) -> Result { + // Validate the TLV stream first + if bytes.is_empty() { + return Err(TlvStreamError::EmptyStream); + } + + // Try to parse the TLV stream to check validity + let mut cursor = io::Cursor::new(bytes); + let mut has_records = false; + + while cursor.position() < bytes.len() as u64 { + // Try to read type + let type_result = ::read(&mut cursor); + if type_result.is_err() { + return Err(TlvStreamError::InvalidRecord); + } + + // Try to read length + let length_result = ::read(&mut cursor); + if length_result.is_err() { + return Err(TlvStreamError::InvalidRecord); + } + + let length = length_result.unwrap().0; + let end_position = cursor.position() + length; + + // Check if the record extends beyond the buffer + if end_position > bytes.len() as u64 { + return Err(TlvStreamError::InvalidRecord); + } + + // Skip the value + cursor.set_position(end_position); + has_records = true; + } + + if !has_records { + return Err(TlvStreamError::EmptyStream); + } + + // If validation passes, create the tagged hash + Ok(Self::from_valid_tlv_stream_bytes(tag, bytes)) + } + /// Creates a tagged hash with the given parameters. /// /// Panics if `tlv_stream` is not a well-formed TLV stream containing at least one TLV record. @@ -93,8 +146,17 @@ pub enum SignError { Verification(secp256k1::Error), } +/// Error when parsing TLV streams. +#[derive(Debug, PartialEq)] +pub enum TlvStreamError { + /// The TLV stream is empty (contains no records). + EmptyStream, + /// The TLV stream contains an invalid record. + InvalidRecord, +} + /// A function for signing a [`TaggedHash`]. -pub(super) trait SignFn> { +pub trait SignFn> { /// Signs a [`TaggedHash`] computed over the merkle root of `message`'s TLV stream. fn sign(&self, message: &T) -> Result; } @@ -111,15 +173,17 @@ where /// Signs a [`TaggedHash`] computed over the merkle root of `message`'s TLV stream, checking if it /// can be verified with the supplied `pubkey`. /// +/// This is a low-level function exposed for specific use cases like command-line tools +/// and testing. For production use, prefer higher-level methods on invoice types that handle +/// signing automatically. +/// /// Since `message` is any type that implements [`AsRef`], `sign` may be a closure that /// takes a message such as [`Bolt12Invoice`] or [`InvoiceRequest`]. This allows further message /// verification before signing its [`TaggedHash`]. /// /// [`Bolt12Invoice`]: crate::offers::invoice::Bolt12Invoice /// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest -pub(super) fn sign_message( - f: F, message: &T, pubkey: PublicKey, -) -> Result +pub fn sign_message(f: F, message: &T, pubkey: PublicKey) -> Result where F: SignFn, T: AsRef, @@ -136,7 +200,13 @@ where /// Verifies the signature with a pubkey over the given message using a tagged hash as the message /// digest. -pub(super) fn verify_signature( +/// +/// This is a low-level function exposed for specific use cases like command-line tools +/// and testing. For production use, prefer higher-level methods like +/// [`Bolt12Invoice::try_from`] which handle signature verification automatically. +/// +/// [`Bolt12Invoice::try_from`]: crate::offers::invoice::Bolt12Invoice::try_from +pub fn verify_signature( signature: &Signature, message: &TaggedHash, pubkey: PublicKey, ) -> Result<(), secp256k1::Error> { let digest = message.as_digest(); @@ -481,6 +551,119 @@ mod tests { assert_eq!(tlv_stream, invoice_request.bytes); } + #[test] + fn validates_tlv_stream_bytes() { + // Test with valid TLV stream + const VALID_HEX: &'static str = "010203e8"; + let valid_bytes = >::from_hex(VALID_HEX).unwrap(); + let result = super::TaggedHash::from_tlv_stream_bytes("test", &valid_bytes); + assert!(result.is_ok()); + + // Test with empty stream + let empty_bytes = Vec::new(); + let result = super::TaggedHash::from_tlv_stream_bytes("test", &empty_bytes); + assert_eq!(result, Err(super::TlvStreamError::EmptyStream)); + + // Test with invalid TLV stream (truncated) + let invalid_bytes = vec![0x01, 0x02]; // Type and length but no value + let result = super::TaggedHash::from_tlv_stream_bytes("test", &invalid_bytes); + assert_eq!(result, Err(super::TlvStreamError::InvalidRecord)); + } + + #[test] + fn consistent_results_between_validating_and_non_validating_functions() { + // Test vectors from BOLT 12 + let test_vectors = vec![ + "010203e8", + "010203e802080000010000020003", + "010203e802080000010000020003", // Using same as above for simplicity + ]; + + for hex_data in test_vectors { + let bytes = >::from_hex(hex_data).unwrap(); + let tag = "test_tag"; + + // Create tagged hash using the validating function + let validating_result = super::TaggedHash::from_tlv_stream_bytes(tag, &bytes); + assert!( + validating_result.is_ok(), + "Validating function should succeed for valid TLV stream" + ); + let validating_hash = validating_result.unwrap(); + + // Create tagged hash using the non-validating function + let non_validating_hash = super::TaggedHash::from_valid_tlv_stream_bytes(tag, &bytes); + + // Both should produce identical results + assert_eq!( + validating_hash.tag(), + non_validating_hash.tag(), + "Tags should be identical" + ); + assert_eq!( + validating_hash.merkle_root(), + non_validating_hash.merkle_root(), + "Merkle roots should be identical" + ); + assert_eq!( + validating_hash.as_digest(), + non_validating_hash.as_digest(), + "Digests should be identical" + ); + assert_eq!(validating_hash, non_validating_hash, "Tagged hashes should be identical"); + } + } + + #[test] + fn regression_test_with_invoice_request_data() { + // Use real invoice request data to ensure no regression + let expanded_key = ExpandedKey::new([42; 32]); + let nonce = Nonce([0u8; 16]); + let secp_ctx = Secp256k1::new(); + let payment_id = PaymentId([1; 32]); + + let recipient_pubkey = { + let secret_key = SecretKey::from_slice(&[41; 32]).unwrap(); + Keypair::from_secret_key(&secp_ctx, &secret_key).public_key() + }; + + let invoice_request = OfferBuilder::new(recipient_pubkey) + .amount_msats(100) + .build_unchecked() + .request_invoice(&expanded_key, nonce, &secp_ctx, payment_id) + .unwrap() + .build_and_sign() + .unwrap(); + + // Extract bytes without signature for testing + let mut bytes_without_signature = Vec::new(); + let tlv_stream_without_signatures = TlvStream::new(&invoice_request.bytes) + .filter(|record| !SIGNATURE_TYPES.contains(&record.r#type)); + for record in tlv_stream_without_signatures { + record.write(&mut bytes_without_signature).unwrap(); + } + + let tag = "invoice_request"; + + // Test both functions produce the same result + let validating_result = + super::TaggedHash::from_tlv_stream_bytes(tag, &bytes_without_signature); + assert!( + validating_result.is_ok(), + "Should successfully validate real invoice request data" + ); + let validating_hash = validating_result.unwrap(); + + let non_validating_hash = + super::TaggedHash::from_valid_tlv_stream_bytes(tag, &bytes_without_signature); + + // Verify they produce identical results + assert_eq!( + validating_hash, non_validating_hash, + "Both functions should produce identical results for real data" + ); + } + impl AsRef<[u8]> for InvoiceRequest { fn as_ref(&self) -> &[u8] { &self.bytes