11import boto3
2- from asn1crypto import core
2+ from cryptography .hazmat .primitives import serialization
3+ from cryptography .hazmat .primitives .asymmetric .utils import decode_dss_signature
34from eth_account .messages import encode_typed_data , _hash_eip191_message
5+ from eth_keys .backends .native .ecdsa import N as SECP256K1_N
46from eth_keys .datatypes import Signature
57from eth_utils import keccak , to_hex
68from hyperliquid .exchange import Exchange
79from hyperliquid .utils .constants import TESTNET_API_URL , MAINNET_API_URL
810from hyperliquid .utils .signing import get_timestamp_ms , action_hash , construct_phantom_agent , l1_payload
911from loguru import logger
1012
13+ from config import Config
14+
15+ SECP256K1_N_HALF = SECP256K1_N // 2
16+
1117
1218class KMSSigner :
13- def __init__ (self , key_id , aws_region_name , use_testnet ):
19+ def __init__ (self , config : Config ):
20+ use_testnet = config .hyperliquid .use_testnet
1421 url = TESTNET_API_URL if use_testnet else MAINNET_API_URL
1522 self .oracle_publisher_exchange : Exchange = Exchange (wallet = None , base_url = url )
23+ self .client = self ._init_client (config )
1624
17- self .key_id = key_id
18- self .client = boto3 .client ("kms" , region_name = aws_region_name )
1925 # Fetch public key once so we can derive address and check recovery id
20- pub_der = self .client .get_public_key (KeyId = key_id )["PublicKey" ]
21-
22- from cryptography .hazmat .primitives import serialization
23- pub = serialization .load_der_public_key (pub_der )
26+ key_path = config .kms .key_path
27+ self .key_id = open (key_path , "r" ).read ().strip ()
28+ self .pubkey_der = self .client .get_public_key (KeyId = self .key_id )["PublicKey" ]
29+ # Construct eth address to log
30+ pub = serialization .load_der_public_key (self .pubkey_der )
2431 numbers = pub .public_numbers ()
2532 x = numbers .x .to_bytes (32 , "big" )
2633 y = numbers .y .to_bytes (32 , "big" )
@@ -29,6 +36,22 @@ def __init__(self, key_id, aws_region_name, use_testnet):
2936 self .address = "0x" + keccak (uncompressed [1 :])[- 20 :].hex ()
3037 logger .info ("KMSSigner address: {}" , self .address )
3138
39+ def _init_client (self , config ):
40+ aws_region_name = config .kms .aws_region_name
41+ access_key_id_path = config .kms .access_key_id_path
42+ access_key_id = open (access_key_id_path , "r" ).read ().strip ()
43+ secret_access_key_path = config .kms .secret_access_key_path
44+ secret_access_key = open (secret_access_key_path , "r" ).read ().strip ()
45+
46+ return boto3 .client (
47+ "kms" ,
48+ region_name = aws_region_name ,
49+ aws_access_key_id = access_key_id ,
50+ aws_secret_access_key = secret_access_key ,
51+ # can specify an endpoint for e.g. LocalStack
52+ # endpoint_url="http://localhost:4566"
53+ )
54+
3255 def set_oracle (self , dex , oracle_pxs , all_mark_pxs , external_perp_pxs ):
3356 timestamp = get_timestamp_ms ()
3457 oracle_pxs_wire = sorted (list (oracle_pxs .items ()))
@@ -60,34 +83,39 @@ def sign_l1_action(self, action, nonce, is_mainnet):
6083 data = l1_payload (phantom_agent )
6184 structured_data = encode_typed_data (full_message = data )
6285 message_hash = _hash_eip191_message (structured_data )
63- signed = self .sign_message (message_hash )
64- return {"r" : to_hex (signed ["r" ]), "s" : to_hex (signed ["s" ]), "v" : signed ["v" ]}
86+ return self .sign_message (message_hash )
6587
66- def sign_message (self , message_hash : bytes ):
88+ def sign_message (self , message_hash : bytes ) -> dict :
89+ # Send message hash to KMS for signing
6790 resp = self .client .sign (
6891 KeyId = self .key_id ,
6992 Message = message_hash ,
7093 MessageType = "DIGEST" ,
7194 SigningAlgorithm = "ECDSA_SHA_256" , # required for secp256k1
7295 )
73- der_sig = resp ["Signature" ]
74-
75- seq = core .Sequence .load (der_sig )
76- r = int (seq [0 ].native )
77- s = int (seq [1 ].native )
78-
79- for recovery_id in (0 , 1 ):
80- candidate = Signature (vrs = (recovery_id , r , s ))
81- pubkey = candidate .recover_public_key_from_msg_hash (message_hash )
82- if pubkey .to_bytes () == self .public_key_bytes :
83- v = recovery_id + 27
84- break
85- else :
86- raise ValueError ("Failed to determine recovery id" )
87-
88- return {
89- "r" : r ,
90- "s" : s ,
91- "v" : v ,
92- "signature" : Signature (vrs = (v , r , s )).to_bytes ().hex (),
93- }
96+ kms_signature = resp ["Signature" ]
97+ # Decode the KMS DER signature -> (r, s)
98+ r , s = decode_dss_signature (kms_signature )
99+ # Ethereum requires low-s form
100+ if s > SECP256K1_N_HALF :
101+ s = SECP256K1_N - s
102+ # Parse KMS public key into uncompressed secp256k1 bytes
103+ # TODO: Pull this into init
104+ pubkey = serialization .load_der_public_key (self .pubkey_der )
105+ pubkey_bytes = pubkey .public_bytes (
106+ serialization .Encoding .X962 ,
107+ serialization .PublicFormat .UncompressedPoint ,
108+ )
109+ # Strip leading 0x04 (uncompressed point indicator)
110+ raw_pubkey_bytes = pubkey_bytes [1 :]
111+ # Try both recovery ids
112+ for v in (0 , 1 ):
113+ sig_obj = Signature (vrs = (v , r , s ))
114+ recovered_pub = sig_obj .recover_public_key_from_msg_hash (message_hash )
115+ if recovered_pub .to_bytes () == raw_pubkey_bytes :
116+ return {
117+ "r" : to_hex (r ),
118+ "s" : to_hex (s ),
119+ "v" : v + 27 ,
120+ }
121+ raise ValueError ("Could not recover public key; signature mismatch" )
0 commit comments