1+ import asyncio
12from decimal import Decimal
23from pathlib import Path
3- from typing import Awaitable , Dict , Optional , Set , Union
4+ from typing import Awaitable , List , Optional , Union
45
56from aleph_message .models import Chain
67from eth_account import Account
78from eth_account .messages import encode_defunct
89from eth_account .signers .local import LocalAccount
910from eth_keys .exceptions import BadSignature as EthBadSignatureError
11+ from eth_utils import to_wei
1012from superfluid import Web3FlowInfo
13+ from web3 import Web3
14+ from web3 .middleware import geth_poa_middleware
15+ from web3 .types import ChecksumAddress , TxParams , TxReceipt
16+
17+ from aleph .sdk .exceptions import InsufficientFundsError
1118
1219from ..conf import settings
1320from ..connectors .superfluid import Superfluid
1421from ..exceptions import BadSignatureError
1522from ..utils import bytes_from_hex
1623from .common import BaseAccount , get_fallback_private_key , get_public_key
1724
18- CHAINS_WITH_SUPERTOKEN : Set [Chain ] = {Chain .AVAX }
19- CHAIN_IDS : Dict [Chain , int ] = {
20- Chain .AVAX : settings .AVAX_CHAIN_ID ,
21- }
25+ MIN_ETH_BALANCE : float = 0.005
26+ MIN_ETH_BALANCE_WEI = Decimal (to_wei (MIN_ETH_BALANCE , "ether" ))
27+ BALANCEOF_ABI = """[{
28+ "name": "balanceOf",
29+ "inputs": [{"name": "account", "type": "address"}],
30+ "outputs": [{"name": "balance", "type": "uint256"}],
31+ "constant": true,
32+ "payable": false,
33+ "stateMutability": "view",
34+ "type": "function"
35+ }]"""
2236
2337
24- def get_rpc_for_chain (chain : Chain ):
25- """Returns the RPC to use for a given Ethereum based blockchain"""
26- if not chain :
27- return None
38+ def to_human_readable_token (amount : Decimal ) -> float :
39+ return float (amount / (Decimal (10 ) ** Decimal (settings .TOKEN_DECIMALS )))
2840
29- if chain == Chain .AVAX :
30- return settings .AVAX_RPC
31- else :
32- raise ValueError (f"Unknown RPC for chain { chain } " )
3341
42+ def to_wei_token (amount : Decimal ) -> Decimal :
43+ return amount * Decimal (10 ) ** Decimal (settings .TOKEN_DECIMALS )
3444
35- def get_chain_id_for_chain (chain : Chain ):
36- """Returns the chain ID of a given Ethereum based blockchain"""
37- if not chain :
38- return None
3945
40- if chain in CHAIN_IDS :
41- return CHAIN_IDS [chain ]
42- else :
43- raise ValueError (f"Unknown RPC for chain { chain } " )
46+ def get_chain_id (chain : Union [Chain , str , None ]) -> Optional [int ]:
47+ """Returns the CHAIN_ID of a given EVM blockchain"""
48+ if chain :
49+ if chain in settings .CHAINS and settings .CHAINS [chain ].chain_id :
50+ return settings .CHAINS [chain ].chain_id
51+ else :
52+ raise ValueError (f"Unknown RPC for chain { chain } " )
53+ return None
54+
55+
56+ def get_rpc (chain : Union [Chain , str , None ]) -> Optional [str ]:
57+ """Returns the RPC to use for a given EVM blockchain"""
58+ if chain :
59+ if chain in settings .CHAINS and settings .CHAINS [chain ].rpc :
60+ return settings .CHAINS [chain ].rpc
61+ else :
62+ raise ValueError (f"Unknown RPC for chain { chain } " )
63+ return None
64+
65+
66+ def get_token_address (chain : Union [Chain , str , None ]) -> Optional [ChecksumAddress ]:
67+ if chain :
68+ if chain in settings .CHAINS :
69+ address = settings .CHAINS [chain ].super_token
70+ if address :
71+ try :
72+ return Web3 .to_checksum_address (address )
73+ except ValueError :
74+ raise ValueError (f"Invalid token address { address } " )
75+ else :
76+ raise ValueError (f"Unknown token for chain { chain } " )
77+ return None
78+
79+
80+ def get_super_token_address (
81+ chain : Union [Chain , str , None ]
82+ ) -> Optional [ChecksumAddress ]:
83+ if chain :
84+ if chain in settings .CHAINS :
85+ address = settings .CHAINS [chain ].super_token
86+ if address :
87+ try :
88+ return Web3 .to_checksum_address (address )
89+ except ValueError :
90+ raise ValueError (f"Invalid token address { address } " )
91+ else :
92+ raise ValueError (f"Unknown super_token for chain { chain } " )
93+ return None
94+
95+
96+ def get_chains_with_super_token () -> List [Union [Chain , str ]]:
97+ return [chain for chain , info in settings .CHAINS .items () if info .super_token ]
4498
4599
46100class ETHAccount (BaseAccount ):
47- """Interact with an Ethereum address or key pair"""
101+ """Interact with an Ethereum address or key pair on EVM blockchains """
48102
49103 CHAIN = "ETH"
50104 CURVE = "secp256k1"
51105 _account : LocalAccount
106+ _provider : Optional [Web3 ]
52107 chain : Optional [Chain ]
108+ chain_id : Optional [int ]
109+ rpc : Optional [str ]
53110 superfluid_connector : Optional [Superfluid ]
54111
55112 def __init__ (
56113 self ,
57114 private_key : bytes ,
58115 chain : Optional [Chain ] = None ,
59- rpc : Optional [str ] = None ,
60- chain_id : Optional [int ] = None ,
61116 ):
62- self .private_key = private_key
63- self ._account = Account .from_key (self .private_key )
64- self .chain = chain
65- rpc = rpc or get_rpc_for_chain (chain )
66- chain_id = chain_id or get_chain_id_for_chain (chain )
67- self .superfluid_connector = (
68- Superfluid (
69- rpc = rpc ,
70- chain_id = chain_id ,
71- account = self ._account ,
72- )
73- if chain in CHAINS_WITH_SUPERTOKEN
74- else None
117+ self ._account : LocalAccount = Account .from_key (private_key )
118+ self .connect_chain (chain = chain )
119+
120+ @staticmethod
121+ def from_mnemonic (mnemonic : str , chain : Optional [Chain ] = None ) -> "ETHAccount" :
122+ Account .enable_unaudited_hdwallet_features ()
123+ return ETHAccount (
124+ private_key = Account .from_mnemonic (mnemonic = mnemonic ).key , chain = chain
75125 )
76126
127+ def get_address (self ) -> str :
128+ return self ._account .address
129+
130+ def get_public_key (self ) -> str :
131+ return "0x" + get_public_key (private_key = self ._account .key ).hex ()
132+
77133 async def sign_raw (self , buffer : bytes ) -> bytes :
78134 """Sign a raw buffer."""
79135 msghash = encode_defunct (text = buffer .decode ("utf-8" ))
80136 sig = self ._account .sign_message (msghash )
81137 return sig ["signature" ]
82138
83- def get_address (self ) -> str :
84- return self ._account .address
139+ def connect_chain (self , chain : Optional [Chain ] = None ):
140+ self .chain = chain
141+ if self .chain :
142+ self .chain_id = get_chain_id (self .chain )
143+ self .rpc = get_rpc (self .chain )
144+ self ._provider = Web3 (Web3 .HTTPProvider (self .rpc ))
145+ if chain == Chain .BSC :
146+ self ._provider .middleware_onion .inject (
147+ geth_poa_middleware , "geth_poa" , layer = 0
148+ )
149+ else :
150+ self .chain_id = None
151+ self .rpc = None
152+ self ._provider = None
85153
86- def get_public_key (self ) -> str :
87- return "0x" + get_public_key (private_key = self ._account .key ).hex ()
154+ if chain in get_chains_with_super_token () and self ._provider :
155+ self .superfluid_connector = Superfluid (self )
156+ else :
157+ self .superfluid_connector = None
88158
89- @staticmethod
90- def from_mnemonic (mnemonic : str ) -> "ETHAccount" :
91- Account .enable_unaudited_hdwallet_features ()
92- return ETHAccount (private_key = Account .from_mnemonic (mnemonic = mnemonic ).key )
159+ def switch_chain (self , chain : Optional [Chain ] = None ):
160+ self .connect_chain (chain = chain )
161+
162+ def can_transact (self , block = True ) -> bool :
163+ balance = self .get_eth_balance ()
164+ valid = balance > MIN_ETH_BALANCE_WEI if self .chain else False
165+ if not valid and block :
166+ raise InsufficientFundsError (
167+ required_funds = MIN_ETH_BALANCE ,
168+ available_funds = to_human_readable_token (balance ),
169+ )
170+ return valid
171+
172+ async def _sign_and_send_transaction (self , tx_params : TxParams ) -> str :
173+ """
174+ Sign and broadcast a transaction using the provided ETHAccount
175+ @param tx_params - Transaction parameters
176+ @returns - str - Transaction hash
177+ """
178+ self .can_transact ()
179+
180+ def sign_and_send () -> TxReceipt :
181+ if self ._provider is None :
182+ raise ValueError ("Provider not connected" )
183+ signed_tx = self ._provider .eth .account .sign_transaction (
184+ tx_params , self ._account .key
185+ )
186+ tx_hash = self ._provider .eth .send_raw_transaction (signed_tx .rawTransaction )
187+ tx_receipt = self ._provider .eth .wait_for_transaction_receipt (
188+ tx_hash , settings .TX_TIMEOUT
189+ )
190+ return tx_receipt
191+
192+ loop = asyncio .get_running_loop ()
193+ tx_receipt = await loop .run_in_executor (None , sign_and_send )
194+ return tx_receipt ["transactionHash" ].hex ()
195+
196+ def get_eth_balance (self ) -> Decimal :
197+ return Decimal (
198+ self ._provider .eth .get_balance (self ._account .address )
199+ if self ._provider
200+ else 0
201+ )
202+
203+ def get_token_balance (self ) -> Decimal :
204+ if self .chain and self ._provider :
205+ contact_address = get_token_address (self .chain )
206+ if contact_address :
207+ contract = self ._provider .eth .contract (
208+ address = contact_address , abi = BALANCEOF_ABI
209+ )
210+ return Decimal (contract .functions .balanceOf (self .get_address ()).call ())
211+ return Decimal (0 )
212+
213+ def get_super_token_balance (self ) -> Decimal :
214+ if self .chain and self ._provider :
215+ contact_address = get_super_token_address (self .chain )
216+ if contact_address :
217+ contract = self ._provider .eth .contract (
218+ address = contact_address , abi = BALANCEOF_ABI
219+ )
220+ return Decimal (contract .functions .balanceOf (self .get_address ()).call ())
221+ return Decimal (0 )
93222
94223 def create_flow (self , receiver : str , flow : Decimal ) -> Awaitable [str ]:
95224 """Creat a Superfluid flow between this account and the receiver address."""
96225 if not self .superfluid_connector :
97226 raise ValueError ("Superfluid connector is required to create a flow" )
98- return self .superfluid_connector .create_flow (
99- sender = self .get_address (), receiver = receiver , flow = flow
100- )
227+ return self .superfluid_connector .create_flow (receiver = receiver , flow = flow )
101228
102229 def get_flow (self , receiver : str ) -> Awaitable [Web3FlowInfo ]:
103230 """Get the Superfluid flow between this account and the receiver address."""
@@ -111,29 +238,19 @@ def update_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]:
111238 """Update the Superfluid flow between this account and the receiver address."""
112239 if not self .superfluid_connector :
113240 raise ValueError ("Superfluid connector is required to update a flow" )
114- return self .superfluid_connector .update_flow (
115- sender = self .get_address (), receiver = receiver , flow = flow
116- )
241+ return self .superfluid_connector .update_flow (receiver = receiver , flow = flow )
117242
118243 def delete_flow (self , receiver : str ) -> Awaitable [str ]:
119244 """Delete the Superfluid flow between this account and the receiver address."""
120245 if not self .superfluid_connector :
121246 raise ValueError ("Superfluid connector is required to delete a flow" )
122- return self .superfluid_connector .delete_flow (
123- sender = self .get_address (), receiver = receiver
124- )
125-
126- def update_superfluid_connector (self , rpc : str , chain_id : int ):
127- """Update the Superfluid connector after initialisation."""
128- self .superfluid_connector = Superfluid (
129- rpc = rpc ,
130- chain_id = chain_id ,
131- account = self ._account ,
132- )
247+ return self .superfluid_connector .delete_flow (receiver = receiver )
133248
134249
135- def get_fallback_account (path : Optional [Path ] = None ) -> ETHAccount :
136- return ETHAccount (private_key = get_fallback_private_key (path = path ))
250+ def get_fallback_account (
251+ path : Optional [Path ] = None , chain : Optional [Chain ] = None
252+ ) -> ETHAccount :
253+ return ETHAccount (private_key = get_fallback_private_key (path = path ), chain = chain )
137254
138255
139256def verify_signature (
0 commit comments