diff --git a/client.go b/client.go index 3f4899a..34b55c9 100644 --- a/client.go +++ b/client.go @@ -192,13 +192,13 @@ func NewClient(dbDir string, cfg *ClientConfig) (*Client, func(), error) { } // FetchSwaps returns all loop in and out swaps currently in the database. -func (s *Client) FetchSwaps() ([]*SwapInfo, error) { - loopOutSwaps, err := s.Store.FetchLoopOutSwaps() +func (s *Client) FetchSwaps(ctx context.Context) ([]*SwapInfo, error) { + loopOutSwaps, err := s.Store.FetchLoopOutSwaps(ctx) if err != nil { return nil, err } - loopInSwaps, err := s.Store.FetchLoopInSwaps() + loopInSwaps, err := s.Store.FetchLoopInSwaps(ctx) if err != nil { return nil, err } @@ -292,12 +292,12 @@ func (s *Client) Run(ctx context.Context, // Query store before starting event loop to prevent new swaps from // being treated as swaps that need to be resumed. - pendingLoopOutSwaps, err := s.Store.FetchLoopOutSwaps() + pendingLoopOutSwaps, err := s.Store.FetchLoopOutSwaps(mainCtx) if err != nil { return err } - pendingLoopInSwaps, err := s.Store.FetchLoopInSwaps() + pendingLoopInSwaps, err := s.Store.FetchLoopInSwaps(mainCtx) if err != nil { return err } diff --git a/liquidity/autoloop_testcontext_test.go b/liquidity/autoloop_testcontext_test.go index bbaf2b5..d0653cb 100644 --- a/liquidity/autoloop_testcontext_test.go +++ b/liquidity/autoloop_testcontext_test.go @@ -144,13 +144,15 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters, return <-testCtx.loopInRestrictions, nil }, - ListLoopOut: func() ([]*loopdb.LoopOut, error) { + ListLoopOut: func(context.Context) ([]*loopdb.LoopOut, error) { return <-testCtx.loopOuts, nil }, - GetLoopOut: func(hash lntypes.Hash) (*loopdb.LoopOut, error) { + GetLoopOut: func(ctx context.Context, + hash lntypes.Hash) (*loopdb.LoopOut, error) { + return testCtx.loopOutSingle, nil }, - ListLoopIn: func() ([]*loopdb.LoopIn, error) { + ListLoopIn: func(context.Context) ([]*loopdb.LoopIn, error) { return <-testCtx.loopIns, nil }, LoopOutQuote: func(_ context.Context, @@ -186,10 +188,10 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters, MinimumConfirmations: loop.DefaultSweepConfTarget, Lnd: &testCtx.lnd.LndServices, Clock: testCtx.testClock, - PutLiquidityParams: func(_ []byte) error { + PutLiquidityParams: func(_ context.Context, _ []byte) error { return nil }, - FetchLiquidityParams: func() ([]byte, error) { + FetchLiquidityParams: func(context.Context) ([]byte, error) { return nil, nil }, } diff --git a/liquidity/liquidity.go b/liquidity/liquidity.go index 9efd526..b7181f6 100644 --- a/liquidity/liquidity.go +++ b/liquidity/liquidity.go @@ -10,22 +10,22 @@ // // Fee restrictions are placed on swap suggestions to ensure that we only // suggest swaps that fit the configured fee preferences. -// - Sweep Fee Rate Limit: the maximum sat/vByte fee estimate for our sweep -// transaction to confirm within our configured number of confirmations -// that we will suggest swaps for. -// - Maximum Swap Fee PPM: the maximum server fee, expressed as parts per -// million of the full swap amount -// - Maximum Routing Fee PPM: the maximum off-chain routing fees for the swap -// invoice, expressed as parts per million of the swap amount. -// - Maximum Prepay Routing Fee PPM: the maximum off-chain routing fees for the -// swap prepayment, expressed as parts per million of the prepay amount. -// - Maximum Prepay: the maximum now-show fee, expressed in satoshis. This -// amount is only payable in the case where the swap server broadcasts a htlc -// and the client fails to sweep the preimage. -// - Maximum miner fee: the maximum miner fee we are willing to pay to sweep the -// on chain htlc. Note that the client will use current fee estimates to -// sweep, so this value acts more as a sanity check in the case of a large fee -// spike. +// - Sweep Fee Rate Limit: the maximum sat/vByte fee estimate for our sweep +// transaction to confirm within our configured number of confirmations +// that we will suggest swaps for. +// - Maximum Swap Fee PPM: the maximum server fee, expressed as parts per +// million of the full swap amount +// - Maximum Routing Fee PPM: the maximum off-chain routing fees for the swap +// invoice, expressed as parts per million of the swap amount. +// - Maximum Prepay Routing Fee PPM: the maximum off-chain routing fees for the +// swap prepayment, expressed as parts per million of the prepay amount. +// - Maximum Prepay: the maximum now-show fee, expressed in satoshis. This +// amount is only payable in the case where the swap server broadcasts a htlc +// and the client fails to sweep the preimage. +// - Maximum miner fee: the maximum miner fee we are willing to pay to sweep the +// on chain htlc. Note that the client will use current fee estimates to +// sweep, so this value acts more as a sanity check in the case of a large fee +// spike. // // The maximum fee per-swap is calculated as follows: // (swap amount * serverPPM/1e6) + miner fee + (swap amount * routingPPM/1e6) @@ -176,14 +176,14 @@ type Config struct { Lnd *lndclient.LndServices // ListLoopOut returns all of the loop our swaps stored on disk. - ListLoopOut func() ([]*loopdb.LoopOut, error) + ListLoopOut func(context.Context) ([]*loopdb.LoopOut, error) // GetLoopOut returns a single loop out swap based on the provided swap // hash. - GetLoopOut func(hash lntypes.Hash) (*loopdb.LoopOut, error) + GetLoopOut func(ctx context.Context, hash lntypes.Hash) (*loopdb.LoopOut, error) // ListLoopIn returns all of the loop in swaps stored on disk. - ListLoopIn func() ([]*loopdb.LoopIn, error) + ListLoopIn func(ctx context.Context) ([]*loopdb.LoopIn, error) // LoopOutQuote gets swap fee, estimated miner fee and prepay amount for // a loop out swap. @@ -219,13 +219,13 @@ type Config struct { // // NOTE: the params are encoded using `proto.Marshal` over an RPC // request. - PutLiquidityParams func(params []byte) error + PutLiquidityParams func(ctx context.Context, params []byte) error // FetchLiquidityParams reads the serialized `Parameters` from db. // // NOTE: the params are decoded using `proto.Unmarshal` over a // serialized RPC request. - FetchLiquidityParams func() ([]byte, error) + FetchLiquidityParams func(ctx context.Context) ([]byte, error) } // Manager contains a set of desired liquidity rules for our channel @@ -260,7 +260,7 @@ func (m *Manager) Run(ctx context.Context) error { defer m.cfg.AutoloopTicker.Stop() // Before we start the main loop, load the params from db. - req, err := m.loadParams() + req, err := m.loadParams(ctx) if err != nil { return err } @@ -338,7 +338,7 @@ func (m *Manager) SetParameters(ctx context.Context, // Since setting params is NOT a frequent action, it's should put // little pressure on our db. Only when performance becomes an issue, // we can then apply the alternative. - return m.saveParams(req) + return m.saveParams(ctx, req) } // SetParameters updates our current set of parameters if the new parameters @@ -372,7 +372,7 @@ func (m *Manager) setParameters(ctx context.Context, } // saveParams marshals an RPC request and saves it to db. -func (m *Manager) saveParams(req proto.Message) error { +func (m *Manager) saveParams(ctx context.Context, req proto.Message) error { // Marshal the params. paramsBytes, err := proto.Marshal(req) if err != nil { @@ -380,7 +380,7 @@ func (m *Manager) saveParams(req proto.Message) error { } // Save the params on disk. - if err := m.cfg.PutLiquidityParams(paramsBytes); err != nil { + if err := m.cfg.PutLiquidityParams(ctx, paramsBytes); err != nil { return fmt.Errorf("failed to save params: %v", err) } @@ -389,8 +389,10 @@ func (m *Manager) saveParams(req proto.Message) error { // loadParams unmarshals a serialized RPC request from db and returns the RPC // request. -func (m *Manager) loadParams() (*clientrpc.LiquidityParameters, error) { - paramsBytes, err := m.cfg.FetchLiquidityParams() +func (m *Manager) loadParams(ctx context.Context) ( + *clientrpc.LiquidityParameters, error) { + + paramsBytes, err := m.cfg.FetchLiquidityParams(ctx) if err != nil { return nil, fmt.Errorf("failed to read params: %v", err) } @@ -509,12 +511,12 @@ func (m *Manager) ForceAutoLoop(ctx context.Context) error { // local balance back to the target. func (m *Manager) dispatchBestEasyAutoloopSwap(ctx context.Context) error { // Retrieve existing swaps. - loopOut, err := m.cfg.ListLoopOut() + loopOut, err := m.cfg.ListLoopOut(ctx) if err != nil { return err } - loopIn, err := m.cfg.ListLoopIn() + loopIn, err := m.cfg.ListLoopIn(ctx) if err != nil { return err } @@ -723,12 +725,12 @@ func (m *Manager) SuggestSwaps(ctx context.Context) ( // List our current set of swaps so that we can determine which channels // are already being utilized by swaps. Note that these calls may race // with manual initiation of swaps. - loopOut, err := m.cfg.ListLoopOut() + loopOut, err := m.cfg.ListLoopOut(ctx) if err != nil { return nil, err } - loopIn, err := m.cfg.ListLoopIn() + loopIn, err := m.cfg.ListLoopIn(ctx) if err != nil { return nil, err } @@ -1212,7 +1214,7 @@ func (m *Manager) refreshAutoloopBudget(ctx context.Context) { return } - err = m.saveParams(paramsRpc) + err = m.saveParams(ctx, paramsRpc) if err != nil { log.Errorf("Error saving parameters: %v", err) } @@ -1334,7 +1336,7 @@ func (m *Manager) waitForSwapPayment(ctx context.Context, swapHash lntypes.Hash, case <-time.After(interval): } - swap, err = m.cfg.GetLoopOut(swapHash) + swap, err = m.cfg.GetLoopOut(ctx, swapHash) if err != nil { log.Errorf( "Error getting swap with hash %x: %v", swapHash, diff --git a/liquidity/liquidity_test.go b/liquidity/liquidity_test.go index 7ca9786..4066d07 100644 --- a/liquidity/liquidity_test.go +++ b/liquidity/liquidity_test.go @@ -154,10 +154,10 @@ func newTestConfig() (*Config, *test.LndMockServices) { }, Lnd: &lnd.LndServices, Clock: clock.NewTestClock(testTime), - ListLoopOut: func() ([]*loopdb.LoopOut, error) { + ListLoopOut: func(context.Context) ([]*loopdb.LoopOut, error) { return nil, nil }, - ListLoopIn: func() ([]*loopdb.LoopIn, error) { + ListLoopIn: func(context.Context) ([]*loopdb.LoopIn, error) { return nil, nil }, LoopOutQuote: func(_ context.Context, @@ -266,30 +266,34 @@ func TestPersistParams(t *testing.T) { cfg, _ := newTestConfig() manager := NewManager(cfg) + ctxb := context.Background() + var paramsBytes []byte // Mock the read method to return empty data. - manager.cfg.FetchLiquidityParams = func() ([]byte, error) { + manager.cfg.FetchLiquidityParams = func(context.Context) ([]byte, error) { return paramsBytes, nil } // Test the nil params is returned. - req, err := manager.loadParams() + req, err := manager.loadParams(ctxb) require.Nil(t, req) require.NoError(t, err) // Mock the write method to return no error. - manager.cfg.PutLiquidityParams = func(data []byte) error { + manager.cfg.PutLiquidityParams = func(ctx context.Context, + data []byte) error { + paramsBytes = data return nil } // Test save the message. - err = manager.saveParams(rpcParams) + err = manager.saveParams(ctxb, rpcParams) require.NoError(t, err) // Test the nil params is returned. - req, err = manager.loadParams() + req, err = manager.loadParams(ctxb) require.NoError(t, err) // Check the specified fields are set as expected. @@ -565,10 +569,10 @@ func TestRestrictedSuggestions(t *testing.T) { // Create a manager config which will return the test // case's set of existing swaps. cfg, lnd := newTestConfig() - cfg.ListLoopOut = func() ([]*loopdb.LoopOut, error) { + cfg.ListLoopOut = func(context.Context) ([]*loopdb.LoopOut, error) { return testCase.loopOut, nil } - cfg.ListLoopIn = func() ([]*loopdb.LoopIn, error) { + cfg.ListLoopIn = func(context.Context) ([]*loopdb.LoopIn, error) { return testCase.loopIn, nil } @@ -1093,7 +1097,7 @@ func TestFeeBudget(t *testing.T) { }) } - cfg.ListLoopOut = func() ([]*loopdb.LoopOut, error) { + cfg.ListLoopOut = func(context.Context) ([]*loopdb.LoopOut, error) { return swaps, nil } @@ -1270,10 +1274,10 @@ func TestInFlightLimit(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { cfg, lnd := newTestConfig() - cfg.ListLoopOut = func() ([]*loopdb.LoopOut, error) { + cfg.ListLoopOut = func(context.Context) ([]*loopdb.LoopOut, error) { return testCase.existingSwaps, nil } - cfg.ListLoopIn = func() ([]*loopdb.LoopIn, error) { + cfg.ListLoopIn = func(context.Context) ([]*loopdb.LoopIn, error) { return testCase.existingInSwaps, nil } @@ -1755,7 +1759,7 @@ func TestBudgetWithLoopin(t *testing.T) { channel1, } - cfg.ListLoopIn = func() ([]*loopdb.LoopIn, error) { + cfg.ListLoopIn = func(context.Context) ([]*loopdb.LoopIn, error) { return testCase.loopIns, nil } diff --git a/loopd/daemon.go b/loopd/daemon.go index adca452..787afd2 100644 --- a/loopd/daemon.go +++ b/loopd/daemon.go @@ -450,7 +450,7 @@ func (d *Daemon) initialize(withMacaroonService bool) error { } // Retrieve all currently existing swaps from the database. - swapsList, err := d.impl.FetchSwaps() + swapsList, err := d.impl.FetchSwaps(d.mainCtx) if err != nil { if d.macaroonService == nil { cleanupMacaroonStore() diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index 90a97c6..798b429 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -741,11 +741,11 @@ func (s *swapClientServer) GetLsatTokens(ctx context.Context, // GetInfo returns basic information about the loop daemon and details to swaps // from the swap store. -func (s *swapClientServer) GetInfo(_ context.Context, +func (s *swapClientServer) GetInfo(ctx context.Context, _ *clientrpc.GetInfoRequest) (*clientrpc.GetInfoResponse, error) { // Fetch loop-outs from the loop db. - outSwaps, err := s.impl.Store.FetchLoopOutSwaps() + outSwaps, err := s.impl.Store.FetchLoopOutSwaps(ctx) if err != nil { return nil, err } @@ -772,7 +772,7 @@ func (s *swapClientServer) GetInfo(_ context.Context, } // Fetch loop-ins from the loop db. - inSwaps, err := s.impl.Store.FetchLoopInSwaps() + inSwaps, err := s.impl.Store.FetchLoopInSwaps(ctx) if err != nil { return nil, err } diff --git a/loopd/view.go b/loopd/view.go index 3a7b600..a809cc5 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -1,6 +1,7 @@ package loopd import ( + "context" "fmt" "github.com/btcsuite/btcd/chaincfg" @@ -42,7 +43,7 @@ func view(config *Config, lisCfg *ListenerCfg) error { } func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { - swaps, err := swapClient.Store.FetchLoopOutSwaps() + swaps, err := swapClient.Store.FetchLoopOutSwaps(context.Background()) if err != nil { return err } @@ -91,7 +92,7 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { } func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { - swaps, err := swapClient.Store.FetchLoopInSwaps() + swaps, err := swapClient.Store.FetchLoopInSwaps(context.Background()) if err != nil { return err } diff --git a/loopdb/interface.go b/loopdb/interface.go index 6f85067..be7db05 100644 --- a/loopdb/interface.go +++ b/loopdb/interface.go @@ -1,6 +1,7 @@ package loopdb import ( + "context" "time" "github.com/lightningnetwork/lnd/lntypes" @@ -10,30 +11,32 @@ import ( // houses information for all pending completed/failed swaps. type SwapStore interface { // FetchLoopOutSwaps returns all swaps currently in the store. - FetchLoopOutSwaps() ([]*LoopOut, error) + FetchLoopOutSwaps(ctx context.Context) ([]*LoopOut, error) // FetchLoopOutSwap returns the loop out swap with the given hash. - FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) + FetchLoopOutSwap(ctx context.Context, hash lntypes.Hash) (*LoopOut, error) // CreateLoopOut adds an initiated swap to the store. - CreateLoopOut(hash lntypes.Hash, swap *LoopOutContract) error + CreateLoopOut(ctx context.Context, hash lntypes.Hash, + swap *LoopOutContract) error // UpdateLoopOut stores a new event for a target loop out swap. This // appends to the event log for a particular swap as it goes through // the various stages in its lifetime. - UpdateLoopOut(hash lntypes.Hash, time time.Time, + UpdateLoopOut(ctx context.Context, hash lntypes.Hash, time time.Time, state SwapStateData) error // FetchLoopInSwaps returns all swaps currently in the store. - FetchLoopInSwaps() ([]*LoopIn, error) + FetchLoopInSwaps(ctx context.Context) ([]*LoopIn, error) // CreateLoopIn adds an initiated swap to the store. - CreateLoopIn(hash lntypes.Hash, swap *LoopInContract) error + CreateLoopIn(ctx context.Context, hash lntypes.Hash, + swap *LoopInContract) error // UpdateLoopIn stores a new event for a target loop in swap. This // appends to the event log for a particular swap as it goes through // the various stages in its lifetime. - UpdateLoopIn(hash lntypes.Hash, time time.Time, + UpdateLoopIn(ctx context.Context, hash lntypes.Hash, time time.Time, state SwapStateData) error // PutLiquidityParams writes the serialized `manager.Parameters` bytes @@ -41,14 +44,14 @@ type SwapStore interface { // // NOTE: it's the caller's responsibility to encode the param. Atm, // it's encoding using the proto package's `Marshal` method. - PutLiquidityParams(params []byte) error + PutLiquidityParams(ctx context.Context, params []byte) error // FetchLiquidityParams reads the serialized `manager.Parameters` bytes // from the bucket. // // NOTE: it's the caller's responsibility to decode the param. Atm, // it's decoding using the proto package's `Unmarshal` method. - FetchLiquidityParams() ([]byte, error) + FetchLiquidityParams(ctx context.Context) ([]byte, error) // Close closes the underlying database. Close() error diff --git a/loopdb/migration_04_updates_test.go b/loopdb/migration_04_updates_test.go index 80c1156..0c206ef 100644 --- a/loopdb/migration_04_updates_test.go +++ b/loopdb/migration_04_updates_test.go @@ -1,6 +1,7 @@ package loopdb import ( + "context" "io/ioutil" "os" "path/filepath" @@ -44,6 +45,8 @@ func TestMigrationUpdates(t *testing.T) { }, } + ctxb := context.Background() + // Restore a legacy database. tempDirName, err := ioutil.TempDir("", "clientstore") require.NoError(t, err) @@ -69,7 +72,7 @@ func TestMigrationUpdates(t *testing.T) { // Fetch the legacy loop out swap and assert that the updates are still // there. - outSwaps, err := store.FetchLoopOutSwaps() + outSwaps, err := store.FetchLoopOutSwaps(ctxb) require.NoError(t, err) outSwap := outSwaps[0] @@ -78,7 +81,7 @@ func TestMigrationUpdates(t *testing.T) { // Fetch the legacy loop in swap and assert that the updates are still // there. - inSwaps, err := store.FetchLoopInSwaps() + inSwaps, err := store.FetchLoopInSwaps(ctxb) require.NoError(t, err) inSwap := inSwaps[0] diff --git a/loopdb/store.go b/loopdb/store.go index 63b4be4..d7025cc 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -2,6 +2,7 @@ package loopdb import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -346,7 +347,9 @@ func unmarshalHtlcKeys(swapBucket *bbolt.Bucket, contract *SwapContract) error { // FetchLoopOutSwaps returns all loop out swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { +func (s *boltSwapStore) FetchLoopOutSwaps(ctx context.Context) ([]*LoopOut, + error) { + var swaps []*LoopOut err := s.db.View(func(tx *bbolt.Tx) error { @@ -385,7 +388,9 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { // FetchLoopOutSwap returns the loop out swap with the given hash. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) { +func (s *boltSwapStore) FetchLoopOutSwap(ctx context.Context, + hash lntypes.Hash) (*LoopOut, error) { + var swap *LoopOut err := s.db.View(func(tx *bbolt.Tx) error { @@ -414,7 +419,9 @@ func (s *boltSwapStore) FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) { // FetchLoopInSwaps returns all loop in swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { +func (s *boltSwapStore) FetchLoopInSwaps(ctx context.Context) ([]*LoopIn, + error) { + var swaps []*LoopIn err := s.db.View(func(tx *bbolt.Tx) error { @@ -475,7 +482,7 @@ func createLoopBucket(tx *bbolt.Tx, swapTypeKey []byte, hash lntypes.Hash) ( // CreateLoopOut adds an initiated swap to the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, +func (s *boltSwapStore) CreateLoopOut(ctx context.Context, hash lntypes.Hash, swap *LoopOutContract) error { // If the hash doesn't match the pre-image, then this is an invalid @@ -561,7 +568,7 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, // CreateLoopIn adds an initiated swap to the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, +func (s *boltSwapStore) CreateLoopIn(ctx context.Context, hash lntypes.Hash, swap *LoopInContract) error { // If the hash doesn't match the pre-image, then this is an invalid @@ -678,8 +685,8 @@ func (s *boltSwapStore) updateLoop(bucketKey []byte, hash lntypes.Hash, // a particular swap as it goes through the various stages in its lifetime. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, - state SwapStateData) error { +func (s *boltSwapStore) UpdateLoopOut(ctx context.Context, + hash lntypes.Hash, time time.Time, state SwapStateData) error { return s.updateLoop(loopOutBucketKey, hash, time, state) } @@ -688,8 +695,8 @@ func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, // a particular swap as it goes through the various stages in its lifetime. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) UpdateLoopIn(hash lntypes.Hash, time time.Time, - state SwapStateData) error { +func (s *boltSwapStore) UpdateLoopIn(ctx context.Context, hash lntypes.Hash, + time time.Time, state SwapStateData) error { return s.updateLoop(loopInBucketKey, hash, time, state) } @@ -706,7 +713,9 @@ func (s *boltSwapStore) Close() error { // // NOTE: it's the caller's responsibility to encode the param. Atm, it's // encoding using the proto package's `Marshal` method. -func (s *boltSwapStore) PutLiquidityParams(params []byte) error { +func (s *boltSwapStore) PutLiquidityParams(ctx context.Context, + params []byte) error { + return s.db.Update(func(tx *bbolt.Tx) error { // Read the root bucket. rootBucket := tx.Bucket(liquidityBucket) @@ -722,7 +731,9 @@ func (s *boltSwapStore) PutLiquidityParams(params []byte) error { // // NOTE: it's the caller's responsibility to decode the param. Atm, it's // decoding using the proto package's `Unmarshal` method. -func (s *boltSwapStore) FetchLiquidityParams() ([]byte, error) { +func (s *boltSwapStore) FetchLiquidityParams(ctx context.Context) ([]byte, + error) { + var params []byte err := s.db.View(func(tx *bbolt.Tx) error { diff --git a/loopdb/store_test.go b/loopdb/store_test.go index 324e965..35ae5e3 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -1,6 +1,7 @@ package loopdb import ( + "context" "crypto/sha256" "io/ioutil" "os" @@ -121,8 +122,10 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) require.NoError(t, err) + ctxb := context.Background() + // First, verify that an empty database has no active swaps. - swaps, err := store.FetchLoopOutSwaps() + swaps, err := store.FetchLoopOutSwaps(ctxb) require.NoError(t, err) require.Empty(t, swaps) @@ -134,12 +137,12 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { checkSwap := func(expectedState SwapState) { t.Helper() - swaps, err := store.FetchLoopOutSwaps() + swaps, err := store.FetchLoopOutSwaps(ctxb) require.NoError(t, err) require.Len(t, swaps, 1) - swap, err := store.FetchLoopOutSwap(hash) + swap, err := store.FetchLoopOutSwap(ctxb, hash) require.NoError(t, err) require.Equal(t, hash, swap.Hash) @@ -158,20 +161,20 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { // If we create a new swap, then it should show up as being initialized // right after. - err = store.CreateLoopOut(hash, pendingSwap) + err = store.CreateLoopOut(ctxb, hash, pendingSwap) require.NoError(t, err) checkSwap(StateInitiated) // Trying to make the same swap again should result in an error. - err = store.CreateLoopOut(hash, pendingSwap) + err = store.CreateLoopOut(ctxb, hash, pendingSwap) require.Error(t, err) checkSwap(StateInitiated) // Next, we'll update to the next state of the pre-image being // revealed. The state should be reflected here again. err = store.UpdateLoopOut( - hash, testTime, + ctxb, hash, testTime, SwapStateData{ State: StatePreimageRevealed, HtlcTxHash: &chainhash.Hash{1, 6, 2}, @@ -184,7 +187,7 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { // Next, we'll update to the final state to ensure that the state is // properly updated. err = store.UpdateLoopOut( - hash, testTime, + ctxb, hash, testTime, SwapStateData{ State: StateFailInsufficientValue, }, @@ -260,8 +263,10 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) require.NoError(t, err) + ctxb := context.Background() + // First, verify that an empty database has no active swaps. - swaps, err := store.FetchLoopInSwaps() + swaps, err := store.FetchLoopInSwaps(ctxb) require.NoError(t, err) require.Empty(t, swaps) @@ -272,7 +277,7 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { checkSwap := func(expectedState SwapState) { t.Helper() - swaps, err := store.FetchLoopInSwaps() + swaps, err := store.FetchLoopInSwaps(ctxb) require.NoError(t, err) require.Len(t, swaps, 1) @@ -285,13 +290,13 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { // If we create a new swap, then it should show up as being initialized // right after. - err = store.CreateLoopIn(hash, &pendingSwap) + err = store.CreateLoopIn(ctxb, hash, &pendingSwap) require.NoError(t, err) checkSwap(StateInitiated) // Trying to make the same swap again should result in an error. - err = store.CreateLoopIn(hash, &pendingSwap) + err = store.CreateLoopIn(ctxb, hash, &pendingSwap) require.Error(t, err) checkSwap(StateInitiated) @@ -299,7 +304,7 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { // Next, we'll update to the next state of the pre-image being // revealed. The state should be reflected here again. err = store.UpdateLoopIn( - hash, testTime, + ctxb, hash, testTime, SwapStateData{ State: StatePreimageRevealed, }, @@ -311,7 +316,7 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { // Next, we'll update to the final state to ensure that the state is // properly updated. err = store.UpdateLoopIn( - hash, testTime, + ctxb, hash, testTime, SwapStateData{ State: StateFailInsufficientValue, }, @@ -407,6 +412,8 @@ func TestLegacyOutgoingChannel(t *testing.T) { legacyOutgoingChannel = Hex("0000000000000005") ) + ctxb := context.Background() + legacyDb := map[string]interface{}{ "loop-in": map[string]interface{}{}, "metadata": map[string]interface{}{ @@ -449,7 +456,7 @@ func TestLegacyOutgoingChannel(t *testing.T) { t.Fatal(err) } - swaps, err := store.FetchLoopOutSwaps() + swaps, err := store.FetchLoopOutSwaps(ctxb) if err != nil { t.Fatal(err) } @@ -467,23 +474,25 @@ func TestLiquidityParams(t *testing.T) { require.NoError(t, err, "failed to db") defer os.RemoveAll(tempDirName) + ctxb := context.Background() + store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) require.NoError(t, err, "failed to create store") // Test when there's no params saved before, an empty bytes is // returned. - params, err := store.FetchLiquidityParams() + params, err := store.FetchLiquidityParams(ctxb) require.NoError(t, err, "failed to fetch params") require.Empty(t, params, "expect empty bytes") params = []byte("test") // Test we can save the params. - err = store.PutLiquidityParams(params) + err = store.PutLiquidityParams(ctxb, params) require.NoError(t, err, "failed to put params") // Now fetch the db again should return the above saved bytes. - paramsRead, err := store.FetchLiquidityParams() + paramsRead, err := store.FetchLiquidityParams(ctxb) require.NoError(t, err, "failed to fetch params") require.Equal(t, params, paramsRead, "unexpected return value") } diff --git a/loopin.go b/loopin.go index d755122..dd1ee8e 100644 --- a/loopin.go +++ b/loopin.go @@ -298,7 +298,7 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig, // Persist the data before exiting this function, so that the caller can // trust that this swap will be resumed on restart. - err = cfg.store.CreateLoopIn(swapHash, &swap.LoopInContract) + err = cfg.store.CreateLoopIn(globalCtx, swapHash, &swap.LoopInContract) if err != nil { return nil, fmt.Errorf("cannot store swap: %v", err) } @@ -776,7 +776,7 @@ func (s *loopInSwap) publishOnChainHtlc(ctx context.Context) (bool, error) { s.cost.Onchain = fee s.lastUpdateTime = time.Now() - if err := s.persistState(); err != nil { + if err := s.persistState(ctx); err != nil { return false, fmt.Errorf("persist htlc tx: %v", err) } @@ -1068,7 +1068,7 @@ func (s *loopInSwap) publishTimeoutTx(ctx context.Context, // update notification. func (s *loopInSwap) persistAndAnnounceState(ctx context.Context) error { // Update state in store. - if err := s.persistState(); err != nil { + if err := s.persistState(ctx); err != nil { return err } @@ -1077,9 +1077,9 @@ func (s *loopInSwap) persistAndAnnounceState(ctx context.Context) error { } // persistState updates the swap state on disk. -func (s *loopInSwap) persistState() error { +func (s *loopInSwap) persistState(ctx context.Context) error { return s.store.UpdateLoopIn( - s.hash, s.lastUpdateTime, + ctx, s.hash, s.lastUpdateTime, loopdb.SwapStateData{ State: s.state, Cost: s.cost, diff --git a/loopin_test.go b/loopin_test.go index de4452b..52dee18 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -395,6 +395,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, storedVersion loopdb.ProtocolVersion) { defer test.Guard(t)() + ctxb := context.Background() ctx := newLoopInTestContext(t) cfg := newSwapConfig(&ctx.lnd.LndServices, ctx.store, ctx.server) @@ -454,7 +455,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, ) require.NoError(t, err) - err = ctx.store.CreateLoopIn(testPreimage.Hash(), contract) + err = ctx.store.CreateLoopIn(ctxb, testPreimage.Hash(), contract) require.NoError(t, err) inSwap, err := resumeLoopInSwap(context.Background(), cfg, pendSwap) diff --git a/loopout.go b/loopout.go index eca55bc..780a63d 100644 --- a/loopout.go +++ b/loopout.go @@ -233,7 +233,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, // Persist the data before exiting this function, so that the caller // can trust that this swap will be resumed on restart. - err = cfg.store.CreateLoopOut(swapHash, &swap.LoopOutContract) + err = cfg.store.CreateLoopOut(globalCtx, swapHash, &swap.LoopOutContract) if err != nil { return nil, fmt.Errorf("cannot store swap: %v", err) } @@ -578,7 +578,7 @@ func (s *loopOutSwap) persistState(ctx context.Context) error { // Update state in store. err := s.store.UpdateLoopOut( - s.hash, updateTime, + ctx, s.hash, updateTime, loopdb.SwapStateData{ State: s.state, Cost: s.cost, diff --git a/store_mock_test.go b/store_mock_test.go index ccf6305..1b9b483 100644 --- a/store_mock_test.go +++ b/store_mock_test.go @@ -1,6 +1,7 @@ package loop import ( + "context" "errors" "testing" "time" @@ -45,7 +46,7 @@ func newStoreMock(t *testing.T) *storeMock { // FetchLoopOutSwaps returns all swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) FetchLoopOutSwaps() ([]*loopdb.LoopOut, error) { +func (s *storeMock) FetchLoopOutSwaps(ctx context.Context) ([]*loopdb.LoopOut, error) { result := []*loopdb.LoopOut{} for hash, contract := range s.loopOutSwaps { @@ -73,7 +74,7 @@ func (s *storeMock) FetchLoopOutSwaps() ([]*loopdb.LoopOut, error) { // FetchLoopOutSwaps returns all swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) FetchLoopOutSwap( +func (s *storeMock) FetchLoopOutSwap(ctx context.Context, hash lntypes.Hash) (*loopdb.LoopOut, error) { contract, ok := s.loopOutSwaps[hash] @@ -103,7 +104,7 @@ func (s *storeMock) FetchLoopOutSwap( // CreateLoopOut adds an initiated swap to the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) CreateLoopOut(hash lntypes.Hash, +func (s *storeMock) CreateLoopOut(ctx context.Context, hash lntypes.Hash, swap *loopdb.LoopOutContract) error { _, ok := s.loopOutSwaps[hash] @@ -119,7 +120,9 @@ func (s *storeMock) CreateLoopOut(hash lntypes.Hash, } // FetchLoopInSwaps returns all in swaps currently in the store. -func (s *storeMock) FetchLoopInSwaps() ([]*loopdb.LoopIn, error) { +func (s *storeMock) FetchLoopInSwaps(ctx context.Context) ([]*loopdb.LoopIn, + error) { + result := []*loopdb.LoopIn{} for hash, contract := range s.loopInSwaps { @@ -147,7 +150,7 @@ func (s *storeMock) FetchLoopInSwaps() ([]*loopdb.LoopIn, error) { // CreateLoopIn adds an initiated loop in swap to the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) CreateLoopIn(hash lntypes.Hash, +func (s *storeMock) CreateLoopIn(ctx context.Context, hash lntypes.Hash, swap *loopdb.LoopInContract) error { _, ok := s.loopInSwaps[hash] @@ -167,8 +170,8 @@ func (s *storeMock) CreateLoopIn(hash lntypes.Hash, // its lifetime. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) UpdateLoopOut(hash lntypes.Hash, time time.Time, - state loopdb.SwapStateData) error { +func (s *storeMock) UpdateLoopOut(ctx context.Context, hash lntypes.Hash, + time time.Time, state loopdb.SwapStateData) error { updates, ok := s.loopOutUpdates[hash] if !ok { @@ -187,8 +190,8 @@ func (s *storeMock) UpdateLoopOut(hash lntypes.Hash, time time.Time, // its lifetime. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) UpdateLoopIn(hash lntypes.Hash, time time.Time, - state loopdb.SwapStateData) error { +func (s *storeMock) UpdateLoopIn(ctx context.Context, hash lntypes.Hash, + time time.Time, state loopdb.SwapStateData) error { updates, ok := s.loopInUpdates[hash] if !ok { @@ -206,7 +209,9 @@ func (s *storeMock) UpdateLoopIn(hash lntypes.Hash, time time.Time, // bucket. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) PutLiquidityParams(params []byte) error { +func (s *storeMock) PutLiquidityParams(ctx context.Context, + params []byte) error { + return nil } @@ -214,7 +219,7 @@ func (s *storeMock) PutLiquidityParams(params []byte) error { // the bucket. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *storeMock) FetchLiquidityParams() ([]byte, error) { +func (s *storeMock) FetchLiquidityParams(ctx context.Context) ([]byte, error) { return nil, nil }