Skip to content

Commit 0f0a123

Browse files
committed
Simplistic trait
1 parent 29f9976 commit 0f0a123

File tree

6 files changed

+239
-89
lines changed

6 files changed

+239
-89
lines changed

program/rust/src/bindings.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ typedef unsigned int uint32_t;
1111
typedef signed long int int64_t;
1212
typedef unsigned long int uint64_t;
1313

14+
#include <stddef.h>
1415
#include "../../c/src/oracle/oracle.h"
16+
17+
const size_t PC_PRICE_T_COMP_OFFSET = offsetof(struct pc_price, comp_);
18+
const size_t PC_MAP_TABLE_T_PROD_OFFSET = offsetof(struct pc_map_table, prod_);

program/rust/src/c_oracle_header.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,32 @@ use bytemuck::{
77
Pod,
88
Zeroable,
99
};
10-
10+
use std::mem::size_of;
1111
//bindings.rs is generated by build.rs to include
1212
//things defined in bindings.h
1313
include!("../bindings.rs");
1414

15+
pub trait PythStruct: Pod {
16+
const ACCOUNT_TYPE: u32;
17+
const INITIAL_SIZE: u32;
18+
}
19+
20+
impl PythStruct for pc_map_table_t {
21+
const ACCOUNT_TYPE: u32 = PC_ACCTYPE_MAPPING;
22+
const INITIAL_SIZE: u32 = PC_MAP_TABLE_T_PROD_OFFSET as u32;
23+
}
24+
25+
impl PythStruct for pc_prod_t {
26+
const ACCOUNT_TYPE: u32 = PC_ACCTYPE_PRODUCT;
27+
const INITIAL_SIZE: u32 = size_of::<pc_prod_t>() as u32;
28+
}
29+
30+
impl PythStruct for pc_price_t {
31+
const ACCOUNT_TYPE: u32 = PC_ACCTYPE_PRICE;
32+
const INITIAL_SIZE: u32 = PC_PRICE_T_COMP_OFFSET as u32;
33+
}
34+
35+
1536
#[cfg(target_endian = "little")]
1637
unsafe impl Zeroable for pc_acc {
1738
}

program/rust/src/rust_oracle.rs

