Skip to content

Commit 4ea3c43

Browse files
authored
feat: support global direct call validation (#164)
1 parent a748955 commit 4ea3c43

File tree

2 files changed

+93
-28
lines changed

2 files changed

+93
-28
lines changed

src/account/ReferenceModularAccount.sol

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ contract ReferenceModularAccount is
5252
ModuleEntity postExecHook;
5353
}
5454

55+
enum ValidationCheckingType {
56+
GLOBAL,
57+
SELECTOR,
58+
EITHER
59+
}
60+
5561
IEntryPoint private immutable _ENTRY_POINT;
5662

5763
// As per the EIP-165 spec, no interface should ever match 0xffffffff
@@ -187,7 +193,11 @@ contract ReferenceModularAccount is
187193

188194
// Check if the runtime validation function is allowed to be called
189195
bool isGlobalValidation = uint8(authorization[24]) == 1;
190-
_checkIfValidationAppliesCallData(data, runtimeValidationFunction, isGlobalValidation);
196+
_checkIfValidationAppliesCallData(
197+
data,
198+
runtimeValidationFunction,
199+
isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR
200+
);
191201

192202
_doRuntimeValidation(runtimeValidationFunction, data, authorization[25:]);
193203

@@ -343,7 +353,11 @@ contract ReferenceModularAccount is
343353
ModuleEntity userOpValidationFunction = ModuleEntity.wrap(bytes24(userOp.signature[:24]));
344354
bool isGlobalValidation = uint8(userOp.signature[24]) == 1;
345355

346-
_checkIfValidationAppliesCallData(userOp.callData, userOpValidationFunction, isGlobalValidation);
356+
_checkIfValidationAppliesCallData(
357+
userOp.callData,
358+
userOpValidationFunction,
359+
isGlobalValidation ? ValidationCheckingType.GLOBAL : ValidationCheckingType.SELECTOR
360+
);
347361

348362
// Check if there are permission hooks associated with the validator, and revert if the call isn't to
349363
// `executeUserOp`
@@ -549,7 +563,7 @@ contract ReferenceModularAccount is
549563
ModuleEntity directCallValidationKey =
550564
ModuleEntityLib.pack(msg.sender, DIRECT_CALL_VALIDATION_ENTITYID);
551565

552-
_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, false);
566+
_checkIfValidationAppliesCallData(msg.data, directCallValidationKey, ValidationCheckingType.EITHER);
553567

554568
// Direct call is allowed, run associated permission & validation hooks
555569

@@ -645,7 +659,7 @@ contract ReferenceModularAccount is
645659
function _checkIfValidationAppliesCallData(
646660
bytes calldata callData,
647661
ModuleEntity validationFunction,
648-
bool isGlobal
662+
ValidationCheckingType checkingType
649663
) internal view {
650664
bytes4 outerSelector = bytes4(callData[:4]);
651665
if (outerSelector == this.executeUserOp.selector) {
@@ -655,7 +669,7 @@ contract ReferenceModularAccount is
655669
outerSelector = bytes4(callData[:4]);
656670
}
657671

658-
_checkIfValidationAppliesSelector(outerSelector, validationFunction, isGlobal);
672+
_checkIfValidationAppliesSelector(outerSelector, validationFunction, checkingType);
659673

660674
if (outerSelector == IModularAccount.execute.selector) {
661675
(address target,,) = abi.decode(callData[4:], (address, uint256, bytes));
@@ -689,26 +703,50 @@ contract ReferenceModularAccount is
689703
revert SelfCallRecursionDepthExceeded();
690704
}
691705

692-
_checkIfValidationAppliesSelector(nestedSelector, validationFunction, isGlobal);
706+
_checkIfValidationAppliesSelector(nestedSelector, validationFunction, checkingType);
693707
}
694708
}
695709
}
696710
}
697711

