
Commit c322ce5

Refactor + Add basic test utilities
1 parent e1d2e9a commit c322ce5

File tree: 5 files changed (+349, -56 lines)

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
// SPDX-License-Identifier: Apache 2

pragma solidity ^0.8.0;

import "./external/UnsafeBytesLib.sol";

/**
 * @dev This library provides methods to construct and verify Merkle Tree proofs efficiently.
 */

library MerkleTree {
    uint8 constant MERKLE_LEAF_PREFIX = 0;
    uint8 constant MERKLE_NODE_PREFIX = 1;

    function hash(bytes memory input) internal pure returns (bytes20) {
        return bytes20(keccak256(input));
    }

    function leafHash(bytes memory data) internal pure returns (bytes20) {
        return hash(abi.encodePacked(MERKLE_LEAF_PREFIX, data));
    }

    function nodeHash(
        bytes20 childA,
        bytes20 childB
    ) internal pure returns (bytes20) {
        // Sorting the children removes the need to encode the sibling
        // position (left/right) in the proof.
        if (childA > childB) {
            (childA, childB) = (childB, childA);
        }
        return hash(abi.encodePacked(MERKLE_NODE_PREFIX, childA, childB));
    }

    /// @notice Verify a Merkle Tree proof for the given leaf data.
    /// @dev The proof is expected to be encoded as a uint8 number of steps
    /// followed by one 20-byte sibling digest per step.
    ///
    /// This method does not perform any check on the boundary of the
    /// `encodedProof` and `proofOffset` parameters. It is the caller's
    /// responsibility to ensure that `encodedProof` is long enough to
    /// contain the proof and that `proofOffset` is not out of bounds.
    function verify(
        bytes memory encodedProof,
        uint proofOffset,
        bytes20 root,
        bytes memory leafData
    ) internal pure returns (bool valid, uint endOffset) {
        unchecked {
            bytes20 currentDigest = MerkleTree.leafHash(leafData);

            uint8 proofSize = UnsafeBytesLib.toUint8(encodedProof, proofOffset);
            proofOffset += 1;

            for (uint i = 0; i < proofSize; i++) {
                bytes20 siblingDigest = bytes20(
                    UnsafeBytesLib.toAddress(encodedProof, proofOffset)
                );
                proofOffset += 20;

                currentDigest = MerkleTree.nodeHash(
                    currentDigest,
                    siblingDigest
                );
            }

            valid = currentDigest == root;
            endOffset = proofOffset;
        }
    }

    /// @notice Construct Merkle Tree proofs for the given list of messages.
    /// @dev This function is only used for testing purposes and is not efficient
    /// for production use-cases.
    ///
    /// This method creates a merkle tree with (2^depth) leaves, using the
    /// messages as leaves (in the same given order), and returns the root
    /// digest and the proofs for each message. If the number of messages is
    /// less than (2^depth), the remaining leaves are padded with empty
    /// messages.
    function constructProofs(
        bytes[] memory messages,
        uint8 depth
    ) internal pure returns (bytes20 root, bytes[] memory proofs) {
        require((1 << depth) >= messages.length, "depth too small");

        bytes20[] memory tree = new bytes20[]((1 << (depth + 1)));

        // The tree is structured as follows:
        //    1
        //   2 3
        //  4 5 6 7
        //    ...
        // In this structure the parent of node x is x//2 and the children
        // of node x are x*2 and x*2 + 1. Also, the sibling of node x is
        // x^1. The root is at index 1 and index 0 is not used.

        // Filling the leaf hashes
        for (uint i = 0; i < (1 << depth); i++) {
            if (i < messages.length) {
                tree[(1 << depth) + i] = leafHash(messages[i]);
            } else {
                tree[(1 << depth) + i] = leafHash("");
            }
        }

        // Filling the node hashes from bottom to top
        for (uint k = depth; k > 0; k--) {
            uint level = k - 1;
            uint levelNumNodes = (1 << level);
            for (uint i = 0; i < levelNumNodes; i++) {
                uint id = (1 << level) + i;
                tree[id] = nodeHash(tree[id * 2], tree[id * 2 + 1]);
            }
        }

        root = tree[1];

        proofs = new bytes[](messages.length);

        for (uint i = 0; i < messages.length; i++) {
            // The proof starts with the number of steps (= depth).
            proofs[i] = abi.encodePacked(depth);

            // This loop iterates through the leaf and its parents
            // and keeps adding the sibling of the current node to the proof.
            uint idx = (1 << depth) + i;
            for (uint k = depth; k > 0; k--) {
                proofs[i] = abi.encodePacked(
                    proofs[i],
                    tree[idx ^ 1] // Sibling of this node
                );

                // Jump to parent
                idx /= 2;
            }
        }
    }
}
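
Not part of the commit, but useful for orientation: the two entry points are meant to compose, with constructProofs emitting each proof as a one-byte step count followed by one 20-byte sibling digest per level, and verify consuming exactly that layout and returning the offset just past the proof. A minimal Foundry-style round-trip sketch follows; the import path and contract name are assumptions.

// Sketch only: round-trip check for the library above. The import path and
// contract name are assumptions, not part of this commit.
// SPDX-License-Identifier: Apache 2
pragma solidity ^0.8.0;

import "forge-std/Test.sol";
import "./MerkleTree.sol";

contract MerkleTreeRoundTripSketch is Test {
    function testRoundTrip() public {
        bytes[] memory messages = new bytes[](3);
        messages[0] = hex"1234";
        messages[1] = hex"4321";
        messages[2] = hex"11";

        // depth 2 => 4 leaves; the unused leaf is padded with leafHash("").
        (bytes20 root, bytes[] memory proofs) = MerkleTree.constructProofs(
            messages,
            2
        );

        for (uint i = 0; i < messages.length; i++) {
            (bool valid, uint endOffset) = MerkleTree.verify(
                proofs[i],
                0,
                root,
                messages[i]
            );
            assertTrue(valid);
            // 1 byte step count + depth * 20 bytes of sibling digests.
            assertEq(endOffset, 1 + 2 * 20);
        }
    }
}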

target_chains/ethereum/contracts/contracts/pyth/PythAccumulator.sol

Lines changed: 13 additions & 55 deletions
@@ -11,6 +11,8 @@ import "./PythGetters.sol";
 import "./PythSetters.sol";
 import "./PythInternalStructs.sol";
 
+import "../libraries/MerkleTree.sol";
+
 abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
     uint32 constant ACCUMULATOR_MAGIC = 0x504e4155; // Stands for PNAU (Pyth Network Accumulator Update)
     uint32 constant ACCUMULATOR_WORMHOLE_MAGIC = 0x41555756; // Stands for AUWV (Accumulator Update Wormhole Verficiation)
@@ -83,11 +85,6 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
             offset += trailingHeaderSize;
         }
 
-        // This value is only used as the check below which currently
-        // never reverts
-        // uint8 minorVersion = UnsafeBytesLib.toUint18(accumulatorUpdate, offset);
-        offset += 1;
-
         UpdateType updateType = UpdateType(
             UnsafeBytesLib.toUint8(accumulatorUpdate, offset)
         );
@@ -156,7 +153,7 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
             digest = bytes20(
                 UnsafeBytesLib.toAddress(encodedPayload, payloadoffset)
             );
-            payloadoffset += 32;
+            payloadoffset += 20;
 
             // TODO: Do we need to be strict about the size of the payload? How it can evolve?
             if (payloadoffset != encodedPayload.length)
@@ -179,15 +176,8 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
         }
     }
 
-    uint8 constant MERKLE_LEAF_PREFIX = 0;
-    uint8 constant MERKLE_NODE_PREFIX = 1;
-
-    function merkleHash(bytes memory input) private pure returns (bytes20) {
-        return bytes20(keccak256(input));
-    }
-
     function verifyAndUpdatePriceFeedFromMerkleProof(
-        bytes32 digest,
+        bytes20 digest,
         bytes memory encoded,
         uint offset
     ) private returns (uint endOffset) {
@@ -202,48 +192,16 @@ abstract contract PythAccumulator is PythGetters, PythSetters, AbstractPyth {
         );
         offset += messageSize;
 
-        {
-            bytes20 currentDigest = merkleHash(
-                abi.encodePacked(MERKLE_LEAF_PREFIX, encodedMessage)
-            );
-
-            uint8 proofSize = UnsafeBytesLib.toUint8(encoded, offset);
-            offset += 1;
-
-            for (uint i = 0; i < proofSize; i++) {
-                uint8 isSiblingRight = UnsafeBytesLib.toUint8(
-                    encoded,
-                    offset
-                );
-                offset += 1;
-
-                bytes32 siblingDigest = UnsafeBytesLib.toBytes32(
-                    encoded,
-                    offset
-                );
-                offset += 32;
-
-                if (isSiblingRight == 0) {
-                    currentDigest = merkleHash(
-                        abi.encodePacked(
-                            MERKLE_NODE_PREFIX,
-                            siblingDigest,
-                            currentDigest
-                        )
-                    );
-                } else {
-                    currentDigest = merkleHash(
-                        abi.encodePacked(
-                            MERKLE_NODE_PREFIX,
-                            currentDigest,
-                            siblingDigest
-                        )
-                    );
-                }
-            }
+        bool valid;
+        (valid, offset) = MerkleTree.verify(
+            encoded,
+            offset,
+            digest,
+            encodedMessage
+        );
 
-            if (currentDigest != digest)
-                revert PythErrors.InvalidUpdateData();
+        if (!valid) {
+            revert PythErrors.InvalidUpdateData();
         }
 
         parseAndProcessMessage(encodedMessage);
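
The net effect of the last two hunks: the accumulator no longer walks the proof itself. The old inline format carried an explicit left/right flag plus a 32-byte sibling per step, while MerkleTree.verify expects a one-byte step count followed by 20-byte sibling digests, with left/right resolved implicitly because nodeHash sorts the two children before hashing; the digest parameter accordingly shrinks from bytes32 to bytes20. A hedged sketch of that proof layout, with a hypothetical helper name:

// Illustration only; the helper name is hypothetical. Encodes a proof in the
// layout consumed by MerkleTree.verify (and produced by constructProofs):
// [uint8 stepCount][bytes20 sibling_0]...[bytes20 sibling_{stepCount-1}].
function encodeMerkleProofSketch(
    bytes20[] memory siblings
) pure returns (bytes memory proof) {
    // Assumes fewer than 256 proof steps so the count fits in a uint8.
    proof = abi.encodePacked(uint8(siblings.length));
    for (uint i = 0; i < siblings.length; i++) {
        // No direction flag is needed: nodeHash orders (childA, childB)
        // by value before hashing.
        proof = abi.encodePacked(proof, siblings[i]);
    }
}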
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
// SPDX-License-Identifier: Apache 2

pragma solidity ^0.8.0;

import "@openzeppelin/contracts/proxy/ERC1967/ERC1967Proxy.sol";
import "forge-std/Test.sol";

import "@pythnetwork/pyth-sdk-solidity/IPyth.sol";
import "@pythnetwork/pyth-sdk-solidity/PythErrors.sol";
import "@pythnetwork/pyth-sdk-solidity/PythStructs.sol";
import "./utils/WormholeTestUtils.t.sol";
import "./utils/PythTestUtils.t.sol";
import "./utils/RandTestUtils.t.sol";

contract PythWormholeMerkleAccumulatorTest is
    Test,
    WormholeTestUtils,
    PythTestUtils,
    RandTestUtils
{
    IPyth public pyth;

    function setUp() public {
        pyth = IPyth(setUpPyth(setUpWormhole(1)));
    }

    function generateRandomPriceFeedMessage(
        uint numPriceFeeds
    ) internal returns (PriceFeedMessage[] memory priceFeedMessages) {
        priceFeedMessages = new PriceFeedMessage[](numPriceFeeds);
        for (uint i = 0; i < numPriceFeeds; i++) {
            priceFeedMessages[i] = PriceFeedMessage({
                priceId: getRandBytes32(),
                price: getRandInt64(),
                conf: getRandUint64(),
                expo: getRandInt32(),
                publishTime: getRandUint64(),
                emaPrice: getRandInt64(),
                emaConf: getRandUint64()
            });
        }
    }

    function createWormholeMerkleUpdateData(
        PriceFeedMessage[] memory priceFeedMessages
    ) internal returns (bytes[] memory updateData, uint updateFee) {
        updateData = new bytes[](1);

        uint8 depth = 0;
        while ((1 << depth) < priceFeedMessages.length) {
            depth++;
        }

        depth += getRandUint8() % 3;

        updateData[0] = generateWhMerkleUpdate(priceFeedMessages, depth, 1);

        updateFee = pyth.getUpdateFee(updateData);
    }

    /// Testing the update price feeds method using the wormhole merkle update type.
    function testUpdatePriceFeedWithWormholeMerkleWorks(uint seed) public {
        setRandSeed(seed);

        PriceFeedMessage[]
            memory priceFeedMessages = generateRandomPriceFeedMessage(
                (getRandUint() % 10) + 1
            );
        (
            bytes[] memory updateData,
            uint updateFee
        ) = createWormholeMerkleUpdateData(priceFeedMessages);

        pyth.updatePriceFeeds{value: updateFee}(updateData);

        for (uint i = 0; i < priceFeedMessages.length; i++) {
            PythStructs.Price memory aggregatePrice = pyth.getPriceUnsafe(
                priceFeedMessages[i].priceId
            );
            assertEq(aggregatePrice.price, priceFeedMessages[i].price);
            assertEq(aggregatePrice.conf, priceFeedMessages[i].conf);
            assertEq(aggregatePrice.expo, priceFeedMessages[i].expo);
            assertEq(
                aggregatePrice.publishTime,
                priceFeedMessages[i].publishTime
            );

            PythStructs.Price memory emaPrice = pyth.getEmaPriceUnsafe(
                priceFeedMessages[i].priceId
            );
            assertEq(emaPrice.price, priceFeedMessages[i].emaPrice);
            assertEq(emaPrice.conf, priceFeedMessages[i].emaConf);
            assertEq(emaPrice.expo, priceFeedMessages[i].expo);
            assertEq(emaPrice.publishTime, priceFeedMessages[i].publishTime);
        }
    }

    function testUpdatePriceFeedWithWormholeMerkleWorksOnMultiUpdate() public {}

    function testUpdatePriceFeedWithWormholeMerkleIgnoresOutOfOrderUpdate()
        public
    {}

    function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAA() public {}

    function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongVAASource()
        public
    {}

    function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongProof()
        public
    {}

    function testUpdatePriceFeedWithWormholeMerkleRevertsOnWrongHeader()
        public
    {}

    function testUpdatePriceFeedWithWormholeMerkleRevertsIfUpdateFeeIsNotPaid()
        public
    {}
}
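
For context on where the empty stubs above are headed, a possible body for the fee stub is sketched below. It assumes the test deployment charges a non-zero per-update fee and that the contract reverts with PythErrors.InsufficientFee when the fee is not paid; both are assumptions, not something this commit establishes.

// Sketch only: a possible replacement for the empty fee stub inside the
// contract above. Assumes a non-zero update fee and a revert with
// PythErrors.InsufficientFee when it is not paid.
function testUpdatePriceFeedWithWormholeMerkleRevertsIfUpdateFeeIsNotPaid()
    public
{
    PriceFeedMessage[]
        memory priceFeedMessages = generateRandomPriceFeedMessage(1);
    (bytes[] memory updateData, ) = createWormholeMerkleUpdateData(
        priceFeedMessages
    );

    // Paying nothing should trip the fee check before any update is applied.
    vm.expectRevert(PythErrors.InsufficientFee.selector);
    pyth.updatePriceFeeds{value: 0}(updateData);
}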

0 commit comments