diff --git a/.gitmodules b/.gitmodules index 229aff1b..962d846f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,9 +7,6 @@ [submodule "lib/forge-std"] path = lib/forge-std url = https://github.com/foundry-rs/forge-std -[submodule "lib/modular-account-libs"] - path = lib/modular-account-libs - url = https://github.com/erc6900/modular-account-libs [submodule "lib/solady"] path = lib/solady url = https://github.com/vectorized/solady diff --git a/lib/modular-account-libs b/lib/modular-account-libs deleted file mode 160000 index 5d9d0e40..00000000 --- a/lib/modular-account-libs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5d9d0e403332251045eee2954c2a8b7ea0bae953 diff --git a/remappings.txt b/remappings.txt index 8d9639cf..7f94ff52 100644 --- a/remappings.txt +++ b/remappings.txt @@ -2,5 +2,4 @@ ds-test/=lib/forge-std/lib/ds-test/src/ forge-std/=lib/forge-std/src/ @eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/ @openzeppelin/=lib/openzeppelin-contracts/ -@modular-account-libs/=lib/modular-account-libs/src/ solady=lib/solady/src/ diff --git a/src/modules/ERC20TokenLimitModule.sol b/src/modules/ERC20TokenLimitModule.sol index 8d5aad46..9966ffa9 100644 --- a/src/modules/ERC20TokenLimitModule.sol +++ b/src/modules/ERC20TokenLimitModule.sol @@ -3,14 +3,7 @@ pragma solidity ^0.8.20; import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol"; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; - -import { - AssociatedLinkedListSet, - AssociatedLinkedListSetLib, - SetValue -} from "@modular-account-libs/libraries/AssociatedLinkedListSetLib.sol"; import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; -import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol"; import {Call, IModularAccount} from "../interfaces/IModularAccount.sol"; @@ -27,29 +20,24 @@ import {BaseModule, IERC165} from "./BaseModule.sol"; /// token contract contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule { using UserOperationLib for PackedUserOperation; - using EnumerableSet for EnumerableSet.AddressSet; - using AssociatedLinkedListSetLib for AssociatedLinkedListSet; struct ERC20SpendLimit { address token; - uint256[] limits; + uint256 limit; } - string internal constant _NAME = "ERC20 Token Limit Module"; - string internal constant _VERSION = "1.0.0"; - string internal constant _AUTHOR = "ERC-6900 Authors"; + struct SpendLimit { + bool hasLimit; + uint256 limit; + } - mapping(uint32 entityId => mapping(address token => mapping(address account => uint256 limit))) public limits; - AssociatedLinkedListSet internal _tokenList; + mapping(uint32 entityId => mapping(address token => mapping(address account => SpendLimit))) public limits; + error ERC20NotAllowed(address); error ExceededTokenLimit(); - error ExceededNumberOfEntities(); + error InvalidCalldataLength(); error SelectorNotAllowed(); - - function updateLimits(uint32 entityId, address token, uint256 newLimit) external { - _tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(token)))); - limits[entityId][token][msg.sender] = newLimit; - } + error SpendingRequestNotAllowed(bytes4); /// @inheritdoc IExecutionHookModule function preExecutionHook(uint32 entityId, address, uint256, bytes calldata data) @@ -60,54 +48,40 @@ contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule { (bytes4 selector, bytes memory callData) = _getSelectorAndCalldata(data); if (selector == IModularAccount.execute.selector) { + // when calling execute or ERC20 functions directly (address token,, bytes memory innerCalldata) = abi.decode(callData, (address, uint256, bytes)); - if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(token))))) { - _decrementLimit(entityId, token, innerCalldata); - } + _decrementLimitIfApplies(entityId, token, innerCalldata); } else if (selector == IModularAccount.executeBatch.selector) { Call[] memory calls = abi.decode(callData, (Call[])); for (uint256 i = 0; i < calls.length; i++) { - if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(calls[i].target))))) { - _decrementLimit(entityId, calls[i].target, calls[i].data); - } + _decrementLimitIfApplies(entityId, calls[i].target, calls[i].data); } + } else { + revert SpendingRequestNotAllowed(selector); } - return ""; } /// @inheritdoc IModule + /// @param data should be encoded with the entityId of the validation and a list of ERC20 spend limits function onInstall(bytes calldata data) external override { - (uint32 startEntityId, ERC20SpendLimit[] memory spendLimits) = - abi.decode(data, (uint32, ERC20SpendLimit[])); - - if (startEntityId + spendLimits.length > type(uint32).max) { - revert ExceededNumberOfEntities(); - } + (uint32 entityId, ERC20SpendLimit[] memory spendLimits) = abi.decode(data, (uint32, ERC20SpendLimit[])); for (uint8 i = 0; i < spendLimits.length; i++) { - _tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(spendLimits[i].token)))); - for (uint256 j = 0; j < spendLimits[i].limits.length; j++) { - limits[i + startEntityId][spendLimits[i].token][msg.sender] = spendLimits[i].limits[j]; - } + address token = spendLimits[i].token; + updateLimits(entityId, token, true, spendLimits[i].limit); } } /// @inheritdoc IModule + /// @notice uninstall this module can only clear limit for one token of one entity. To clear all limits, users + /// are recommended to use updateLimit for each token and entityId. + /// @param data should be encoded with the entityId of the validation and the token address to be uninstalled function onUninstall(bytes calldata data) external override { (address token, uint32 entityId) = abi.decode(data, (address, uint32)); delete limits[entityId][token][msg.sender]; } - function getTokensForAccount(address account) external view returns (address[] memory tokens) { - SetValue[] memory set = _tokenList.getAll(account); - tokens = new address[](set.length); - for (uint256 i = 0; i < tokens.length; i++) { - tokens[i] = address(bytes20(bytes32(SetValue.unwrap(set[i])))); - } - return tokens; - } - /// @inheritdoc IExecutionHookModule function postExecutionHook(uint32, bytes calldata) external pure override { revert NotImplemented(); @@ -118,27 +92,53 @@ contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule { return "erc6900.erc20-token-limit-module.1.0.0"; } + /// @notice Update the token limit of a validation + /// @param entityId The validation entityId to update + /// @param token The token address whose limit will be updated + /// @param newLimit The new limit of the token for the validation + function updateLimits(uint32 entityId, address token, bool hasLimit, uint256 newLimit) public { + if (token == address(0)) { + revert ERC20NotAllowed(address(0)); + } + limits[entityId][token][msg.sender] = SpendLimit({hasLimit: hasLimit, limit: newLimit}); + } + /// @inheritdoc BaseModule function supportsInterface(bytes4 interfaceId) public view override(BaseModule, IERC165) returns (bool) { return interfaceId == type(IExecutionHookModule).interfaceId || super.supportsInterface(interfaceId); } - function _decrementLimit(uint32 entityId, address token, bytes memory innerCalldata) internal { + function _decrementLimitIfApplies(uint32 entityId, address token, bytes memory innerCalldata) internal { + SpendLimit storage spendLimit = limits[entityId][token][msg.sender]; + + if (!spendLimit.hasLimit) { + return; + } + + if (innerCalldata.length < 68) { + revert InvalidCalldataLength(); + } + bytes4 selector; uint256 spend; - assembly { + assembly ("memory-safe") { selector := mload(add(innerCalldata, 32)) // 0:32 is arr len, 32:36 is selector spend := mload(add(innerCalldata, 68)) // 36:68 is recipient, 68:100 is spend } - if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) { - uint256 limit = limits[entityId][token][msg.sender]; + if (_isAllowedERC20Function(selector)) { + uint256 limit = spendLimit.limit; if (spend > limit) { revert ExceededTokenLimit(); } - // solhint-disable-next-line reentrancy - limits[entityId][token][msg.sender] = limit - spend; + unchecked { + spendLimit.limit = limit - spend; + } } else { revert SelectorNotAllowed(); } } + + function _isAllowedERC20Function(bytes4 selector) internal pure returns (bool) { + return selector == IERC20.transfer.selector || selector == IERC20.approve.selector; + } } diff --git a/test/module/ERC20TokenLimitModule.t.sol b/test/module/ERC20TokenLimitModule.t.sol index 390a9829..d903cd84 100644 --- a/test/module/ERC20TokenLimitModule.t.sol +++ b/test/module/ERC20TokenLimitModule.t.sol @@ -43,7 +43,7 @@ contract ERC20TokenLimitModuleTest is AccountTestBase { limits[0] = spendLimit; ERC20TokenLimitModule.ERC20SpendLimit[] memory limit = new ERC20TokenLimitModule.ERC20SpendLimit[](1); - limit[0] = ERC20TokenLimitModule.ERC20SpendLimit({token: address(erc20), limits: limits}); + limit[0] = ERC20TokenLimitModule.ERC20SpendLimit({token: address(erc20), limit: spendLimit}); bytes[] memory hooks = new bytes[](1); hooks[0] = abi.encodePacked( @@ -82,9 +82,14 @@ contract ERC20TokenLimitModuleTest is AccountTestBase { function test_userOp_executeLimit() public { vm.startPrank(address(entryPoint)); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + + (, uint256 limit) = module.limits(0, address(erc20), address(acct)); + + assertEq(limit, 10 ether); acct.executeUserOp(_getPackedUO(_getExecuteWithSpend(5 ether)), bytes32(0)); - assertEq(module.limits(0, address(erc20), address(acct)), 5 ether); + + (, limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 5 ether); } function test_userOp_executeBatchLimit() public { @@ -100,9 +105,12 @@ contract ERC20TokenLimitModuleTest is AccountTestBase { }); vm.startPrank(address(entryPoint)); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + (, uint256 limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether); acct.executeUserOp(_getPackedUO(abi.encodeCall(IModularAccount.executeBatch, (calls))), bytes32(0)); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100_001); + + (, limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether - 6 ether - 100_001); } function test_userOp_executeBatch_approveAndTransferLimit() public { @@ -118,9 +126,12 @@ contract ERC20TokenLimitModuleTest is AccountTestBase { }); vm.startPrank(address(entryPoint)); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + (, uint256 limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether); acct.executeUserOp(_getPackedUO(abi.encodeCall(IModularAccount.executeBatch, (calls))), bytes32(0)); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100_001); + + (, limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether - 6 ether - 100_001); } function test_userOp_executeBatch_approveAndTransferLimit_fail() public { @@ -136,21 +147,27 @@ contract ERC20TokenLimitModuleTest is AccountTestBase { }); vm.startPrank(address(entryPoint)); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + (, uint256 limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether); PackedUserOperation[] memory uos = new PackedUserOperation[](1); uos[0] = _getPackedUO(abi.encodeCall(IModularAccount.executeBatch, (calls))); entryPoint.handleOps(uos, bundler); // no spend consumed - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + + (, limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether); } function test_runtime_executeLimit() public { - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + (, uint256 limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether); acct.executeWithRuntimeValidation( _getExecuteWithSpend(5 ether), _encodeSignature(ModuleEntityLib.pack(address(validationModule), 0), 1, "") ); - assertEq(module.limits(0, address(erc20), address(acct)), 5 ether); + + (, limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 5 ether); } function test_runtime_executeBatchLimit() public { @@ -165,11 +182,14 @@ contract ERC20TokenLimitModuleTest is AccountTestBase { data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100_000)) }); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether); + (, uint256 limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether); acct.executeWithRuntimeValidation( abi.encodeCall(IModularAccount.executeBatch, (calls)), _encodeSignature(ModuleEntityLib.pack(address(validationModule), 0), 1, "") ); - assertEq(module.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100_001); + + (, limit) = module.limits(0, address(erc20), address(acct)); + assertEq(limit, 10 ether - 6 ether - 100_001); } }