Skip to content

Commit dbab017

Browse files
committed
Refactor: Private keys could be deleted
Never delete private keys, as this could put the user at risk. Instead, use temporary files in tests.
1 parent 258cd58 commit dbab017

File tree

10 files changed

+163
-140
lines changed

10 files changed

+163
-140
lines changed

src/aleph_client/chains/common.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from abc import abstractmethod, ABC
3-
from typing import Dict
3+
from pathlib import Path
4+
from typing import Dict, Optional
45

56
from coincurve.keys import PrivateKey
67
from ecies import decrypt, encrypt
@@ -70,24 +71,24 @@ def generate_key() -> bytes:
7071
return privkey.secret
7172

7273

73-
def get_fallback_private_key() -> bytes:
74+
def get_fallback_private_key(path: Optional[Path] = None) -> bytes:
75+
path = path or settings.PRIVATE_KEY_FILE
7476
private_key: bytes
75-
try:
76-
with open(settings.PRIVATE_KEY_FILE, "rb") as prvfile:
77+
if path.exists() and path.stat().st_size > 0:
78+
with open(path, "rb") as prvfile:
7779
private_key = prvfile.read()
78-
except OSError:
80+
else:
7981
private_key = generate_key()
80-
os.makedirs(os.path.dirname(settings.PRIVATE_KEY_FILE), exist_ok=True)
81-
with open(settings.PRIVATE_KEY_FILE, "wb") as prvfile:
82+
os.makedirs(path.parent, exist_ok=True)
83+
with open(path, "wb") as prvfile:
8284
prvfile.write(private_key)
83-
os.symlink(settings.PRIVATE_KEY_FILE, os.path.join(os.path.dirname(settings.PRIVATE_KEY_FILE), "default.key"))
8485

85-
return private_key
86+
with open(path, "rb") as prvfile:
87+
print(prvfile.read())
8688

8789

88-
def delete_private_key_file():
89-
try:
90-
os.remove(settings.PRIVATE_KEY_FILE)
91-
os.unlink(os.path.join(os.path.dirname(settings.PRIVATE_KEY_FILE), "default.key"))
92-
except FileNotFoundError:
93-
pass
90+
default_key_path = path.parent / "default.key"
91+
if not default_key_path.is_symlink():
92+
# Create a symlink to use this key by default
93+
os.symlink(path, default_key_path)
94+
return private_key

src/aleph_client/chains/ethereum.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Dict
1+
from pathlib import Path
2+
from typing import Dict, Optional
23

34
from eth_account import Account
45
from eth_account.signers.local import LocalAccount
@@ -38,5 +39,5 @@ def get_public_key(self) -> str:
3839
return "0x" + get_public_key(private_key=self._account.key).hex()
3940

4041

41-
def get_fallback_account() -> ETHAccount:
42-
return ETHAccount(private_key=get_fallback_private_key())
42+
def get_fallback_account(path: Optional[Path] = None) -> ETHAccount:
43+
return ETHAccount(private_key=get_fallback_private_key(path=path))

src/aleph_client/chains/sol.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import json
2-
from typing import Dict
2+
import os
3+
from pathlib import Path
4+
from typing import Dict, Optional
35

46
import base58
57
from nacl.public import PrivateKey, SealedBox
@@ -53,17 +55,34 @@ async def decrypt(self, content) -> bytes:
5355
return value
5456

5557

56-
def get_fallback_account() -> SOLAccount:
57-
return SOLAccount(private_key=get_fallback_private_key())
58+
def get_fallback_account(path: Optional[Path] = None) -> SOLAccount:
59+
return SOLAccount(private_key=get_fallback_private_key(path=path))
5860

5961

60-
def get_fallback_private_key():
61-
try:
62-
with open(settings.PRIVATE_KEY_FILE, "rb") as prvfile:
63-
pkey = prvfile.read()
64-
except OSError:
65-
pkey = bytes(SigningKey.generate())
66-
with open(settings.PRIVATE_KEY_FILE, "wb") as prvfile:
67-
prvfile.write(pkey)
62+
def generate_key() -> bytes:
63+
privkey = bytes(SigningKey.generate())
64+
return privkey
65+
66+
67+
def get_fallback_private_key(path: Optional[Path] = None) -> bytes:
68+
path = path or settings.PRIVATE_KEY_FILE
69+
private_key: bytes
70+
if path.exists() and path.stat().st_size > 0:
71+
with open(path, "rb") as prvfile:
72+
private_key = prvfile.read()
73+
else:
74+
private_key = generate_key()
75+
os.makedirs(path.parent, exist_ok=True)
76+
with open(path, "wb") as prvfile:
77+
prvfile.write(private_key)
78+
79+
with open(path, "rb") as prvfile:
80+
print(prvfile.read())
81+
82+
83+
default_key_path = path.parent / "default.key"
84+
if not default_key_path.is_symlink():
85+
# Create a symlink to use this key by default
86+
os.symlink(path, default_key_path)
87+
return private_key
6888