Lines changed: 33 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ use crate::c_oracle_header::{
2828
pc_price_t,
2929
pc_prod_t,
3030
pc_pub_key_t,
31-
PC_ACCTYPE_MAPPING,
32-
PC_ACCTYPE_PRICE,
33-
PC_ACCTYPE_PRODUCT,
31+
PythStruct,
3432
PC_MAGIC,
3533
PC_MAP_TABLE_SIZE,
3634
PC_MAX_NUM_DECIMALS,
@@ -90,7 +88,7 @@ pub fn init_mapping(
9088

9189
// Initialize by setting to zero again (just in case) and populating the account header
9290
let hdr = load::<cmd_hdr_t>(instruction_data)?;
93-
initialize_mapping_account(fresh_mapping_account, hdr.ver_)?;
91+
initialize_checked::<pc_map_table_t>(fresh_mapping_account, hdr.ver_)?;
9492

9593
Ok(SUCCESS)
9694
}
@@ -111,14 +109,14 @@ pub fn add_mapping(
111109
check_valid_fresh_account(next_mapping)?;
112110

113111
let hdr = load::<cmd_hdr_t>(instruction_data)?;
114-
let mut cur_mapping = load_mapping_account_mut(cur_mapping, hdr.ver_)?;
112+
let mut cur_mapping = load_checked::<pc_map_table_t>(cur_mapping, hdr.ver_)?;
115113
pyth_assert(
116114
cur_mapping.num_ == PC_MAP_TABLE_SIZE
117115
&& unsafe { cur_mapping.next_.k8_.iter().all(|x| *x == 0) },
118116
ProgramError::InvalidArgument,
119117
)?;
120118

121-
initialize_mapping_account(next_mapping, hdr.ver_)?;
119+
initialize_checked::<pc_map_table_t>(next_mapping, hdr.ver_)?;
122120
pubkey_assign(&mut cur_mapping.next_, &next_mapping.key.to_bytes());
123121

124122
Ok(SUCCESS)
@@ -152,15 +150,9 @@ pub fn add_price(
152150
check_valid_signable_account(program_id, price_account, size_of::<pc_price_t>())?;
153151
check_valid_fresh_account(price_account)?;
154152

155-
let mut product_data = load_product_account_mut(product_account, cmd_args.ver_)?;
153+
let mut product_data = load_checked::<pc_prod_t>(product_account, cmd_args.ver_)?;
156154

157-
clear_account(price_account)?;
158-
159-
let mut price_data = load_account_as_mut::<pc_price_t>(price_account)?;
160-
price_data.magic_ = PC_MAGIC;
161-
price_data.ver_ = cmd_args.ver_;
162-
price_data.type_ = PC_ACCTYPE_PRICE;
163-
price_data.size_ = (size_of::<pc_price_t>() - size_of_val(&price_data.comp_)) as u32;
155+
let mut price_data = initialize_checked::<pc_price_t>(price_account, cmd_args.ver_)?;
164156
price_data.expo_ = cmd_args.expo_;
165157
price_data.ptype_ = cmd_args.ptype_;
166158
pubkey_assign(&mut price_data.prod_, &product_account.key.to_bytes());
@@ -190,14 +182,14 @@ pub fn add_product(
190182
check_valid_fresh_account(new_product_account)?;
191183

192184
let hdr = load::<cmd_hdr_t>(instruction_data)?;
193-
let mut mapping_data = load_mapping_account_mut(tail_mapping_account, hdr.ver_)?;
185+
let mut mapping_data = load_checked::<pc_map_table_t>(tail_mapping_account, hdr.ver_)?;
194186
// The mapping account must have free space to add the product account
195187
pyth_assert(
196188
mapping_data.num_ < PC_MAP_TABLE_SIZE,
197189
ProgramError::InvalidArgument,
198190
)?;
199191

200-
initialize_product_account(new_product_account, hdr.ver_)?;
192+
initialize_checked::<pc_prod_t>(new_product_account, hdr.ver_)?;
201193

202194
let current_index: usize = try_convert(mapping_data.num_)?;
203195
unsafe {
@@ -267,73 +259,40 @@ pub fn clear_account(account: &AccountInfo) -> Result<(), ProgramError> {
267259
Ok(())
268260
}
269261

270-
271-
/// Mutably borrow the data in `account` as a mapping account, validating that the account
272-
/// is properly formatted. Any mutations to the returned value will be reflected in the
273-
/// account data. Use this to read already-initialized accounts.
274-
pub fn load_mapping_account_mut<'a>(
262+
pub fn load_checked<'a, T: PythStruct>(
275263
account: &'a AccountInfo,
276-
expected_version: u32,
277-
) -> Result<RefMut<'a, pc_map_table_t>, ProgramError> {
278-
let mapping_data = load_account_as_mut::<pc_map_table_t>(account)?;
279-
280-
pyth_assert(
281-
mapping_data.magic_ == PC_MAGIC
282-
&& mapping_data.ver_ == expected_version
283-
&& mapping_data.type_ == PC_ACCTYPE_MAPPING,
284-
ProgramError::InvalidArgument,
285-
)?;
264+
version: u32,
265+
) -> Result<RefMut<'a, T>, ProgramError> {
266+
{
267+
let account_header = load_account_as::<pc_acc>(account)?;
268+
pyth_assert(
269+
account_header.magic_ == PC_MAGIC
270+
&& account_header.ver_ == version
271+
&& account_header.type_ == T::ACCOUNT_TYPE,
272+
ProgramError::InvalidArgument,
273+
)?;
274+
}
286275

287-
Ok(mapping_data)
276+
load_account_as_mut::<T>(account)
288277
}
289278

290-
/// Initialize account as a new mapping account. This function will zero out any existing data in
291-
/// the account.
292-
pub fn initialize_mapping_account(account: &AccountInfo, version: u32) -> Result<(), ProgramError> {
293-
clear_account(account)?;
294-
295-
let mut mapping_account = load_account_as_mut::<pc_map_table_t>(account)?;
296-
mapping_account.magic_ = PC_MAGIC;
297-
mapping_account.ver_ = version;
298-
mapping_account.type_ = PC_ACCTYPE_MAPPING;
299-
mapping_account.size_ =
300-
try_convert(size_of::<pc_map_table_t>() - size_of_val(&mapping_account.prod_))?;
301-
302-
Ok(())
303-
}
304-
305-
/// Initialize account as a new product account. This function will zero out any existing data in
306-
/// the account.
307-
pub fn initialize_product_account(account: &AccountInfo, version: u32) -> Result<(), ProgramError> {
279+
pub fn initialize_checked<'a, T: PythStruct>(
280+
account: &'a AccountInfo,
281+
version: u32,
282+
) -> Result<RefMut<'a, T>, ProgramError> {
308283
clear_account(account)?;
309284

310-
let mut prod_account = load_account_as_mut::<pc_prod_t>(account)?;
311-
prod_account.magic_ = PC_MAGIC;
312-
prod_account.ver_ = version;
313-
prod_account.type_ = PC_ACCTYPE_PRODUCT;
314-
prod_account.size_ = try_convert(size_of::<pc_prod_t>())?;
285+
{
286+
let mut account_header = load_account_as_mut::<pc_acc>(account)?;
287+
account_header.magic_ = PC_MAGIC;
288+
account_header.ver_ = version;
289+
account_header.type_ = T::ACCOUNT_TYPE;
290+
account_header.size_ = T::INITIAL_SIZE;
291+
}
315292

316-
Ok(())
293+
load_account_as_mut::<T>(account)
317294
}
318295

319-
/// Mutably borrow the data in `account` as a product account, validating that the account
320-
/// is properly formatted. Any mutations to the returned value will be reflected in the
321-
/// account data. Use this to read already-initialized accounts.
322-
pub fn load_product_account_mut<'a>(
323-
account: &'a AccountInfo,
324-
expected_version: u32,
325-
) -> Result<RefMut<'a, pc_prod_t>, ProgramError> {
326-
let product_data = load_account_as_mut::<pc_prod_t>(account)?;
327-
328-
pyth_assert(
329-
product_data.magic_ == PC_MAGIC
330-
&& product_data.ver_ == expected_version
331-
&& product_data.type_ == PC_ACCTYPE_PRODUCT,
332-
ProgramError::InvalidArgument,
333-
)?;
334-
335-
Ok(product_data)
336-
}
337296

