Skip to content

Commit 801f1f7

Browse files
authored
refactor: merge segment collection logic [1/2] (#143)
1 parent 98507f8 commit 801f1f7

File tree

3 files changed

+120
-63
lines changed

3 files changed

+120
-63
lines changed

src/account/UpgradeableModularAccount.sol

Lines changed: 9 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import {SparseCalldataSegmentLib} from "../helpers/SparseCalldataSegmentLib.sol"
1919
import {ValidationConfigLib} from "../helpers/ValidationConfigLib.sol";
2020
import {_coalescePreValidation, _coalesceValidation} from "../helpers/ValidationResHelpers.sol";
2121

22-
import {DIRECT_CALL_VALIDATION_ENTITYID, RESERVED_VALIDATION_DATA_INDEX} from "../helpers/Constants.sol";
22+
import {DIRECT_CALL_VALIDATION_ENTITYID} from "../helpers/Constants.sol";
2323

2424
import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol";
2525
import {ExecutionManifest} from "../interfaces/IExecutionModule.sol";
@@ -67,7 +67,6 @@ contract UpgradeableModularAccount is
6767
bytes4 internal constant _1271_MAGIC_VALUE = 0x1626ba7e;
6868
bytes4 internal constant _1271_INVALID = 0xffffffff;
6969

70-
error NonCanonicalEncoding();
7170
error NotEntryPoint();
7271
error PostExecHookReverted(address module, uint32 entityId, bytes revertReason);
7372
error PreExecHookReverted(address module, uint32 entityId, bytes revertReason);
@@ -79,8 +78,6 @@ contract UpgradeableModularAccount is
7978
error UnexpectedAggregator(address module, uint32 entityId, address aggregator);
8079
error UnrecognizedFunction(bytes4 selector);
8180
error ValidationFunctionMissing(bytes4 selector);
82-
error ValidationSignatureSegmentMissing();
83-
error SignatureSegmentOutOfOrder();
8481

8582
// Wraps execution of a native function with runtime validation and hooks
8683
// Used for upgradeTo, upgradeToAndCall, execute, executeBatch, installExecution, uninstallExecution
@@ -347,36 +344,14 @@ contract UpgradeableModularAccount is
347344
bytes calldata signature,
348345
bytes32 userOpHash
349346
) internal returns (uint256) {
350-
// Set up the per-hook data tracking fields
351-
bytes calldata signatureSegment;
352-
(signatureSegment, signature) = signature.getNextSegment();
353-
354347
uint256 validationRes;
355348

356349
// Do preUserOpValidation hooks
357350
ModuleEntity[] memory preUserOpValidationHooks =
358351
getAccountStorage().validationData[userOpValidationFunction].preValidationHooks;
359352

360353
for (uint256 i = 0; i < preUserOpValidationHooks.length; ++i) {
361-
// Load per-hook data, if any is present
362-
// The segment index is the first byte of the signature
363-
if (signatureSegment.getIndex() == i) {
364-
// Use the current segment
365-
userOp.signature = signatureSegment.getBody();
366-
367-
if (userOp.signature.length == 0) {
368-
revert NonCanonicalEncoding();
369-
}
370-
371-
// Load the next per-hook data segment
372-
(signatureSegment, signature) = signature.getNextSegment();
373-
374-
if (signatureSegment.getIndex() <= i) {
375-
revert SignatureSegmentOutOfOrder();
376-
}
377-
} else {
378-
userOp.signature = "";
379-
}
354+
(userOp.signature, signature) = signature.advanceSegmentIfAtIndex(uint8(i));
380355

381356
(address module, uint32 entityId) = preUserOpValidationHooks[i].unpack();
382357
uint256 currentValidationRes =
@@ -389,13 +364,9 @@ contract UpgradeableModularAccount is
389364
validationRes = _coalescePreValidation(validationRes, currentValidationRes);
390365
}
391366

392-
// Run the user op validationFunction
367+
// Run the user op validation function
393368
{
394-
if (signatureSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) {
395-
revert ValidationSignatureSegmentMissing();
396-
}
397-
398-
userOp.signature = signatureSegment.getBody();
369+
userOp.signature = signature.getFinalSegment();
399370

400371
uint256 currentValidationRes = _execUserOpValidation(userOpValidationFunction, userOp, userOpHash);
401372

@@ -415,42 +386,21 @@ contract UpgradeableModularAccount is
415386
bytes calldata callData,
416387
bytes calldata authorizationData
417388
) internal {
418-
// Set up the per-hook data tracking fields
419-
bytes calldata authSegment;
420-
(authSegment, authorizationData) = authorizationData.getNextSegment();
421-
422389
// run all preRuntimeValidation hooks
423390
ModuleEntity[] memory preRuntimeValidationHooks =
424391
getAccountStorage().validationData[runtimeValidationFunction].preValidationHooks;
425392

426393
for (uint256 i = 0; i < preRuntimeValidationHooks.length; ++i) {
427-
bytes memory currentAuthData;
428-
429-
if (authSegment.getIndex() == i) {
430-
// Use the current segment
431-
currentAuthData = authSegment.getBody();
432-
433-
if (currentAuthData.length == 0) {
434-
revert NonCanonicalEncoding();
435-
}
394+
bytes memory currentAuthSegment;
436395

437-
// Load the next per-hook data segment
438-
(authSegment, authorizationData) = authorizationData.getNextSegment();
396+
(currentAuthSegment, authorizationData) = authorizationData.advanceSegmentIfAtIndex(uint8(i));
439397

440-
if (authSegment.getIndex() <= i) {
441-
revert SignatureSegmentOutOfOrder();
442-
}
443-
} else {
444-
currentAuthData = "";
445-
}
446-
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthData);
398+
_doPreRuntimeValidationHook(preRuntimeValidationHooks[i], callData, currentAuthSegment);
447399
}
448400

449-
if (authSegment.getIndex() != RESERVED_VALIDATION_DATA_INDEX) {
450-
revert ValidationSignatureSegmentMissing();
451-
}
401+
authorizationData = authorizationData.getFinalSegment();
452402

453-
_execRuntimeValidation(runtimeValidationFunction, callData, authSegment.getBody());
403+
_execRuntimeValidation(runtimeValidationFunction, callData, authorizationData);
454404
}
455405