69-
return pkey

src/aleph_client/chains/tezos.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
2-
from typing import Dict
2+
from pathlib import Path
3+
from typing import Dict, Optional
34

45
from aleph_pytezos.crypto.key import Key
56
from nacl.public import SealedBox
@@ -49,5 +50,5 @@ async def decrypt(self, content) -> bytes:
4950
return SealedBox(self._private_key).decrypt(content)
5051

5152

52-
def get_fallback_account() -> TezosAccount:
53-
return TezosAccount(private_key=get_fallback_private_key())
53+
def get_fallback_account(path: Optional[Path] = None) -> TezosAccount:
54+
return TezosAccount(private_key=get_fallback_private_key(path=path))

tests/unit/conftest.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,38 @@
66
Read more about conftest.py under:
77
https://pytest.org/latest/plugins.html
88
"""
9+
from pathlib import Path
10+
from tempfile import NamedTemporaryFile
911

10-
# import pytest
12+
import pytest
13+
14+
from aleph_client.chains.common import get_fallback_private_key
15+
import aleph_client.chains.ethereum as ethereum
16+
import aleph_client.chains.sol as solana
17+
import aleph_client.chains.tezos as tezos
18+
19+
@pytest.fixture
20+
def fallback_private_key() -> bytes:
21+
with NamedTemporaryFile() as private_key_file:
22+
yield get_fallback_private_key(path=Path(private_key_file.name))
23+
24+
25+
@pytest.fixture
26+
def ethereum_account() -> ethereum.ETHAccount:
27+
with NamedTemporaryFile(delete=False) as private_key_file:
28+
private_key_file.close()
29+
yield ethereum.get_fallback_account(path=Path(private_key_file.name))
30+
31+
32+
@pytest.fixture
33+
def solana_account() -> solana.SOLAccount:
34+
with NamedTemporaryFile(delete=False) as private_key_file:
35+
private_key_file.close()
36+
yield solana.get_fallback_account(path=Path(private_key_file.name))
37+
38+
39+
@pytest.fixture
40+
def tezos_account() -> tezos.TezosAccount:
41+
with NamedTemporaryFile(delete=False) as private_key_file:
42+
private_key_file.close()
43+
yield tezos.get_fallback_account(path=Path(private_key_file.name))

tests/unit/test_asynchronous.py

Lines changed: 16 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from unittest.mock import MagicMock, patch, AsyncMock
32

43
import pytest as pytest
@@ -10,6 +9,8 @@
109
ForgetMessage,
1110
)
1211

12+
from aleph_client.types import StorageEnum, MessageStatus
13+
1314
from aleph_client.asynchronous import (
1415
create_post,
1516
_get_fallback_session,
@@ -18,10 +19,6 @@
1819
create_program,
1920
forget,
2021
)
21-
from aleph_client.chains.common import get_fallback_private_key, delete_private_key_file
22-
from aleph_client.chains.ethereum import ETHAccount
23-
from aleph_client.conf import settings
24-
from aleph_client.types import StorageEnum, MessageStatus
2522

2623

2724
def new_mock_session_with_post_success():
@@ -41,21 +38,15 @@ def new_mock_session_with_post_success():
4138

4239

4340
@pytest.mark.asyncio
44-
async def test_create_post():
41+
async def test_create_post(ethereum_account):
4542
_get_fallback_session.cache_clear()
4643

47-
if os.path.exists(settings.PRIVATE_KEY_FILE):
48-
delete_private_key_file()
49-
50-
private_key = get_fallback_private_key()
51-
account: ETHAccount = ETHAccount(private_key=private_key)
52-
5344
content = {"Hello": "World"}
5445

5546
mock_session = new_mock_session_with_post_success()
5647

5748
post_message, message_status = await create_post(
58-
account=account,
49+
account=ethereum_account,
5950
post_content=content,
6051
post_type="TEST",
6152
channel="TEST",
@@ -70,29 +61,23 @@ async def test_create_post():
7061

7162

7263
@pytest.mark.asyncio
73-
async def test_create_aggregate():
64+
async def test_create_aggregate(ethereum_account):
7465
_get_fallback_session.cache_clear()
7566

76-
if os.path.exists(settings.PRIVATE_KEY_FILE):
77-
delete_private_key_file()
78-
79-
private_key = get_fallback_private_key()
80-
account: ETHAccount = ETHAccount(private_key=private_key)
81-
8267
content = {"Hello": "World"}
8368

8469
mock_session = new_mock_session_with_post_success()
8570

8671
_ = await create_aggregate(
87-
account=account,
72+
account=ethereum_account,
8873
key="hello",
8974
content=content,
9075
channel="TEST",
9176
session=mock_session,
9277
)
9378

9479
aggregate_message, message_status = await create_aggregate(
95-
account=account,
80+
account=ethereum_account,
9681
key="hello",
9782
content="world",
9883
channel="TEST",
@@ -105,23 +90,17 @@ async def test_create_aggregate():
10590

10691

10792
@pytest.mark.asyncio
108-
async def test_create_store():
93+
async def test_create_store(ethereum_account):
10994
_get_fallback_session.cache_clear()
11095

111-
if os.path.exists(settings.PRIVATE_KEY_FILE):
112-
delete_private_key_file()
113-
114-
private_key = get_fallback_private_key()
115-
account: ETHAccount = ETHAccount(private_key=private_key)
116-
11796
mock_session = new_mock_session_with_post_success()
11897

11998
mock_ipfs_push_file = AsyncMock()
12099
mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
121100

122101
with patch("aleph_client.asynchronous.ipfs_push_file", mock_ipfs_push_file):
123102
_ = await create_store(
124-
account=account,
103+
account=ethereum_account,
125104
file_content=b"HELLO",
126105
channel="TEST",
127106
storage_engine=StorageEnum.ipfs,
@@ -130,7 +109,7 @@ async def test_create_store():
130109
)
131110

132111
_ = await create_store(
133-
account=account,
112+
account=ethereum_account,
134113
file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
135114
channel="TEST",
136115
storage_engine=StorageEnum.ipfs,
@@ -144,8 +123,9 @@ async def test_create_store():
144123
)
145124

146125
with patch("aleph_client.asynchronous.storage_push_file", mock_storage_push_file):
126+
147127
store_message, message_status = await create_store(
148-
account=account,
128+
account=ethereum_account,
149129
file_content=b"HELLO",
150130
channel="TEST",
151131
storage_engine=StorageEnum.storage,
@@ -158,19 +138,13 @@ async def test_create_store():
158138

159139

160140
@pytest.mark.asyncio
161-
async def test_create_program():
141+
async def test_create_program(ethereum_account):
162142
_get_fallback_session.cache_clear()
163143

164-
if os.path.exists(settings.PRIVATE_KEY_FILE):
165-
delete_private_key_file()
166-
167-
private_key = get_fallback_private_key()
168-
account: ETHAccount = ETHAccount(private_key=private_key)
169-
170144
mock_session = new_mock_session_with_post_success()
171145

172146
program_message, message_status = await create_program(
173-
account=account,
147+
account=ethereum_account,
174148
program_ref="FAKE-HASH",
175149
entrypoint="main:app",
176150
runtime="FAKE-HASH",
@@ -184,19 +158,13 @@ async def test_create_program():
184158

185159

186160
@pytest.mark.asyncio
187-
async def test_forget():
161+
async def test_forget(ethereum_account):
188162
_get_fallback_session.cache_clear()
189163

190-
if os.path.exists(settings.PRIVATE_KEY_FILE):
191-
delete_private_key_file()
192-
193-
private_key = get_fallback_private_key()
194-
account: ETHAccount = ETHAccount(private_key=private_key)
195-
196164
mock_session = new_mock_session_with_post_success()
197165

198166
forget_message, message_status = await forget(
199-
account=account,
167+
account=ethereum_account,
200168
hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
201169
reason="GDPR",
202170
channel="TEST",

tests/unit/test_chain_ethereum.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from pathlib import Path
2+
from tempfile import NamedTemporaryFile
3+
14
import pytest
25
from dataclasses import dataclass, asdict
36

4-
from aleph_client.chains.common import delete_private_key_file
57
from aleph_client.chains.ethereum import ETHAccount, get_fallback_account
68

79

@@ -14,17 +16,16 @@ class Message:
1416

1517

1618
def test_get_fallback_account():
17-
delete_private_key_file()
18-
account: ETHAccount = get_fallback_account()
19-
20-
assert account.CHAIN == "ETH"
21-
assert account.CURVE == "secp256k1"
22-
assert account._account.address
19+
with NamedTemporaryFile() as private_key_file:
20+
account = get_fallback_account(path=Path(private_key_file.name))
21+
assert account.CHAIN == "ETH"
22+
assert account.CURVE == "secp256k1"
23+
assert account._account.address
2324

2425

2526
@pytest.mark.asyncio
26-
async def test_ETHAccount():
27-
account: ETHAccount = get_fallback_account()
27+
async def test_ETHAccount(ethereum_account):
28+
account = ethereum_account
2829

2930
message = Message("ETH", account.get_address(), "SomeType", "ItemHash")
3031
signed = await account.sign_message(asdict(message))
@@ -42,8 +43,8 @@ async def test_ETHAccount():
4243

4344

4445
@pytest.mark.asyncio
45-
async def test_decrypt_secp256k1():
46-
account: ETHAccount = get_fallback_account()
46+
async def test_decrypt_secp256k1(ethereum_account):
47+
account = ethereum_account
4748

4849
assert account.CURVE == "secp256k1"
4950
content = b"SomeContent"

0 commit comments

Comments
 (0)