Merge pull request #548 from GeorgeTsagk/autoloop-amount-backoff

Autoloop amount backoff
pull/557/head
Olaoluwa Osuntokun 1 year ago committed by GitHub
commit 55845ff8ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -199,16 +199,30 @@ func TestAutoLoopEnabled(t *testing.T) {
}, },
}, },
} }
singleLoopOut = &loopdb.LoopOut{
Loop: loopdb.Loop{
Events: []*loopdb.LoopEvent{
{
SwapStateData: loopdb.SwapStateData{
State: loopdb.StateInitiated,
},
},
},
},
}
) )
// Tick our autolooper with no existing swaps, we expect a loop out // Tick our autolooper with no existing swaps, we expect a loop out
// swap to be dispatched for each channel. // swap to be dispatched for each channel.
step := &autoloopStep{ step := &autoloopStep{
minAmt: 1, minAmt: 1,
maxAmt: amt + 1, maxAmt: amt + 1,
quotesOut: quotes, quotesOut: quotes,
expectedOut: loopOuts, expectedOut: loopOuts,
existingOutSingle: singleLoopOut,
} }
c.autoloop(step) c.autoloop(step)
// Tick again with both of our swaps in progress. We haven't shifted our // Tick again with both of our swaps in progress. We haven't shifted our
@ -220,9 +234,10 @@ func TestAutoLoopEnabled(t *testing.T) {
} }
step = &autoloopStep{ step = &autoloopStep{
minAmt: 1, minAmt: 1,
maxAmt: amt + 1, maxAmt: amt + 1,
existingOut: existing, existingOut: existing,
existingOutSingle: singleLoopOut,
} }
c.autoloop(step) c.autoloop(step)
@ -278,11 +293,12 @@ func TestAutoLoopEnabled(t *testing.T) {
// still has balances which reflect that we need to swap), but nothing // still has balances which reflect that we need to swap), but nothing
// for channel 2, since it has had a failure. // for channel 2, since it has had a failure.
step = &autoloopStep{ step = &autoloopStep{
minAmt: 1, minAmt: 1,
maxAmt: amt + 1, maxAmt: amt + 1,
existingOut: existing, existingOut: existing,
quotesOut: quotes, quotesOut: quotes,
expectedOut: loopOuts, expectedOut: loopOuts,
existingOutSingle: singleLoopOut,
} }
c.autoloop(step) c.autoloop(step)
@ -299,10 +315,11 @@ func TestAutoLoopEnabled(t *testing.T) {
} }
step = &autoloopStep{ step = &autoloopStep{
minAmt: 1, minAmt: 1,
maxAmt: amt + 1, maxAmt: amt + 1,
existingOut: existing, existingOut: existing,
quotesOut: quotes, quotesOut: quotes,
existingOutSingle: singleLoopOut,
} }
c.autoloop(step) c.autoloop(step)
@ -446,13 +463,27 @@ func TestAutoloopAddress(t *testing.T) {
}, },
}, },
} }
singleLoopOut = &loopdb.LoopOut{
Loop: loopdb.Loop{
Events: []*loopdb.LoopEvent{
{
SwapStateData: loopdb.SwapStateData{
State: loopdb.StateHtlcPublished,
},
},
},
},
}
) )
step := &autoloopStep{ step := &autoloopStep{
minAmt: 1, minAmt: 1,
maxAmt: amt + 1, maxAmt: amt + 1,
quotesOut: quotes, quotesOut: quotes,
expectedOut: loopOuts, expectedOut: loopOuts,
existingOutSingle: singleLoopOut,
keepDestAddr: true,
} }
c.autoloop(step) c.autoloop(step)
@ -606,6 +637,18 @@ func TestCompositeRules(t *testing.T) {
}, },
}, },
} }
singleLoopOut = &loopdb.LoopOut{
Loop: loopdb.Loop{
Events: []*loopdb.LoopEvent{
{
SwapStateData: loopdb.SwapStateData{
State: loopdb.StateHtlcPublished,
},
},
},
},
}
) )
// Tick our autolooper with no existing swaps, we expect a loop out // Tick our autolooper with no existing swaps, we expect a loop out
@ -613,10 +656,11 @@ func TestCompositeRules(t *testing.T) {
// maximum to be greater than the swap amount for our peer swap (which // maximum to be greater than the swap amount for our peer swap (which
// is the larger of the two swaps). // is the larger of the two swaps).
step := &autoloopStep{ step := &autoloopStep{
minAmt: 1, minAmt: 1,
maxAmt: peerAmount + 1, maxAmt: peerAmount + 1,
quotesOut: quotes, quotesOut: quotes,
expectedOut: loopOuts, expectedOut: loopOuts,
existingOutSingle: singleLoopOut,
} }
c.autoloop(step) c.autoloop(step)
@ -928,6 +972,18 @@ func TestAutoloopBothTypes(t *testing.T) {
Label: labels.AutoloopLabel(swap.TypeIn), Label: labels.AutoloopLabel(swap.TypeIn),
Initiator: autoloopSwapInitiator, Initiator: autoloopSwapInitiator,
} }
singleLoopOut = &loopdb.LoopOut{
Loop: loopdb.Loop{
Events: []*loopdb.LoopEvent{
{
SwapStateData: loopdb.SwapStateData{
State: loopdb.StateHtlcPublished,
},
},
},
},
}
) )
step := &autoloopStep{ step := &autoloopStep{
@ -961,6 +1017,7 @@ func TestAutoloopBothTypes(t *testing.T) {
}, },
}, },
}, },
existingOutSingle: singleLoopOut,
} }
c.autoloop(step) c.autoloop(step)
c.stop() c.stop()

