diff --git a/src/samples/plugins/ModularSessionKeyPlugin.sol b/src/samples/plugins/ModularSessionKeyPlugin.sol new file mode 100644 index 00000000..e4a4d02b --- /dev/null +++ b/src/samples/plugins/ModularSessionKeyPlugin.sol @@ -0,0 +1,391 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.19; + +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +import {EnumerableSet} from "@openzeppelin/contracts/utils/structs/EnumerableSet.sol"; +import {UserOperation} from "@eth-infinitism/account-abstraction/interfaces/UserOperation.sol"; +import {UpgradeableModularAccount} from "../../account/UpgradeableModularAccount.sol"; +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata, + SelectorPermission +} from "../../interfaces/IPlugin.sol"; +import {BasePlugin} from "../../plugins/BasePlugin.sol"; +import {IModularSessionKeyPlugin} from "./interfaces/ISessionKeyPlugin.sol"; +import {ISingleOwnerPlugin} from "../../plugins/owner/ISingleOwnerPlugin.sol"; +import {SingleOwnerPlugin} from "../../plugins/owner/SingleOwnerPlugin.sol"; +import {PluginStorageLib, StoragePointer} from "../../libraries/PluginStorageLib.sol"; + +/// @title Modular Session Key Plugin +/// @author Decipher ERC-6900 Team +/// @notice This plugin allows some designated EOA or smart contract to temporarily +/// own a modular account. Note that this plugin is ONLY for demonstrating the purpose +/// of the functionalities of ERC-6900, and MUST not be used at the production level. +/// This modular session key plugin acts as a 'parent plugin' for all specific session +/// keys. Using dependency, this plugin can be thought as a parent contract that stores +/// session key duration information, and validation functions for session keys. All +/// logics for session keys will be implemented in child plugins. +/// It allows for session key owners to access MSCA both through user operation and +/// runtime, with its own validation functions. +/// Also, it has a dependency on SingleOwnerPlugin, to make sure that only the owner of +/// the MSCA can add or remove session keys. +contract ModularSessionKeyPlugin is BasePlugin, IModularSessionKeyPlugin { + using ECDSA for bytes32; + using PluginStorageLib for address; + using PluginStorageLib for bytes; + using EnumerableSet for EnumerableSet.Bytes32Set; + + string public constant NAME = "Modular Session Key Plugin"; + string public constant VERSION = "1.0.0"; + string public constant AUTHOR = "Decipher ERC-6900 Team"; + + uint256 internal constant _SIG_VALIDATION_FAILED = 1; + + mapping(address account => EnumerableSet.Bytes32Set) private _sessionKeySet; + + struct SessionInfo { + uint48 validAfter; + uint48 validUntil; + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IModularSessionKeyPlugin + function addSessionKey(address sessionKey, bytes4 allowedSelector, uint48 validAfter, uint48 validUntil) + external + { + _addSessionKey(msg.sender, sessionKey, allowedSelector, validAfter, validUntil); + emit SessionKeyAdded(msg.sender, sessionKey, allowedSelector, validAfter, validUntil); + } + + /// @inheritdoc IModularSessionKeyPlugin + function removeSessionKey(address sessionKey, bytes4 allowedSelector) external { + _removeSessionKey(msg.sender, sessionKey, allowedSelector); + emit SessionKeyRemoved(msg.sender, sessionKey, allowedSelector); + } + + /// @inheritdoc IModularSessionKeyPlugin + function addSessionKeyBatch( + address[] calldata sessionKeys, + bytes4[] calldata allowedSelectors, + uint48[] calldata validAfters, + uint48[] calldata validUntils + ) external { + if ( + sessionKeys.length != allowedSelectors.length || sessionKeys.length != validAfters.length + || sessionKeys.length != validUntils.length + ) { + revert WrongDataLength(); + } + for (uint256 i = 0; i < sessionKeys.length;) { + _addSessionKey(msg.sender, sessionKeys[i], allowedSelectors[i], validAfters[i], validUntils[i]); + + unchecked { + ++i; + } + } + emit SessionKeysAdded(msg.sender, sessionKeys, allowedSelectors, validAfters, validUntils); + } + + function removeSessionKeyBatch(address[] calldata sessionKeys, bytes4[] calldata allowedSelectors) external { + if (sessionKeys.length != allowedSelectors.length) { + revert WrongDataLength(); + } + for (uint256 i = 0; i < sessionKeys.length;) { + _removeSessionKey(msg.sender, sessionKeys[i], allowedSelectors[i]); + + unchecked { + ++i; + } + } + emit SessionKeysRemoved(msg.sender, sessionKeys, allowedSelectors); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin view functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc IModularSessionKeyPlugin + function getSessionDuration(address account, address sessionKey, bytes4 allowedSelector) + external + view + returns (uint48 validAfter, uint48 validUntil) + { + bytes memory key = account.allocateAssociatedStorageKey(0, 1); + StoragePointer ptr = key.associatedStorageLookup(keccak256(abi.encodePacked(sessionKey, allowedSelector))); + SessionInfo storage sessionInfo = _castPtrToStruct(ptr); + validAfter = sessionInfo.validAfter; + validUntil = sessionInfo.validUntil; + } + + /// @inheritdoc IModularSessionKeyPlugin + function getSessionKeysAndSelectors(address account) + external + view + returns (address[] memory sessionKeys, bytes4[] memory selectors) + { + EnumerableSet.Bytes32Set storage sessionKeySet = _sessionKeySet[account]; + uint256 length = sessionKeySet.length(); + sessionKeys = new address[](length); + selectors = new bytes4[](length); + for (uint256 i = 0; i < length;) { + (sessionKeys[i], selectors[i]) = _castToAddressAndBytes4(sessionKeySet.at(i)); + + unchecked { + ++i; + } + } + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function onInstall(bytes calldata data) external override { + if (data.length != 0) { + ( + address[] memory sessionKeys, + bytes4[] memory allowedSelectors, + uint48[] memory validAfters, + uint48[] memory validUntils + ) = abi.decode(data, (address[], bytes4[], uint48[], uint48[])); + if ( + sessionKeys.length != allowedSelectors.length || sessionKeys.length != validAfters.length + || sessionKeys.length != validUntils.length + ) { + revert WrongDataLength(); + } + for (uint256 i = 0; i < sessionKeys.length;) { + _addSessionKey(msg.sender, sessionKeys[i], allowedSelectors[i], validAfters[i], validUntils[i]); + + unchecked { + ++i; + } + } + } + } + + /// @inheritdoc BasePlugin + function onUninstall(bytes calldata) external override { + EnumerableSet.Bytes32Set storage sessionKeySet = _sessionKeySet[msg.sender]; + uint256 length = sessionKeySet.length(); + for (uint256 i = 0; i < length;) { + (address sessionKey, bytes4 allowedSelecor) = _castToAddressAndBytes4(sessionKeySet.at(i)); + _removeSessionKey(msg.sender, sessionKey, allowedSelecor); + + unchecked { + ++i; + } + } + } + + /// @inheritdoc BasePlugin + function userOpValidationFunction(uint8 functionId, UserOperation calldata userOp, bytes32 userOpHash) + external + view + override + returns (uint256) + { + if (functionId == uint8(FunctionId.USER_OP_VALIDATION_TEMPORARY_OWNER)) { + (address signer, ECDSA.RecoverError err) = + userOpHash.toEthSignedMessageHash().tryRecover(userOp.signature); + if (err != ECDSA.RecoverError.NoError) { + revert InvalidSignature(); + } + bytes4 selector = bytes4(userOp.callData[0:4]); + bytes memory key = msg.sender.allocateAssociatedStorageKey(0, 1); + StoragePointer ptr = key.associatedStorageLookup(keccak256(abi.encodePacked(signer, selector))); + SessionInfo storage duration = _castPtrToStruct(ptr); + uint48 validAfter = duration.validAfter; + uint48 validUntil = duration.validUntil; + + return _packValidationData(validUntil == 0, validUntil, validAfter); + } + revert NotImplemented(); + } + + /// @inheritdoc BasePlugin + function runtimeValidationFunction(uint8 functionId, address sender, uint256, bytes calldata data) + external + view + override + { + if (functionId == uint8(FunctionId.RUNTIME_VALIDATION_TEMPORARY_OWNER)) { + bytes4 selector = bytes4(data[0:4]); + bytes memory key = msg.sender.allocateAssociatedStorageKey(0, 1); + StoragePointer ptr = key.associatedStorageLookup(keccak256(abi.encodePacked(sender, selector))); + SessionInfo storage duration = _castPtrToStruct(ptr); + uint48 validAfter = duration.validAfter; + uint48 validUntil = duration.validUntil; + + if (validUntil != 0) { + if (block.timestamp < validAfter || block.timestamp > validUntil) { + revert WrongTimeRangeForSession(); + } + return; + } + revert NotAuthorized(); + } + revert NotImplemented(); + } + + /// @inheritdoc BasePlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](4); + manifest.executionFunctions[0] = this.addSessionKey.selector; + manifest.executionFunctions[1] = this.removeSessionKey.selector; + manifest.executionFunctions[2] = this.addSessionKeyBatch.selector; + manifest.executionFunctions[3] = this.removeSessionKeyBatch.selector; + + ManifestFunction memory ownerUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // Unused. + dependencyIndex: 0 // Used as first index. + }); + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](4); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.addSessionKey.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.removeSessionKey.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: this.addSessionKeyBatch.selector, + associatedFunction: ownerUserOpValidationFunction + }); + manifest.userOpValidationFunctions[3] = ManifestAssociatedFunction({ + executionSelector: this.removeSessionKeyBatch.selector, + associatedFunction: ownerUserOpValidationFunction + }); + + ManifestFunction memory ownerOrSelfRuntimeValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // Unused. + dependencyIndex: 1 + }); + ManifestFunction memory alwaysAllowFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.RUNTIME_VALIDATION_ALWAYS_ALLOW, + functionId: 0, // Unused. + dependencyIndex: 0 // Unused. + }); + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](5); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.addSessionKey.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[1] = ManifestAssociatedFunction({ + executionSelector: this.removeSessionKey.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[2] = ManifestAssociatedFunction({ + executionSelector: this.addSessionKeyBatch.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[3] = ManifestAssociatedFunction({ + executionSelector: this.removeSessionKeyBatch.selector, + associatedFunction: ownerOrSelfRuntimeValidationFunction + }); + manifest.runtimeValidationFunctions[4] = ManifestAssociatedFunction({ + executionSelector: this.getSessionDuration.selector, + associatedFunction: alwaysAllowFunction + }); + + manifest.dependencyInterfaceIds = new bytes4[](2); + manifest.dependencyInterfaceIds[0] = type(ISingleOwnerPlugin).interfaceId; + manifest.dependencyInterfaceIds[1] = type(ISingleOwnerPlugin).interfaceId; + + return manifest; + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = NAME; + metadata.version = VERSION; + metadata.author = AUTHOR; + + return metadata; + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ EIP-165 ┃ + // ┗━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override returns (bool) { + return interfaceId == type(IModularSessionKeyPlugin).interfaceId || super.supportsInterface(interfaceId); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Internal / Private functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + function _addSessionKey( + address account, + address sessionKey, + bytes4 allowedSelector, + uint48 validAfter, + uint48 validUntil + ) internal { + if (validUntil <= validAfter) { + revert WrongTimeRangeForSession(); + } + bytes memory key = account.allocateAssociatedStorageKey(0, 1); + StoragePointer ptr = key.associatedStorageLookup(keccak256(abi.encodePacked(sessionKey, allowedSelector))); + SessionInfo storage sessionInfo = _castPtrToStruct(ptr); + sessionInfo.validAfter = validAfter; + sessionInfo.validUntil = validUntil; + + EnumerableSet.Bytes32Set storage sessionKeySet = _sessionKeySet[account]; + sessionKeySet.add(_castToBytes32(sessionKey, allowedSelector)); + } + + function _removeSessionKey(address account, address sessionKey, bytes4 allowedSelector) internal { + bytes memory key = account.allocateAssociatedStorageKey(0, 1); + StoragePointer ptr = key.associatedStorageLookup(keccak256(abi.encodePacked(sessionKey, allowedSelector))); + SessionInfo storage sessionInfo = _castPtrToStruct(ptr); + sessionInfo.validAfter = 0; + sessionInfo.validUntil = 0; + + EnumerableSet.Bytes32Set storage sessionKeySet = _sessionKeySet[account]; + sessionKeySet.remove(_castToBytes32(sessionKey, allowedSelector)); + } + + function _castPtrToStruct(StoragePointer ptr) internal pure returns (SessionInfo storage val) { + assembly ("memory-safe") { + val.slot := ptr + } + } + + function _castToBytes32(address addr, bytes4 b4) internal pure returns (bytes32 res) { + assembly { + res := or(shl(32, addr), b4) + } + } + + function _castToAddressAndBytes4(bytes32 b32) internal pure returns (address addr, bytes4 b4) { + assembly { + addr := shr(32, b32) + b4 := and(b32, 0xFFFFFFFF) + } + } + + function _packValidationData(bool sigFailed, uint48 validUntil, uint48 validAfter) + internal + pure + returns (uint256) + { + return (sigFailed ? 1 : 0) | (uint256(validUntil) << 160) | (uint256(validAfter) << (160 + 48)); + } +} diff --git a/src/samples/plugins/TokenSessionKeyPlugin.sol b/src/samples/plugins/TokenSessionKeyPlugin.sol new file mode 100644 index 00000000..9e2e8abd --- /dev/null +++ b/src/samples/plugins/TokenSessionKeyPlugin.sol @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.19; + +import { + ManifestFunction, + ManifestAssociatedFunctionType, + ManifestAssociatedFunction, + PluginManifest, + PluginMetadata, + SelectorPermission, + ManifestExternalCallPermission +} from "../../interfaces/IPlugin.sol"; +import {BasePlugin} from "../../plugins/BasePlugin.sol"; +import {ModularSessionKeyPlugin} from "./ModularSessionKeyPlugin.sol"; +import {ITokenSessionKeyPlugin} from "./interfaces/ITokenSessionKeyPlugin.sol"; +import {IModularSessionKeyPlugin} from "./interfaces/ISessionKeyPlugin.sol"; +import {IPluginExecutor} from "../../interfaces/IPluginExecutor.sol"; + +/// @title Token Session Key Plugin +/// @author Decipher ERC-6900 Team +/// @notice This plugin acts as a 'child plugin' for ModularSessionKeyPlugin. +/// It implements the logic for session keys that are allowed to call ERC20 +/// transferFrom function. It allows for session key owners to access MSCA +/// with `transferFromSessionKey` function, which calls `executeFromPluginExternal` +/// function in PluginExecutor contract. +/// The target ERC20 contract and the selector for transferFrom function are hardcoded +/// in this plugin, since the pluginManifest function requires the information of +/// permitted external calls not to be changed in the future. For other child session +/// key plugins, there can be a set of permitted external calls according to the +/// specific needs. +contract TokenSessionKeyPlugin is BasePlugin, ITokenSessionKeyPlugin { + string public constant NAME = "Token Session Key Plugin"; + string public constant VERSION = "1.0.0"; + string public constant AUTHOR = "Decipher ERC-6900 Team"; + + // Mock address of target ERC20 contract + address public constant TARGET_ERC20_CONTRACT = 0xdeaDDeADDEaDdeaDdEAddEADDEAdDeadDEADDEaD; + bytes4 public constant TRANSFERFROM_SELECTOR = + bytes4(keccak256(bytes("transferFrom(address,address,uint256)"))); + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Execution functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc ITokenSessionKeyPlugin + function transferFromSessionKey(address target, address from, address to, uint256 amount) + external + returns (bytes memory returnData) + { + bytes memory data = abi.encodeWithSelector(TRANSFERFROM_SELECTOR, from, to, amount); + returnData = IPluginExecutor(msg.sender).executeFromPluginExternal(target, 0, data); + } + + // ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + // ┃ Plugin interface functions ┃ + // ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function onInstall(bytes calldata data) external override {} + + /// @inheritdoc BasePlugin + function onUninstall(bytes calldata data) external override {} + + /// @inheritdoc BasePlugin + function pluginManifest() external pure override returns (PluginManifest memory) { + PluginManifest memory manifest; + + manifest.executionFunctions = new bytes4[](1); + manifest.executionFunctions[0] = this.transferFromSessionKey.selector; + + ManifestFunction memory tempOwnerUserOpValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // Unused + dependencyIndex: 0 // Used as first index + }); + ManifestFunction memory tempOwnerRuntimeValidationFunction = ManifestFunction({ + functionType: ManifestAssociatedFunctionType.DEPENDENCY, + functionId: 0, // Unused + dependencyIndex: 1 // Used as second index + }); + + manifest.userOpValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.userOpValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.transferFromSessionKey.selector, + associatedFunction: tempOwnerUserOpValidationFunction + }); + + manifest.runtimeValidationFunctions = new ManifestAssociatedFunction[](1); + manifest.runtimeValidationFunctions[0] = ManifestAssociatedFunction({ + executionSelector: this.transferFromSessionKey.selector, + associatedFunction: tempOwnerRuntimeValidationFunction + }); + + manifest.dependencyInterfaceIds = new bytes4[](2); + manifest.dependencyInterfaceIds[0] = type(IModularSessionKeyPlugin).interfaceId; + manifest.dependencyInterfaceIds[1] = type(IModularSessionKeyPlugin).interfaceId; + + bytes4[] memory permittedExternalSelectors = new bytes4[](1); + permittedExternalSelectors[0] = TRANSFERFROM_SELECTOR; + + manifest.permittedExternalCalls = new ManifestExternalCallPermission[](1); + manifest.permittedExternalCalls[0] = ManifestExternalCallPermission({ + externalAddress: TARGET_ERC20_CONTRACT, + permitAnySelector: false, + selectors: permittedExternalSelectors + }); + + return manifest; + } + + /// @inheritdoc BasePlugin + function pluginMetadata() external pure virtual override returns (PluginMetadata memory) { + PluginMetadata memory metadata; + metadata.name = NAME; + metadata.version = VERSION; + metadata.author = AUTHOR; + + return metadata; + } + + // ┏━━━━━━━━━━━━━━━┓ + // ┃ EIP-165 ┃ + // ┗━━━━━━━━━━━━━━━┛ + + /// @inheritdoc BasePlugin + function supportsInterface(bytes4 interfaceId) public view override returns (bool) { + return interfaceId == type(ITokenSessionKeyPlugin).interfaceId || super.supportsInterface(interfaceId); + } +} diff --git a/src/samples/plugins/interfaces/ISessionKeyPlugin.sol b/src/samples/plugins/interfaces/ISessionKeyPlugin.sol new file mode 100644 index 00000000..88d31f1a --- /dev/null +++ b/src/samples/plugins/interfaces/ISessionKeyPlugin.sol @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.19; + +import {UserOperation} from "@eth-infinitism/account-abstraction/interfaces/UserOperation.sol"; + +interface IModularSessionKeyPlugin { + enum FunctionId { + RUNTIME_VALIDATION_TEMPORARY_OWNER, + USER_OP_VALIDATION_TEMPORARY_OWNER + } + + /// @notice This event is emitted when a session key is added to the account. + /// @param account The account whose session key is updated. + /// @param sessionKey The address of the session key. + /// @param selector The selector of the function that the session key is allowed to call. + /// @param validAfter The time after which the owner is valid. + /// @param validUntil The time until which the owner is valid. + event SessionKeyAdded( + address indexed account, address indexed sessionKey, bytes4 selector, uint48 validAfter, uint48 validUntil + ); + + /// @notice This event is emitted when a session key is removed from the account. + /// @param account The account whose session key is updated. + /// @param sessionKey The address of the session key. + /// @param selector The selector of the function that the session key is allowed to call. + event SessionKeyRemoved(address indexed account, address indexed sessionKey, bytes4 selector); + + /// @notice This event is emitted when session keys are added to the account. + /// @param account The account whose session keys are updated. + /// @param sessionKeys The addresses of the session keys. + /// @param selectors The selectors of the functions that the session keys are allowed to call. + /// @param validAfters The times after which the owners are valid. + /// @param validUntils The times until which the owners are valid. + event SessionKeysAdded( + address indexed account, + address[] sessionKeys, + bytes4[] selectors, + uint48[] validAfters, + uint48[] validUntils + ); + + /// @notice This event is emitted when session keys are removed from the account. + /// @param account The account whose session keys are updated. + /// @param sessionKeys The addresses of the session keys. + /// @param selectors The selectors of the functions that the session keys are allowed to call. + event SessionKeysRemoved(address indexed account, address[] sessionKeys, bytes4[] selectors); + + error InvalidSignature(); + error NotAuthorized(); + error WrongTimeRangeForSession(); + error WrongDataLength(); + + /// @notice Add a session key to the account. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. The function selector installed by a child session key plugin + /// is passed as a parameter, which enforces its own permissions on the calls it can make. + /// @param sessionKey The address of the session key. + /// @param allowedSelector The selector of the function that the session key is allowed to call. + /// @param validAfter The time after which the owner is valid. + /// @param validUntil The time until which the owner is valid. + function addSessionKey(address sessionKey, bytes4 allowedSelector, uint48 validAfter, uint48 validUntil) + external; + + /// @notice Remove a session key from the account. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. + /// @param sessionKey The address of the session key. + /// @param allowedSelector The selector of the function that the session key is allowed to call. + function removeSessionKey(address sessionKey, bytes4 allowedSelector) external; + + /// @notice Add session keys to the account. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. + /// @param sessionKeys The addresses of the session keys. + /// @param allowedSelectors The selectors of the functions that the session keys are allowed to call. + /// @param validAfters The times after which the owners are valid. + /// @param validUntils The times until which the owners are valid. + function addSessionKeyBatch( + address[] calldata sessionKeys, + bytes4[] calldata allowedSelectors, + uint48[] calldata validAfters, + uint48[] calldata validUntils + ) external; + + /// @notice Remove session keys from the account. + /// @dev This function is installed on the account as part of plugin installation, and should + /// only be called from an account. + /// @param sessionKeys The addresses of the session keys. + /// @param allowedSelectors The selectors of the functions that the session keys are allowed to call. + function removeSessionKeyBatch(address[] calldata sessionKeys, bytes4[] calldata allowedSelectors) external; + + /// @notice Get Session data for a given account and session key. + /// @param account The account to get session data for. + /// @param sessionKey The address of the session key. + /// @param allowedSelector The selector of the function that the session key is allowed to call. + function getSessionDuration(address account, address sessionKey, bytes4 allowedSelector) + external + view + returns (uint48 validAfter, uint48 validUntil); + + /// @notice Get all session keys and selectors for a given account. + /// @param account The account to get session keys and selectors for. + /// @return sessionKeys The addresses of the session keys. + /// @return selectors The selectors of the functions that the session keys are allowed to call. + function getSessionKeysAndSelectors(address account) + external + view + returns (address[] memory sessionKeys, bytes4[] memory selectors); +} diff --git a/src/samples/plugins/interfaces/ITokenSessionKeyPlugin.sol b/src/samples/plugins/interfaces/ITokenSessionKeyPlugin.sol new file mode 100644 index 00000000..65e64113 --- /dev/null +++ b/src/samples/plugins/interfaces/ITokenSessionKeyPlugin.sol @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.19; + +import {UserOperation} from "@eth-infinitism/account-abstraction/interfaces/UserOperation.sol"; + +interface ITokenSessionKeyPlugin { + error NotAuthorized(); + + /// @notice Route call to executeFromPluginExternal at the MSCA. + /// @dev This function will call with value = 0, since sending ether + /// to ERC20 contract is not a normal case. + /// @param target The target address to execute the call on. + /// @param from The address to transfer tokens from. + /// @param to The address to transfer tokens to. + /// @param amount The amount of tokens to transfer. + function transferFromSessionKey(address target, address from, address to, uint256 amount) + external + returns (bytes memory returnData); +} diff --git a/test/mocks/MockERC20.sol b/test/mocks/MockERC20.sol new file mode 100644 index 00000000..d87a9234 --- /dev/null +++ b/test/mocks/MockERC20.sol @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {ERC20} from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; + +contract MockERC20 is ERC20 { + constructor(string memory name, string memory symbol) ERC20(name, symbol) {} + + function mint(address account, uint256 amount) external { + _mint(account, amount); + } +} diff --git a/test/samples/plugins/ModularSessionKeyPlugin.t.sol b/test/samples/plugins/ModularSessionKeyPlugin.t.sol new file mode 100644 index 00000000..bdbcbfff --- /dev/null +++ b/test/samples/plugins/ModularSessionKeyPlugin.t.sol @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.19; + +import {Test} from "forge-std/Test.sol"; + +import {EntryPoint} from "@eth-infinitism/account-abstraction/core/EntryPoint.sol"; +import {UserOperation} from "@eth-infinitism/account-abstraction/interfaces/UserOperation.sol"; +import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; + +import {SingleOwnerPlugin} from "../../../src/plugins/owner/SingleOwnerPlugin.sol"; +import {ISingleOwnerPlugin} from "../../../src/plugins/owner/ISingleOwnerPlugin.sol"; +import {ModularSessionKeyPlugin} from "../../../src/samples/plugins/ModularSessionKeyPlugin.sol"; +import {IModularSessionKeyPlugin} from "../../../src/samples/plugins/interfaces/ISessionKeyPlugin.sol"; +import {TokenSessionKeyPlugin} from "../../../src/samples/plugins/TokenSessionKeyPlugin.sol"; +import {ITokenSessionKeyPlugin} from "../../../src/samples/plugins/interfaces/ITokenSessionKeyPlugin.sol"; + +import {UpgradeableModularAccount} from "../../../src/account/UpgradeableModularAccount.sol"; +import {MSCAFactoryFixture} from "../../mocks/MSCAFactoryFixture.sol"; +import {FunctionReference, FunctionReferenceLib} from "../../../src/helpers/FunctionReferenceLib.sol"; +import {IPluginManager} from "../../../src/interfaces/IPluginManager.sol"; +import {MockERC20} from "../../mocks/MockERC20.sol"; + +contract ModularSessionKeyPluginTest is Test { + using ECDSA for bytes32; + + SingleOwnerPlugin public ownerPlugin; + ModularSessionKeyPlugin public modularSessionKeyPlugin; + TokenSessionKeyPlugin public tokenSessionKeyPlugin; + EntryPoint public entryPoint; + MSCAFactoryFixture public factory; + UpgradeableModularAccount public account; + + MockERC20 public mockERC20impl; + MockERC20 public mockERC20; + address public mockEmptyERC20Addr; + + address public owner; + uint256 public ownerKey; + + address public maliciousOwner; + + address public tempOwner; + uint256 public tempOwnerKey; + + address public target; + + address payable public beneficiary; + + uint256 public constant CALL_GAS_LIMIT = 150000; + uint256 public constant VERIFICATION_GAS_LIMIT = 3600000; + + bytes4 public constant TRANSFERFROM_SESSIONKEY_SELECTOR = + ITokenSessionKeyPlugin.transferFromSessionKey.selector; + + // Event declarations (needed for vm.expectEmit) + event UserOperationRevertReason( + bytes32 indexed userOpHash, address indexed sender, uint256 nonce, bytes revertReason + ); + event SessionKeyAdded( + address indexed account, address indexed sessionKey, bytes4 allowedSelector, uint48 _after, uint48 _until + ); + event SessionKeyRemoved(address indexed account, address indexed sessionKey, bytes4 allowedSelector); + event SessionKeysAdded( + address indexed account, address[] sessionKeys, bytes4[] allowedSelectors, uint48[] afters, uint48[] untils + ); + event SessionKeysRemoved(address indexed account, address[] sessionKeys, bytes4[] allowedSelectors); + event PluginUninstalled(address indexed plugin, bool indexed onUninstallSuccess); + + function setUp() public { + ownerPlugin = new SingleOwnerPlugin(); + modularSessionKeyPlugin = new ModularSessionKeyPlugin(); + tokenSessionKeyPlugin = new TokenSessionKeyPlugin(); + + entryPoint = new EntryPoint(); + factory = new MSCAFactoryFixture(entryPoint, ownerPlugin); + mockERC20impl = new MockERC20("Mock", "MCK"); + + // Etching MockERC20 code into hardcoded address at TokenSessionKeyPlugin + mockEmptyERC20Addr = tokenSessionKeyPlugin.TARGET_ERC20_CONTRACT(); + bytes memory code = address(mockERC20impl).code; + vm.etch(mockEmptyERC20Addr, code); + mockERC20 = MockERC20(mockEmptyERC20Addr); + + (owner, ownerKey) = makeAddrAndKey("owner"); + (maliciousOwner,) = makeAddrAndKey("maliciousOwner"); + (tempOwner, tempOwnerKey) = makeAddrAndKey("tempOwner"); + target = makeAddr("target"); + + beneficiary = payable(makeAddr("beneficiary")); + vm.deal(beneficiary, 1 wei); + vm.deal(owner, 10 ether); + + // Here, SingleOwnerPlugin already installed in factory + account = factory.createAccount(owner, 0); + + // Mint Mock ERC20 Tokens to account + mockERC20.mint(address(account), 1 ether); + // Fund the account with some ether + vm.deal(address(account), 1 ether); + + vm.startPrank(owner); + FunctionReference[] memory modularSessionDependency = new FunctionReference[](2); + modularSessionDependency[0] = FunctionReferenceLib.pack( + address(ownerPlugin), uint8(ISingleOwnerPlugin.FunctionId.USER_OP_VALIDATION_OWNER) + ); + modularSessionDependency[1] = FunctionReferenceLib.pack( + address(ownerPlugin), uint8(ISingleOwnerPlugin.FunctionId.RUNTIME_VALIDATION_OWNER_OR_SELF) + ); + + bytes32 modularSessionKeyManifestHash = keccak256(abi.encode(modularSessionKeyPlugin.pluginManifest())); + + address[] memory tempOwners = new address[](1); + tempOwners[0] = address(tempOwner); + + bytes4[] memory allowedSelectors = new bytes4[](1); + allowedSelectors[0] = TRANSFERFROM_SESSIONKEY_SELECTOR; + + uint48[] memory afters = new uint48[](1); + afters[0] = 0; + + uint48[] memory untils = new uint48[](1); + untils[0] = 2; + + bytes memory data = abi.encode(tempOwners, allowedSelectors, afters, untils); + + account.installPlugin({ + plugin: address(modularSessionKeyPlugin), + manifestHash: modularSessionKeyManifestHash, + pluginInstallData: data, + dependencies: modularSessionDependency + }); + + FunctionReference[] memory tokenSessionDependency = new FunctionReference[](2); + tokenSessionDependency[0] = FunctionReferenceLib.pack( + address(modularSessionKeyPlugin), + uint8(IModularSessionKeyPlugin.FunctionId.USER_OP_VALIDATION_TEMPORARY_OWNER) + ); + tokenSessionDependency[1] = FunctionReferenceLib.pack( + address(modularSessionKeyPlugin), + uint8(IModularSessionKeyPlugin.FunctionId.RUNTIME_VALIDATION_TEMPORARY_OWNER) + ); + bytes32 tokenSessionKeyManifestHash = keccak256(abi.encode(tokenSessionKeyPlugin.pluginManifest())); + + account.installPlugin({ + plugin: address(tokenSessionKeyPlugin), + manifestHash: tokenSessionKeyManifestHash, + pluginInstallData: "", + dependencies: tokenSessionDependency + }); + vm.stopPrank(); + + vm.startPrank(address(account)); + mockERC20.approve(address(account), 1 ether); + + (uint48 _after, uint48 _until) = modularSessionKeyPlugin.getSessionDuration( + address(account), tempOwner, TRANSFERFROM_SESSIONKEY_SELECTOR + ); + + assertEq(_after, 0); + assertEq(_until, 2); + vm.stopPrank(); + } + + function test_sessionKey_batch() public { + address tempOwner2 = makeAddr("tempOwner2"); + address tempOwner3 = makeAddr("tempOwner3"); + + address[] memory tempOwners = new address[](2); + tempOwners[0] = tempOwner2; + tempOwners[1] = tempOwner3; + + bytes4[] memory allowedSelectors = new bytes4[](2); + allowedSelectors[0] = TRANSFERFROM_SESSIONKEY_SELECTOR; + allowedSelectors[1] = TRANSFERFROM_SESSIONKEY_SELECTOR; + + uint48[] memory afters = new uint48[](2); + afters[0] = 0; + afters[1] = 0; + + uint48[] memory untils = new uint48[](2); + untils[0] = 2; + untils[1] = 2; + + vm.expectEmit(true, true, true, true); + emit SessionKeysAdded(address(account), tempOwners, allowedSelectors, afters, untils); + vm.prank(address(account)); + modularSessionKeyPlugin.addSessionKeyBatch(tempOwners, allowedSelectors, afters, untils); + + vm.prank(tempOwner3); + TokenSessionKeyPlugin(address(account)).transferFromSessionKey( + address(mockERC20), address(account), target, 1 ether + ); + + assertEq(mockERC20.balanceOf(address(account)), 0); + assertEq(mockERC20.balanceOf(target), 1 ether); + + vm.expectEmit(true, true, true, true); + emit SessionKeysRemoved(address(account), tempOwners, allowedSelectors); + vm.prank(address(account)); + modularSessionKeyPlugin.removeSessionKeyBatch(tempOwners, allowedSelectors); + } + + function test_sessionKey_userOp() public { + UserOperation[] memory userOps = new UserOperation[](1); + + (, UserOperation memory userOp) = _constructUserOp(address(mockERC20), address(account), target, 1 ether); + userOps[0] = userOp; + + entryPoint.handleOps(userOps, beneficiary); + + assertEq(mockERC20.balanceOf(address(account)), 0); + assertEq(mockERC20.balanceOf(target), 1 ether); + } + + function test_sessionKey_runtime() public { + vm.prank(address(tempOwner)); + TokenSessionKeyPlugin(address(account)).transferFromSessionKey( + address(mockERC20), address(account), target, 1 ether + ); + + assertEq(mockERC20.balanceOf(address(account)), 0); + assertEq(mockERC20.balanceOf(target), 1 ether); + } + + function test_sessionKey_removeTempOwner() public { + vm.startPrank(address(account)); + + vm.expectEmit(true, true, true, true); + emit SessionKeyRemoved(address(account), tempOwner, TRANSFERFROM_SESSIONKEY_SELECTOR); + modularSessionKeyPlugin.removeSessionKey(tempOwner, TRANSFERFROM_SESSIONKEY_SELECTOR); + + vm.stopPrank(); + + (uint48 _after, uint48 _until) = modularSessionKeyPlugin.getSessionDuration( + address(account), tempOwner, TRANSFERFROM_SESSIONKEY_SELECTOR + ); + assertEq(_after, 0); + assertEq(_until, 0); + + // Check if tempOwner can still send user operations + vm.startPrank(address(tempOwner)); + + bytes memory revertReason = abi.encodeWithSelector(IModularSessionKeyPlugin.NotAuthorized.selector); + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, + address(modularSessionKeyPlugin), + IModularSessionKeyPlugin.FunctionId.RUNTIME_VALIDATION_TEMPORARY_OWNER, + revertReason + ) + ); + TokenSessionKeyPlugin(address(account)).transferFromSessionKey( + address(mockERC20), address(account), target, 1 ether + ); + } + + function test_sessionKey_invalidContractFails() public { + address wrongERC20Contract = makeAddr("wrongERC20Contract"); + (bytes32 userOpHash, UserOperation memory userOp) = + _constructUserOp(address(wrongERC20Contract), address(account), target, 1 ether); + + UserOperation[] memory userOps = new UserOperation[](1); + userOps[0] = userOp; + + bytes memory revertCallData = abi.encodeWithSelector( + tokenSessionKeyPlugin.TRANSFERFROM_SELECTOR(), address(account), target, 1 ether + ); + bytes memory revertReason = abi.encodeWithSelector( + UpgradeableModularAccount.ExecFromPluginExternalNotPermitted.selector, + address(tokenSessionKeyPlugin), + address(wrongERC20Contract), + 0, + revertCallData + ); + vm.expectEmit(true, true, true, true); + emit UserOperationRevertReason(userOpHash, address(account), 0, revertReason); + + entryPoint.handleOps(userOps, beneficiary); + } + + function test_sessionKey_unregisteredTempOwnerFails() public { + vm.prank(address(maliciousOwner)); + bytes memory revertReason = abi.encodeWithSelector(IModularSessionKeyPlugin.NotAuthorized.selector); + + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, + address(modularSessionKeyPlugin), + IModularSessionKeyPlugin.FunctionId.RUNTIME_VALIDATION_TEMPORARY_OWNER, + revertReason + ) + ); + TokenSessionKeyPlugin(address(account)).transferFromSessionKey( + address(mockERC20), address(account), target, 1 ether + ); + } + + function test_sessionKey_invalidSessionDurationFails() public { + // Move block.timestamp to 12345 + vm.warp(12345); + + vm.startPrank(address(tempOwner)); + + bytes memory revertReason = + abi.encodeWithSelector(IModularSessionKeyPlugin.WrongTimeRangeForSession.selector); + + vm.expectRevert( + abi.encodeWithSelector( + UpgradeableModularAccount.RuntimeValidationFunctionReverted.selector, + address(modularSessionKeyPlugin), + IModularSessionKeyPlugin.FunctionId.RUNTIME_VALIDATION_TEMPORARY_OWNER, + revertReason + ) + ); + TokenSessionKeyPlugin(address(account)).transferFromSessionKey( + address(mockERC20), address(account), target, 1 ether + ); + } + + function test_sessionKey_uninstallModularSessionKeyPlugin() public { + address[] memory tempOwners = new address[](1); + tempOwners[0] = address(tempOwner); + + bytes4[] memory allowedSelectors = new bytes4[](1); + allowedSelectors[0] = TRANSFERFROM_SESSIONKEY_SELECTOR; + + vm.startPrank(owner); + + vm.expectEmit(true, true, true, true); + + emit PluginUninstalled(address(tokenSessionKeyPlugin), true); + account.uninstallPlugin({ + plugin: address(tokenSessionKeyPlugin), + config: bytes(""), + pluginUninstallData: "" + }); + + vm.expectEmit(true, true, true, true); + emit PluginUninstalled(address(modularSessionKeyPlugin), true); + account.uninstallPlugin({ + plugin: address(modularSessionKeyPlugin), + config: bytes(""), + pluginUninstallData: "" + }); + + vm.stopPrank(); + } + + // Internal Function + function _constructUserOp(address targetContract, address from, address to, uint256 amount) + internal + view + returns (bytes32, UserOperation memory) + { + bytes memory userOpCallData = + abi.encodeCall(TokenSessionKeyPlugin.transferFromSessionKey, (targetContract, from, to, amount)); + + UserOperation memory userOp = UserOperation({ + sender: address(account), + nonce: 0, + initCode: "", + callData: userOpCallData, + callGasLimit: CALL_GAS_LIMIT, + verificationGasLimit: VERIFICATION_GAS_LIMIT, + preVerificationGas: 0, + maxFeePerGas: 2, + maxPriorityFeePerGas: 1, + paymasterAndData: "", + signature: "" + }); + + // Generate signature + bytes32 userOpHash = entryPoint.getUserOpHash(userOp); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(tempOwnerKey, userOpHash.toEthSignedMessageHash()); + userOp.signature = abi.encodePacked(r, s, v); + + return (userOpHash, userOp); + } +}