22import base64
33from decimal import Decimal
44from pathlib import Path
5- from typing import Awaitable , Optional , Union
5+ from typing import Awaitable , Dict , Optional , Union
66
77from aleph_message .models import Chain
88from eth_account import Account # type: ignore
1111from eth_keys .exceptions import BadSignature as EthBadSignatureError
1212from superfluid import Web3FlowInfo
1313from web3 import Web3
14- from web3 .middleware import geth_poa_middleware
14+ from web3 .exceptions import ContractCustomError
15+ from web3 .middleware import ExtraDataToPOAMiddleware
1516from web3 .types import TxParams , TxReceipt
1617
1718from aleph .sdk .exceptions import InsufficientFundsError
2122from ..connectors .superfluid import Superfluid
2223from ..evm_utils import (
2324 BALANCEOF_ABI ,
24- MIN_ETH_BALANCE ,
2525 MIN_ETH_BALANCE_WEI ,
2626 FlowUpdate ,
2727 from_wei_token ,
@@ -80,6 +80,22 @@ async def sign_raw(self, buffer: bytes) -> bytes:
8080 sig = self ._account .sign_message (msghash )
8181 return sig ["signature" ]
8282
83+ async def sign_message (self , message : Dict ) -> Dict :
84+ """
85+ Returns a signed message from an aleph.im message.
86+ Args:
87+ message: Message to sign
88+ Returns:
89+ Dict: Signed message
90+ """
91+ signed_message = await super ().sign_message (message )
92+
93+ # Apply that fix as seems that sometimes the .hex() method doesn't add the 0x str at the beginning
94+ if not str (signed_message ["signature" ]).startswith ("0x" ):
95+ signed_message ["signature" ] = "0x" + signed_message ["signature" ]
96+
97+ return signed_message
98+
8399 def connect_chain (self , chain : Optional [Chain ] = None ):
84100 self .chain = chain
85101 if self .chain :
@@ -88,7 +104,7 @@ def connect_chain(self, chain: Optional[Chain] = None):
88104 self ._provider = Web3 (Web3 .HTTPProvider (self .rpc ))
89105 if chain == Chain .BSC :
90106 self ._provider .middleware_onion .inject (
91- geth_poa_middleware , "geth_poa" , layer = 0
107+ ExtraDataToPOAMiddleware , "geth_poa" , layer = 0
92108 )
93109 else :
94110 self .chain_id = None
@@ -103,14 +119,34 @@ def connect_chain(self, chain: Optional[Chain] = None):
103119 def switch_chain (self , chain : Optional [Chain ] = None ):
104120 self .connect_chain (chain = chain )
105121
106- def can_transact (self , block = True ) -> bool :
107- balance = self .get_eth_balance ()
108- valid = balance > MIN_ETH_BALANCE_WEI if self .chain else False
122+ def can_transact (self , tx : TxParams , block = True ) -> bool :
123+ balance_wei = self .get_eth_balance ()
124+ try :
125+ assert self ._provider is not None
126+
127+ estimated_gas = self ._provider .eth .estimate_gas (tx )
128+
129+ gas_price = tx .get ("gasPrice" , self ._provider .eth .gas_price )
130+
131+ if "maxFeePerGas" in tx :
132+ max_fee = tx ["maxFeePerGas" ]
133+ total_fee_wei = estimated_gas * max_fee
134+ else :
135+ total_fee_wei = estimated_gas * gas_price
136+
137+ total_fee_wei = int (total_fee_wei * 1.2 )
138+
139+ except ContractCustomError :
140+ total_fee_wei = MIN_ETH_BALANCE_WEI # Fallback if estimation fails
141+
142+ required_fee_wei = total_fee_wei + (tx .get ("value" , 0 ))
143+
144+ valid = balance_wei > required_fee_wei if self .chain else False
109145 if not valid and block :
110146 raise InsufficientFundsError (
111147 token_type = TokenType .GAS ,
112- required_funds = MIN_ETH_BALANCE ,
113- available_funds = float (from_wei_token (balance )),
148+ required_funds = float ( from_wei_token ( required_fee_wei )) ,
149+ available_funds = float (from_wei_token (balance_wei )),
114150 )
115151 return valid
116152
@@ -120,15 +156,15 @@ async def _sign_and_send_transaction(self, tx_params: TxParams) -> str:
120156 @param tx_params - Transaction parameters
121157 @returns - str - Transaction hash
122158 """
123- self .can_transact ()
124159
125160 def sign_and_send () -> TxReceipt :
126161 if self ._provider is None :
127162 raise ValueError ("Provider not connected" )
128163 signed_tx = self ._provider .eth .account .sign_transaction (
129164 tx_params , self ._account .key
130165 )
131- tx_hash = self ._provider .eth .send_raw_transaction (signed_tx .rawTransaction )
166+
167+ tx_hash = self ._provider .eth .send_raw_transaction (signed_tx .raw_transaction )
132168 tx_receipt = self ._provider .eth .wait_for_transaction_receipt (
133169 tx_hash , settings .TX_TIMEOUT
134170 )
0 commit comments