@ -2,7 +2,9 @@ package liquidity
import ( import (
"context" "context"
"reflect"
"testing" "testing"
"time"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/lightninglabs/lndclient" "github.com/lightninglabs/lndclient"
@ -11,8 +13,10 @@ import (
"github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/swap"
"github.com/lightninglabs/loop/test" "github.com/lightninglabs/loop/test"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type autoloopTestCtx struct { type autoloopTestCtx struct {
@ -45,9 +49,17 @@ type autoloopTestCtx struct {
// loopOuts is a channel that we get existing loop out swaps on. // loopOuts is a channel that we get existing loop out swaps on.
loopOuts chan []*loopdb.LoopOut loopOuts chan []*loopdb.LoopOut
// loopOutSingle is the single loop out returned from fetching a single
// swap from store.
loopOutSingle *loopdb.LoopOut
// loopIns is a channel that we get existing loop in swaps on. // loopIns is a channel that we get existing loop in swaps on.
loopIns chan []*loopdb.LoopIn loopIns chan []*loopdb.LoopIn
// loopInSingle is the single loop in returned from fetching a single
// swap from store.
loopInSingle *loopdb.LoopIn
// restrictions is a channel that we get swap restrictions on. // restrictions is a channel that we get swap restrictions on.
restrictions chan *Restrictions restrictions chan *Restrictions
@ -131,6 +143,9 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters,
ListLoopOut: func() ([]*loopdb.LoopOut, error) { ListLoopOut: func() ([]*loopdb.LoopOut, error) {
return <-testCtx.loopOuts, nil return <-testCtx.loopOuts, nil
}, },
GetLoopOut: func(hash lntypes.Hash) (*loopdb.LoopOut, error) {
return testCtx.loopOutSingle, nil
},
ListLoopIn: func() ([]*loopdb.LoopIn, error) { ListLoopIn: func() ([]*loopdb.LoopIn, error) {
return <-testCtx.loopIns, nil return <-testCtx.loopIns, nil
}, },
@ -188,6 +203,10 @@ func newAutoloopTestCtx(t *testing.T, parameters Parameters,
testCtx.manager = NewManager(cfg) testCtx.manager = NewManager(cfg)
err := testCtx.manager.setParameters(context.Background(), parameters) err := testCtx.manager.setParameters(context.Background(), parameters)
assert.NoError(t, err) assert.NoError(t, err)
// Override the payments check interval for the tests in order to not
// timeout.
testCtx.manager.params.CustomPaymentCheckInterval =
150 * time.Millisecond
<-done <-done
return testCtx return testCtx
} }
@ -241,14 +260,17 @@ type loopInRequestResp struct {
// autoloopStep contains all of the information to required to step // autoloopStep contains all of the information to required to step
// through an autoloop tick. // through an autoloop tick.
type autoloopStep struct { type autoloopStep struct {
minAmt btcutil.Amount minAmt btcutil.Amount
maxAmt btcutil.Amount maxAmt btcutil.Amount
existingOut []*loopdb.LoopOut existingOut []*loopdb.LoopOut
existingIn []*loopdb.LoopIn existingOutSingle *loopdb.LoopOut
quotesOut []quoteRequestResp existingIn []*loopdb.LoopIn
quotesIn []quoteInRequestResp existingInSingle *loopdb.LoopIn
expectedOut []loopOutRequestResp quotesOut []quoteRequestResp
expectedIn []loopInRequestResp quotesIn []quoteInRequestResp
expectedOut []loopOutRequestResp
expectedIn []loopInRequestResp
keepDestAddr bool
} }
// autoloop walks our test context through the process of triggering our // autoloop walks our test context through the process of triggering our
@ -269,6 +291,9 @@ func (c *autoloopTestCtx) autoloop(step *autoloopStep) {
c.loopOuts <- step.existingOut c.loopOuts <- step.existingOut
c.loopIns <- step.existingIn c.loopIns <- step.existingIn
c.loopOutSingle = step.existingOutSingle
c.loopInSingle = step.existingInSingle
// Assert that we query the server for a quote for each of our // Assert that we query the server for a quote for each of our
// recommended swaps. Note that this differs from our set of expected // recommended swaps. Note that this differs from our set of expected
// swaps because we may get quotes for suggested swaps but then just // swaps because we may get quotes for suggested swaps but then just
@ -299,25 +324,77 @@ func (c *autoloopTestCtx) autoloop(step *autoloopStep) {
c.quotes <- expected.quote c.quotes <- expected.quote
} }
// Assert that we dispatch the expected set of swaps. require.True(c.t, c.matchLoopOuts(step.expectedOut, step.keepDestAddr))
for _, expected := range step.expectedOut { require.True(c.t, c.matchLoopIns(step.expectedIn))
}
// matchLoopOuts checks that the actual loop out requests we got match the
// expected ones. The argument keepDestAddr is used to indicate whether we keep
// the actual loops destination address for the comparison. This is useful
// because we don't want to compare the destination address generated by the
// wallet mock. We want to compare the destination address when testing the
// autoloop DestAddr parameter for loop outs.
func (c *autoloopTestCtx) matchLoopOuts(swaps []loopOutRequestResp,
keepDestAddr bool) bool {
swapsCopy := make([]loopOutRequestResp, len(swaps))
copy(swapsCopy, swaps)
length := len(swapsCopy)
for i := 0; i < length; i++ {
actual := <-c.outRequest actual := <-c.outRequest
// Set our destination address to nil so that we do not need to if !keepDestAddr {
// provide the address that is obtained by the mock wallet kit.
if expected.request.DestAddr == nil {
actual.DestAddr = nil actual.DestAddr = nil
} }
assert.Equal(c.t, expected.request, actual) inner:
c.loopOut <- expected.response for index, swap := range swapsCopy {
equal := reflect.DeepEqual(swap.request, actual)
if equal {
c.loopOut <- swap.response
swapsCopy = append(
swapsCopy[:index],
swapsCopy[index+1:]...,
)
break inner
}
}
} }
for _, expected := range step.expectedIn { return len(swapsCopy) == 0
}
// matchLoopIns checks that the actual loop in requests we got match the
// expected ones.
func (c *autoloopTestCtx) matchLoopIns(
swaps []loopInRequestResp) bool {
swapsCopy := make([]loopInRequestResp, len(swaps))
copy(swapsCopy, swaps)
for i := 0; i < len(swapsCopy); i++ {
actual := <-c.inRequest actual := <-c.inRequest
assert.Equal(c.t, expected.request, actual) inner:
for i, swap := range swapsCopy {
equal := reflect.DeepEqual(swap.request, actual)
c.loopIn <- expected.response if equal {
c.loopIn <- swap.response
swapsCopy = append(
swapsCopy[:i], swapsCopy[i+1:]...,
)
break inner
}
}
} }
return len(swapsCopy) == 0
} }

@ -48,6 +48,7 @@ import (
"github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/swap"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
@ -62,6 +63,22 @@ const (
// a channel is part of a temporarily failed swap. // a channel is part of a temporarily failed swap.
defaultFailureBackoff = time.Hour * 24 defaultFailureBackoff = time.Hour * 24
// defaultAmountBackoff is the default backoff we apply to the amount
// of a loop out swap that failed the off-chain payments.
defaultAmountBackoff = float64(0.25)
// defaultAmountBackoffRetry is the default number of times we will
// perform an amount backoff to a loop out swap before we give up.
defaultAmountBackoffRetry = 5
// defaultSwapWaitTimeout is the default maximum amount of time we
// wait for a swap to reach a terminal state.
defaultSwapWaitTimeout = time.Hour * 24
// defaultPaymentCheckInterval is the default time that passes between
// checks for loop out payments status.
defaultPaymentCheckInterval = time.Second * 2
// defaultConfTarget is the default sweep target we use for loop outs. // defaultConfTarget is the default sweep target we use for loop outs.
// We get our inbound liquidity quickly using preimage push, so we can // We get our inbound liquidity quickly using preimage push, so we can
// use a long conf target without worrying about ux impact. // use a long conf target without worrying about ux impact.
@ -78,7 +95,7 @@ const (
// DefaultAutoloopTicker is the default amount of time between automated // DefaultAutoloopTicker is the default amount of time between automated
// swap checks. // swap checks.
DefaultAutoloopTicker = time.Minute * 10 DefaultAutoloopTicker = time.Minute * 20
// autoloopSwapInitiator is the value we send in the initiator field of // autoloopSwapInitiator is the value we send in the initiator field of
// a swap request when issuing an automatic swap. // a swap request when issuing an automatic swap.
@ -164,6 +181,10 @@ type Config struct {
// ListLoopOut returns all of the loop our swaps stored on disk. // ListLoopOut returns all of the loop our swaps stored on disk.
ListLoopOut func() ([]*loopdb.LoopOut, error) ListLoopOut func() ([]*loopdb.LoopOut, error)
// GetLoopOut returns a single loop out swap based on the provided swap
// hash.
GetLoopOut func(hash lntypes.Hash) (*loopdb.LoopOut, error)
// ListLoopIn returns all of the loop in swaps stored on disk. // ListLoopIn returns all of the loop in swaps stored on disk.
ListLoopIn func() ([]*loopdb.LoopIn, error) ListLoopIn func() ([]*loopdb.LoopIn, error)
@ -399,13 +420,10 @@ func (m *Manager) autoloop(ctx context.Context) error {
swap.DestAddr = m.params.DestAddr swap.DestAddr = m.params.DestAddr
} }
loopOut, err := m.cfg.LoopOut(ctx, &swap) go m.dispatchStickyLoopOut(
if err != nil { ctx, swap, defaultAmountBackoffRetry,
return err defaultAmountBackoff,
} )
log.Infof("loop out automatically dispatched: hash: %v, "+
"address: %v", loopOut.SwapHash, loopOut.HtlcAddress)
} }
for _, in := range suggestion.InSwaps { for _, in := range suggestion.InSwaps {
@ -1044,6 +1062,143 @@ func (m *Manager) refreshAutoloopBudget(ctx context.Context) {
} }
} }
// dispatchStickyLoopOut attempts to dispatch a loop out swap that will
// automatically retry its execution with an amount based backoff.
func (m *Manager) dispatchStickyLoopOut(ctx context.Context,
out loop.OutRequest, retryCount uint16, amountBackoff float64) {
for i := 0; i < int(retryCount); i++ {
// Dispatch the swap.
swap, err := m.cfg.LoopOut(ctx, &out)
if err != nil {
log.Errorf("unable to dispatch loop out, hash: %v, "+
"err: %v", swap.SwapHash, err)
}
log.Infof("loop out automatically dispatched: hash: %v, "+
"address: %v, amount %v", swap.SwapHash,
swap.HtlcAddress, out.Amount)
updates := make(chan *loopdb.SwapState, 1)
// Monitor the swap state and write the desired update to the
// update channel. We do not want to read all of the swap state
// updates, just the one that will help us assume the state of
// the off-chain payment.
go m.waitForSwapPayment(
ctx, swap.SwapHash, updates, defaultSwapWaitTimeout,
)
select {
case <-ctx.Done():
return
case update := <-updates:
if update == nil {
// If update is nil then no update occurred
// within the defined timeout period. It's
// better to return and not attempt a retry.
log.Debug(
"No payment update received for swap "+
"%v, skipping amount backoff",
swap.SwapHash,
)
return
}
if *update == loopdb.StateFailOffchainPayments {
// Save the old amount so we can log it.
oldAmt := out.Amount
// If we failed to pay the server, we will
// decrease the amount of the swap and try
// again.
out.Amount -= btcutil.Amount(
float64(out.Amount) * amountBackoff,
)
log.Infof("swap %v: amount backoff old amount="+
"%v, new amount=%v", swap.SwapHash,
oldAmt, out.Amount)
continue
} else {
// If the update channel did not return an
// off-chain payment failure we won't retry.
return
}
}
}
}
// waitForSwapPayment waits for a swap to progress beyond the stage of
// forwarding the payment to the server through the network. It returns the
// final update on the outcome through a channel.
func (m *Manager) waitForSwapPayment(ctx context.Context, swapHash lntypes.Hash,
updateChan chan *loopdb.SwapState, timeout time.Duration) {
startTime := time.Now()
var (
swap *loopdb.LoopOut
err error
interval time.Duration
)
if m.params.CustomPaymentCheckInterval != 0 {
interval = m.params.CustomPaymentCheckInterval
} else {
interval = defaultPaymentCheckInterval
}
for time.Since(startTime) < timeout {
select {
case <-ctx.Done():
return
case <-time.After(interval):
}
swap, err = m.cfg.GetLoopOut(swapHash)
if err != nil {
log.Errorf(
"Error getting swap with hash %x: %v", swapHash,
err,
)
continue
}
// If no update has occurred yet, continue in order to wait.
update := swap.LastUpdate()
if update == nil {
continue
}
// Write the update if the swap has reached a state the helps
// us determine whether the off-chain payment successfully
// reached the destination.
switch update.State {
case loopdb.StateFailInsufficientValue:
fallthrough
case loopdb.StateSuccess:
fallthrough
case loopdb.StateFailSweepTimeout:
fallthrough
case loopdb.StateFailTimeout:
fallthrough
case loopdb.StatePreimageRevealed:
fallthrough
case loopdb.StateFailOffchainPayments:
updateChan <- &update.State
return
}
}
// If no update occurred within the defined timeout we return an empty
// update to the channel, causing the sticky loop out to not retry
// anymore.
updateChan <- nil
}
// swapTraffic contains a summary of our current and previously failed swaps. // swapTraffic contains a summary of our current and previously failed swaps.
type swapTraffic struct { type swapTraffic struct {
ongoingLoopOut map[lnwire.ShortChannelID]bool ongoingLoopOut map[lnwire.ShortChannelID]bool

@ -87,6 +87,10 @@ type Parameters struct {
// ChannelRules are exclusively set to prevent overlap between peer // ChannelRules are exclusively set to prevent overlap between peer
// and channel rules map to avoid ambiguity. // and channel rules map to avoid ambiguity.
PeerRules map[route.Vertex]*SwapRule PeerRules map[route.Vertex]*SwapRule
// CustomPaymentCheckInterval is an optional custom interval to use when
// checking an autoloop loop out payments' payment status.
CustomPaymentCheckInterval time.Duration
} }
// String returns the string representation of our parameters. // String returns the string representation of our parameters.

@ -72,6 +72,7 @@ func getLiquidityManager(client *loop.Client) *liquidity.Manager {
LoopOutQuote: client.LoopOutQuote, LoopOutQuote: client.LoopOutQuote,
LoopInQuote: client.LoopInQuote, LoopInQuote: client.LoopInQuote,
ListLoopOut: client.Store.FetchLoopOutSwaps, ListLoopOut: client.Store.FetchLoopOutSwaps,
GetLoopOut: client.Store.FetchLoopOutSwap,
ListLoopIn: client.Store.FetchLoopInSwaps, ListLoopIn: client.Store.FetchLoopInSwaps,
MinimumConfirmations: minConfTarget, MinimumConfirmations: minConfTarget,
PutLiquidityParams: client.Store.PutLiquidityParams, PutLiquidityParams: client.Store.PutLiquidityParams,

@ -12,6 +12,9 @@ type SwapStore interface {
// FetchLoopOutSwaps returns all swaps currently in the store. // FetchLoopOutSwaps returns all swaps currently in the store.
FetchLoopOutSwaps() ([]*LoopOut, error) FetchLoopOutSwaps() ([]*LoopOut, error)
// FetchLoopOutSwap returns the loop out swap with the given hash.
FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error)
// CreateLoopOut adds an initiated swap to the store. // CreateLoopOut adds an initiated swap to the store.
CreateLoopOut(hash lntypes.Hash, swap *LoopOutContract) error CreateLoopOut(hash lntypes.Hash, swap *LoopOutContract) error

@ -255,111 +255,12 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
return nil return nil
} }
// From the root bucket, we'll grab the next swap loop, err := s.fetchLoopOutSwap(rootBucket, swapHash)
// bucket for this swap from its swaphash.
swapBucket := rootBucket.Bucket(swapHash)
if swapBucket == nil {
return fmt.Errorf("swap bucket %x not found",
swapHash)
}
// With the main swap bucket obtained, we'll grab the
// raw swap contract bytes and decode it.
contractBytes := swapBucket.Get(contractKey)
if contractBytes == nil {
return errors.New("contract not found")
}
contract, err := deserializeLoopOutContract(
contractBytes, s.chainParams,
)
if err != nil { if err != nil {
return err return err
} }
// Get our label for this swap, if it is present. swaps = append(swaps, loop)
contract.Label = getLabel(swapBucket)
// Read the list of concatenated outgoing channel ids
// that form the outgoing set.
setBytes := swapBucket.Get(outgoingChanSetKey)
if outgoingChanSetKey != nil {
r := bytes.NewReader(setBytes)
readLoop:
for {
var chanID uint64
err := binary.Read(r, byteOrder, &chanID)
switch {
case err == io.EOF:
break readLoop
case err != nil:
return err
}
contract.OutgoingChanSet = append(
contract.OutgoingChanSet,
chanID,
)
}
}
// Set our default number of confirmations for the swap.
contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations
// If we have the number of confirmations stored for
// this swap, we overwrite our default with the stored
// value.
confBytes := swapBucket.Get(confirmationsKey)
if confBytes != nil {
r := bytes.NewReader(confBytes)
err := binary.Read(
r, byteOrder, &contract.HtlcConfirmations,
)
if err != nil {
return err
}
}
updates, err := deserializeUpdates(swapBucket)
if err != nil {
return err
}
// Try to unmarshal the protocol version for the swap.
// If the protocol version is not stored (which is
// the case for old clients), we'll assume the
// ProtocolVersionUnrecorded instead.
contract.ProtocolVersion, err =
UnmarshalProtocolVersion(
swapBucket.Get(protocolVersionKey),
)
if err != nil {
return err
}
// Try to unmarshal the key locator.
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
swapBucket.Get(keyLocatorKey),
)
if err != nil {
return err
}
}
loop := LoopOut{
Loop: Loop{
Events: updates,
},
Contract: contract,
}
loop.Hash, err = lntypes.MakeHash(swapHash)
if err != nil {
return err
}
swaps = append(swaps, &loop)
return nil return nil
}) })
@ -371,53 +272,33 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
return swaps, nil return swaps, nil
} }
// deserializeUpdates deserializes the list of swap updates that are stored as a // FetchLoopOutSwap returns the loop out swap with the given hash.
// key of the given bucket. //
func deserializeUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { // NOTE: Part of the loopdb.SwapStore interface.
// Once we have the raw swap, we'll also need to decode func (s *boltSwapStore) FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) {
// each of the past updates to the swap itself. var swap *LoopOut
stateBucket := swapBucket.Bucket(updatesBucketKey)
if stateBucket == nil {
return nil, errors.New("updates bucket not found")
}
// Deserialize and collect each swap update into our slice of swap
// events.
var updates []*LoopEvent
err := stateBucket.ForEach(func(k, v []byte) error {
updateBucket := stateBucket.Bucket(k)
if updateBucket == nil {
return fmt.Errorf("expected state sub-bucket for %x", k)
}
basicState := updateBucket.Get(basicStateKey) err := s.db.View(func(tx *bbolt.Tx) error {
if basicState == nil { // First, we'll grab our main loop out bucket key.
return errors.New("no basic state for update") rootBucket := tx.Bucket(loopOutBucketKey)
if rootBucket == nil {
return errors.New("bucket does not exist")
} }
event, err := deserializeLoopEvent(basicState) loop, err := s.fetchLoopOutSwap(rootBucket, hash[:])
if err != nil { if err != nil {
return err return err
} }
// Deserialize htlc tx hash if this updates contains one. swap = loop
htlcTxHashBytes := updateBucket.Get(htlcTxHashKey)
if htlcTxHashBytes != nil {
htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes)
if err != nil {
return err
}
event.HtlcTxHash = htlcTxHash
}
updates = append(updates, event)
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return updates, nil return swap, nil
} }
// FetchLoopInSwaps returns all loop in swaps currently in the store. // FetchLoopInSwaps returns all loop in swaps currently in the store.
@ -442,71 +323,12 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) {
return nil return nil
} }
// From the root bucket, we'll grab the next swap loop, err := s.fetchLoopInSwap(rootBucket, swapHash)
// bucket for this swap from its swaphash.
swapBucket := rootBucket.Bucket(swapHash)
if swapBucket == nil {
return fmt.Errorf("swap bucket %x not found",
swapHash)
}
// With the main swap bucket obtained, we'll grab the
// raw swap contract bytes and decode it.
contractBytes := swapBucket.Get(contractKey)
if contractBytes == nil {
return errors.New("contract not found")
}
contract, err := deserializeLoopInContract(
contractBytes,
)
if err != nil { if err != nil {
return err return err
} }
// Get our label for this swap, if it is present. swaps = append(swaps, loop)
contract.Label = getLabel(swapBucket)
updates, err := deserializeUpdates(swapBucket)
if err != nil {
return err
}
// Try to unmarshal the protocol version for the swap.
// If the protocol version is not stored (which is
// the case for old clients), we'll assume the
// ProtocolVersionUnrecorded instead.
contract.ProtocolVersion, err =
UnmarshalProtocolVersion(
swapBucket.Get(protocolVersionKey),
)
if err != nil {
return err
}
// Try to unmarshal the key locator.
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
swapBucket.Get(keyLocatorKey),
)
if err != nil {
return err
}
}
loop := LoopIn{
Loop: Loop{
Events: updates,
},
Contract: contract,
}
loop.Hash, err = lntypes.MakeHash(swapHash)
if err != nil {
return err
}
swaps = append(swaps, &loop)
return nil return nil
}) })
@ -824,3 +646,243 @@ func (s *boltSwapStore) FetchLiquidityParams() ([]byte, error) {
return params, err return params, err
} }
// fetchUpdates deserializes the list of swap updates that are stored as a
// key of the given bucket.
func fetchUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) {
// Once we have the raw swap, we'll also need to decode
// each of the past updates to the swap itself.
stateBucket := swapBucket.Bucket(updatesBucketKey)
if stateBucket == nil {
return nil, errors.New("updates bucket not found")
}
// Deserialize and collect each swap update into our slice of swap
// events.
var updates []*LoopEvent
err := stateBucket.ForEach(func(k, v []byte) error {
updateBucket := stateBucket.Bucket(k)
if updateBucket == nil {
return fmt.Errorf("expected state sub-bucket for %x", k)
}
basicState := updateBucket.Get(basicStateKey)
if basicState == nil {
return errors.New("no basic state for update")
}
event, err := deserializeLoopEvent(basicState)
if err != nil {
return err
}
// Deserialize htlc tx hash if this updates contains one.
htlcTxHashBytes := updateBucket.Get(htlcTxHashKey)
if htlcTxHashBytes != nil {
htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes)
if err != nil {
return err
}
event.HtlcTxHash = htlcTxHash
}
updates = append(updates, event)
return nil
})
if err != nil {
return nil, err
}
return updates, nil
}
// fetchLoopOutSwap fetches and deserializes the raw swap bytes into a LoopOut
// struct.
func (s *boltSwapStore) fetchLoopOutSwap(rootBucket *bbolt.Bucket,
swapHash []byte) (*LoopOut, error) {
// From the root bucket, we'll grab the next swap
// bucket for this swap from its swaphash.
swapBucket := rootBucket.Bucket(swapHash)
if swapBucket == nil {
return nil, fmt.Errorf("swap bucket %x not found",
swapHash)
}
hash, err := lntypes.MakeHash(swapHash)
if err != nil {
return nil, err
}
// With the main swap bucket obtained, we'll grab the
// raw swap contract bytes and decode it.
contractBytes := swapBucket.Get(contractKey)
if contractBytes == nil {
return nil, errors.New("contract not found")
}
contract, err := deserializeLoopOutContract(
contractBytes, s.chainParams,
)
if err != nil {
return nil, err
}
// Get our label for this swap, if it is present.
contract.Label = getLabel(swapBucket)
// Read the list of concatenated outgoing channel ids
// that form the outgoing set.
setBytes := swapBucket.Get(outgoingChanSetKey)
if outgoingChanSetKey != nil {
r := bytes.NewReader(setBytes)
readLoop:
for {
var chanID uint64
err := binary.Read(r, byteOrder, &chanID)
switch {
case err == io.EOF:
break readLoop
case err != nil:
return nil, err
}
contract.OutgoingChanSet = append(
contract.OutgoingChanSet,
chanID,
)
}
}
// Set our default number of confirmations for the swap.
contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations
// If we have the number of confirmations stored for
// this swap, we overwrite our default with the stored
// value.
confBytes := swapBucket.Get(confirmationsKey)
if confBytes != nil {
r := bytes.NewReader(confBytes)
err := binary.Read(
r, byteOrder, &contract.HtlcConfirmations,
)
if err != nil {
return nil, err
}
}
updates, err := fetchUpdates(swapBucket)
if err != nil {
return nil, err
}
// Try to unmarshal the protocol version for the swap.
// If the protocol version is not stored (which is
// the case for old clients), we'll assume the
// ProtocolVersionUnrecorded instead.
contract.ProtocolVersion, err =
UnmarshalProtocolVersion(
swapBucket.Get(protocolVersionKey),
)
if err != nil {
return nil, err
}
// Try to unmarshal the key locator.
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
swapBucket.Get(keyLocatorKey),
)
if err != nil {
return nil, err
}
}
loop := LoopOut{
Loop: Loop{
Events: updates,
},
Contract: contract,
}
loop.Hash, err = lntypes.MakeHash(hash[:])
if err != nil {
return nil, err
}
return &loop, nil
}
// fetchLoopInSwap fetches and deserializes the raw swap bytes into a LoopIn
// struct.
func (s *boltSwapStore) fetchLoopInSwap(rootBucket *bbolt.Bucket,
swapHash []byte) (*LoopIn, error) {
// From the root bucket, we'll grab the next swap
// bucket for this swap from its swaphash.
swapBucket := rootBucket.Bucket(swapHash)
if swapBucket == nil {
return nil, fmt.Errorf("swap bucket %x not found",
swapHash)
}
hash, err := lntypes.MakeHash(swapHash)
if err != nil {
return nil, err
}
// With the main swap bucket obtained, we'll grab the
// raw swap contract bytes and decode it.
contractBytes := swapBucket.Get(contractKey)
if contractBytes == nil {
return nil, errors.New("contract not found")
}
contract, err := deserializeLoopInContract(
contractBytes,
)
if err != nil {
return nil, err
}
// Get our label for this swap, if it is present.
contract.Label = getLabel(swapBucket)
updates, err := fetchUpdates(swapBucket)
if err != nil {
return nil, err
}
// Try to unmarshal the protocol version for the swap.
// If the protocol version is not stored (which is
// the case for old clients), we'll assume the
// ProtocolVersionUnrecorded instead.
contract.ProtocolVersion, err =
UnmarshalProtocolVersion(
swapBucket.Get(protocolVersionKey),
)
if err != nil {
return nil, err
}
// Try to unmarshal the key locator.
if contract.ProtocolVersion >= ProtocolVersionHtlcV3 {
contract.ClientKeyLocator, err = UnmarshalKeyLocator(
swapBucket.Get(keyLocatorKey),
)
if err != nil {
return nil, err
}
}
loop := LoopIn{
Loop: Loop{
Events: updates,
},
Contract: contract,
}
loop.Hash = hash
return &loop, nil
}

