Skip to content

Commit 45a9e8a

Browse files
authored
feat: skip signature length encoding on final sig and add type safety (#192)
1 parent 14f1f89 commit 45a9e8a

File tree

9 files changed

+161
-197
lines changed

9 files changed

+161
-197
lines changed

src/libraries/SparseCalldataSegmentLib.sol

Lines changed: 16 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,26 @@ library SparseCalldataSegmentLib {
1515

1616
/// @notice Splits out a segment of calldata, sparsely-packed.
1717
/// The expected format is:
18-
/// [uint32(len(segment0)), segment0, uint32(len(segment1)), segment1, ... uint32(len(segmentN)), segmentN]
18+
/// [uint8(index0), uint32(len(segment0)), segment0, uint8(index1), uint32(len(segment1)), segment1,
19+
/// ... uint8(indexN), uint32(len(segmentN)), segmentN]
1920
/// @param source The calldata to extract the segment from.
2021
/// @return segment The extracted segment. Using the above example, this would be segment0.
2122
/// @return remainder The remaining calldata. Using the above example,
22-
/// this would start at uint32(len(segment1)) and continue to the end at segmentN.
23+
/// this would start at uint8(index1) and continue to the end at segmentN.
2324
function getNextSegment(bytes calldata source)
2425
internal
2526
pure
2627
returns (bytes calldata segment, bytes calldata remainder)
2728
{
28-
// The first 4 bytes hold the length of the segment, excluding the index.
29-
uint32 length = uint32(bytes4(source[:4]));
29+
// The first byte of the segment is the index.
30+
// The next 4 bytes hold the length of the segment, excluding the index.
31+
uint32 length = uint32(bytes4(source[1:5]));
3032

3133
// The offset of the remainder of the calldata.
32-
uint256 remainderOffset = 4 + length;
34+
uint256 remainderOffset = 5 + length;
3335

34-
// The segment is the next `length` + 1 bytes, to account for the index.
35-
// By convention, the first byte of each segment is the index of the segment.
36-
segment = source[4:remainderOffset];
36+
// The segment is the next `length` bytes after the first 5 bytes.
37+
segment = source[5:remainderOffset];
3738

3839
// The remainder is the rest of the calldata.
3940
remainder = source[remainderOffset:];
@@ -52,7 +53,7 @@ library SparseCalldataSegmentLib {
5253
pure
5354
returns (bytes memory, bytes calldata)
5455
{
55-
uint8 nextIndex = peekIndex(source);
56+
uint8 nextIndex = getIndex(source);
5657

5758
if (nextIndex < index) {
5859
revert SegmentOutOfOrder();
@@ -61,8 +62,6 @@ library SparseCalldataSegmentLib {
6162
if (nextIndex == index) {
6263
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);
6364

64-
segment = getBody(segment);
65-
6665
if (segment.length == 0) {
6766
revert NonCanonicalEncoding();
6867
}
@@ -73,25 +72,16 @@ library SparseCalldataSegmentLib {
7372
return ("", source);
7473
}
7574

75+
/// @notice Extracts the final segment from the source.
76+
/// @dev Reverts if the index of the segment is not RESERVED_VALIDATION_DATA_INDEX.
77+
/// @param source The calldata to extract the segment from.
78+
/// @return The final segment.
7679
function getFinalSegment(bytes calldata source) internal pure returns (bytes calldata) {
77-
(bytes calldata segment, bytes calldata remainder) = getNextSegment(source);
78-
79-
if (getIndex(segment) != RESERVED_VALIDATION_DATA_INDEX) {
80+
if (getIndex(source) != RESERVED_VALIDATION_DATA_INDEX) {
8081
revert ValidationSignatureSegmentMissing();
8182
}
8283

83-
if (remainder.length != 0) {
84-
revert NonCanonicalEncoding();
85-
}
86-
87-
return getBody(segment);
88-
}
89-
90-
/// @notice Returns the index of the next segment in the source.
91-
/// @param source The calldata to extract the index from.
92-
/// @return The index of the next segment.
93-
function peekIndex(bytes calldata source) internal pure returns (uint8) {
94-
return uint8(source[4]);
84+
return source[1:];
9585
}
9686

9787
/// @notice Extracts the index from a segment.
@@ -101,12 +91,4 @@ library SparseCalldataSegmentLib {
10191
function getIndex(bytes calldata segment) internal pure returns (uint8) {
10292
return uint8(segment[0]);
10393
}
104-
105-
/// @notice Extracts the body from a segment.
106-
/// @dev The body is the segment without the index.
107-
/// @param segment The segment to extract the body from
108-
/// @return The body of the segment.
109-
function getBody(bytes calldata segment) internal pure returns (bytes calldata) {
110-
return segment[1:];
111-
}
11294
}

src/libraries/ValidationConfigLib.sol

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ pragma solidity ^0.8.20;
33

44
import {ModuleEntity, ValidationConfig} from "../interfaces/IModularAccount.sol";
55

6+
// Validation flags layout:
7+
// 0b00000___ // unused
8+
// 0b_____A__ // isGlobal
9+
// 0b______B_ // isSignatureValidation
10+
// 0b_______C // isUserOpValidation
11+
type ValidationFlags is uint8;
12+
613
// Validation config is a packed representation of a validation function and flags for its configuration.
714
// Layout:
815
// 0xAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA________________________ // Address
@@ -63,22 +70,22 @@ library ValidationConfigLib {
6370
function unpackUnderlying(ValidationConfig config)
6471
internal
6572
pure
66-
returns (address _module, uint32 _entityId, uint8 flags)
73+
returns (address _module, uint32 _entityId, ValidationFlags flags)
6774
{
6875
bytes25 configBytes = ValidationConfig.unwrap(config);
6976
_module = address(bytes20(configBytes));
7077
_entityId = uint32(bytes4(configBytes << 160));
71-
flags = uint8(configBytes[24]);
78+
flags = ValidationFlags.wrap(uint8(configBytes[24]));
7279
}
7380

7481
function unpack(ValidationConfig config)
7582
internal
7683
pure
77-
returns (ModuleEntity _validationFunction, uint8 flags)
84+
returns (ModuleEntity _validationFunction, ValidationFlags flags)
7885
{
7986
bytes25 configBytes = ValidationConfig.unwrap(config);
8087
_validationFunction = ModuleEntity.wrap(bytes24(configBytes));
81-
flags = uint8(configBytes[24]);
88+
flags = ValidationFlags.wrap(uint8(configBytes[24]));
8289
}
8390

8491
function module(ValidationConfig config) internal pure returns (address) {
@@ -97,23 +104,23 @@ library ValidationConfigLib {
97104
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_GLOBAL != 0;
98105
}
99106

100-
function isGlobal(uint8 flags) internal pure returns (bool) {
101-
return flags & 0x04 != 0;
107+
function isGlobal(ValidationFlags flags) internal pure returns (bool) {
108+
return ValidationFlags.unwrap(flags) & 0x04 != 0;
102109
}
103110

104111
function isSignatureValidation(ValidationConfig config) internal pure returns (bool) {
105112
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_SIGNATURE != 0;
106113
}
107114

108-
function isSignatureValidation(uint8 flags) internal pure returns (bool) {
109-
return flags & 0x02 != 0;
115+
function isSignatureValidation(ValidationFlags flags) internal pure returns (bool) {
116+
return ValidationFlags.unwrap(flags) & 0x02 != 0;
110117
}
111118

112119
function isUserOpValidation(ValidationConfig config) internal pure returns (bool) {
113120
return ValidationConfig.unwrap(config) & _VALIDATION_FLAG_IS_USER_OP != 0;
114121
}
115122

116-
function isUserOpValidation(uint8 flags) internal pure returns (bool) {
117-
return flags & 0x01 != 0;
123+
function isUserOpValidation(ValidationFlags flags) internal pure returns (bool) {
124+
return ValidationFlags.unwrap(flags) & 0x01 != 0;
118125
}
119126
}

test/account/AccountReturnData.t.sol

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ contract AccountReturnDataTest is AccountTestBase {
112112

113113
// Tests the ability to read data via executeWithRuntimeValidation
114114
function test_returnData_authorized_exec() public {
115-
bool result = ResultConsumerModule(address(account1)).checkResultExecuteWithAuthorization(
115+
bool result = ResultConsumerModule(address(account1)).checkResultExecuteWithRuntimeValidation(
116116
address(regularResultContract), keccak256("bar")
117117
);
118118

test/account/PerHookData.t.sol

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -246,34 +246,6 @@ contract PerHookDataTest is CustomValidationTestBase {
246246
entryPoint.handleOps(userOps, beneficiary);
247247
}
248248

249-
function test_failPerHookData_excessData_userOp() public {
250-
(PackedUserOperation memory userOp, bytes32 userOpHash) = _getCounterUserOP();
251-
(uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, userOpHash.toEthSignedMessageHash());
252-
253-
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
254-
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});
255-
256-
userOp.signature = abi.encodePacked(
257-
_encodeSignature(
258-
_signerValidation, GLOBAL_VALIDATION, preValidationHookData, abi.encodePacked(r, s, v)
259-
),
260-
"extra data"
261-
);
262-
263-
PackedUserOperation[] memory userOps = new PackedUserOperation[](1);
264-
userOps[0] = userOp;
265-
266-
vm.expectRevert(
267-
abi.encodeWithSelector(
268-
IEntryPoint.FailedOpWithRevert.selector,
269-
0,
270-
"AA23 reverted",
271-
abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector)
272-
)
273-
);
274-
entryPoint.handleOps(userOps, beneficiary);
275-
}
276-
277249
function test_passAccessControl_runtime() public {
278250
assertEq(_counter.number(), 0);
279251

@@ -420,22 +392,6 @@ contract PerHookDataTest is CustomValidationTestBase {
420392
);
421393
}
422394

423-
function test_failPerHookData_excessData_runtime() public {
424-
PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1);
425-
preValidationHookData[0] = PreValidationHookData({index: 0, validationData: abi.encodePacked(_counter)});
426-
427-
vm.prank(owner1);
428-
vm.expectRevert(abi.encodeWithSelector(SparseCalldataSegmentLib.NonCanonicalEncoding.selector));
429-
account1.executeWithRuntimeValidation(
430-
abi.encodeCall(
431-
ReferenceModularAccount.execute, (address(_counter), 0 wei, abi.encodeCall(Counter.increment, ()))
432-
),
433-
abi.encodePacked(
434-
_encodeSignature(_signerValidation, GLOBAL_VALIDATION, preValidationHookData, ""), "extra data"
435-
)
436-
);
437-
}
438-
439395
function test_pass1271AccessControl() public {
440396
string memory message = "Hello, world!";
441397

test/libraries/SparseCalldataSegmentLib.t.sol

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ contract SparseCalldataSegmentLibTest is Test {
4949
bytes memory result = "";
5050

5151
for (uint256 i = 0; i < segments.length; i++) {
52-
result = abi.encodePacked(result, uint32(segments[i].length), segments[i]);
52+
result = abi.encodePacked(result, uint8(0), uint32(segments[i].length), segments[i]);
5353
}
5454

5555
return result;
@@ -65,7 +65,7 @@ contract SparseCalldataSegmentLibTest is Test {
6565
bytes memory result = "";
6666

6767
for (uint256 i = 0; i < segments.length; i++) {
68-
result = abi.encodePacked(result, uint32(segments[i].length + 1), indices[i], segments[i]);
68+
result = abi.encodePacked(result, indices[i], uint32(segments[i].length), segments[i]);
6969
}
7070

7171
return result;
@@ -99,10 +99,10 @@ contract SparseCalldataSegmentLibTest is Test {
9999

100100
uint256 index = 0;
101101
while (remainder.length > 0) {
102+
indices[index] = remainder.getIndex();
102103
bytes calldata segment;
103104
(segment, remainder) = remainder.getNextSegment();
104-
bodies[index] = segment.getBody();
105-
indices[index] = segment.getIndex();
105+
bodies[index] = segment;
106106
index++;
107107
}
108108

test/libraries/ValidationConfigLib.t.sol

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ pragma solidity ^0.8.20;
44
import {Test} from "forge-std/Test.sol";
55

66
import {ModuleEntity, ModuleEntityLib} from "../../src/libraries/ModuleEntityLib.sol";
7-
import {ValidationConfig, ValidationConfigLib} from "../../src/libraries/ValidationConfigLib.sol";
7+
import {
8+
ValidationConfig, ValidationConfigLib, ValidationFlags
9+
} from "../../src/libraries/ValidationConfigLib.sol";
810

911
contract ValidationConfigLibTest is Test {
1012
using ModuleEntityLib for ModuleEntity;
@@ -23,7 +25,7 @@ contract ValidationConfigLibTest is Test {
2325
ValidationConfigLib.pack(module, entityId, isGlobal, isSignatureValidation, isUserOpValidation);
2426

2527
// Test unpacking underlying
26-
(address module2, uint32 entityId2, uint8 flags2) = validationConfig.unpackUnderlying();
28+
(address module2, uint32 entityId2, ValidationFlags flags2) = validationConfig.unpackUnderlying();
2729

2830
assertEq(module, module2, "module mismatch");
2931
assertEq(entityId, entityId2, "entityId mismatch");
@@ -35,7 +37,7 @@ contract ValidationConfigLibTest is Test {
3537

3638
ModuleEntity expectedModuleEntity = ModuleEntityLib.pack(module, entityId);
3739

38-
(ModuleEntity validationFunction, uint8 flags3) = validationConfig.unpack();
40+
(ModuleEntity validationFunction, ValidationFlags flags3) = validationConfig.unpack();
3941

4042
assertEq(
4143
ModuleEntity.unwrap(validationFunction),
@@ -73,7 +75,7 @@ contract ValidationConfigLibTest is Test {
7375

7476
(address expectedModule, uint32 expectedEntityId) = validationFunction.unpack();
7577

76-
(address module, uint32 entityId, uint8 flags2) = validationConfig.unpackUnderlying();
78+
(address module, uint32 entityId, ValidationFlags flags2) = validationConfig.unpackUnderlying();
7779

7880
assertEq(expectedModule, module, "module mismatch");
7981
assertEq(expectedEntityId, entityId, "entityId mismatch");
@@ -83,7 +85,7 @@ contract ValidationConfigLibTest is Test {
8385

8486
// Test unpacking to ModuleEntity
8587

86-
(ModuleEntity validationFunction2, uint8 flags3) = validationConfig.unpack();
88+
(ModuleEntity validationFunction2, ValidationFlags flags3) = validationConfig.unpack();
8789

8890
assertEq(
8991
ModuleEntity.unwrap(validationFunction),

test/mocks/modules/ReturnDataModuleMocks.sol

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,19 @@ pragma solidity ^0.8.20;
33

44
import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";
55

6+
import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../../src/helpers/Constants.sol";
67
import {
78
ExecutionManifest,
89
IExecutionModule,
910
ManifestExecutionFunction
1011
} from "../../../src/interfaces/IExecutionModule.sol";
11-
12-
import {DIRECT_CALL_VALIDATION_ENTITYID} from "../../../src/helpers/Constants.sol";
13-
1412
import {IModularAccount} from "../../../src/interfaces/IModularAccount.sol";
1513
import {IValidationModule} from "../../../src/interfaces/IValidationModule.sol";
16-
14+
import {ModuleEntityLib} from "../../../src/libraries/ModuleEntityLib.sol";
1715
import {BaseModule} from "../../../src/modules/BaseModule.sol";
1816

17+
import {ModuleSignatureUtils} from "../../utils/ModuleSignatureUtils.sol";
18+
1919
contract RegularResultContract {
2020
function foo() external pure returns (bytes32) {
2121
return keccak256("bar");
@@ -62,7 +62,7 @@ contract ResultCreatorModule is IExecutionModule, BaseModule {
6262
}
6363
}
6464

65-
contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule {
65+
contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule, ModuleSignatureUtils {
6666
ResultCreatorModule public immutable RESULT_CREATOR;
6767
RegularResultContract public immutable REGULAR_RESULT_CONTRACT;
6868

@@ -102,13 +102,11 @@ contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule
102102
}
103103

104104
// Check the return data through the execute with authorization case
105-
function checkResultExecuteWithAuthorization(address target, bytes32 expected) external returns (bool) {
105+
function checkResultExecuteWithRuntimeValidation(address target, bytes32 expected) external returns (bool) {
106106
// This result should be allowed based on the manifest permission request
107107
bytes memory returnData = IModularAccount(msg.sender).executeWithRuntimeValidation(
108108
abi.encodeCall(IModularAccount.execute, (target, 0, abi.encodeCall(RegularResultContract.foo, ()))),
109-
abi.encodePacked(this, DIRECT_CALL_VALIDATION_ENTITYID, uint8(0), uint32(1), uint8(255)) // Validation
110-
// function of self,
111-
// selector-associated, with no auth data
109+
_encodeSignature(ModuleEntityLib.pack(address(this), DIRECT_CALL_VALIDATION_ENTITYID), uint8(0), "")
112110
);
113111

114112
bytes32 actual = abi.decode(abi.decode(returnData, (bytes)), (bytes32));
@@ -130,7 +128,7 @@ contract ResultConsumerModule is IExecutionModule, BaseModule, IValidationModule
130128
allowGlobalValidation: false
131129
});
132130
manifest.executionFunctions[1] = ManifestExecutionFunction({
133-
executionSelector: this.checkResultExecuteWithAuthorization.selector,
131+
executionSelector: this.checkResultExecuteWithRuntimeValidation.selector,
134132
skipRuntimeValidation: true,
135133
allowGlobalValidation: false
136134
});

0 commit comments

Comments
 (0)