456406
function _doPreHooks(EnumerableSet.Bytes32Set storage executionHooks, bytes memory data)

src/helpers/SparseCalldataSegmentLib.sol

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
// SPDX-License-Identifier: GPL-3.0
22
pragma solidity ^0.8.25;
33

4+
import {RESERVED_VALIDATION_DATA_INDEX} from "./Constants.sol";
5+
46
/// @title Sparse Calldata Segment Library
57
/// @notice Library for working with sparsely-packed calldata segments, identified with an index.
68
/// @dev The first byte of each segment is the index of the segment.
79
/// To prevent accidental stack-to-deep errors, the body and index of the segment are extracted separately, rather
810
/// than inline as part of the tuple returned by `getNextSegment`.
911
library SparseCalldataSegmentLib {
12+
error NonCanonicalEncoding();
13+
error SegmentOutOfOrder();
14+
error ValidationSignatureSegmentMissing();
15+
1016
/// @notice Splits out a segment of calldata, sparsely-packed.
1117
/// The expected format is:
1218
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
@@ -33,6 +39,61 @@ library SparseCalldataSegmentLib {
3339
remainder = source[remainderOffset:];
3440
}
3541

42+
/// @notice If the index of the next segment in the source equals the provided index, return the next body and
43+
/// advance the source by one segment.
44+
/// @dev Reverts if the index of the next segment is less than the provided index, or if the extracted segment
45+
/// has length 0.
46+
/// @param source The calldata to extract the segment from.
47+
/// @param index The index of the segment to extract.
48+
/// @return A tuple containing the extracted segment's body, or an empty buffer if the index is not found, and
49+
/// the remaining calldata.
50+
function advanceSegmentIfAtIndex(bytes calldata source, uint8 index)
51+
internal
52+
pure
53+
returns (bytes memory, bytes calldata)
54+
{
55+
uint8 nextIndex = peekIndex(source);
56+
57+
if (nextIndex < index) {
58+
revert SegmentOutOfOrder();
59+
}
60+
61+
if (nextIndex == index) {
62+
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);
63+
64+
segment = getBody(segment);
65+
66+
if (segment.length == 0) {
67+
revert NonCanonicalEncoding();
68+
}
69+
70+
return (segment, remainder);
71+
}
72+
73+
return ("", source);
74+
}
75+
76+
function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) {
77+
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);
78+
79+
if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) {
80+
revert ValidationSignatureSegmentMissing();
81+
}
82+
83+
if (remainder.length != 0) {
84+
revert NonCanonicalEncoding();
85+
}
86+
87+
return getBody(segment);
88+
}
89+
90+
/// @notice Returns the index of the next segment in the source.
91+
/// @param source The calldata to extract the index from.
92+
/// @return The index of the next segment.
93+
function peekIndex(bytes calldata source) internal pure returns (uint8) {
94+
return uint8(source[4]);
95+
}
96+
3697
/// @notice Extracts the index from a segment.
3798
/// @dev The first byte of the segment is the index.
3899
/// @param segment The segment to extract the index from

