diff --git a/program/c/src/oracle/oracle.c b/program/c/src/oracle/oracle.c index b08d78d21..620c4b1f1 100644 --- a/program/c/src/oracle/oracle.c +++ b/program/c/src/oracle/oracle.c @@ -59,56 +59,6 @@ static bool valid_writable_account( SolParameters *prm, is_rent_exempt( *ka->lamports, ka->data_len ); } -#define PC_ADD_STR \ - tag = (pc_str_t*)src;\ - tag_len = 1 + tag->len_;\ - if ( &src[tag_len] > end ) return ERROR_INVALID_ARGUMENT;\ - sol_memcpy( tgt, tag, tag_len );\ - tgt += tag_len;\ - src += tag_len;\ - -static uint64_t upd_product( SolParameters *prm, SolAccountInfo *ka ) -{ - // Account (1) is the existing product account - // Verify that these are signed, writable accounts with correct ownership - // and size - if ( prm->ka_num != 2 || - !valid_funding_account( &ka[0] ) || - !valid_signable_account( prm, &ka[1], PC_PROD_ACC_SIZE ) ) { - return ERROR_INVALID_ARGUMENT; - } - - // verify that product account is valid - cmd_hdr_t *hdr = (cmd_hdr_t*)prm->data; - pc_prod_t *pptr = (pc_prod_t*)ka[1].data; - if ( pptr->magic_ != PC_MAGIC || - pptr->ver_ != hdr->ver_ || - pptr->type_ != PC_ACCTYPE_PRODUCT ) { - return ERROR_INVALID_ARGUMENT; - } - - // unpack and verify attribute set and ssign to product account - if ( prm->data_len < sizeof( cmd_upd_product_t ) || - prm->data_len > PC_PROD_ACC_SIZE + - sizeof( cmd_upd_product_t ) - sizeof( pc_prod_t ) ) { - return ERROR_INVALID_ARGUMENT; - } - pptr->size_ = ( uint32_t )( sizeof( pc_prod_t ) + prm->data_len - - sizeof( cmd_upd_product_t ) ); - uint8_t *tgt = (uint8_t*)pptr + sizeof( pc_prod_t ); - const uint8_t *src = prm->data + sizeof( cmd_upd_product_t ); - const uint8_t *end = prm->data + prm->data_len; - const pc_str_t *tag; - int tag_len; - while( src != end ) { - // check key string - PC_ADD_STR - // check value string - PC_ADD_STR - } - return SUCCESS; -} - static uint64_t add_price( SolParameters *prm, SolAccountInfo *ka ) { // Validate command parameters @@ -432,7 +382,7 @@ static uint64_t dispatch( SolParameters *prm, SolAccountInfo *ka ) case e_cmd_init_mapping: return ERROR_INVALID_ARGUMENT; case e_cmd_add_mapping: return ERROR_INVALID_ARGUMENT; case e_cmd_add_product: return ERROR_INVALID_ARGUMENT; - case e_cmd_upd_product: return upd_product( prm, ka ); + case e_cmd_upd_product: return ERROR_INVALID_ARGUMENT; case e_cmd_add_price: return add_price( prm, ka ); case e_cmd_add_publisher: return add_publisher( prm, ka ); case e_cmd_del_publisher: return del_publisher( prm, ka ); diff --git a/program/rust/src/processor.rs b/program/rust/src/processor.rs index cc5eb6a74..dbaa75ba8 100644 --- a/program/rust/src/processor.rs +++ b/program/rust/src/processor.rs @@ -16,8 +16,10 @@ use crate::c_oracle_header::{ command_t_e_cmd_upd_account_version, command_t_e_cmd_upd_price, command_t_e_cmd_upd_price_no_fail_on_error, + command_t_e_cmd_upd_product, PC_VERSION, }; +use crate::deserialize::load; use crate::error::{ OracleError, OracleResult, @@ -28,11 +30,11 @@ use crate::rust_oracle::{ add_product, add_publisher, init_mapping, + upd_product, update_price, update_version, }; -use crate::deserialize::load; ///dispatch to the right instruction in the oracle pub fn process_instruction( program_id: &Pubkey, @@ -69,6 +71,7 @@ pub fn process_instruction( command_t_e_cmd_add_mapping => add_mapping(program_id, accounts, instruction_data), command_t_e_cmd_add_publisher => add_publisher(program_id, accounts, instruction_data), command_t_e_cmd_add_product => add_product(program_id, accounts, instruction_data), + command_t_e_cmd_upd_product => upd_product(program_id, accounts, instruction_data), _ => c_entrypoint_wrapper(input), } } diff --git a/program/rust/src/rust_oracle.rs b/program/rust/src/rust_oracle.rs index ddee73360..41b18d43e 100644 --- a/program/rust/src/rust_oracle.rs +++ b/program/rust/src/rust_oracle.rs @@ -5,20 +5,17 @@ use std::mem::{ size_of_val, }; -use crate::deserialize::{ - load, - load_account_as, - load_account_as_mut, -}; use bytemuck::{ bytes_of, bytes_of_mut, }; - use solana_program::account_info::AccountInfo; use solana_program::entrypoint::SUCCESS; use solana_program::program_error::ProgramError; -use solana_program::program_memory::sol_memset; +use solana_program::program_memory::{ + sol_memcpy, + sol_memset, +}; use solana_program::pubkey::Pubkey; use solana_program::rent::Rent; @@ -26,6 +23,7 @@ use crate::c_oracle_header::{ cmd_add_price_t, cmd_add_publisher_t, cmd_hdr_t, + cmd_upd_product_t, pc_acc, pc_map_table_t, pc_price_comp, @@ -40,10 +38,14 @@ use crate::c_oracle_header::{ PC_PROD_ACC_SIZE, PC_PTYPE_UNKNOWN, }; +use crate::deserialize::{ + load, + load_account_as, + load_account_as_mut, +}; use crate::error::OracleResult; -use crate::OracleError; - use crate::utils::pyth_assert; +use crate::OracleError; use super::c_entrypoint_wrapper; @@ -260,6 +262,66 @@ pub fn add_product( Ok(SUCCESS) } +/// Update the metadata associated with a product, overwriting any existing metadata. +/// The metadata is provided as a list of key-value pairs at the end of the `instruction_data`. +pub fn upd_product( + program_id: &Pubkey, + accounts: &[AccountInfo], + instruction_data: &[u8], +) -> OracleResult { + let [funding_account, product_account] = match accounts { + [x, y] => Ok([x, y]), + _ => Err(ProgramError::InvalidArgument), + }?; + + check_valid_funding_account(funding_account)?; + check_valid_signable_account(program_id, product_account, try_convert(PC_PROD_ACC_SIZE)?)?; + + let hdr = load::(instruction_data)?; + { + // Validate that product_account contains the appropriate account header + let mut _product_data = load_checked::(product_account, hdr.ver_)?; + } + + pyth_assert( + instruction_data.len() >= size_of::(), + ProgramError::InvalidInstructionData, + )?; + let new_data_len = instruction_data.len() - size_of::(); + let max_data_len = try_convert::<_, usize>(PC_PROD_ACC_SIZE)? - size_of::(); + pyth_assert(new_data_len <= max_data_len, ProgramError::InvalidArgument)?; + + let new_data = &instruction_data[size_of::()..instruction_data.len()]; + let mut idx = 0; + // new_data must be a list of key-value pairs, both of which are instances of pc_str_t. + // Try reading the key-value pairs to validate that new_data is properly formatted. + while idx < new_data.len() { + let key = read_pc_str_t(&new_data[idx..])?; + idx += key.len(); + let value = read_pc_str_t(&new_data[idx..])?; + idx += value.len(); + } + + // This assertion shouldn't ever fail, but be defensive. + pyth_assert(idx == new_data.len(), ProgramError::InvalidArgument)?; + + { + let mut data = product_account.try_borrow_mut_data()?; + // Note that this memcpy doesn't necessarily overwrite all existing data in the account. + // This case is handled by updating the .size_ field below. + sol_memcpy( + &mut data[size_of::()..], + new_data, + new_data.len(), + ); + } + + let mut product_data = load_checked::(product_account, hdr.ver_)?; + product_data.size_ = try_convert(size_of::() + new_data.len())?; + + Ok(SUCCESS) +} + fn valid_funding_account(account: &AccountInfo) -> bool { account.is_signer && account.is_writable } @@ -362,7 +424,22 @@ pub fn pubkey_equal(target: &pc_pub_key_t, source: &[u8]) -> bool { } /// Convert `x: T` into a `U`, returning the appropriate `OracleError` if the conversion fails. -fn try_convert>(x: T) -> Result { +pub fn try_convert>(x: T) -> Result { // Note: the error here assumes we're only applying this function to integers right now. U::try_from(x).map_err(|_| OracleError::IntegerCastingError) } + +/// Read a `pc_str_t` from the beginning of `source`. Returns a slice of `source` containing +/// the bytes of the `pc_str_t`. +pub fn read_pc_str_t(source: &[u8]) -> Result<&[u8], ProgramError> { + if source.is_empty() { + Err(ProgramError::InvalidArgument) + } else { + let tag_len: usize = try_convert(source[0])?; + if tag_len + 1 > source.len() { + Err(ProgramError::InvalidArgument) + } else { + Ok(&source[..(1 + tag_len)]) + } + } +} diff --git a/program/rust/src/tests/mod.rs b/program/rust/src/tests/mod.rs index fba2a74a9..95aeb1ae8 100644 --- a/program/rust/src/tests/mod.rs +++ b/program/rust/src/tests/mod.rs @@ -1,3 +1,4 @@ mod test_add_mapping; mod test_add_product; mod test_init_mapping; +mod test_upd_product; diff --git a/program/rust/src/tests/test_add_product.rs b/program/rust/src/tests/test_add_product.rs index 15c27a6b9..5e31a475c 100644 --- a/program/rust/src/tests/test_add_product.rs +++ b/program/rust/src/tests/test_add_product.rs @@ -7,6 +7,7 @@ use bytemuck::{ use solana_program::account_info::AccountInfo; use solana_program::clock::Epoch; use solana_program::native_token::LAMPORTS_PER_SOL; +use solana_program::program_error::ProgramError; use solana_program::pubkey::Pubkey; use solana_program::rent::Rent; use solana_program::system_program; @@ -162,16 +163,18 @@ fn test_add_product() { false, Epoch::default(), ); - assert!(add_product( - &program_id, - &[ - funding_account.clone(), - mapping_account.clone(), - product_account_3.clone() - ], - instruction_data - ) - .is_err()); + assert_eq!( + add_product( + &program_id, + &[ + funding_account.clone(), + mapping_account.clone(), + product_account_3.clone() + ], + instruction_data + ), + Err(ProgramError::InvalidArgument) + ); // test fill up of mapping table clear_account(&mapping_account).unwrap(); @@ -196,16 +199,18 @@ fn test_add_product() { clear_account(&product_account).unwrap(); - assert!(add_product( - &program_id, - &[ - funding_account.clone(), - mapping_account.clone(), - product_account.clone() - ], - instruction_data - ) - .is_err()); + assert_eq!( + add_product( + &program_id, + &[ + funding_account.clone(), + mapping_account.clone(), + product_account.clone() + ], + instruction_data + ), + Err(ProgramError::InvalidArgument) + ); let mapping_data = load_checked::(&mapping_account, PC_VERSION).unwrap(); assert_eq!(mapping_data.num_, PC_MAP_TABLE_SIZE); diff --git a/program/rust/src/tests/test_upd_product.rs b/program/rust/src/tests/test_upd_product.rs new file mode 100644 index 000000000..a547d8c2f --- /dev/null +++ b/program/rust/src/tests/test_upd_product.rs @@ -0,0 +1,167 @@ +use std::mem::size_of; + +use solana_program::account_info::AccountInfo; +use solana_program::clock::Epoch; +use solana_program::native_token::LAMPORTS_PER_SOL; +use solana_program::program_error::ProgramError; +use solana_program::pubkey::Pubkey; +use solana_program::rent::Rent; +use solana_program::system_program; + +use crate::c_oracle_header::{ + cmd_hdr_t, + cmd_upd_product_t, + command_t_e_cmd_upd_product, + pc_prod_t, + PC_PROD_ACC_SIZE, + PC_VERSION, +}; +use crate::deserialize::load_mut; +use crate::rust_oracle::{ + initialize_checked, + load_checked, + read_pc_str_t, + try_convert, + upd_product, +}; + +#[test] +fn test_upd_product() { + let mut instruction_data = [0u8; PC_PROD_ACC_SIZE as usize]; + + let program_id = Pubkey::new_unique(); + let funding_key = Pubkey::new_unique(); + let product_key = Pubkey::new_unique(); + + let system_program = system_program::id(); + let mut funding_balance = LAMPORTS_PER_SOL.clone(); + let funding_account = AccountInfo::new( + &funding_key, + true, + true, + &mut funding_balance, + &mut [], + &system_program, + false, + Epoch::default(), + ); + + let mut product_balance = Rent::minimum_balance(&Rent::default(), PC_PROD_ACC_SIZE as usize); + let mut prod_raw_data = [0u8; PC_PROD_ACC_SIZE as usize]; + let product_account = AccountInfo::new( + &product_key, + true, + true, + &mut product_balance, + &mut prod_raw_data, + &program_id, + false, + Epoch::default(), + ); + + initialize_checked::(&product_account, PC_VERSION).unwrap(); + + let kvs = ["foo", "barz"]; + let size = populate_instruction(&mut instruction_data, &kvs); + assert!(upd_product( + &program_id, + &[funding_account.clone(), product_account.clone()], + &instruction_data[..size] + ) + .is_ok()); + assert!(account_has_key_values(&product_account, &kvs).unwrap_or(false)); + + // bad size on the 1st string in the key-value pair list + instruction_data[size_of::()] = 2; + assert_eq!( + upd_product( + &program_id, + &[funding_account.clone(), product_account.clone()], + &instruction_data[..size] + ), + Err(ProgramError::InvalidArgument) + ); + assert!(account_has_key_values(&product_account, &kvs).unwrap_or(false)); + + let kvs = []; + let size = populate_instruction(&mut instruction_data, &kvs); + assert!(upd_product( + &program_id, + &[funding_account.clone(), product_account.clone()], + &instruction_data[..size] + ) + .is_ok()); + assert!(account_has_key_values(&product_account, &kvs).unwrap_or(false)); + + // uneven number of keys and values + let bad_kvs = ["foo", "bar", "baz"]; + let size = populate_instruction(&mut instruction_data, &bad_kvs); + assert_eq!( + upd_product( + &program_id, + &[funding_account.clone(), product_account.clone()], + &instruction_data[..size] + ), + Err(ProgramError::InvalidArgument) + ); + assert!(account_has_key_values(&product_account, &kvs).unwrap_or(false)); +} + +// Create an upd_product instruction that sets the product metadata to strings +fn populate_instruction(instruction_data: &mut [u8], strings: &[&str]) -> usize { + { + let mut hdr = load_mut::(instruction_data).unwrap(); + hdr.ver_ = PC_VERSION; + hdr.cmd_ = command_t_e_cmd_upd_product as i32 + } + + let mut idx = size_of::(); + for s in strings.iter() { + let pc_str = create_pc_str_t(s); + instruction_data[idx..(idx + pc_str.len())].copy_from_slice(pc_str.as_slice()); + idx += pc_str.len() + } + + idx +} + +fn create_pc_str_t(s: &str) -> Vec { + let mut v = vec![s.len() as u8]; + v.extend_from_slice(s.as_bytes()); + v +} + +// Check that the key-value list in product_account equals the strings in expected +// Returns an Err if the account data is incorrectly formatted and the comparison cannot be +// performed. +fn account_has_key_values( + product_account: &AccountInfo, + expected: &[&str], +) -> Result { + let account_size: usize = + try_convert(load_checked::(product_account, PC_VERSION)?.size_)?; + let mut all_account_data = product_account.try_borrow_mut_data()?; + let kv_data = &mut all_account_data[size_of::()..account_size]; + let mut kv_idx = 0; + let mut expected_idx = 0; + + while kv_idx < kv_data.len() { + let key = read_pc_str_t(&kv_data[kv_idx..])?; + if key[0] != try_convert::<_, u8>(key.len())? - 1 { + return Ok(false); + } + + if &key[1..] != expected[expected_idx].as_bytes() { + return Ok(false); + } + + kv_idx += key.len(); + expected_idx += 1; + } + + if expected_idx != expected.len() { + return Ok(false); + } + + Ok(true) +}