Skip to content

Commit e1be4fe

Browse files
committed
Add chain arg to _load_account
1 parent 22eda92 commit e1be4fe

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

src/aleph/sdk/account.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aleph.sdk.chains.remote import RemoteAccount
1111
from aleph.sdk.chains.solana import SOLAccount
1212
from aleph.sdk.conf import load_main_configuration, settings
13+
from aleph.sdk.evm_utils import get_chains_with_super_token
1314
from aleph.sdk.types import AccountFromPrivateKey
1415

1516
logger = logging.getLogger(__name__)
@@ -27,21 +28,48 @@ def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]:
2728
return chain_account_map.get(chain) or ETHAccount
2829

2930

30-
def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T:
31+
def account_from_hex_string(
32+
private_key_str: str, account_type: Optional[Type[T]], chain: Optional[Chain] = None
33+
) -> T:
3134
if private_key_str.startswith("0x"):
3235
private_key_str = private_key_str[2:]
33-
return account_type(bytes.fromhex(private_key_str))
3436

37+
if not chain:
38+
if not account_type:
39+
account_type = ETHAccount
40+
return account_type(bytes.fromhex(private_key_str))
3541

36-
def account_from_file(private_key_path: Path, account_type: Type[T]) -> T:
42+
account_type = load_chain_account_type(chain)
43+
account = account_type(bytes.fromhex(private_key_str))
44+
if chain in get_chains_with_super_token():
45+
account.switch_chain(chain)
46+
return account
47+
48+
49+
def account_from_file(
50+
private_key_path: Path,
51+
account_type: Optional[Type[T]],
52+
chain: Optional[Chain] = None,
53+
) -> T:
3754
private_key = private_key_path.read_bytes()
38-
return account_type(private_key)
55+
56+
if not chain:
57+
if not account_type:
58+
account_type = ETHAccount
59+
return account_type(private_key)
60+
61+
account_type = load_chain_account_type(chain)
62+
account = account_type(private_key)
63+
if chain in get_chains_with_super_token():
64+
account.switch_chain(chain)
65+
return account
3966

4067

4168
def _load_account(
4269
private_key_str: Optional[str] = None,
4370
private_key_path: Optional[Path] = None,
4471
account_type: Optional[Type[AccountFromPrivateKey]] = None,
72+
chain: Optional[Chain] = None,
4573
) -> AccountFromPrivateKey:
4674
"""Load an account from a private key string or file, or from the configuration file."""
4775

@@ -61,10 +89,10 @@ def _load_account(
6189

6290
# Loads private key from a string
6391
if private_key_str:
64-
return account_from_hex_string(private_key_str, account_type)
92+
return account_from_hex_string(private_key_str, account_type, chain)
6593
# Loads private key from a file
6694
elif private_key_path and private_key_path.is_file():
67-
return account_from_file(private_key_path, account_type)
95+
return account_from_file(private_key_path, account_type, chain)
6896
# For ledger keys
6997
elif settings.REMOTE_CRYPTO_HOST:
7098
logger.debug("Using remote account")
@@ -78,7 +106,9 @@ def _load_account(
78106
# Fallback: config.path if set, else generate a new private key
79107
else:
80108
new_private_key = get_fallback_private_key()
81-
account = account_type(private_key=new_private_key)
109+
account = account_from_hex_string(
110+
bytes.hex(new_private_key), account_type, chain
111+
)
82112
logger.info(
83113
f"Generated fallback private key with address {account.get_address()}"
84114
)

src/aleph/sdk/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ async def sign_raw(self, buffer: bytes) -> bytes: ...
4141

4242
def export_private_key(self) -> str: ...
4343

44+
def switch_chain(self, chain: Optional[str] = None) -> None: ...
45+
4446

4547
GenericMessage = TypeVar("GenericMessage", bound=AlephMessage)
4648

0 commit comments

Comments
 (0)