test/account/PerHookData.t.sol

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {UpgradeableModularAccount} from "../../src/account/UpgradeableModularAcc
99

1010
import {HookConfigLib} from "../../src/helpers/HookConfigLib.sol";
1111
import {ModuleEntity, ModuleEntityLib} from "../../src/helpers/ModuleEntityLib.sol";
12+
import {SparseCalldataSegmentLib} from "../../src/helpers/SparseCalldataSegmentLib.sol";
1213

1314
import {Counter} from "../mocks/Counter.sol";
1415
import {MockAccessControlHookModule} from "../mocks/modules/MockAccessControlHookModule.sol";
@@ -123,7 +124,7 @@ contract PerHookDataTest is CustomValidationTestBase {
123124
IEntryPoint.FailedOpWithRevert.selector,
124125
0,
125126
"AA23 reverted",
126-
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
127+
abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector)
127128
)
128129
);
129130
entryPoint.handleOps(userOps, beneficiary);
@@ -187,7 +188,35 @@ contract PerHookDataTest is CustomValidationTestBase {
187188
IEntryPoint.FailedOpWithRevert.selector,
188189
0,
189190
"AA23 reverted",
190-
abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector)
191+
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
192+
)
193+
);
194+
entryPoint.handleOps(userOps, beneficiary);
195+
}
196+
197+
function test_failPerHookData_excessData_userOp() public {
198+
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
199+
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());
200+
201+
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
202+
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});
203+
204+
userOp.signature = abi.encodePacked(
205+
_encodeSignature(
206+
_signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)
207+
),
208+
"extra data"
209+
);
210+
211+
PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
212+
userOps[0] = userOp;
213+
214+
vm.expectRevert(
215+
abi.encodeWithSelector(
216+
IEntryPoint.FailedOpWithRevert.selector,
217+
0,
218+
"AA23 reverted",
219+
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
191220
)
192221
);
193222
entryPoint.handleOps(userOps, beneficiary);
@@ -262,7 +291,7 @@ contract PerHookDataTest is CustomValidationTestBase {
262291

263292
vm.prank(owner1);
264293
vm.expectRevert(
265-
abi.encodeWithSelector(UpgradeableModularAccount.ValidationSignatureSegmentMissing.selector)
294+
abi.encodeWithSelector(SparseCalldataSegmentLib.ValidationSignatureSegmentMissing.selector)
266295
);
267296
account1.executeWithAuthorization(
268297
abi.encodeCall(
@@ -299,7 +328,7 @@ contract PerHookDataTest is CustomValidationTestBase {
299328
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: ""});
300329

301330
vm.prank(owner1);
302-
vm.expectRevert(abi.encodeWithSelector(UpgradeableModularAccount.NonCanonicalEncoding.selector));
331+
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
303332
account1.executeWithAuthorization(
304333
abi.encodeCall(
305334
UpgradeableModularAccount.execute,
@@ -309,6 +338,23 @@ contract PerHookDataTest is CustomValidationTestBase {
309338
);
310339
}
311340

341+
function test_failPerHookData_excessData_runtime() public {
342+
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
343+
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});
344+
345+
vm.prank(owner1);
346+
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
347+
account1.executeWithAuthorization(
348+
abi.encodeCall(
349+
UpgradeableModularAccount.execute,
350+
(address(_counter), 0 wei, abi.encodeCall(Counter.increment, ()))
351+
),
352+
abi.encodePacked(
353+
_encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data"
354+
)
355+
);
356+
}
357+
312358
function _getCounterUserOP() internal view returns (PackedUserOperation memory, bytes32) {
313359
PackedUserOperation memory userOp = PackedUserOperation({
314360
sender: address(account1),

0 commit comments

Comments
 (0)