Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
[submodule "lib/forge-std"]
path = lib/forge-std
url = https://github.com/foundry-rs/forge-std
[submodule "lib/modular-account-libs"]
path = lib/modular-account-libs
url = https://github.com/erc6900/modular-account-libs
[submodule "lib/solady"]
path = lib/solady
url = https://github.com/vectorized/solady
1 change: 0 additions & 1 deletion lib/modular-account-libs
Submodule modular-account-libs deleted from 5d9d0e
1 change: 0 additions & 1 deletion remappings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@ ds-test/=lib/forge-std/lib/ds-test/src/
forge-std/=lib/forge-std/src/
@eth-infinitism/account-abstraction/=lib/account-abstraction/contracts/
@openzeppelin/=lib/openzeppelin-contracts/
@modular-account-libs/=lib/modular-account-libs/src/
solady=lib/solady/src/
106 changes: 53 additions & 53 deletions src/modules/ERC20TokenLimitModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@ pragma solidity ^0.8.20;

import {UserOperationLib} from "@eth-infinitism/account-abstraction/core/UserOperationLib.sol";
import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol";

import {
AssociatedLinkedListSet,
AssociatedLinkedListSetLib,
SetValue
} from "@modular-account-libs/libraries/AssociatedLinkedListSetLib.sol";
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol";

import {IExecutionHookModule} from "../interfaces/IExecutionHookModule.sol";
import {Call, IModularAccount} from "../interfaces/IModularAccount.sol";
Expand All @@ -27,29 +20,24 @@ import {BaseModule, IERC165} from "./BaseModule.sol";
/// token contract
contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule {
using UserOperationLib for PackedUserOperation;
using EnumerableSet for EnumerableSet.AddressSet;
using AssociatedLinkedListSetLib for AssociatedLinkedListSet;

struct ERC20SpendLimit {
address token;
uint256[] limits;
uint256 limit;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this struct only used in installation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's right.

}

string internal constant _NAME = "ERC20 Token Limit Module";
string internal constant _VERSION = "1.0.0";
string internal constant _AUTHOR = "ERC-6900 Authors";
struct SpendLimit {
bool hasLimit;
uint256 limit;
}

mapping(uint32 entityId => mapping(address token => mapping(address account => uint256 limit))) public limits;
AssociatedLinkedListSet internal _tokenList;
mapping(uint32 entityId => mapping(address token => mapping(address account => SpendLimit))) public limits;

error ERC20NotAllowed(address);
error ExceededTokenLimit();
error ExceededNumberOfEntities();
error InvalidCalldataLength();
error SelectorNotAllowed();

function updateLimits(uint32 entityId, address token, uint256 newLimit) external {
_tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(token))));
limits[entityId][token][msg.sender] = newLimit;
}
error SpendingRequestNotAllowed(bytes4);

/// @inheritdoc IExecutionHookModule
function preExecutionHook(uint32 entityId, address, uint256, bytes calldata data)
Expand All @@ -60,54 +48,40 @@ contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule {
(bytes4 selector, bytes memory callData) = _getSelectorAndCalldata(data);

if (selector == IModularAccount.execute.selector) {
// when calling execute or ERC20 functions directly
(address token,, bytes memory innerCalldata) = abi.decode(callData, (address, uint256, bytes));
if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(token))))) {
_decrementLimit(entityId, token, innerCalldata);
}
_decrementLimitIfApplies(entityId, token, innerCalldata);
} else if (selector == IModularAccount.executeBatch.selector) {
Call[] memory calls = abi.decode(callData, (Call[]));
for (uint256 i = 0; i < calls.length; i++) {
if (_tokenList.contains(msg.sender, SetValue.wrap(bytes30(bytes20(calls[i].target))))) {
_decrementLimit(entityId, calls[i].target, calls[i].data);
}
_decrementLimitIfApplies(entityId, calls[i].target, calls[i].data);
}
} else {
revert SpendingRequestNotAllowed(selector);
}

