From a49422f36ffb243e8a40a5751c9ccdeb5cbf4034 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 15 Sep 2023 13:25:28 +0200 Subject: [PATCH 1/8] accounts: disallow requests after critical errors --- accounts/checkers_test.go | 4 + accounts/interceptor.go | 11 +++ accounts/interface.go | 9 +++ accounts/rpcserver.go | 20 +++++ accounts/service.go | 163 +++++++++++++++++++++++++++++++++----- accounts/service_test.go | 14 +++- 6 files changed, 201 insertions(+), 20 deletions(-) diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 481609c91..64ce53a0b 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -85,6 +85,10 @@ func (m *mockService) RemovePayment(hash lntypes.Hash) error { return nil } +func (*mockService) IsRunning() bool { + return true +} + var _ Service = (*mockService)(nil) // TestAccountChecker makes sure all round trip checkers can be instantiated diff --git a/accounts/interceptor.go b/accounts/interceptor.go index d6dff8dfe..836c4a2fa 100644 --- a/accounts/interceptor.go +++ b/accounts/interceptor.go @@ -52,6 +52,17 @@ func (s *InterceptorService) Intercept(ctx context.Context, s.requestMtx.Lock() defer s.requestMtx.Unlock() + // If the account service is not running, we reject all requests. + // Note that this is by no means a guarantee that the account service + // will be running throughout processing the request, but at least we + // can stop requests early if the service was already disabled when the + // request came in. + if !s.IsRunning() { + return mid.RPCErrString( + req, "the account service has been stopped", + ) + } + mac := &macaroon.Macaroon{} err := mac.UnmarshalBinary(req.RawMacaroon) if err != nil { diff --git a/accounts/interface.go b/accounts/interface.go index 879d9a51a..efbefc442 100644 --- a/accounts/interface.go +++ b/accounts/interface.go @@ -160,6 +160,12 @@ var ( ErrNotSupportedWithAccounts = errors.New("this RPC call is not " + "supported with restricted account macaroons") + // ErrAccountServiceDisabled is the error that is returned when the + // account service has been disabled due to an error being thrown + // in the service that cannot be recovered from. + ErrAccountServiceDisabled = errors.New("the account service has been " + + "stopped") + // MacaroonPermissions are the permissions required for an account // macaroon. MacaroonPermissions = []bakery.Op{{ @@ -240,4 +246,7 @@ type Service interface { // longer needs to be tracked. The payment is certain to never succeed, // so we never need to debit the amount from the account. RemovePayment(hash lntypes.Hash) error + + // IsRunning returns true if the service can be used. + IsRunning() bool } diff --git a/accounts/rpcserver.go b/accounts/rpcserver.go index 22135f634..7556a95d8 100644 --- a/accounts/rpcserver.go +++ b/accounts/rpcserver.go @@ -53,6 +53,10 @@ func (s *RPCServer) CreateAccount(ctx context.Context, log.Infof("[createaccount] label=%v, balance=%d, expiration=%d", req.Label, req.AccountBalance, req.ExpirationDate) + if !s.service.IsRunning() { + return nil, ErrAccountServiceDisabled + } + var ( balanceMsat lnwire.MilliSatoshi expirationDate time.Time @@ -115,6 +119,10 @@ func (s *RPCServer) UpdateAccount(_ context.Context, log.Infof("[updateaccount] id=%s, label=%v, balance=%d, expiration=%d", req.Id, req.Label, req.AccountBalance, req.ExpirationDate) + if !s.service.IsRunning() { + return nil, ErrAccountServiceDisabled + } + accountID, err := s.findAccount(req.Id, req.Label) if err != nil { return nil, err @@ -138,6 +146,10 @@ func (s *RPCServer) ListAccounts(context.Context, log.Info("[listaccounts]") + if !s.service.IsRunning() { + return nil, ErrAccountServiceDisabled + } + // Retrieve all accounts from the macaroon account store. accts, err := s.service.Accounts() if err != nil { @@ -163,6 +175,10 @@ func (s *RPCServer) AccountInfo(_ context.Context, log.Infof("[accountinfo] id=%v, label=%v", req.Id, req.Label) + if !s.service.IsRunning() { + return nil, ErrAccountServiceDisabled + } + accountID, err := s.findAccount(req.Id, req.Label) if err != nil { return nil, err @@ -183,6 +199,10 @@ func (s *RPCServer) RemoveAccount(_ context.Context, log.Infof("[removeaccount] id=%v, label=%v", req.Id, req.Label) + if !s.service.IsRunning() { + return nil, ErrAccountServiceDisabled + } + accountID, err := s.findAccount(req.Id, req.Label) if err != nil { return nil, err diff --git a/accounts/service.go b/accounts/service.go index a229d831c..c4f0ab5e9 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -63,6 +63,8 @@ type InterceptorService struct { mainErrChan chan<- error wg sync.WaitGroup quit chan struct{} + + isEnabled bool } // NewService returns a service backed by the macaroon Bolt DB stored in the @@ -83,6 +85,7 @@ func NewService(dir string, errChan chan<- error) (*InterceptorService, error) { pendingPayments: make(map[lntypes.Hash]*trackedPayment), mainErrChan: errChan, quit: make(chan struct{}), + isEnabled: false, }, nil } @@ -93,12 +96,15 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, s.routerClient = routerClient s.checkers = NewAccountChecker(s, params) + s.isEnabled = true + // Let's first fill our cache that maps invoices to accounts, which // allows us to credit an account easily once an invoice is settled. We // also track payments that aren't in a final state yet. existingAccounts, err := s.store.Accounts() if err != nil { - return fmt.Errorf("error querying existing accounts: %w", err) + return s.disableAndErrorf("error querying existing "+ + "accounts: %w", err) } for _, acct := range existingAccounts { acct := acct @@ -116,8 +122,8 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, acct.ID, hash, entry.FullAmount, ) if err != nil { - return fmt.Errorf("error tracking "+ - "payment: %w", err) + return s.disableAndErrorf("error "+ + "tracking payment: %w", err) } } } @@ -146,8 +152,8 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, s.currentSettleIndex = 0 default: - return fmt.Errorf("error determining last invoice indexes: %w", - err) + return s.disableAndErrorf("error determining last invoice "+ + "indexes: %w", err) } invoiceChan, invoiceErrChan, err := lightningClient.SubscribeInvoices( @@ -157,7 +163,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, }, ) if err != nil { - return fmt.Errorf("error subscribing invoices: %w", err) + return s.disableAndErrorf("error subscribing invoices: %w", err) } s.wg.Add(1) @@ -187,8 +193,11 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, } case err := <-invoiceErrChan: - log.Errorf("Error in invoice subscription: %v", - err) + // If the invoice subscription errors out, we + // stop the service as we won't be able to + // process invoices. + err = s.disableAndErrorf("Error in invoice "+ + "subscription: %w", err) select { case s.mainErrChan <- err: @@ -219,6 +228,49 @@ func (s *InterceptorService) Stop() error { return s.store.Close() } +// IsRunning checks if the account service is running, and returns a boolean +// indicating whether it is running or not. +func (s *InterceptorService) IsRunning() bool { + s.RLock() + defer s.RUnlock() + + return s.isEnabled +} + +// isRunningUnsafe checks if the account service is running, and returns a +// boolean indicating whether it is running or not +// +// NOTE: The store lock MUST be held as either a read or write lock when calling +// this method. +func (s *InterceptorService) isRunningUnsafe() bool { + return s.isEnabled +} + +// disable disables the account service, and marks the service as not running. +// The function acquires the store write lock before disabling the service. +// The function returns an error with the given format and arguments. +func (s *InterceptorService) disableAndErrorf(format string, a ...any) error { + s.Lock() + defer s.Unlock() + + s.isEnabled = false + + return fmt.Errorf(format, a...) +} + +// disableAndErrorfUnsafe disables the account service, and marks the service as +// not running. The function returns an error with the given format and +// arguments. +// +// NOTE: The store lock MUST be held when calling this method. +func (s *InterceptorService) disableAndErrorfUnsafe(format string, + a ...any) error { + + s.isEnabled = false + + return fmt.Errorf(format, a...) +} + // NewAccount creates a new OffChainBalanceAccount with the given balance and a // randomly chosen ID. func (s *InterceptorService) NewAccount(balance lnwire.MilliSatoshi, @@ -239,6 +291,14 @@ func (s *InterceptorService) UpdateAccount(accountID AccountID, accountBalance, s.Lock() defer s.Unlock() + // As this function updates account balances, we require that the + // service is running before we execute it. + if s.isRunningUnsafe() { + // This case can only happen if the service is disabled while + // we we're processing a request. + return nil, ErrAccountServiceDisabled + } + account, err := s.store.Account(accountID) if err != nil { return nil, fmt.Errorf("error fetching account: %w", err) @@ -364,10 +424,26 @@ func (s *InterceptorService) AssociateInvoice(id AccountID, // invoiceUpdate credits the account an invoice was registered with, in case the // invoice was settled. +// +// NOTE: Any code that errors in this function MUST call disableAndErrorfUnsafe +// while the store lock is held to ensure that the service is disabled under +// the same lock. Else we risk that other threads will try to update invoices +// while the service should be disabled, which could lead to us missing invoice +// updates on next startup. func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { s.Lock() defer s.Unlock() + // As this function updates account balances, and is called from the + // invoice subscription, we ensure that the service is running before we + // execute it. + if !s.isRunningUnsafe() { + // We will process the invoice update on next startup instead, + // once the error that caused the service to stop has been + // resolved. + return ErrAccountServiceDisabled + } + // We update our indexes each time we get a new invoice from our // subscription. This might be a bit inefficient but makes sure we don't // miss an update. @@ -386,7 +462,9 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { s.currentAddIndex, s.currentSettleIndex, ) if err != nil { - return err + return s.disableAndErrorfUnsafe( + "error storing last indexes: %w", err, + ) } } @@ -405,7 +483,9 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { account, err := s.store.Account(acctID) if err != nil { - return fmt.Errorf("error fetching account: %w", err) + return s.disableAndErrorfUnsafe( + "error fetching account: %w", err, + ) } // If we get here, the current account has the invoice associated with @@ -413,7 +493,9 @@ func (s *InterceptorService) invoiceUpdate(invoice *lndclient.Invoice) error { // in the DB. account.CurrentBalance += int64(invoice.AmountPaid) if err := s.store.UpdateAccount(account); err != nil { - return fmt.Errorf("error updating account: %w", err) + return s.disableAndErrorfUnsafe( + "error updating account: %w", err, + ) } // We've now fully processed the invoice and don't need to keep it @@ -461,6 +543,15 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, return fmt.Errorf("error updating account: %w", err) } + // As this function updates account balances, we ensure that the service + // is running before we execute it. + if !s.isRunningUnsafe() { + // We will track the payment on next on next startup instead, + // once the error that caused the service to stop has been + // resolved. + return ErrAccountServiceDisabled + } + // And start the long-running TrackPayment RPC. ctxc, cancel := context.WithCancel(s.mainCtx) statusChan, errChan, err := s.routerClient.TrackPayment(ctxc, hash) @@ -516,10 +607,13 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, return } - log.Errorf("Received error from TrackPayment "+ - "RPC for payment %v: %v", hash, err) - if err != nil { + // If we error when tracking the + // payment, we stop the service. + err = s.disableAndErrorf("received "+ + "error from TrackPayment RPC "+ + "for payment %v: %w", hash, err) + select { case s.mainErrChan <- err: case <-s.mainCtx.Done(): @@ -544,6 +638,10 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, // associated with, in case it is settled. The boolean value returned indicates // whether the status was terminal or not. If it's not terminal then further // updates are expected. +// +// NOTE: Any code that errors in this function MUST call disableAndErrorfUnsafe +// while the store lock is held to ensure that the service is disabled under +// the same lock. func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, status lndclient.PaymentStatus) (bool, error) { @@ -563,21 +661,40 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, s.Lock() defer s.Unlock() + // As this function updates account balances, we ensure that the service + // is running before we execute it. + if !s.isRunningUnsafe() { + // We will update the payment on next startup instead, once the + // error that caused the service to stop has been resolved. + return false, ErrAccountServiceDisabled + } + pendingPayment, ok := s.pendingPayments[hash] if !ok { - return terminalState, fmt.Errorf("payment %x not mapped to "+ - "any account", hash[:]) + err := s.disableAndErrorfUnsafe("payment %x not mapped to any "+ + "account", hash[:]) + + return terminalState, err } // A failed payment can just be removed, no further action needed. if status.State == lnrpc.Payment_FAILED { - return terminalState, s.removePayment(hash, status.State) + err := s.removePayment(hash, status.State) + if err != nil { + err = s.disableAndErrorfUnsafe("error removing "+ + "payment: %w", err) + } + + return terminalState, err } // The payment went through! We now need to debit the full amount from // the account. account, err := s.store.Account(pendingPayment.accountID) if err != nil { + err = s.disableAndErrorfUnsafe("error fetching account: %w", + err) + return terminalState, err } @@ -590,13 +707,21 @@ func (s *InterceptorService) paymentUpdate(hash lntypes.Hash, FullAmount: fullAmount, } if err := s.store.UpdateAccount(account); err != nil { - return terminalState, fmt.Errorf("error updating account: %w", + err = s.disableAndErrorfUnsafe("error updating account: %w", err) + + return terminalState, err } // We've now fully processed the payment and don't need to keep it // mapped or tracked anymore. - return terminalState, s.removePayment(hash, lnrpc.Payment_SUCCEEDED) + err = s.removePayment(hash, lnrpc.Payment_SUCCEEDED) + if err != nil { + err = s.disableAndErrorfUnsafe("error removing payment: %w", + err) + } + + return terminalState, err } // RemovePayment removes a failed payment from the service because it no longer diff --git a/accounts/service_test.go b/accounts/service_test.go index 3b0b604ae..8357ef2e6 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -72,6 +72,18 @@ func (m *mockLnd) assertMainErr(t *testing.T, expectedErr error) { } } +// assertMainErrContains asserts that the main error contains the expected error +// string. +func (m *mockLnd) assertMainErrContains(t *testing.T, expectedStr string) { + select { + case err := <-m.mainErrChan: + require.ErrorContains(t, err, expectedStr) + + case <-time.After(testTimeout): + t.Fatalf("Did not get expected main err before timeout") + } +} + func (m *mockLnd) assertNoInvoiceRequest(t *testing.T) { select { case req := <-m.invoiceReq: @@ -201,7 +213,7 @@ func TestAccountService(t *testing.T) { s *InterceptorService) { lnd.assertInvoiceRequest(t, 0, 0) - lnd.assertMainErr(t, testErr) + lnd.assertMainErrContains(t, testErr.Error()) }, }, { name: "startup do not track completed payments", From 4696c47f5904749f10cd3ee1046ed2d3807926e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Sun, 24 Sep 2023 17:07:32 +0200 Subject: [PATCH 2/8] accounts: add service disabled unit tests --- accounts/service_test.go | 241 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 225 insertions(+), 16 deletions(-) diff --git a/accounts/service_test.go b/accounts/service_test.go index 8357ef2e6..7098e72b3 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -32,10 +32,12 @@ type mockLnd struct { invoiceReq chan lndclient.InvoiceSubscriptionRequest paymentReq chan lntypes.Hash - callErr error - errChan chan error - invoiceChan chan *lndclient.Invoice - paymentChans map[lntypes.Hash]chan lndclient.PaymentStatus + invoiceSubscriptionErr error + trackPaymentErr error + invoiceErrChan chan error + paymentErrChan chan error + invoiceChan chan *lndclient.Invoice + paymentChans map[lntypes.Hash]chan lndclient.PaymentStatus } func newMockLnd() *mockLnd { @@ -44,9 +46,10 @@ func newMockLnd() *mockLnd { invoiceReq: make( chan lndclient.InvoiceSubscriptionRequest, 10, ), - paymentReq: make(chan lntypes.Hash, 10), - errChan: make(chan error, 10), - invoiceChan: make(chan *lndclient.Invoice), + paymentReq: make(chan lntypes.Hash, 10), + invoiceErrChan: make(chan error, 10), + paymentErrChan: make(chan error, 10), + invoiceChan: make(chan *lndclient.Invoice), paymentChans: make( map[lntypes.Hash]chan lndclient.PaymentStatus, ), @@ -144,13 +147,13 @@ func (m *mockLnd) SubscribeInvoices(_ context.Context, req lndclient.InvoiceSubscriptionRequest) (<-chan *lndclient.Invoice, <-chan error, error) { - if m.callErr != nil { - return nil, nil, m.callErr + if m.invoiceSubscriptionErr != nil { + return nil, nil, m.invoiceSubscriptionErr } m.invoiceReq <- req - return m.invoiceChan, m.errChan, nil + return m.invoiceChan, m.invoiceErrChan, nil } // TrackPayment picks up a previously started payment and returns a payment @@ -158,14 +161,14 @@ func (m *mockLnd) SubscribeInvoices(_ context.Context, func (m *mockLnd) TrackPayment(_ context.Context, hash lntypes.Hash) (chan lndclient.PaymentStatus, chan error, error) { - if m.callErr != nil { - return nil, nil, m.callErr + if m.trackPaymentErr != nil { + return nil, nil, m.trackPaymentErr } m.paymentReq <- hash m.paymentChans[hash] = make(chan lndclient.PaymentStatus, 1) - return m.paymentChans[hash], m.errChan, nil + return m.paymentChans[hash], m.paymentErrChan, nil } // TestAccountService tests that the account service can track payments and @@ -181,15 +184,92 @@ func TestAccountService(t *testing.T) { validate func(t *testing.T, lnd *mockLnd, s *InterceptorService) }{{ - name: "startup err on tracking payment", + name: "startup err on invoice subscription", setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { - lnd.callErr = testErr + lnd.invoiceSubscriptionErr = testErr }, startupErr: testErr.Error(), validate: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { lnd.assertNoInvoiceRequest(t) + require.False(t, s.IsRunning()) + }, + }, { + name: "err on invoice update", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Invoices: AccountInvoices{ + testHash: {}, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + + // Start by closing the store. This should cause an + // error once we make an invoice update, as the service + // will fail when persisting the invoice update. + s.store.Close() + + // Ensure that the service was started successfully and + // still running though, despite the closing of the + // db store. + require.True(t, s.IsRunning()) + + // Now let's send the invoice update, which should fail. + lnd.invoiceChan <- &lndclient.Invoice{ + AddIndex: 12, + SettleIndex: 12, + Hash: testHash, + AmountPaid: 777, + State: invpkg.ContractSettled, + } + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + lnd.assertMainErrContains(t, "database not open") + }, + }, { + name: "err in invoice err channel", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Invoices: AccountInvoices{ + testHash: {}, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + // Ensure that the service was started successfully. + require.True(t, s.IsRunning()) + + // Now let's send an error over the invoice error + // channel. This should disable the service. + lnd.invoiceErrChan <- testErr + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + + lnd.assertMainErrContains(t, testErr.Error()) }, }, { name: "goroutine err sent on main err chan", @@ -207,7 +287,7 @@ func TestAccountService(t *testing.T) { err := s.store.UpdateAccount(acct) require.NoError(t, err) - lnd.errChan <- testErr + lnd.mainErrChan <- testErr }, validate: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { @@ -239,6 +319,135 @@ func TestAccountService(t *testing.T) { lnd.assertNoPaymentRequest(t) lnd.assertInvoiceRequest(t, 0, 0) lnd.assertNoMainErr(t) + require.True(t, s.IsRunning()) + }, + }, { + name: "startup err on payment tracking", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Invoices: AccountInvoices{ + testHash: {}, + }, + Payments: AccountPayments{ + testHash: { + Status: lnrpc.Payment_IN_FLIGHT, + FullAmount: 1234, + }, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + + lnd.trackPaymentErr = testErr + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + + // Assert that the invoice subscription succeeded. + require.Contains(t, s.invoiceToAccount, testHash) + + // But setting up the payment tracking should have failed. + require.False(t, s.IsRunning()) + + // Finally let's assert that we didn't successfully add the + // payment to pending payment, and that lnd isn't awaiting + // the payment request. + require.NotContains(t, s.pendingPayments, testHash) + lnd.assertNoPaymentRequest(t) + }, + }, { + name: "err on payment update", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Payments: AccountPayments{ + testHash: { + Status: lnrpc.Payment_IN_FLIGHT, + FullAmount: 1234, + }, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + // Ensure that the service was started successfully, + // and lnd contains the payment request. + require.True(t, s.IsRunning()) + lnd.assertPaymentRequests(t, map[lntypes.Hash]struct{}{ + testHash: {}, + }) + + // Now let's wipe the service's pending payments. + // This will cause an error send an update over + // the payment channel, which should disable the + // service. + s.pendingPayments = make(map[lntypes.Hash]*trackedPayment) + + // Send an invalid payment over the payment chan + // which should error and disable the service + lnd.paymentChans[testHash] <- lndclient.PaymentStatus{ + State: lnrpc.Payment_SUCCEEDED, + Fee: 234, + Value: 1000, + } + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + lnd.assertMainErrContains( + t, "not mapped to any account", + ) + + }, + }, { + name: "err in payment update chan", + setup: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { + acct := &OffChainBalanceAccount{ + ID: testID, + Type: TypeInitialBalance, + CurrentBalance: 1234, + Payments: AccountPayments{ + testHash: { + Status: lnrpc.Payment_IN_FLIGHT, + FullAmount: 1234, + }, + }, + } + + err := s.store.UpdateAccount(acct) + require.NoError(t, err) + }, + validate: func(t *testing.T, lnd *mockLnd, + s *InterceptorService) { + // Ensure that the service was started successfully, + // and lnd contains the payment request. + require.True(t, s.IsRunning()) + lnd.assertPaymentRequests(t, map[lntypes.Hash]struct{}{ + testHash: {}, + }) + + // Now let's send an error over the payment error + // channel. This should disable the service. + lnd.paymentErrChan <- testErr + + // Ensure that the service was eventually disabled. + assertEventually(t, func() bool { + isRunning := s.IsRunning() + return isRunning == false + }) + + lnd.assertMainErrContains(t, testErr.Error()) }, }, { name: "startup track in-flight payments", From 32b2a5885f062845e32761b9542e748d1d86e149 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Fri, 15 Sep 2023 14:14:19 +0200 Subject: [PATCH 3/8] terminal: don't stop litd on account system error --- accounts/service.go | 36 +++++++++++------------------------- accounts/service_test.go | 9 +++++---- terminal.go | 15 ++++++++++++--- 3 files changed, 28 insertions(+), 32 deletions(-) diff --git a/accounts/service.go b/accounts/service.go index c4f0ab5e9..bdf671730 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -60,16 +60,18 @@ type InterceptorService struct { invoiceToAccount map[lntypes.Hash]AccountID pendingPayments map[lntypes.Hash]*trackedPayment - mainErrChan chan<- error - wg sync.WaitGroup - quit chan struct{} + mainErrCallback func(error) + wg sync.WaitGroup + quit chan struct{} isEnabled bool } // NewService returns a service backed by the macaroon Bolt DB stored in the // passed-in directory. -func NewService(dir string, errChan chan<- error) (*InterceptorService, error) { +func NewService(dir string, + errCallback func(error)) (*InterceptorService, error) { + accountStore, err := NewBoltStore(dir, DBFilename) if err != nil { return nil, err @@ -83,7 +85,7 @@ func NewService(dir string, errChan chan<- error) (*InterceptorService, error) { contextCancel: contextCancel, invoiceToAccount: make(map[lntypes.Hash]AccountID), pendingPayments: make(map[lntypes.Hash]*trackedPayment), - mainErrChan: errChan, + mainErrCallback: errCallback, quit: make(chan struct{}), isEnabled: false, }, nil @@ -184,11 +186,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, log.Errorf("Error processing invoice "+ "update: %v", err) - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) return } @@ -199,11 +197,7 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, err = s.disableAndErrorf("Error in invoice "+ "subscription: %w", err) - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) return case <-s.mainCtx.Done(): @@ -581,11 +575,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, hash, paymentUpdate, ) if err != nil { - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) return } @@ -614,11 +604,7 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, "error from TrackPayment RPC "+ "for payment %v: %w", hash, err) - select { - case s.mainErrChan <- err: - case <-s.mainCtx.Done(): - case <-s.quit: - } + s.mainErrCallback(err) } return diff --git a/accounts/service_test.go b/accounts/service_test.go index 7098e72b3..f1cf120d8 100644 --- a/accounts/service_test.go +++ b/accounts/service_test.go @@ -287,7 +287,7 @@ func TestAccountService(t *testing.T) { err := s.store.UpdateAccount(acct) require.NoError(t, err) - lnd.mainErrChan <- testErr + s.mainErrCallback(testErr) }, validate: func(t *testing.T, lnd *mockLnd, s *InterceptorService) { @@ -672,9 +672,10 @@ func TestAccountService(t *testing.T) { tt.Parallel() lndMock := newMockLnd() - service, err := NewService( - t.TempDir(), lndMock.mainErrChan, - ) + errFunc := func(err error) { + lndMock.mainErrChan <- err + } + service, err := NewService(t.TempDir(), errFunc) require.NoError(t, err) // Is a setup call required to initialize initial diff --git a/terminal.go b/terminal.go index a4ac4afdf..c141e1d6e 100644 --- a/terminal.go +++ b/terminal.go @@ -305,8 +305,14 @@ func (g *LightningTerminal) Run() error { func (g *LightningTerminal) start() error { var err error + accountServiceErrCallback := func(err error) { + log.Errorf("Error thrown in the accounts service, keeping "+ + "litd running: %v", err, + ) + } + g.accountService, err = accounts.NewService( - filepath.Dir(g.cfg.MacaroonPath), g.errQueue.ChanIn(), + filepath.Dir(g.cfg.MacaroonPath), accountServiceErrCallback, ) if err != nil { return fmt.Errorf("error creating account service: %v", err) @@ -843,9 +849,12 @@ func (g *LightningTerminal) startInternalSubServers( g.lndClient.ChainParams, ) if err != nil { - return fmt.Errorf("error starting account service: %v", - err) + log.Errorf("error starting account service: %v, disabling "+ + "account service", err) } + // Even if we error on accountService.Start, we still want to mark the + // service as started so that we can properly shut it down in the + // shutdownSubServers call. g.accountServiceStarted = true requestLogger, err := firewall.NewRequestLogger( From 666f19cee855bb58b9a29529fdd67f5c6d34846e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Wed, 20 Sep 2023 01:07:20 +0200 Subject: [PATCH 4/8] accounts: associate payments before sending them --- accounts/checkers.go | 13 +++++++++ accounts/checkers_test.go | 6 ++++ accounts/interface.go | 7 +++-- accounts/service.go | 58 +++++++++++++++++++++++++++++++++++++-- 4 files changed, 80 insertions(+), 4 deletions(-) diff --git a/accounts/checkers.go b/accounts/checkers.go index 8fdb8c9b3..e4eff8935 100644 --- a/accounts/checkers.go +++ b/accounts/checkers.go @@ -522,6 +522,7 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, } // The invoice is optional. + var paymentHash lntypes.Hash if len(invoice) > 0 { payReq, err := zpay32.Decode(invoice, chainParams) if err != nil { @@ -531,6 +532,10 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, if payReq.MilliSat != nil && *payReq.MilliSat > sendAmt { sendAmt = *payReq.MilliSat } + + if payReq.PaymentHash != nil { + paymentHash = *payReq.PaymentHash + } } // We also add the max fee to the amount to check. This might mean that @@ -549,6 +554,14 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params, return fmt.Errorf("error validating account balance: %w", err) } + emptyHash := lntypes.Hash{} + if paymentHash != emptyHash { + err = service.AssociatePayment(acct.ID, paymentHash, sendAmt) + if err != nil { + return fmt.Errorf("error associating payment: %w", err) + } + } + return nil } diff --git a/accounts/checkers_test.go b/accounts/checkers_test.go index 64ce53a0b..2b37b9493 100644 --- a/accounts/checkers_test.go +++ b/accounts/checkers_test.go @@ -68,6 +68,12 @@ func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error { return nil } +func (m *mockService) AssociatePayment(id AccountID, paymentHash lntypes.Hash, + amt lnwire.MilliSatoshi) error { + + return nil +} + func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash, amt lnwire.MilliSatoshi) error { diff --git a/accounts/interface.go b/accounts/interface.go index efbefc442..8fec489e2 100644 --- a/accounts/interface.go +++ b/accounts/interface.go @@ -247,6 +247,9 @@ type Service interface { // so we never need to debit the amount from the account. RemovePayment(hash lntypes.Hash) error - // IsRunning returns true if the service can be used. - IsRunning() bool + // AssociatePayment associates a payment (hash) with the given account, + // ensuring that the payment will be tracked for a user when LiT is + // restarted. + AssociatePayment(id AccountID, paymentHash lntypes.Hash, + fullAmt lnwire.MilliSatoshi) error } diff --git a/accounts/service.go b/accounts/service.go index bdf671730..661faac8f 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -416,6 +416,40 @@ func (s *InterceptorService) AssociateInvoice(id AccountID, return s.store.UpdateAccount(account) } +// AssociatePayment associates a payment (hash) with the given account, +// ensuring that the payment will be tracked for a user when LiT is +// restarted. +func (s *InterceptorService) AssociatePayment(id AccountID, + paymentHash lntypes.Hash, fullAmt lnwire.MilliSatoshi) error { + + s.Lock() + defer s.Unlock() + + account, err := s.store.Account(id) + if err != nil { + return err + } + + // If the payment is already associated with the account, we don't need + // to associate it again. + _, ok := account.Payments[paymentHash] + if ok { + return nil + } + + // Associate the payment with the account and store it. + account.Payments[paymentHash] = &PaymentEntry{ + Status: lnrpc.Payment_UNKNOWN, + FullAmount: fullAmt, + } + + if err := s.store.UpdateAccount(account); err != nil { + return fmt.Errorf("error updating account: %w", err) + } + + return nil +} + // invoiceUpdate credits the account an invoice was registered with, in case the // invoice was settled. // @@ -527,13 +561,33 @@ func (s *InterceptorService) TrackPayment(id AccountID, hash lntypes.Hash, return nil } - // Okay, we haven't tracked this payment before. So let's now associate - // the account with it. account.Payments[hash] = &PaymentEntry{ Status: lnrpc.Payment_UNKNOWN, FullAmount: fullAmt, } + if err := s.store.UpdateAccount(account); err != nil { + if !ok { + // In the rare case that the payment isn't associated + // with an account yet, and we fail to update the + // account we will not be tracking the payment, even if + // track the service is restarted. Therefore the node + // runner needs to manually check if the payment was + // made and debit the account if that's the case. + errStr := "critical error: failed to store the " + + "payment with hash %v for user with account " + + "id %v. Manual intervention required! " + + "Verify if the payment was executed, and " + + "manually update the user account balance by " + + "subtracting the payment amount if it was" + + mainChanErr := s.disableAndErrorfUnsafe( + errStr, hash, id, + ) + + s.mainErrCallback(mainChanErr) + } + return fmt.Errorf("error updating account: %w", err) } From 6f9f3244fb4fdfcac832bed95cb73e378517b6e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Thu, 28 Sep 2023 00:30:50 +0200 Subject: [PATCH 5/8] accounts: process requests before stopping service Ensure that we don't stop the service while we're processing a request. This is especially important to ensure that we don't stop the service exactly after a user has made an rpc call to send a payment we can't know the payment hash for prior to the actual payment being sent (i.e. Keysend or SendToRoute). This is because if we stop the service after the send request has been sent to lnd, but before TrackPayment has been called, we won't be able to track the payment and debit the account. --- accounts/service.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/accounts/service.go b/accounts/service.go index 661faac8f..fb77c061b 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -214,6 +214,18 @@ func (s *InterceptorService) Start(lightningClient lndclient.LightningClient, // Stop shuts down the account service. func (s *InterceptorService) Stop() error { + // We need to lock the request mutex to ensure that we don't stop the + // service while we're processing a request. + // This is especially important to ensure that we don't stop the service + // exactly after a user has made an rpc call to send a payment we can't + // know the payment hash for prior to the actual payment being sent + // (i.e. Keysend or SendToRoute). This is because if we stop the service + // after the send request has been sent to lnd, but before TrackPayment + // has been called, we won't be able to track the payment and debit the + // account. + s.requestMtx.Lock() + defer s.requestMtx.Unlock() + s.contextCancel() close(s.quit) From ec2a7a39b49dec586f0a73eb9ffab2b3f8a7400d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Thu, 21 Sep 2023 02:01:38 +0200 Subject: [PATCH 6/8] multi: add accounts service to status manager Add the accounts service to status manager. This will allow us to query the status of the accounts service and see if it is running or not. For incoming gRPC requests to the accounts service, we also use the status manager to check if the accounts service is running or not to determine if we should let the request through or not. --- accounts/rpcserver.go | 20 -------------------- rpc_proxy.go | 13 +++++++++++++ subservers/subserver.go | 13 +++++++------ terminal.go | 12 +++++++++++- 4 files changed, 31 insertions(+), 27 deletions(-) diff --git a/accounts/rpcserver.go b/accounts/rpcserver.go index 7556a95d8..22135f634 100644 --- a/accounts/rpcserver.go +++ b/accounts/rpcserver.go @@ -53,10 +53,6 @@ func (s *RPCServer) CreateAccount(ctx context.Context, log.Infof("[createaccount] label=%v, balance=%d, expiration=%d", req.Label, req.AccountBalance, req.ExpirationDate) - if !s.service.IsRunning() { - return nil, ErrAccountServiceDisabled - } - var ( balanceMsat lnwire.MilliSatoshi expirationDate time.Time @@ -119,10 +115,6 @@ func (s *RPCServer) UpdateAccount(_ context.Context, log.Infof("[updateaccount] id=%s, label=%v, balance=%d, expiration=%d", req.Id, req.Label, req.AccountBalance, req.ExpirationDate) - if !s.service.IsRunning() { - return nil, ErrAccountServiceDisabled - } - accountID, err := s.findAccount(req.Id, req.Label) if err != nil { return nil, err @@ -146,10 +138,6 @@ func (s *RPCServer) ListAccounts(context.Context, log.Info("[listaccounts]") - if !s.service.IsRunning() { - return nil, ErrAccountServiceDisabled - } - // Retrieve all accounts from the macaroon account store. accts, err := s.service.Accounts() if err != nil { @@ -175,10 +163,6 @@ func (s *RPCServer) AccountInfo(_ context.Context, log.Infof("[accountinfo] id=%v, label=%v", req.Id, req.Label) - if !s.service.IsRunning() { - return nil, ErrAccountServiceDisabled - } - accountID, err := s.findAccount(req.Id, req.Label) if err != nil { return nil, err @@ -199,10 +183,6 @@ func (s *RPCServer) RemoveAccount(_ context.Context, log.Infof("[removeaccount] id=%v, label=%v", req.Id, req.Label) - if !s.service.IsRunning() { - return nil, ErrAccountServiceDisabled - } - accountID, err := s.findAccount(req.Id, req.Label) if err != nil { return nil, err diff --git a/rpc_proxy.go b/rpc_proxy.go index 798f8592e..f1be75934 100644 --- a/rpc_proxy.go +++ b/rpc_proxy.go @@ -623,6 +623,9 @@ func (p *rpcProxy) checkSubSystemStarted(requestURI string) error { switch { case handled: + case isAccountsReq(requestURI): + system = subservers.ACCOUNTS + case p.permsMgr.IsSubServerURI(subservers.LIT, requestURI): system = subservers.LIT @@ -694,3 +697,13 @@ func isProxyReq(uri string) bool { uri, fmt.Sprintf("/%s", litrpc.Proxy_ServiceDesc.ServiceName), ) } + +// isAccountsReq returns true if the given request is intended for the +// litrpc.Accounts service. +func isAccountsReq(uri string) bool { + return strings.HasPrefix( + uri, fmt.Sprintf( + "/%s", litrpc.Accounts_ServiceDesc.ServiceName, + ), + ) +} diff --git a/subservers/subserver.go b/subservers/subserver.go index 8605eb422..82ba22292 100644 --- a/subservers/subserver.go +++ b/subservers/subserver.go @@ -11,12 +11,13 @@ import ( ) const ( - LND string = "lnd" - LIT string = "lit" - LOOP string = "loop" - POOL string = "pool" - TAP string = "taproot-assets" - FARADAY string = "faraday" + LND string = "lnd" + LIT string = "lit" + LOOP string = "loop" + POOL string = "pool" + TAP string = "taproot-assets" + FARADAY string = "faraday" + ACCOUNTS string = "accounts" ) // subServerWrapper is a wrapper around the SubServer interface and is used by diff --git a/terminal.go b/terminal.go index c141e1d6e..571c9dc57 100644 --- a/terminal.go +++ b/terminal.go @@ -232,9 +232,10 @@ func (g *LightningTerminal) Run() error { return fmt.Errorf("could not create permissions manager") } - // Register LND and LiT with the status manager. + // Register LND, LiT and Accounts with the status manager. g.statusMgr.RegisterAndEnableSubServer(subservers.LND) g.statusMgr.RegisterAndEnableSubServer(subservers.LIT) + g.statusMgr.RegisterAndEnableSubServer(subservers.ACCOUNTS) // Create the instances of our subservers now so we can hook them up to // lnd once it's fully started. @@ -306,6 +307,11 @@ func (g *LightningTerminal) start() error { var err error accountServiceErrCallback := func(err error) { + g.statusMgr.SetErrored( + subservers.ACCOUNTS, + err.Error(), + ) + log.Errorf("Error thrown in the accounts service, keeping "+ "litd running: %v", err, ) @@ -851,6 +857,10 @@ func (g *LightningTerminal) startInternalSubServers( if err != nil { log.Errorf("error starting account service: %v, disabling "+ "account service", err) + + g.statusMgr.SetErrored(subservers.ACCOUNTS, err.Error()) + } else { + g.statusMgr.SetRunning(subservers.ACCOUNTS) } // Even if we error on accountService.Start, we still want to mark the // service as started so that we can properly shut it down in the From 20fa2c0915de0d43f793f4c22a140e514dcbdbad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Thu, 21 Sep 2023 02:45:28 +0200 Subject: [PATCH 7/8] terminal: add disable accounts service cfg option --- accounts/service.go | 6 ++++ config.go | 4 +++ terminal.go | 70 ++++++++++++++++++++++++++++++++------------- 3 files changed, 60 insertions(+), 20 deletions(-) diff --git a/accounts/service.go b/accounts/service.go index fb77c061b..7712258bc 100644 --- a/accounts/service.go +++ b/accounts/service.go @@ -16,6 +16,12 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +// Config holds the configuration options for the accounts service. +type Config struct { + // Disable will disable the accounts service if set. + Disable bool `long:"disable" description:"disable the accounts service"` +} + // trackedPayment is a struct that holds all information that identifies a // payment that we are tracking in the service. type trackedPayment struct { diff --git a/config.go b/config.go index 107e2a4f8..65b9f54c5 100644 --- a/config.go +++ b/config.go @@ -17,6 +17,7 @@ import ( "github.com/lightninglabs/faraday" "github.com/lightninglabs/faraday/chain" "github.com/lightninglabs/faraday/frdrpcserver" + "github.com/lightninglabs/lightning-terminal/accounts" "github.com/lightninglabs/lightning-terminal/autopilotserver" "github.com/lightninglabs/lightning-terminal/firewall" mid "github.com/lightninglabs/lightning-terminal/rpcmiddleware" @@ -205,6 +206,8 @@ type Config struct { Firewall *firewall.Config `group:"Firewall options" namespace:"firewall"` + Accounts *accounts.Config `group:"Accounts options" namespace:"accounts"` + // faradayRpcConfig is a subset of faraday's full configuration that is // passed into faraday's RPC server. faradayRpcConfig *frdrpcserver.Config @@ -320,6 +323,7 @@ func defaultConfig() *Config { PingCadence: time.Hour, }, Firewall: firewall.DefaultConfig(), + Accounts: &accounts.Config{}, } } diff --git a/terminal.go b/terminal.go index 571c9dc57..f0c1e9350 100644 --- a/terminal.go +++ b/terminal.go @@ -235,7 +235,12 @@ func (g *LightningTerminal) Run() error { // Register LND, LiT and Accounts with the status manager. g.statusMgr.RegisterAndEnableSubServer(subservers.LND) g.statusMgr.RegisterAndEnableSubServer(subservers.LIT) - g.statusMgr.RegisterAndEnableSubServer(subservers.ACCOUNTS) + g.statusMgr.RegisterSubServer(subservers.ACCOUNTS) + + // Also enable the accounts subserver if it's not disabled. + if !g.cfg.Accounts.Disable { + g.statusMgr.SetEnabled(subservers.ACCOUNTS) + } // Create the instances of our subservers now so we can hook them up to // lnd once it's fully started. @@ -849,23 +854,41 @@ func (g *LightningTerminal) startInternalSubServers( return nil } + // Even if the accounts service fails on the Start function, or the + // accounts service is disabled, we still want to call Stop function as + // this closes the contexts and the db store which were opened with the + // accounts.NewService function call in the LightningTerminal start + // function above. + closeAccountService := func() { + if err := g.accountService.Stop(); err != nil { + // We only log the error if we fail to stop the service, + // as it's not critical that this succeeds in order to + // keep litd running + log.Errorf("Error stopping account service: %v", err) + } + } + log.Infof("Starting LiT account service") - err = g.accountService.Start( - g.lndClient.Client, g.lndClient.Router, - g.lndClient.ChainParams, - ) - if err != nil { - log.Errorf("error starting account service: %v, disabling "+ - "account service", err) + if !g.cfg.Accounts.Disable { + err = g.accountService.Start( + g.lndClient.Client, g.lndClient.Router, + g.lndClient.ChainParams, + ) + if err != nil { + log.Errorf("error starting account service: %v, "+ + "disabling account service", err) + + g.statusMgr.SetErrored(subservers.ACCOUNTS, err.Error()) + + closeAccountService() + } else { + g.statusMgr.SetRunning(subservers.ACCOUNTS) - g.statusMgr.SetErrored(subservers.ACCOUNTS, err.Error()) + g.accountServiceStarted = true + } } else { - g.statusMgr.SetRunning(subservers.ACCOUNTS) + closeAccountService() } - // Even if we error on accountService.Start, we still want to mark the - // service as started so that we can properly shut it down in the - // shutdownSubServers call. - g.accountServiceStarted = true requestLogger, err := firewall.NewRequestLogger( g.cfg.Firewall.RequestLogger, g.firewallDB, @@ -952,7 +975,12 @@ func (g *LightningTerminal) registerSubDaemonGrpcServers(server *grpc.Server, litrpc.RegisterStatusServer(server, g.statusMgr) } else { litrpc.RegisterSessionsServer(server, g.sessionRpcServer) - litrpc.RegisterAccountsServer(server, g.accountRpcServer) + + if !g.cfg.Accounts.Disable { + litrpc.RegisterAccountsServer( + server, g.accountRpcServer, + ) + } } litrpc.RegisterFirewallServer(server, g.sessionRpcServer) @@ -979,11 +1007,13 @@ func (g *LightningTerminal) RegisterRestSubserver(ctx context.Context, return err } - err = litrpc.RegisterAccountsHandlerFromEndpoint( - ctx, mux, endpoint, dialOpts, - ) - if err != nil { - return err + if !g.cfg.Accounts.Disable { + err = litrpc.RegisterAccountsHandlerFromEndpoint( + ctx, mux, endpoint, dialOpts, + ) + if err != nil { + return err + } } err = litrpc.RegisterFirewallHandlerFromEndpoint( From 294e9d00b5d63f7d6be38a862470e7a99bd49864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Tigerstr=C3=B6m?= Date: Mon, 25 Sep 2023 19:57:52 +0200 Subject: [PATCH 8/8] itest: add disable test for accounts endpoint --- itest/litd_mode_integrated_test.go | 52 +++++++++++++++++++++++------- itest/litd_mode_remote_test.go | 20 +++++++++--- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/itest/litd_mode_integrated_test.go b/itest/litd_mode_integrated_test.go index 2d5aeecdf..d75433619 100644 --- a/itest/litd_mode_integrated_test.go +++ b/itest/litd_mode_integrated_test.go @@ -213,10 +213,21 @@ var ( } endpoints = []struct { - name string - macaroonFn macaroonFn - requestFn requestFn - successPattern string + name string + macaroonFn macaroonFn + requestFn requestFn + successPattern string + + // disabledPattern represents a substring that is expected to be + // part of the error returned when a gRPC request is made to the + // disabled endpoint. + // TODO: once we have a subsystem manager, we can unify the + // returned for disabled endpoints for both subsystems and + // subservers by not registering the subsystem URIs to the + // permsMgr if it has been disabled. This field will then be + // unnecessary and can be removed. + disabledPattern string + allowedThroughLNC bool grpcWebURI string restWebURI string @@ -269,6 +280,7 @@ var ( macaroonFn: faradayMacaroonFn, requestFn: faradayRequestFn, successPattern: "\"reports\":[]", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/frdrpc.FaradayServer/RevenueReport", restWebURI: "/v1/faraday/revenue", @@ -278,6 +290,7 @@ var ( macaroonFn: loopMacaroonFn, requestFn: loopRequestFn, successPattern: "\"swaps\":[]", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/looprpc.SwapClient/ListSwaps", restWebURI: "/v1/loop/swaps", @@ -287,6 +300,7 @@ var ( macaroonFn: poolMacaroonFn, requestFn: poolRequestFn, successPattern: "\"accounts_active\":0", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/poolrpc.Trader/GetInfo", restWebURI: "/v1/pool/info", @@ -296,6 +310,7 @@ var ( macaroonFn: tapMacaroonFn, requestFn: tapRequestFn, successPattern: "\"assets\":[]", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/taprpc.TaprootAssets/ListAssets", restWebURI: "/v1/taproot-assets/assets", @@ -305,6 +320,7 @@ var ( macaroonFn: emptyMacaroonFn, requestFn: tapUniverseRequestFn, successPattern: "\"num_assets\":", + disabledPattern: "unknown request", allowedThroughLNC: true, grpcWebURI: "/universerpc.Universe/Info", restWebURI: "/v1/taproot-assets/universe/info", @@ -326,9 +342,11 @@ var ( macaroonFn: litMacaroonFn, requestFn: litAccountRequestFn, successPattern: "\"accounts\":[", + disabledPattern: "accounts has been disabled", allowedThroughLNC: false, grpcWebURI: "/litrpc.Accounts/ListAccounts", restWebURI: "/v1/accounts", + canDisable: true, }, { name: "litrpc-autopilot", macaroonFn: litMacaroonFn, @@ -384,6 +402,7 @@ func testDisablingSubServers(ctx context.Context, net *NetworkHarness, WithLitArg("loop-mode", "disable"), WithLitArg("pool-mode", "disable"), WithLitArg("faraday-mode", "disable"), + WithLitArg("accounts.disable", ""), }, ) require.NoError(t, err) @@ -494,7 +513,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointDisabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -532,7 +551,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, shouldFailWithoutMacaroon, endpoint.successPattern, endpointDisabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -557,7 +576,8 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, ttt, cfg.LitAddr(), cfg.UIPassword, endpoint.grpcWebURI, withoutUIPassword, endpointDisabled, - "unknown request", endpoint.noAuth, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -596,7 +616,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointDisabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -649,7 +669,9 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, endpoint.allowedThroughLNC, "unknown service", - endpointDisabled, endpoint.noAuth, + endpointDisabled, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -658,6 +680,12 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, t.Run("gRPC super macaroon account system test", func(tt *testing.T) { cfg := net.Alice.Cfg + // If the accounts service is disabled, we skip this test as it + // will fail due to the accounts service being disabled. + if subServersDisabled { + return + } + superMacFile, err := bakeSuperMacaroon(cfg, false) require.NoError(tt, err) @@ -722,6 +750,7 @@ func integratedTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, allowed, expectedErr, endpointDisabled, + endpoint.disabledPattern, endpoint.noAuth, ) }) @@ -1169,7 +1198,8 @@ func runRESTAuthTest(t *testing.T, hostPort, uiPassword, macaroonPath, restURI, // through Lightning Node Connect. func runLNCAuthTest(t *testing.T, rawLNCConn grpc.ClientConnInterface, makeRequest requestFn, successContent string, callAllowed bool, - expectErrContains string, disabled, noMac bool) { + expectErrContains string, disabled bool, disabledPattern string, + noMac bool) { ctxt, cancel := context.WithTimeout( context.Background(), defaultTimeout, @@ -1186,7 +1216,7 @@ func runLNCAuthTest(t *testing.T, rawLNCConn grpc.ClientConnInterface, // The call should be allowed, so we expect no error unless this is // for a disabled sub-server. case disabled: - require.ErrorContains(t, err, "unknown request") + require.ErrorContains(t, err, disabledPattern) return case noMac: diff --git a/itest/litd_mode_remote_test.go b/itest/litd_mode_remote_test.go index 41f34c13f..67e0aa14f 100644 --- a/itest/litd_mode_remote_test.go +++ b/itest/litd_mode_remote_test.go @@ -67,7 +67,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointEnabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -94,7 +94,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, shouldFailWithoutMacaroon, endpoint.successPattern, endpointEnabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -117,7 +117,8 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, ttt, cfg.LitAddr(), cfg.UIPassword, endpoint.grpcWebURI, withoutUIPassword, endpointEnabled, - "unknown request", endpoint.noAuth, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -145,7 +146,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.requestFn, endpoint.successPattern, endpointEnabled, - "unknown request", + endpoint.disabledPattern, ) }) } @@ -197,7 +198,9 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, endpoint.allowedThroughLNC, "unknown service", - endpointDisabled, endpoint.noAuth, + endpointDisabled, + endpoint.disabledPattern, + endpoint.noAuth, ) }) } @@ -248,6 +251,7 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, endpoint.successPattern, allowed, expectedErr, endpointDisabled, + endpoint.disabledPattern, endpoint.noAuth, ) }) @@ -257,6 +261,12 @@ func remoteTestSuite(ctx context.Context, net *NetworkHarness, t *testing.T, t.Run("gRPC super macaroon account system test", func(tt *testing.T) { cfg := net.Bob.Cfg + // If the accounts service is disabled, we skip this test as it + // will fail due to the accounts service being disabled. + if subServersDisabled { + return + } + superMacFile, err := bakeSuperMacaroon(cfg, false) require.NoError(tt, err)