diff --git a/liquidity/autoloop_test.go b/liquidity/autoloop_test.go index bcc07a5..253b12b 100644 --- a/liquidity/autoloop_test.go +++ b/liquidity/autoloop_test.go @@ -199,16 +199,30 @@ func TestAutoLoopEnabled(t *testing.T) { }, }, } + + singleLoopOut = &loopdb.LoopOut{ + Loop: loopdb.Loop{ + Events: []*loopdb.LoopEvent{ + { + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateInitiated, + }, + }, + }, + }, + } ) // Tick our autolooper with no existing swaps, we expect a loop out // swap to be dispatched for each channel. step := &autoloopStep{ - minAmt: 1, - maxAmt: amt + 1, - quotesOut: quotes, - expectedOut: loopOuts, + minAmt: 1, + maxAmt: amt + 1, + quotesOut: quotes, + expectedOut: loopOuts, + existingOutSingle: singleLoopOut, } + c.autoloop(step) // Tick again with both of our swaps in progress. We haven't shifted our @@ -220,9 +234,10 @@ func TestAutoLoopEnabled(t *testing.T) { } step = &autoloopStep{ - minAmt: 1, - maxAmt: amt + 1, - existingOut: existing, + minAmt: 1, + maxAmt: amt + 1, + existingOut: existing, + existingOutSingle: singleLoopOut, } c.autoloop(step) @@ -278,11 +293,12 @@ func TestAutoLoopEnabled(t *testing.T) { // still has balances which reflect that we need to swap), but nothing // for channel 2, since it has had a failure. step = &autoloopStep{ - minAmt: 1, - maxAmt: amt + 1, - existingOut: existing, - quotesOut: quotes, - expectedOut: loopOuts, + minAmt: 1, + maxAmt: amt + 1, + existingOut: existing, + quotesOut: quotes, + expectedOut: loopOuts, + existingOutSingle: singleLoopOut, } c.autoloop(step) @@ -299,10 +315,11 @@ func TestAutoLoopEnabled(t *testing.T) { } step = &autoloopStep{ - minAmt: 1, - maxAmt: amt + 1, - existingOut: existing, - quotesOut: quotes, + minAmt: 1, + maxAmt: amt + 1, + existingOut: existing, + quotesOut: quotes, + existingOutSingle: singleLoopOut, } c.autoloop(step) @@ -446,13 +463,27 @@ func TestAutoloopAddress(t *testing.T) { }, }, } + + singleLoopOut = &loopdb.LoopOut{ + Loop: loopdb.Loop{ + Events: []*loopdb.LoopEvent{ + { + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateHtlcPublished, + }, + }, + }, + }, + } ) step := &autoloopStep{ - minAmt: 1, - maxAmt: amt + 1, - quotesOut: quotes, - expectedOut: loopOuts, + minAmt: 1, + maxAmt: amt + 1, + quotesOut: quotes, + expectedOut: loopOuts, + existingOutSingle: singleLoopOut, + keepDestAddr: true, } c.autoloop(step) @@ -606,6 +637,18 @@ func TestCompositeRules(t *testing.T) { }, }, } + + singleLoopOut = &loopdb.LoopOut{ + Loop: loopdb.Loop{ + Events: []*loopdb.LoopEvent{ + { + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateHtlcPublished, + }, + }, + }, + }, + } ) // Tick our autolooper with no existing swaps, we expect a loop out @@ -613,10 +656,11 @@ func TestCompositeRules(t *testing.T) { // maximum to be greater than the swap amount for our peer swap (which // is the larger of the two swaps). step := &autoloopStep{ - minAmt: 1, - maxAmt: peerAmount + 1, - quotesOut: quotes, - expectedOut: loopOuts, + minAmt: 1, + maxAmt: peerAmount + 1, + quotesOut: quotes, + expectedOut: loopOuts, + existingOutSingle: singleLoopOut, } c.autoloop(step) @@ -928,6 +972,18 @@ func TestAutoloopBothTypes(t *testing.T) { Label: labels.AutoloopLabel(swap.TypeIn), Initiator: autoloopSwapInitiator, } + + singleLoopOut = &loopdb.LoopOut{ + Loop: loopdb.Loop{ + Events: []*loopdb.LoopEvent{ + { + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateHtlcPublished, + }, + }, + }, + }, + } ) step := &autoloopStep{ @@ -961,6 +1017,7 @@ func TestAutoloopBothTypes(t *testing.T) { }, }, }, + existingOutSingle: singleLoopOut, } c.autoloop(step) c.stop() diff --git a/liquidity/autoloop_testcontext_test.go b/liquidity/autoloop_testcontext_test.go index bf25ab1..8379fa9 100644 --- a/liquidity/autoloop_testcontext_test.go +++ b/liquidity/autoloop_testcontext_test.go @@ -2,7 +2,9 @@ package liquidity import ( "context" + "reflect" "testing" + "time" "github.com/btcsuite/btcd/btcutil" "github.com/lightninglabs/lndclient" @@ -11,8 +13,10 @@ import ( "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/ticker" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type autoloopTestCtx struct { @@ -45,9 +49,17 @@ type autoloopTestCtx struct { // loopOuts is a channel that we get existing loop out swaps on. loopOuts chan []*loopdb.LoopOut + // loopOutSingle is the single loop out returned from fetching a single + // swap from store. + loopOutSingle *loopdb.LoopOut + // loopIns is a channel that we get existing loop in swaps on. loopIns chan []*loopdb.LoopIn + // loopInSingle is the single loop in returned from fetching a single + // swap from store. + loopInSingle *loopdb.LoopIn + // restrictions is a channel that we get swap restrictions on. restrictions chan *Restrictions @@ -131,6 +143,9 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters, ListLoopOut: func() ([]*loopdb.LoopOut, error) { return <-testCtx.loopOuts, nil }, + GetLoopOut: func(hash lntypes.Hash) (*loopdb.LoopOut, error) { + return testCtx.loopOutSingle, nil + }, ListLoopIn: func() ([]*loopdb.LoopIn, error) { return <-testCtx.loopIns, nil }, @@ -188,6 +203,10 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters, testCtx.manager = NewManager(cfg) err := testCtx.manager.setParameters(context.Background(), parameters) assert.NoError(t, err) + // Override the payments check interval for the tests in order to not + // timeout. + testCtx.manager.params.CustomPaymentCheckInterval = + 150 * time.Millisecond <-done return testCtx } @@ -241,14 +260,17 @@ type loopInRequestResp struct { // autoloopStep contains all of the information to required to step // through an autoloop tick. type autoloopStep struct { - minAmt btcutil.Amount - maxAmt btcutil.Amount - existingOut []*loopdb.LoopOut - existingIn []*loopdb.LoopIn - quotesOut []quoteRequestResp - quotesIn []quoteInRequestResp - expectedOut []loopOutRequestResp - expectedIn []loopInRequestResp + minAmt btcutil.Amount + maxAmt btcutil.Amount + existingOut []*loopdb.LoopOut + existingOutSingle *loopdb.LoopOut + existingIn []*loopdb.LoopIn + existingInSingle *loopdb.LoopIn + quotesOut []quoteRequestResp + quotesIn []quoteInRequestResp + expectedOut []loopOutRequestResp + expectedIn []loopInRequestResp + keepDestAddr bool } // autoloop walks our test context through the process of triggering our @@ -269,6 +291,9 @@ func (c *autoloopTestCtx) autoloop(step *autoloopStep) { c.loopOuts <- step.existingOut c.loopIns <- step.existingIn + c.loopOutSingle = step.existingOutSingle + c.loopInSingle = step.existingInSingle + // Assert that we query the server for a quote for each of our // recommended swaps. Note that this differs from our set of expected // swaps because we may get quotes for suggested swaps but then just @@ -299,25 +324,77 @@ func (c *autoloopTestCtx) autoloop(step *autoloopStep) { c.quotes <- expected.quote } - // Assert that we dispatch the expected set of swaps. - for _, expected := range step.expectedOut { + require.True(c.t, c.matchLoopOuts(step.expectedOut, step.keepDestAddr)) + require.True(c.t, c.matchLoopIns(step.expectedIn)) +} + +// matchLoopOuts checks that the actual loop out requests we got match the +// expected ones. The argument keepDestAddr is used to indicate whether we keep +// the actual loops destination address for the comparison. This is useful +// because we don't want to compare the destination address generated by the +// wallet mock. We want to compare the destination address when testing the +// autoloop DestAddr parameter for loop outs. +func (c *autoloopTestCtx) matchLoopOuts(swaps []loopOutRequestResp, + keepDestAddr bool) bool { + + swapsCopy := make([]loopOutRequestResp, len(swaps)) + copy(swapsCopy, swaps) + + length := len(swapsCopy) + + for i := 0; i < length; i++ { actual := <-c.outRequest - // Set our destination address to nil so that we do not need to - // provide the address that is obtained by the mock wallet kit. - if expected.request.DestAddr == nil { + if !keepDestAddr { actual.DestAddr = nil } - assert.Equal(c.t, expected.request, actual) - c.loopOut <- expected.response + inner: + for index, swap := range swapsCopy { + equal := reflect.DeepEqual(swap.request, actual) + + if equal { + c.loopOut <- swap.response + + swapsCopy = append( + swapsCopy[:index], + swapsCopy[index+1:]..., + ) + + break inner + } + } } - for _, expected := range step.expectedIn { + return len(swapsCopy) == 0 +} + +// matchLoopIns checks that the actual loop in requests we got match the +// expected ones. +func (c *autoloopTestCtx) matchLoopIns( + swaps []loopInRequestResp) bool { + + swapsCopy := make([]loopInRequestResp, len(swaps)) + copy(swapsCopy, swaps) + + for i := 0; i < len(swapsCopy); i++ { actual := <-c.inRequest - assert.Equal(c.t, expected.request, actual) + inner: + for i, swap := range swapsCopy { + equal := reflect.DeepEqual(swap.request, actual) - c.loopIn <- expected.response + if equal { + c.loopIn <- swap.response + + swapsCopy = append( + swapsCopy[:i], swapsCopy[i+1:]..., + ) + + break inner + } + } } + + return len(swapsCopy) == 0 } diff --git a/liquidity/liquidity.go b/liquidity/liquidity.go index 9b8179f..4f6a45d 100644 --- a/liquidity/liquidity.go +++ b/liquidity/liquidity.go @@ -48,6 +48,7 @@ import ( "github.com/lightninglabs/loop/swap" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -62,6 +63,22 @@ const ( // a channel is part of a temporarily failed swap. defaultFailureBackoff = time.Hour * 24 + // defaultAmountBackoff is the default backoff we apply to the amount + // of a loop out swap that failed the off-chain payments. + defaultAmountBackoff = float64(0.25) + + // defaultAmountBackoffRetry is the default number of times we will + // perform an amount backoff to a loop out swap before we give up. + defaultAmountBackoffRetry = 5 + + // defaultSwapWaitTimeout is the default maximum amount of time we + // wait for a swap to reach a terminal state. + defaultSwapWaitTimeout = time.Hour * 24 + + // defaultPaymentCheckInterval is the default time that passes between + // checks for loop out payments status. + defaultPaymentCheckInterval = time.Second * 2 + // defaultConfTarget is the default sweep target we use for loop outs. // We get our inbound liquidity quickly using preimage push, so we can // use a long conf target without worrying about ux impact. @@ -78,7 +95,7 @@ const ( // DefaultAutoloopTicker is the default amount of time between automated // swap checks. - DefaultAutoloopTicker = time.Minute * 10 + DefaultAutoloopTicker = time.Minute * 20 // autoloopSwapInitiator is the value we send in the initiator field of // a swap request when issuing an automatic swap. @@ -164,6 +181,10 @@ type Config struct { // ListLoopOut returns all of the loop our swaps stored on disk. ListLoopOut func() ([]*loopdb.LoopOut, error) + // GetLoopOut returns a single loop out swap based on the provided swap + // hash. + GetLoopOut func(hash lntypes.Hash) (*loopdb.LoopOut, error) + // ListLoopIn returns all of the loop in swaps stored on disk. ListLoopIn func() ([]*loopdb.LoopIn, error) @@ -399,13 +420,10 @@ func (m *Manager) autoloop(ctx context.Context) error { swap.DestAddr = m.params.DestAddr } - loopOut, err := m.cfg.LoopOut(ctx, &swap) - if err != nil { - return err - } - - log.Infof("loop out automatically dispatched: hash: %v, "+ - "address: %v", loopOut.SwapHash, loopOut.HtlcAddress) + go m.dispatchStickyLoopOut( + ctx, swap, defaultAmountBackoffRetry, + defaultAmountBackoff, + ) } for _, in := range suggestion.InSwaps { @@ -1044,6 +1062,143 @@ func (m *Manager) refreshAutoloopBudget(ctx context.Context) { } } +// dispatchStickyLoopOut attempts to dispatch a loop out swap that will +// automatically retry its execution with an amount based backoff. +func (m *Manager) dispatchStickyLoopOut(ctx context.Context, + out loop.OutRequest, retryCount uint16, amountBackoff float64) { + + for i := 0; i < int(retryCount); i++ { + // Dispatch the swap. + swap, err := m.cfg.LoopOut(ctx, &out) + if err != nil { + log.Errorf("unable to dispatch loop out, hash: %v, "+ + "err: %v", swap.SwapHash, err) + } + + log.Infof("loop out automatically dispatched: hash: %v, "+ + "address: %v, amount %v", swap.SwapHash, + swap.HtlcAddress, out.Amount) + + updates := make(chan *loopdb.SwapState, 1) + + // Monitor the swap state and write the desired update to the + // update channel. We do not want to read all of the swap state + // updates, just the one that will help us assume the state of + // the off-chain payment. + go m.waitForSwapPayment( + ctx, swap.SwapHash, updates, defaultSwapWaitTimeout, + ) + + select { + case <-ctx.Done(): + return + + case update := <-updates: + if update == nil { + // If update is nil then no update occurred + // within the defined timeout period. It's + // better to return and not attempt a retry. + log.Debug( + "No payment update received for swap "+ + "%v, skipping amount backoff", + swap.SwapHash, + ) + + return + } + + if *update == loopdb.StateFailOffchainPayments { + // Save the old amount so we can log it. + oldAmt := out.Amount + + // If we failed to pay the server, we will + // decrease the amount of the swap and try + // again. + out.Amount -= btcutil.Amount( + float64(out.Amount) * amountBackoff, + ) + + log.Infof("swap %v: amount backoff old amount="+ + "%v, new amount=%v", swap.SwapHash, + oldAmt, out.Amount) + + continue + } else { + // If the update channel did not return an + // off-chain payment failure we won't retry. + return + } + } + } +} + +// waitForSwapPayment waits for a swap to progress beyond the stage of +// forwarding the payment to the server through the network. It returns the +// final update on the outcome through a channel. +func (m *Manager) waitForSwapPayment(ctx context.Context, swapHash lntypes.Hash, + updateChan chan *loopdb.SwapState, timeout time.Duration) { + + startTime := time.Now() + var ( + swap *loopdb.LoopOut + err error + interval time.Duration + ) + + if m.params.CustomPaymentCheckInterval != 0 { + interval = m.params.CustomPaymentCheckInterval + } else { + interval = defaultPaymentCheckInterval + } + + for time.Since(startTime) < timeout { + select { + case <-ctx.Done(): + return + case <-time.After(interval): + } + + swap, err = m.cfg.GetLoopOut(swapHash) + if err != nil { + log.Errorf( + "Error getting swap with hash %x: %v", swapHash, + err, + ) + continue + } + + // If no update has occurred yet, continue in order to wait. + update := swap.LastUpdate() + if update == nil { + continue + } + + // Write the update if the swap has reached a state the helps + // us determine whether the off-chain payment successfully + // reached the destination. + switch update.State { + case loopdb.StateFailInsufficientValue: + fallthrough + case loopdb.StateSuccess: + fallthrough + case loopdb.StateFailSweepTimeout: + fallthrough + case loopdb.StateFailTimeout: + fallthrough + case loopdb.StatePreimageRevealed: + fallthrough + case loopdb.StateFailOffchainPayments: + updateChan <- &update.State + return + } + } + + // If no update occurred within the defined timeout we return an empty + // update to the channel, causing the sticky loop out to not retry + // anymore. + updateChan <- nil +} + // swapTraffic contains a summary of our current and previously failed swaps. type swapTraffic struct { ongoingLoopOut map[lnwire.ShortChannelID]bool diff --git a/liquidity/parameters.go b/liquidity/parameters.go index 12bfa5a..2b74c5c 100644 --- a/liquidity/parameters.go +++ b/liquidity/parameters.go @@ -87,6 +87,10 @@ type Parameters struct { // ChannelRules are exclusively set to prevent overlap between peer // and channel rules map to avoid ambiguity. PeerRules map[route.Vertex]*SwapRule + + // CustomPaymentCheckInterval is an optional custom interval to use when + // checking an autoloop loop out payments' payment status. + CustomPaymentCheckInterval time.Duration } // String returns the string representation of our parameters. diff --git a/loopd/utils.go b/loopd/utils.go index 611e370..c21b781 100644 --- a/loopd/utils.go +++ b/loopd/utils.go @@ -72,6 +72,7 @@ func getLiquidityManager(client *loop.Client) *liquidity.Manager { LoopOutQuote: client.LoopOutQuote, LoopInQuote: client.LoopInQuote, ListLoopOut: client.Store.FetchLoopOutSwaps, + GetLoopOut: client.Store.FetchLoopOutSwap, ListLoopIn: client.Store.FetchLoopInSwaps, MinimumConfirmations: minConfTarget, PutLiquidityParams: client.Store.PutLiquidityParams, diff --git a/loopdb/interface.go b/loopdb/interface.go index 41172da..6f85067 100644 --- a/loopdb/interface.go +++ b/loopdb/interface.go @@ -12,6 +12,9 @@ type SwapStore interface { // FetchLoopOutSwaps returns all swaps currently in the store. FetchLoopOutSwaps() ([]*LoopOut, error) + // FetchLoopOutSwap returns the loop out swap with the given hash. + FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) + // CreateLoopOut adds an initiated swap to the store. CreateLoopOut(hash lntypes.Hash, swap *LoopOutContract) error diff --git a/loopdb/store.go b/loopdb/store.go index 1343974..b86008d 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -255,111 +255,12 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return nil } - // From the root bucket, we'll grab the next swap - // bucket for this swap from its swaphash. - swapBucket := rootBucket.Bucket(swapHash) - if swapBucket == nil { - return fmt.Errorf("swap bucket %x not found", - swapHash) - } - - // With the main swap bucket obtained, we'll grab the - // raw swap contract bytes and decode it. - contractBytes := swapBucket.Get(contractKey) - if contractBytes == nil { - return errors.New("contract not found") - } - - contract, err := deserializeLoopOutContract( - contractBytes, s.chainParams, - ) + loop, err := s.fetchLoopOutSwap(rootBucket, swapHash) if err != nil { return err } - // Get our label for this swap, if it is present. - contract.Label = getLabel(swapBucket) - - // Read the list of concatenated outgoing channel ids - // that form the outgoing set. - setBytes := swapBucket.Get(outgoingChanSetKey) - if outgoingChanSetKey != nil { - r := bytes.NewReader(setBytes) - readLoop: - for { - var chanID uint64 - err := binary.Read(r, byteOrder, &chanID) - switch { - case err == io.EOF: - break readLoop - case err != nil: - return err - } - - contract.OutgoingChanSet = append( - contract.OutgoingChanSet, - chanID, - ) - } - } - - // Set our default number of confirmations for the swap. - contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations - - // If we have the number of confirmations stored for - // this swap, we overwrite our default with the stored - // value. - confBytes := swapBucket.Get(confirmationsKey) - if confBytes != nil { - r := bytes.NewReader(confBytes) - err := binary.Read( - r, byteOrder, &contract.HtlcConfirmations, - ) - if err != nil { - return err - } - } - - updates, err := deserializeUpdates(swapBucket) - if err != nil { - return err - } - - // Try to unmarshal the protocol version for the swap. - // If the protocol version is not stored (which is - // the case for old clients), we'll assume the - // ProtocolVersionUnrecorded instead. - contract.ProtocolVersion, err = - UnmarshalProtocolVersion( - swapBucket.Get(protocolVersionKey), - ) - if err != nil { - return err - } - - // Try to unmarshal the key locator. - if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { - contract.ClientKeyLocator, err = UnmarshalKeyLocator( - swapBucket.Get(keyLocatorKey), - ) - if err != nil { - return err - } - } - - loop := LoopOut{ - Loop: Loop{ - Events: updates, - }, - Contract: contract, - } - - loop.Hash, err = lntypes.MakeHash(swapHash) - if err != nil { - return err - } - - swaps = append(swaps, &loop) + swaps = append(swaps, loop) return nil }) @@ -371,53 +272,33 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return swaps, nil } -// deserializeUpdates deserializes the list of swap updates that are stored as a -// key of the given bucket. -func deserializeUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { - // Once we have the raw swap, we'll also need to decode - // each of the past updates to the swap itself. - stateBucket := swapBucket.Bucket(updatesBucketKey) - if stateBucket == nil { - return nil, errors.New("updates bucket not found") - } - - // Deserialize and collect each swap update into our slice of swap - // events. - var updates []*LoopEvent - err := stateBucket.ForEach(func(k, v []byte) error { - updateBucket := stateBucket.Bucket(k) - if updateBucket == nil { - return fmt.Errorf("expected state sub-bucket for %x", k) - } +// FetchLoopOutSwap returns the loop out swap with the given hash. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) { + var swap *LoopOut - basicState := updateBucket.Get(basicStateKey) - if basicState == nil { - return errors.New("no basic state for update") + err := s.db.View(func(tx *bbolt.Tx) error { + // First, we'll grab our main loop out bucket key. + rootBucket := tx.Bucket(loopOutBucketKey) + if rootBucket == nil { + return errors.New("bucket does not exist") } - event, err := deserializeLoopEvent(basicState) + loop, err := s.fetchLoopOutSwap(rootBucket, hash[:]) if err != nil { return err } - // Deserialize htlc tx hash if this updates contains one. - htlcTxHashBytes := updateBucket.Get(htlcTxHashKey) - if htlcTxHashBytes != nil { - htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes) - if err != nil { - return err - } - event.HtlcTxHash = htlcTxHash - } + swap = loop - updates = append(updates, event) return nil }) if err != nil { return nil, err } - return updates, nil + return swap, nil } // FetchLoopInSwaps returns all loop in swaps currently in the store. @@ -442,71 +323,12 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return nil } - // From the root bucket, we'll grab the next swap - // bucket for this swap from its swaphash. - swapBucket := rootBucket.Bucket(swapHash) - if swapBucket == nil { - return fmt.Errorf("swap bucket %x not found", - swapHash) - } - - // With the main swap bucket obtained, we'll grab the - // raw swap contract bytes and decode it. - contractBytes := swapBucket.Get(contractKey) - if contractBytes == nil { - return errors.New("contract not found") - } - - contract, err := deserializeLoopInContract( - contractBytes, - ) + loop, err := s.fetchLoopInSwap(rootBucket, swapHash) if err != nil { return err } - // Get our label for this swap, if it is present. - contract.Label = getLabel(swapBucket) - - updates, err := deserializeUpdates(swapBucket) - if err != nil { - return err - } - - // Try to unmarshal the protocol version for the swap. - // If the protocol version is not stored (which is - // the case for old clients), we'll assume the - // ProtocolVersionUnrecorded instead. - contract.ProtocolVersion, err = - UnmarshalProtocolVersion( - swapBucket.Get(protocolVersionKey), - ) - if err != nil { - return err - } - - // Try to unmarshal the key locator. - if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { - contract.ClientKeyLocator, err = UnmarshalKeyLocator( - swapBucket.Get(keyLocatorKey), - ) - if err != nil { - return err - } - } - - loop := LoopIn{ - Loop: Loop{ - Events: updates, - }, - Contract: contract, - } - - loop.Hash, err = lntypes.MakeHash(swapHash) - if err != nil { - return err - } - - swaps = append(swaps, &loop) + swaps = append(swaps, loop) return nil }) @@ -824,3 +646,243 @@ func (s *boltSwapStore) FetchLiquidityParams() ([]byte, error) { return params, err } + +// fetchUpdates deserializes the list of swap updates that are stored as a +// key of the given bucket. +func fetchUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { + // Once we have the raw swap, we'll also need to decode + // each of the past updates to the swap itself. + stateBucket := swapBucket.Bucket(updatesBucketKey) + if stateBucket == nil { + return nil, errors.New("updates bucket not found") + } + + // Deserialize and collect each swap update into our slice of swap + // events. + var updates []*LoopEvent + err := stateBucket.ForEach(func(k, v []byte) error { + updateBucket := stateBucket.Bucket(k) + if updateBucket == nil { + return fmt.Errorf("expected state sub-bucket for %x", k) + } + + basicState := updateBucket.Get(basicStateKey) + if basicState == nil { + return errors.New("no basic state for update") + } + + event, err := deserializeLoopEvent(basicState) + if err != nil { + return err + } + + // Deserialize htlc tx hash if this updates contains one. + htlcTxHashBytes := updateBucket.Get(htlcTxHashKey) + if htlcTxHashBytes != nil { + htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes) + if err != nil { + return err + } + event.HtlcTxHash = htlcTxHash + } + + updates = append(updates, event) + return nil + }) + if err != nil { + return nil, err + } + + return updates, nil +} + +// fetchLoopOutSwap fetches and deserializes the raw swap bytes into a LoopOut +// struct. +func (s *boltSwapStore) fetchLoopOutSwap(rootBucket *bbolt.Bucket, + swapHash []byte) (*LoopOut, error) { + + // From the root bucket, we'll grab the next swap + // bucket for this swap from its swaphash. + swapBucket := rootBucket.Bucket(swapHash) + if swapBucket == nil { + return nil, fmt.Errorf("swap bucket %x not found", + swapHash) + } + + hash, err := lntypes.MakeHash(swapHash) + if err != nil { + return nil, err + } + + // With the main swap bucket obtained, we'll grab the + // raw swap contract bytes and decode it. + contractBytes := swapBucket.Get(contractKey) + if contractBytes == nil { + return nil, errors.New("contract not found") + } + + contract, err := deserializeLoopOutContract( + contractBytes, s.chainParams, + ) + if err != nil { + return nil, err + } + + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + + // Read the list of concatenated outgoing channel ids + // that form the outgoing set. + setBytes := swapBucket.Get(outgoingChanSetKey) + if outgoingChanSetKey != nil { + r := bytes.NewReader(setBytes) + readLoop: + for { + var chanID uint64 + err := binary.Read(r, byteOrder, &chanID) + switch { + case err == io.EOF: + break readLoop + case err != nil: + return nil, err + } + + contract.OutgoingChanSet = append( + contract.OutgoingChanSet, + chanID, + ) + } + } + + // Set our default number of confirmations for the swap. + contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations + + // If we have the number of confirmations stored for + // this swap, we overwrite our default with the stored + // value. + confBytes := swapBucket.Get(confirmationsKey) + if confBytes != nil { + r := bytes.NewReader(confBytes) + err := binary.Read( + r, byteOrder, &contract.HtlcConfirmations, + ) + if err != nil { + return nil, err + } + } + + updates, err := fetchUpdates(swapBucket) + if err != nil { + return nil, err + } + + // Try to unmarshal the protocol version for the swap. + // If the protocol version is not stored (which is + // the case for old clients), we'll assume the + // ProtocolVersionUnrecorded instead. + contract.ProtocolVersion, err = + UnmarshalProtocolVersion( + swapBucket.Get(protocolVersionKey), + ) + if err != nil { + return nil, err + } + + // Try to unmarshal the key locator. + if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { + contract.ClientKeyLocator, err = UnmarshalKeyLocator( + swapBucket.Get(keyLocatorKey), + ) + if err != nil { + return nil, err + } + } + + loop := LoopOut{ + Loop: Loop{ + Events: updates, + }, + Contract: contract, + } + + loop.Hash, err = lntypes.MakeHash(hash[:]) + if err != nil { + return nil, err + } + + return &loop, nil +} + +// fetchLoopInSwap fetches and deserializes the raw swap bytes into a LoopIn +// struct. +func (s *boltSwapStore) fetchLoopInSwap(rootBucket *bbolt.Bucket, + swapHash []byte) (*LoopIn, error) { + + // From the root bucket, we'll grab the next swap + // bucket for this swap from its swaphash. + swapBucket := rootBucket.Bucket(swapHash) + if swapBucket == nil { + return nil, fmt.Errorf("swap bucket %x not found", + swapHash) + } + + hash, err := lntypes.MakeHash(swapHash) + if err != nil { + return nil, err + } + + // With the main swap bucket obtained, we'll grab the + // raw swap contract bytes and decode it. + contractBytes := swapBucket.Get(contractKey) + if contractBytes == nil { + return nil, errors.New("contract not found") + } + + contract, err := deserializeLoopInContract( + contractBytes, + ) + if err != nil { + return nil, err + } + + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + + updates, err := fetchUpdates(swapBucket) + if err != nil { + return nil, err + } + + // Try to unmarshal the protocol version for the swap. + // If the protocol version is not stored (which is + // the case for old clients), we'll assume the + // ProtocolVersionUnrecorded instead. + contract.ProtocolVersion, err = + UnmarshalProtocolVersion( + swapBucket.Get(protocolVersionKey), + ) + if err != nil { + return nil, err + } + + // Try to unmarshal the key locator. + if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { + contract.ClientKeyLocator, err = UnmarshalKeyLocator( + swapBucket.Get(keyLocatorKey), + ) + if err != nil { + return nil, err + } + } + + loop := LoopIn{ + Loop: Loop{ + Events: updates, + }, + Contract: contract, + } + + loop.Hash = hash + + return &loop, nil +} diff --git a/loopdb/store_test.go b/loopdb/store_test.go index ea349d6..8d2602c 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "os" "path/filepath" - "reflect" "testing" "time" @@ -96,24 +95,20 @@ func TestLoopOutStore(t *testing.T) { // swap store for specific swap parameters. func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { tempDirName, err := ioutil.TempDir("", "clientstore") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + defer os.RemoveAll(tempDirName) store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // First, verify that an empty database has no active swaps. swaps, err := store.FetchLoopOutSwaps() - if err != nil { - t.Fatal(err) - } - if len(swaps) != 0 { - t.Fatal("expected empty store") - } + + require.NoError(t, err) + require.Empty(t, swaps) + + hash := pendingSwap.Preimage.Hash() // checkSwap is a test helper function that'll assert the state of a // swap. @@ -121,43 +116,37 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { t.Helper() swaps, err := store.FetchLoopOutSwaps() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if len(swaps) != 1 { - t.Fatal("expected pending swap in store") - } + require.Len(t, swaps, 1) - swap := swaps[0].Contract - if !reflect.DeepEqual(swap, pendingSwap) { - t.Fatal("invalid pending swap data") - } + swap, err := store.FetchLoopOutSwap(hash) + require.NoError(t, err) - if swaps[0].State().State != expectedState { - t.Fatalf("expected state %v, but got %v", - expectedState, swaps[0].State(), - ) - } + require.Equal(t, hash, swap.Hash) + require.Equal(t, hash, swaps[0].Hash) + + swapContract := swap.Contract + + require.Equal(t, swapContract, pendingSwap) + + require.Equal(t, expectedState, swap.State().State) if expectedState == StatePreimageRevealed { - require.NotNil(t, swaps[0].State().HtlcTxHash) + require.NotNil(t, swap.State().HtlcTxHash) } } - hash := pendingSwap.Preimage.Hash() - // If we create a new swap, then it should show up as being initialized // right after. - if err := store.CreateLoopOut(hash, pendingSwap); err != nil { - t.Fatal(err) - } + err = store.CreateLoopOut(hash, pendingSwap) + require.NoError(t, err) + checkSwap(StateInitiated) // Trying to make the same swap again should result in an error. - if err := store.CreateLoopOut(hash, pendingSwap); err == nil { - t.Fatal("expected error on storing duplicate") - } + err = store.CreateLoopOut(hash, pendingSwap) + require.Error(t, err) checkSwap(StateInitiated) // Next, we'll update to the next state of the pre-image being @@ -169,9 +158,8 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { HtlcTxHash: &chainhash.Hash{1, 6, 2}, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + checkSwap(StatePreimageRevealed) // Next, we'll update to the final state to ensure that the state is @@ -182,21 +170,17 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { State: StateFailInsufficientValue, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) checkSwap(StateFailInsufficientValue) - if err := store.Close(); err != nil { - t.Fatal(err) - } + err = store.Close() + require.NoError(t, err) // If we re-open the same store, then the state of the current swap // should be the same. store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + checkSwap(StateFailInsufficientValue) } @@ -242,24 +226,18 @@ func TestLoopInStore(t *testing.T) { func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { tempDirName, err := ioutil.TempDir("", "clientstore") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer os.RemoveAll(tempDirName) store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // First, verify that an empty database has no active swaps. swaps, err := store.FetchLoopInSwaps() - if err != nil { - t.Fatal(err) - } - if len(swaps) != 0 { - t.Fatal("expected empty store") - } + require.NoError(t, err) + require.Empty(t, swaps) + + hash := sha256.Sum256(testPreimage[:]) // checkSwap is a test helper function that'll assert the state of a // swap. @@ -267,39 +245,27 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { t.Helper() swaps, err := store.FetchLoopInSwaps() - if err != nil { - t.Fatal(err) - } - - if len(swaps) != 1 { - t.Fatal("expected pending swap in store") - } + require.NoError(t, err) + require.Len(t, swaps, 1) swap := swaps[0].Contract - if !reflect.DeepEqual(swap, &pendingSwap) { - t.Fatal("invalid pending swap data") - } - if swaps[0].State().State != expectedState { - t.Fatalf("expected state %v, but got %v", - expectedState, swaps[0].State(), - ) - } - } + require.Equal(t, swap, &pendingSwap) - hash := sha256.Sum256(testPreimage[:]) + require.Equal(t, swaps[0].State().State, expectedState) + } // If we create a new swap, then it should show up as being initialized // right after. - if err := store.CreateLoopIn(hash, &pendingSwap); err != nil { - t.Fatal(err) - } + err = store.CreateLoopIn(hash, &pendingSwap) + require.NoError(t, err) + checkSwap(StateInitiated) // Trying to make the same swap again should result in an error. - if err := store.CreateLoopIn(hash, &pendingSwap); err == nil { - t.Fatal("expected error on storing duplicate") - } + err = store.CreateLoopIn(hash, &pendingSwap) + require.Error(t, err) + checkSwap(StateInitiated) // Next, we'll update to the next state of the pre-image being @@ -310,9 +276,8 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { State: StatePreimageRevealed, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + checkSwap(StatePreimageRevealed) // Next, we'll update to the final state to ensure that the state is @@ -323,21 +288,17 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { State: StateFailInsufficientValue, }, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) checkSwap(StateFailInsufficientValue) - if err := store.Close(); err != nil { - t.Fatal(err) - } + err = store.Close() + require.NoError(t, err) // If we re-open the same store, then the state of the current swap // should be the same. store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + checkSwap(StateFailInsufficientValue) } @@ -467,9 +428,8 @@ func TestLegacyOutgoingChannel(t *testing.T) { // Assert that the outgoing channel is read properly. expectedChannelSet := ChannelSet{5} - if !reflect.DeepEqual(swaps[0].Contract.OutgoingChanSet, expectedChannelSet) { - t.Fatal("invalid outgoing channel") - } + + require.Equal(t, expectedChannelSet, swaps[0].Contract.OutgoingChanSet) } // TestLiquidityParams checks that reading and writing to liquidty bucket are diff --git a/store_mock_test.go b/store_mock_test.go index 4e8a0ad..ccf6305 100644 --- a/store_mock_test.go +++ b/store_mock_test.go @@ -70,6 +70,36 @@ func (s *storeMock) FetchLoopOutSwaps() ([]*loopdb.LoopOut, error) { return result, nil } +// FetchLoopOutSwaps returns all swaps currently in the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *storeMock) FetchLoopOutSwap( + hash lntypes.Hash) (*loopdb.LoopOut, error) { + + contract, ok := s.loopOutSwaps[hash] + if !ok { + return nil, errors.New("swap not found") + } + + updates := s.loopOutUpdates[hash] + events := make([]*loopdb.LoopEvent, len(updates)) + for i, u := range updates { + events[i] = &loopdb.LoopEvent{ + SwapStateData: u, + } + } + + swap := &loopdb.LoopOut{ + Loop: loopdb.Loop{ + Hash: hash, + Events: events, + }, + Contract: contract, + } + + return swap, nil +} + // CreateLoopOut adds an initiated swap to the store. // // NOTE: Part of the loopdb.SwapStore interface.