return "";
}

/// @inheritdoc IModule
/// @param data should be encoded with the entityId of the validation and a list of ERC20 spend limits
function onInstall(bytes calldata data) external override {
(uint32 startEntityId, ERC20SpendLimit[] memory spendLimits) =
abi.decode(data, (uint32, ERC20SpendLimit[]));

if (startEntityId + spendLimits.length > type(uint32).max) {
revert ExceededNumberOfEntities();
}
(uint32 entityId, ERC20SpendLimit[] memory spendLimits) = abi.decode(data, (uint32, ERC20SpendLimit[]));

for (uint8 i = 0; i < spendLimits.length; i++) {
_tokenList.tryAdd(msg.sender, SetValue.wrap(bytes30(bytes20(spendLimits[i].token))));
for (uint256 j = 0; j < spendLimits[i].limits.length; j++) {
limits[i + startEntityId][spendLimits[i].token][msg.sender] = spendLimits[i].limits[j];
}
address token = spendLimits[i].token;
updateLimits(entityId, token, true, spendLimits[i].limit);
}
}

/// @inheritdoc IModule
/// @notice uninstall this module can only clear limit for one token of one entity. To clear all limits, users
/// are recommended to use updateLimit for each token and entityId.
/// @param data should be encoded with the entityId of the validation and the token address to be uninstalled
function onUninstall(bytes calldata data) external override {
(address token, uint32 entityId) = abi.decode(data, (address, uint32));
delete limits[entityId][token][msg.sender];
}

function getTokensForAccount(address account) external view returns (address[] memory tokens) {
SetValue[] memory set = _tokenList.getAll(account);
tokens = new address[](set.length);
for (uint256 i = 0; i < tokens.length; i++) {
tokens[i] = address(bytes20(bytes32(SetValue.unwrap(set[i]))));
}
return tokens;
}