338297
// Assign pubkey bytes from source to target, fails if source is not 32 bytes
339298
pub fn pubkey_assign(target: &mut pc_pub_key_t, source: &[u8]) {

program/rust/src/tests/test_add_mapping.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use crate::deserialize::load_account_as_mut;
1010
use crate::rust_oracle::{
1111
add_mapping,
1212
clear_account,
13-
initialize_mapping_account,
14-
load_mapping_account_mut,
13+
initialize_checked,
14+
load_checked,
1515
pubkey_assign,
1616
};
1717
use bytemuck::bytes_of;
@@ -65,10 +65,11 @@ fn test_add_mapping() {
6565
Epoch::default(),
6666
);
6767

68-
initialize_mapping_account(&cur_mapping, PC_VERSION).unwrap();
68+
initialize_checked::<pc_map_table_t>(&cur_mapping, PC_VERSION).unwrap();
6969

7070
{
71-
let mut cur_mapping_data = load_mapping_account_mut(&cur_mapping, PC_VERSION).unwrap();
71+
let mut cur_mapping_data =
72+
load_checked::<pc_map_table_t>(&cur_mapping, PC_VERSION).unwrap();
7273
cur_mapping_data.num_ = PC_MAP_TABLE_SIZE;
7374
}
7475

@@ -99,8 +100,9 @@ fn test_add_mapping() {
99100
.is_ok());
100101

101102
{
102-
let next_mapping_data = load_mapping_account_mut(&next_mapping, PC_VERSION).unwrap();
103-
let mut cur_mapping_data = load_mapping_account_mut(&cur_mapping, PC_VERSION).unwrap();
103+
let next_mapping_data = load_checked::<pc_map_table_t>(&next_mapping, PC_VERSION).unwrap();
104+
let mut cur_mapping_data =
105+
load_checked::<pc_map_table_t>(&cur_mapping, PC_VERSION).unwrap();
104106

105107
assert!(unsafe {
106108
cur_mapping_data
@@ -131,7 +133,8 @@ fn test_add_mapping() {
131133
);
132134

133135
{
134-
let mut cur_mapping_data = load_mapping_account_mut(&cur_mapping, PC_VERSION).unwrap();
136+
let mut cur_mapping_data =
137+
load_checked::<pc_map_table_t>(&cur_mapping, PC_VERSION).unwrap();
135138
assert!(unsafe { cur_mapping_data.next_.k8_.iter().all(|x| *x == 0) });
136139
cur_mapping_data.num_ = PC_MAP_TABLE_SIZE;
137140
cur_mapping_data.magic_ = 0;

program/rust/src/tests/test_add_product.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ use crate::deserialize::load_account_as;
2828
use crate::rust_oracle::{
2929
add_product,
3030
clear_account,
31-
initialize_mapping_account,
32-
load_mapping_account_mut,
31+
initialize_checked,
32+
load_checked,
3333
};
3434

3535
#[test]
@@ -116,7 +116,7 @@ fn test_add_product() {
116116

117117
{
118118
let product_data = load_account_as::<pc_prod_t>(&product_account).unwrap();
119-
let mapping_data = load_mapping_account_mut(&mapping_account, PC_VERSION).unwrap();
119+
let mapping_data = load_checked::<pc_map_table_t>(&mapping_account, PC_VERSION).unwrap();
120120

121121
assert_eq!(product_data.magic_, PC_MAGIC);
122122
assert_eq!(product_data.ver_, PC_VERSION);
@@ -140,7 +140,7 @@ fn test_add_product() {
140140
)
141141
.is_ok());
142142
{
143-
let mapping_data = load_mapping_account_mut(&mapping_account, PC_VERSION).unwrap();
143+
let mapping_data = load_checked::<pc_map_table_t>(&mapping_account, PC_VERSION).unwrap();
144144
assert_eq!(mapping_data.num_, 2);
145145
assert!(pubkey_equal(
146146
&mapping_data.prod_[1],
@@ -175,7 +175,7 @@ fn test_add_product() {
175175

176176
// test fill up of mapping table
177177
clear_account(&mapping_account).unwrap();
178-
initialize_mapping_account(&mapping_account, PC_VERSION).unwrap();
178+
initialize_checked::<pc_map_table_t>(&mapping_account, PC_VERSION).unwrap();
179179

180180
for i in 0..PC_MAP_TABLE_SIZE {
181181
clear_account(&product_account).unwrap();
@@ -190,7 +190,7 @@ fn test_add_product() {
190190
instruction_data
191191
)
192192
.is_ok());
193-
let mapping_data = load_mapping_account_mut(&mapping_account, PC_VERSION).unwrap();
193+
let mapping_data = load_checked::<pc_map_table_t>(&mapping_account, PC_VERSION).unwrap();
194194
assert_eq!(mapping_data.num_, i + 1);
195195
}
196196

@@ -207,7 +207,7 @@ fn test_add_product() {
207207
)
208208
.is_err());
209209

210-
let mapping_data = load_mapping_account_mut(&mapping_account, PC_VERSION).unwrap();
210+
let mapping_data = load_checked::<pc_map_table_t>(&mapping_account, PC_VERSION).unwrap();
211211
assert_eq!(mapping_data.num_, PC_MAP_TABLE_SIZE);
212212
}
213213

0 commit comments

Comments
 (0)