@ -5,7 +5,6 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"testing" "testing"
"time" "time"
@ -96,24 +95,20 @@ func TestLoopOutStore(t *testing.T) {
// swap store for specific swap parameters. // swap store for specific swap parameters.
func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
tempDirName, err := ioutil.TempDir("", "clientstore") tempDirName, err := ioutil.TempDir("", "clientstore")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempDirName) defer os.RemoveAll(tempDirName)
store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// First, verify that an empty database has no active swaps. // First, verify that an empty database has no active swaps.
swaps, err := store.FetchLoopOutSwaps() swaps, err := store.FetchLoopOutSwaps()
if err != nil {
t.Fatal(err) require.NoError(t, err)
} require.Empty(t, swaps)
if len(swaps) != 0 {
t.Fatal("expected empty store") hash := pendingSwap.Preimage.Hash()
}
// checkSwap is a test helper function that'll assert the state of a // checkSwap is a test helper function that'll assert the state of a
// swap. // swap.
@ -121,43 +116,37 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
t.Helper() t.Helper()
swaps, err := store.FetchLoopOutSwaps() swaps, err := store.FetchLoopOutSwaps()
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
if len(swaps) != 1 { require.Len(t, swaps, 1)
t.Fatal("expected pending swap in store")
}
swap := swaps[0].Contract swap, err := store.FetchLoopOutSwap(hash)
if !reflect.DeepEqual(swap, pendingSwap) { require.NoError(t, err)
t.Fatal("invalid pending swap data")
}
if swaps[0].State().State != expectedState { require.Equal(t, hash, swap.Hash)
t.Fatalf("expected state %v, but got %v", require.Equal(t, hash, swaps[0].Hash)
expectedState, swaps[0].State(),
) swapContract := swap.Contract
}
require.Equal(t, swapContract, pendingSwap)
require.Equal(t, expectedState, swap.State().State)
if expectedState == StatePreimageRevealed { if expectedState == StatePreimageRevealed {
require.NotNil(t, swaps[0].State().HtlcTxHash) require.NotNil(t, swap.State().HtlcTxHash)
} }
} }
hash := pendingSwap.Preimage.Hash()
// If we create a new swap, then it should show up as being initialized // If we create a new swap, then it should show up as being initialized
// right after. // right after.
if err := store.CreateLoopOut(hash, pendingSwap); err != nil { err = store.CreateLoopOut(hash, pendingSwap)
t.Fatal(err) require.NoError(t, err)
}
checkSwap(StateInitiated) checkSwap(StateInitiated)
// Trying to make the same swap again should result in an error. // Trying to make the same swap again should result in an error.
if err := store.CreateLoopOut(hash, pendingSwap); err == nil { err = store.CreateLoopOut(hash, pendingSwap)
t.Fatal("expected error on storing duplicate") require.Error(t, err)
}
checkSwap(StateInitiated) checkSwap(StateInitiated)
// Next, we'll update to the next state of the pre-image being // Next, we'll update to the next state of the pre-image being
@ -169,9 +158,8 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
HtlcTxHash: &chainhash.Hash{1, 6, 2}, HtlcTxHash: &chainhash.Hash{1, 6, 2},
}, },
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
checkSwap(StatePreimageRevealed) checkSwap(StatePreimageRevealed)
// Next, we'll update to the final state to ensure that the state is // Next, we'll update to the final state to ensure that the state is
@ -182,21 +170,17 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) {
State: StateFailInsufficientValue, State: StateFailInsufficientValue,
}, },
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
checkSwap(StateFailInsufficientValue) checkSwap(StateFailInsufficientValue)
if err := store.Close(); err != nil { err = store.Close()
t.Fatal(err) require.NoError(t, err)
}
// If we re-open the same store, then the state of the current swap // If we re-open the same store, then the state of the current swap
// should be the same. // should be the same.
store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
checkSwap(StateFailInsufficientValue) checkSwap(StateFailInsufficientValue)
} }
@ -242,24 +226,18 @@ func TestLoopInStore(t *testing.T) {
func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
tempDirName, err := ioutil.TempDir("", "clientstore") tempDirName, err := ioutil.TempDir("", "clientstore")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
defer os.RemoveAll(tempDirName) defer os.RemoveAll(tempDirName)
store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// First, verify that an empty database has no active swaps. // First, verify that an empty database has no active swaps.
swaps, err := store.FetchLoopInSwaps() swaps, err := store.FetchLoopInSwaps()
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Empty(t, swaps)
}
if len(swaps) != 0 { hash := sha256.Sum256(testPreimage[:])
t.Fatal("expected empty store")
}
// checkSwap is a test helper function that'll assert the state of a // checkSwap is a test helper function that'll assert the state of a
// swap. // swap.
@ -267,39 +245,27 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
t.Helper() t.Helper()
swaps, err := store.FetchLoopInSwaps() swaps, err := store.FetchLoopInSwaps()
if err != nil { require.NoError(t, err)
t.Fatal(err) require.Len(t, swaps, 1)
}
if len(swaps) != 1 {
t.Fatal("expected pending swap in store")
}
swap := swaps[0].Contract swap := swaps[0].Contract
if !reflect.DeepEqual(swap, &pendingSwap) {
t.Fatal("invalid pending swap data")
}
if swaps[0].State().State != expectedState { require.Equal(t, swap, &pendingSwap)
t.Fatalf("expected state %v, but got %v",
expectedState, swaps[0].State(),
)
}
}
hash := sha256.Sum256(testPreimage[:]) require.Equal(t, swaps[0].State().State, expectedState)
}
// If we create a new swap, then it should show up as being initialized // If we create a new swap, then it should show up as being initialized
// right after. // right after.
if err := store.CreateLoopIn(hash, &pendingSwap); err != nil { err = store.CreateLoopIn(hash, &pendingSwap)
t.Fatal(err) require.NoError(t, err)
}
checkSwap(StateInitiated) checkSwap(StateInitiated)
// Trying to make the same swap again should result in an error. // Trying to make the same swap again should result in an error.
if err := store.CreateLoopIn(hash, &pendingSwap); err == nil { err = store.CreateLoopIn(hash, &pendingSwap)
t.Fatal("expected error on storing duplicate") require.Error(t, err)
}
checkSwap(StateInitiated) checkSwap(StateInitiated)
// Next, we'll update to the next state of the pre-image being // Next, we'll update to the next state of the pre-image being
@ -310,9 +276,8 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
State: StatePreimageRevealed, State: StatePreimageRevealed,
}, },
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
checkSwap(StatePreimageRevealed) checkSwap(StatePreimageRevealed)
// Next, we'll update to the final state to ensure that the state is // Next, we'll update to the final state to ensure that the state is
@ -323,21 +288,17 @@ func testLoopInStore(t *testing.T, pendingSwap LoopInContract) {
State: StateFailInsufficientValue, State: StateFailInsufficientValue,
}, },
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
checkSwap(StateFailInsufficientValue) checkSwap(StateFailInsufficientValue)
if err := store.Close(); err != nil { err = store.Close()
t.Fatal(err) require.NoError(t, err)
}
// If we re-open the same store, then the state of the current swap // If we re-open the same store, then the state of the current swap
// should be the same. // should be the same.
store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
checkSwap(StateFailInsufficientValue) checkSwap(StateFailInsufficientValue)
} }
@ -467,9 +428,8 @@ func TestLegacyOutgoingChannel(t *testing.T) {
// Assert that the outgoing channel is read properly. // Assert that the outgoing channel is read properly.
expectedChannelSet := ChannelSet{5} expectedChannelSet := ChannelSet{5}
if !reflect.DeepEqual(swaps[0].Contract.OutgoingChanSet, expectedChannelSet) {
t.Fatal("invalid outgoing channel") require.Equal(t, expectedChannelSet, swaps[0].Contract.OutgoingChanSet)
}
} }
// TestLiquidityParams checks that reading and writing to liquidty bucket are // TestLiquidityParams checks that reading and writing to liquidty bucket are

@ -70,6 +70,36 @@ func (s *storeMock) FetchLoopOutSwaps() ([]*loopdb.LoopOut, error) {
return result, nil return result, nil
} }
// FetchLoopOutSwaps returns all swaps currently in the store.
//
// NOTE: Part of the loopdb.SwapStore interface.
func (s *storeMock) FetchLoopOutSwap(
hash lntypes.Hash) (*loopdb.LoopOut, error) {
contract, ok := s.loopOutSwaps[hash]
if !ok {
return nil, errors.New("swap not found")
}
updates := s.loopOutUpdates[hash]
events := make([]*loopdb.LoopEvent, len(updates))
for i, u := range updates {
events[i] = &loopdb.LoopEvent{
SwapStateData: u,
}
}
swap := &loopdb.LoopOut{
Loop: loopdb.Loop{
Hash: hash,
Events: events,
},
Contract: contract,
}
return swap, nil
}
// CreateLoopOut adds an initiated swap to the store. // CreateLoopOut adds an initiated swap to the store.
// //
// NOTE: Part of the loopdb.SwapStore interface. // NOTE: Part of the loopdb.SwapStore interface.

Loading…
Cancel
Save