/// @inheritdoc IExecutionHookModule
function postExecutionHook(uint32, bytes calldata) external pure override {
revert NotImplemented();
Expand All @@ -118,27 +92,53 @@ contract ERC20TokenLimitModule is BaseModule, IExecutionHookModule {
return "erc6900.erc20-token-limit-module.1.0.0";
}

/// @notice Update the token limit of a validation
/// @param entityId The validation entityId to update
/// @param token The token address whose limit will be updated
/// @param newLimit The new limit of the token for the validation
function updateLimits(uint32 entityId, address token, bool hasLimit, uint256 newLimit) public {
if (token == address(0)) {
revert ERC20NotAllowed(address(0));
}
limits[entityId][token][msg.sender] = SpendLimit({hasLimit: hasLimit, limit: newLimit});
}

/// @inheritdoc BaseModule
function supportsInterface(bytes4 interfaceId) public view override(BaseModule, IERC165) returns (bool) {
return interfaceId == type(IExecutionHookModule).interfaceId || super.supportsInterface(interfaceId);
}

function _decrementLimit(uint32 entityId, address token, bytes memory innerCalldata) internal {
function _decrementLimitIfApplies(uint32 entityId, address token, bytes memory innerCalldata) internal {
SpendLimit storage spendLimit = limits[entityId][token][msg.sender];

if (!spendLimit.hasLimit) {
return;
}

if (innerCalldata.length < 68) {
revert InvalidCalldataLength();
}

bytes4 selector;
uint256 spend;
assembly {
assembly ("memory-safe") {
selector := mload(add(innerCalldata, 32)) // 0:32 is arr len, 32:36 is selector
spend := mload(add(innerCalldata, 68)) // 36:68 is recipient, 68:100 is spend
}
if (selector == IERC20.transfer.selector || selector == IERC20.approve.selector) {
uint256 limit = limits[entityId][token][msg.sender];
if (_isAllowedERC20Function(selector)) {
uint256 limit = spendLimit.limit;
if (spend > limit) {
revert ExceededTokenLimit();
}
// solhint-disable-next-line reentrancy
limits[entityId][token][msg.sender] = limit - spend;
unchecked {
spendLimit.limit = limit - spend;
}
} else {
revert SelectorNotAllowed();
}
}

function _isAllowedERC20Function(bytes4 selector) internal pure returns (bool) {
return selector == IERC20.transfer.selector || selector == IERC20.approve.selector;
}
}
46 changes: 33 additions & 13 deletions test/module/ERC20TokenLimitModule.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {
limits[0] = spendLimit;

ERC20TokenLimitModule.ERC20SpendLimit[] memory limit = new ERC20TokenLimitModule.ERC20SpendLimit[](1);
limit[0] = ERC20TokenLimitModule.ERC20SpendLimit({token: address(erc20), limits: limits});
limit[0] = ERC20TokenLimitModule.ERC20SpendLimit({token: address(erc20), limit: spendLimit});

bytes[] memory hooks = new bytes[](1);
hooks[0] = abi.encodePacked(
Expand Down Expand Up @@ -82,9 +82,14 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {

function test_userOp_executeLimit() public {
vm.startPrank(address(entryPoint));
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);

(, uint256 limit) = module.limits(0, address(erc20), address(acct));

assertEq(limit, 10 ether);
acct.executeUserOp(_getPackedUO(_getExecuteWithSpend(5 ether)), bytes32(0));
assertEq(module.limits(0, address(erc20), address(acct)), 5 ether);

(, limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 5 ether);
}

function test_userOp_executeBatchLimit() public {
Expand All @@ -100,9 +105,12 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {
});

vm.startPrank(address(entryPoint));
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);
(, uint256 limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether);
acct.executeUserOp(_getPackedUO(abi.encodeCall(IModularAccount.executeBatch, (calls))), bytes32(0));
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100_001);

(, limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether - 6 ether - 100_001);
}

function test_userOp_executeBatch_approveAndTransferLimit() public {
Expand All @@ -118,9 +126,12 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {
});

vm.startPrank(address(entryPoint));
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);
(, uint256 limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether);
acct.executeUserOp(_getPackedUO(abi.encodeCall(IModularAccount.executeBatch, (calls))), bytes32(0));
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100_001);

(, limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether - 6 ether - 100_001);
}

function test_userOp_executeBatch_approveAndTransferLimit_fail() public {
Expand All @@ -136,21 +147,27 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {
});

vm.startPrank(address(entryPoint));
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);
(, uint256 limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether);
PackedUserOperation[] memory uos = new PackedUserOperation[](1);
uos[0] = _getPackedUO(abi.encodeCall(IModularAccount.executeBatch, (calls)));
entryPoint.handleOps(uos, bundler);
// no spend consumed
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);

(, limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether);
}

function test_runtime_executeLimit() public {
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);
(, uint256 limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether);
acct.executeWithRuntimeValidation(
_getExecuteWithSpend(5 ether),
_encodeSignature(ModuleEntityLib.pack(address(validationModule), 0), 1, "")
);
assertEq(module.limits(0, address(erc20), address(acct)), 5 ether);

(, limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 5 ether);
}

function test_runtime_executeBatchLimit() public {
Expand All @@ -165,11 +182,14 @@ contract ERC20TokenLimitModuleTest is AccountTestBase {
data: abi.encodeCall(IERC20.approve, (recipient, 5 ether + 100_000))
});

assertEq(module.limits(0, address(erc20), address(acct)), 10 ether);
(, uint256 limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether);
acct.executeWithRuntimeValidation(
abi.encodeCall(IModularAccount.executeBatch, (calls)),
_encodeSignature(ModuleEntityLib.pack(address(validationModule), 0), 1, "")
);
assertEq(module.limits(0, address(erc20), address(acct)), 10 ether - 6 ether - 100_001);

(, limit) = module.limits(0, address(erc20), address(acct));
assertEq(limit, 10 ether - 6 ether - 100_001);
}
}
Loading