698-
function _checkIfValidationAppliesSelector(bytes4 selector, ModuleEntity validationFunction, bool isGlobal)
699-
internal
700-
view
701-
{
712+
function _checkIfValidationAppliesSelector(
713+
bytes4 selector,
714+
ModuleEntity validationFunction,
715+
ValidationCheckingType checkingType
716+
) internal view {
702717
// Check that the provided validation function is applicable to the selector
703-
if (isGlobal) {
704-
if (!_globalValidationAllowed(selector) || !_isValidationGlobal(validationFunction)) {
718+
719+
if (checkingType == ValidationCheckingType.GLOBAL) {
720+
if (!_globalValidationApplies(selector, validationFunction)) {
721+
revert ValidationFunctionMissing(selector);
722+
}
723+
} else if (checkingType == ValidationCheckingType.SELECTOR) {
724+
if (!_selectorValidationApplies(selector, validationFunction)) {
705725
revert ValidationFunctionMissing(selector);
706726
}
707727
} else {
708-
// Not global validation, but per-selector
709-
if (!getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector))) {
728+
if (
729+
!_globalValidationApplies(selector, validationFunction)
730+
&& !_selectorValidationApplies(selector, validationFunction)
731+
) {
710732
revert ValidationFunctionMissing(selector);
711733
}
712734
}
713735
}
736+
737+
function _globalValidationApplies(bytes4 selector, ModuleEntity validationFunction)
738+
internal
739+
view
740+
returns (bool)
741+
{
742+
return _globalValidationAllowed(selector) && _isValidationGlobal(validationFunction);
743+
}
744+
745+
function _selectorValidationApplies(bytes4 selector, ModuleEntity validationFunction)
746+
internal
747+
view
748+
returns (bool)
749+
{
750+
return getAccountStorage().validationData[validationFunction].selectors.contains(toSetValue(selector));
751+
}
714752
}

test/account/DirectCallsFromModule.t.sol

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,15 @@ contract DirectCallsFromModuleTest is AccountTestBase {
2121

2222
event ValidationUninstalled(address indexed module, uint32 indexed entityId, bool onUninstallSucceeded);
2323

24+
modifier randomizedValidationType(bool selectorValidation) {
25+
if (selectorValidation) {
26+
_installValidationSelector();
27+
} else {
28+
_installValidationGlobal();
29+
}
30+
_;
31+
}
32+
2433
function setUp() public {
2534
_module = new DirectCallModule();
2635
assertFalse(_module.preHookRan());
@@ -38,9 +47,10 @@ contract DirectCallsFromModuleTest is AccountTestBase {
3847
account1.execute(address(0), 0, "");
3948
}
4049

41-
function test_Fail_DirectCallModuleUninstalled() external {
42-
_installValidation();
43-
50+
function testFuzz_Fail_DirectCallModuleUninstalled(bool validationType)
51+
external
52+
randomizedValidationType(validationType)
53+
{
4454
_uninstallValidation();
4555

4656
vm.prank(address(_module));
@@ -49,7 +59,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
4959
}
5060

5161
function test_Fail_DirectCallModuleCallOtherSelector() external {
52-
_installValidation();
62+
_installValidationSelector();
5363

5464
Call[] memory calls = new Call[](0);
5565

@@ -62,19 +72,21 @@ contract DirectCallsFromModuleTest is AccountTestBase {
6272
/* Positives */
6373
/* -------------------------------------------------------------------------- */
6474

65-
function test_Pass_DirectCallFromModulePrank() external {
66-
_installValidation();
67-
75+
function testFuzz_Pass_DirectCallFromModulePrank(bool validationType)
76+
external
77+
randomizedValidationType(validationType)
78+
{
6879
vm.prank(address(_module));
6980
account1.execute(address(0), 0, "");
7081

7182
assertTrue(_module.preHookRan());
7283
assertTrue(_module.postHookRan());
7384
}
7485

75-
function test_Pass_DirectCallFromModuleCallback() external {
76-
_installValidation();
77-
86+
function testFuzz_Pass_DirectCallFromModuleCallback(bool validationType)
87+
external
88+
randomizedValidationType(validationType)
89+
{
7890
bytes memory encodedCall = abi.encodeCall(DirectCallModule.directCall, ());
7991

8092
vm.prank(address(entryPoint));
@@ -88,11 +100,12 @@ contract DirectCallsFromModuleTest is AccountTestBase {
88100
assertEq(abi.decode(result, (bytes)), abi.encode(_module.getData()));
89101
}
90102

91-
function test_Flow_DirectCallFromModuleSequence() external {
103+
function testFuzz_Flow_DirectCallFromModuleSequence(bool validationType)
104+
external
105+
randomizedValidationType(validationType)
106+
{
92107
// Install => Succeesfully call => uninstall => fail to call
93108

94-
_installValidation();
95-
96109
vm.prank(address(_module));
97110
account1.execute(address(0), 0, "");
98111

@@ -129,7 +142,7 @@ contract DirectCallsFromModuleTest is AccountTestBase {
129142
/* Internals */
130143
/* -------------------------------------------------------------------------- */
131144

132-
function _installValidation() internal {
145+
function _installValidationSelector() internal {
133146
bytes4[] memory selectors = new bytes4[](1);
134147
selectors[0] = IModularAccount.execute.selector;
135148

@@ -146,6 +159,20 @@ contract DirectCallsFromModuleTest is AccountTestBase {
146159
account1.installValidation(validationConfig, selectors, "", hooks);
147160
}
148161

162+
function _installValidationGlobal() internal {
163+
bytes[] memory hooks = new bytes[](1);
164+
hooks[0] = abi.encodePacked(
165+
HookConfigLib.packExecHook({_hookFunction: _moduleEntity, _hasPre: true, _hasPost: true}),
166+
hex"00" // onInstall data
167+
);
168+
169+
vm.prank(address(entryPoint));
170+
171+
ValidationConfig validationConfig = ValidationConfigLib.pack(_moduleEntity, true, false);
172+
173+
account1.installValidation(validationConfig, new bytes4[](0), "", hooks);
174+
}
175+
149176
function _uninstallValidation() internal {
150177
(address module, uint32 entityId) = ModuleEntityLib.unpack(_moduleEntity);
151178
vm.prank(address(entryPoint));

0 commit comments

Comments
 (0)