Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions accounts/checkers.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params,
if len(invoice) > 0 {
payReq, err := zpay32.Decode(invoice, chainParams)
if err != nil {
return fmt.Errorf("error decoding pay req: %v", err)
return fmt.Errorf("error decoding pay req: %w", err)
}

if payReq.MilliSat != nil && *payReq.MilliSat > sendAmt {
Expand All @@ -546,7 +546,7 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params,

err = service.CheckBalance(acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %v", err)
return fmt.Errorf("error validating account balance: %w", err)
}

return nil
Expand Down Expand Up @@ -609,7 +609,7 @@ func checkSendToRoute(ctx context.Context, service Service,

err = service.CheckBalance(acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %v", err)
return fmt.Errorf("error validating account balance: %w", err)
}

return nil
Expand Down
10 changes: 5 additions & 5 deletions accounts/checkers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ type mockService struct {
acctBalanceMsat lnwire.MilliSatoshi

trackedInvoices map[lntypes.Hash]AccountID
trackedPayments map[lntypes.Hash]*PaymentEntry
trackedPayments AccountPayments
}

func newMockService() *mockService {
return &mockService{
acctBalanceMsat: 0,
trackedInvoices: make(map[lntypes.Hash]AccountID),
trackedPayments: make(map[lntypes.Hash]*PaymentEntry),
trackedPayments: make(AccountPayments),
}
}

Expand All @@ -68,7 +68,7 @@ func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error {
return nil
}

func (m *mockService) TrackPayment(id AccountID, hash lntypes.Hash,
func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
amt lnwire.MilliSatoshi) error {

m.trackedPayments[hash] = &PaymentEntry{
Expand Down Expand Up @@ -403,8 +403,8 @@ func TestAccountCheckers(t *testing.T) {
acct := &OffChainBalanceAccount{
ID: testID,
Type: TypeInitialBalance,
Invoices: make(map[lntypes.Hash]struct{}),
Payments: make(map[lntypes.Hash]*PaymentEntry),
Invoices: make(AccountInvoices),
Payments: make(AccountPayments),
}
ctx := AddToContext(
context.Background(), KeyAccount, acct,
Expand Down
2 changes: 1 addition & 1 deletion accounts/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func parseRPCMessage(msg *lnrpc.RPCMessage) (proto.Message, error) {
// No, it's a normal message.
parsedMsg, err := mid.ParseProtobuf(msg.TypeName, msg.Serialized)
if err != nil {
return nil, fmt.Errorf("error parsing proto of type %v: %v",
return nil, fmt.Errorf("error parsing proto of type %v: %w",
msg.TypeName, err)
}

Expand Down
20 changes: 15 additions & 5 deletions accounts/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ParseAccountID(idStr string) (*AccountID, error) {

idBytes, err := hex.DecodeString(idStr)
if err != nil {
return nil, fmt.Errorf("error decoding account ID: %v", err)
return nil, fmt.Errorf("error decoding account ID: %w", err)
}

var id AccountID
Expand All @@ -67,6 +67,12 @@ type PaymentEntry struct {
FullAmount lnwire.MilliSatoshi
}

// AccountInvoices is the set of invoices that are associated with an account.
type AccountInvoices map[lntypes.Hash]struct{}

// AccountPayments is the set of payments that are associated with an account.
type AccountPayments map[lntypes.Hash]*PaymentEntry

// OffChainBalanceAccount holds all information that is needed to keep track of
// a user's off-chain account balance. This balance can only be spent by paying
// invoices.
Expand Down Expand Up @@ -99,11 +105,15 @@ type OffChainBalanceAccount struct {

// Invoices is a list of all invoices that are associated with the
// account.
Invoices map[lntypes.Hash]struct{}
Invoices AccountInvoices

// Payments is a list of all payments that are associated with the
// account and the last status we were aware of.
Payments map[lntypes.Hash]*PaymentEntry
Payments AccountPayments

// Label is an optional label that can be set for the account. If it is
// not empty then it must be unique.
Label string
}

// HasExpired returns true if the account has an expiration date set and that
Expand Down Expand Up @@ -180,8 +190,8 @@ var (
type Store interface {
// NewAccount creates a new OffChainBalanceAccount with the given
// balance and a randomly chosen ID.
NewAccount(balance lnwire.MilliSatoshi,
expirationDate time.Time) (*OffChainBalanceAccount, error)
NewAccount(balance lnwire.MilliSatoshi, expirationDate time.Time,
label string) (*OffChainBalanceAccount, error)

// UpdateAccount writes an account to the database, overwriting the
// existing one if it exists.
Expand Down
101 changes: 80 additions & 21 deletions accounts/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
req *litrpc.CreateAccountRequest) (*litrpc.CreateAccountResponse,
error) {

log.Infof("[createaccount] balance=%d, expiration=%d",
req.AccountBalance, req.ExpirationDate)
log.Infof("[createaccount] label=%v, balance=%d, expiration=%d",
req.Label, req.AccountBalance, req.ExpirationDate)

var (
balanceMsat lnwire.MilliSatoshi
Expand All @@ -70,9 +70,11 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
balanceMsat = lnwire.NewMSatFromSatoshis(balance)

// Create the actual account in the macaroon account store.
account, err := s.service.NewAccount(balanceMsat, expirationDate)
account, err := s.service.NewAccount(
balanceMsat, expirationDate, req.Label,
)
if err != nil {
return nil, fmt.Errorf("unable to create account: %v", err)
return nil, fmt.Errorf("unable to create account: %w", err)
}

var rootKeyIdSuffix [4]byte
Expand All @@ -91,12 +93,12 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
}},
})
if err != nil {
return nil, fmt.Errorf("error baking account macaroon: %v", err)
return nil, fmt.Errorf("error baking account macaroon: %w", err)
}

macBytes, err := hex.DecodeString(macHex)
if err != nil {
return nil, fmt.Errorf("error decoding account macaroon: %v",
return nil, fmt.Errorf("error decoding account macaroon: %w",
err)
}

Expand All @@ -110,16 +112,13 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
func (s *RPCServer) UpdateAccount(_ context.Context,
req *litrpc.UpdateAccountRequest) (*litrpc.Account, error) {

log.Infof("[updateaccount] id=%s, balance=%d, expiration=%d", req.Id,
req.AccountBalance, req.ExpirationDate)
log.Infof("[updateaccount] id=%s, label=%v, balance=%d, expiration=%d",
req.Id, req.Label, req.AccountBalance, req.ExpirationDate)

// Account ID is always a hex string, convert it to our account ID type.
var accountID AccountID
decoded, err := hex.DecodeString(req.Id)
accountID, err := s.findAccount(req.Id, req.Label)
if err != nil {
return nil, fmt.Errorf("error decoding account ID: %v", err)
return nil, err
}
copy(accountID[:], decoded)

// Ask the service to update the account.
account, err := s.service.UpdateAccount(
Expand All @@ -142,7 +141,7 @@ func (s *RPCServer) ListAccounts(context.Context,
// Retrieve all accounts from the macaroon account store.
accts, err := s.service.Accounts()
if err != nil {
return nil, fmt.Errorf("unable to list accounts: %v", err)
return nil, fmt.Errorf("unable to list accounts: %w", err)
}

// Map the response into the proper response type and return it.
Expand All @@ -158,30 +157,89 @@ func (s *RPCServer) ListAccounts(context.Context,
}, nil
}

// AccountInfo returns the account with the given ID or label.
func (s *RPCServer) AccountInfo(_ context.Context,
req *litrpc.AccountInfoRequest) (*litrpc.Account, error) {

log.Infof("[accountinfo] id=%v, label=%v", req.Id, req.Label)

accountID, err := s.findAccount(req.Id, req.Label)
if err != nil {
return nil, err
}

dbAccount, err := s.service.Account(accountID)
if err != nil {
return nil, fmt.Errorf("error retrieving account: %w", err)
}

return marshalAccount(dbAccount), nil
}

// RemoveAccount removes the given account from the account database.
func (s *RPCServer) RemoveAccount(_ context.Context,
req *litrpc.RemoveAccountRequest) (*litrpc.RemoveAccountResponse,
error) {

log.Infof("[removeaccount] id=%v", req.Id)
log.Infof("[removeaccount] id=%v, label=%v", req.Id, req.Label)

// Account ID is always a hex string, convert it to our account ID type.
var accountID AccountID
decoded, err := hex.DecodeString(req.Id)
accountID, err := s.findAccount(req.Id, req.Label)
if err != nil {
return nil, fmt.Errorf("error decoding account ID: %v", err)
return nil, err
}
copy(accountID[:], decoded)

// Now remove the account.
err = s.service.RemoveAccount(accountID)
if err != nil {
return nil, fmt.Errorf("error removing account: %v", err)
return nil, fmt.Errorf("error removing account: %w", err)
}

return &litrpc.RemoveAccountResponse{}, nil
}

// findAccount finds an account by its ID or label.
func (s *RPCServer) findAccount(id string, label string) (AccountID, error) {
switch {
case id != "" && label != "":
return AccountID{}, fmt.Errorf("either account ID or label " +
"must be specified, not both")

case id != "":
// Account ID is always a hex string, convert it to our account
// ID type.
var accountID AccountID
decoded, err := hex.DecodeString(id)
if err != nil {
return AccountID{}, fmt.Errorf("error decoding "+
"account ID: %w", err)
}
copy(accountID[:], decoded)

return accountID, nil

case label != "":
// We need to find the account by its label.
accounts, err := s.service.Accounts()
if err != nil {
return AccountID{}, fmt.Errorf("unable to list "+
"accounts: %w", err)
}

for _, acct := range accounts {
if acct.Label == label {
return acct.ID, nil
}
}

return AccountID{}, fmt.Errorf("unable to find account "+
"with label '%s'", label)

default:
return AccountID{}, fmt.Errorf("either account ID or label " +
"must be specified")
}
}

// marshalAccount converts an account into its RPC counterpart.
func marshalAccount(acct *OffChainBalanceAccount) *litrpc.Account {
rpcAccount := &litrpc.Account{
Expand All @@ -196,6 +254,7 @@ func marshalAccount(acct *OffChainBalanceAccount) *litrpc.Account {
Payments: make(
[]*litrpc.AccountPayment, 0, len(acct.Payments),
),
Label: acct.Label,
}

for hash := range acct.Invoices {
Expand Down
Loading