diff --git a/client.go b/client.go index a58a2d9..7f7db0c 100644 --- a/client.go +++ b/client.go @@ -129,6 +129,7 @@ func NewClient(dbDir string, cfg *ClientConfig) (*Client, func(), error) { CreateExpiryTimer: func(d time.Duration) <-chan time.Time { return time.NewTimer(d).C }, + LoopOutMaxParts: cfg.LoopOutMaxParts, } sweeper := &sweep.Sweeper{ diff --git a/config.go b/config.go index cc69839..198b591 100644 --- a/config.go +++ b/config.go @@ -15,4 +15,5 @@ type clientConfig struct { Store loopdb.SwapStore LsatStore lsat.Store CreateExpiryTimer func(expiry time.Duration) <-chan time.Time + LoopOutMaxParts uint32 } diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index 4a6b9ea..3933845 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -44,6 +44,13 @@ var ( // errConfTargetTooLow is returned when the chosen confirmation target // is below the allowed minimum. errConfTargetTooLow = errors.New("confirmation target too low") + + // errBalanceTooLow is returned when the loop out amount can't be + // satisfied given total balance of the selection of channels to loop + // out on. + errBalanceTooLow = errors.New( + "channel balance too low for loop out amount", + ) ) // swapClientServer implements the grpc service exposed by loopd. @@ -89,7 +96,8 @@ func (s *swapClientServer) LoopOut(ctx context.Context, } sweepConfTarget, err := validateLoopOutRequest( - s.lnd.ChainParams, in.SweepConfTarget, sweepAddr, in.Label, + ctx, s.lnd.Client, s.lnd.ChainParams, in, sweepAddr, + s.impl.LoopOutMaxParts, ) if err != nil { return nil, err @@ -981,9 +989,12 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) { } // validateLoopOutRequest validates the confirmation target, destination -// address and label of the loop out request. -func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32, - sweepAddr btcutil.Address, label string) (int32, error) { +// address and label of the loop out request. It also checks that the requested +// loop amount is valid given the available balance. +func validateLoopOutRequest(ctx context.Context, lnd lndclient.LightningClient, + chainParams *chaincfg.Params, req *looprpc.LoopOutRequest, + sweepAddr btcutil.Address, maxParts uint32) (int32, error) { + // Check that the provided destination address has the correct format // for the active network. if !sweepAddr.IsForNet(chainParams) { @@ -992,9 +1003,101 @@ func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32, } // Check that the label is valid. - if err := labels.Validate(label); err != nil { + if err := labels.Validate(req.Label); err != nil { + return 0, err + } + + channels, err := lnd.ListChannels(ctx) + if err != nil { return 0, err } - return validateConfTarget(confTarget, loop.DefaultSweepConfTarget) + unlimitedChannels := len(req.OutgoingChanSet) == 0 + outgoingChanSetMap := make(map[uint64]bool) + for _, chanID := range req.OutgoingChanSet { + outgoingChanSetMap[chanID] = true + } + + var activeChannelSet []lndclient.ChannelInfo + for _, c := range channels { + // Don't bother looking at inactive channels. + if !c.Active { + continue + } + + // If no outgoing channel set was specified then all active + // channels are considered. However, if a channel set was + // specified then only the specified channels are considered. + if unlimitedChannels || outgoingChanSetMap[c.ChannelID] { + activeChannelSet = append(activeChannelSet, c) + } + } + + // Determine if the loop out request is theoretically possible given + // the amount requested, the maximum possible routing fees, + // the available channel set and the fact that equal splitting is + // used for MPP. + requiredBalance := btcutil.Amount(req.Amt + req.MaxSwapRoutingFee) + isRoutable, _ := hasBandwidth(activeChannelSet, requiredBalance, + int(maxParts)) + if !isRoutable { + return 0, fmt.Errorf("%w: Requested swap amount of %d "+ + "sats along with the maximum routing fee of %d sats "+ + "is more than what can be routed given current state "+ + "of the channel set", errBalanceTooLow, req.Amt, + req.MaxSwapRoutingFee) + } + + return validateConfTarget( + req.SweepConfTarget, loop.DefaultSweepConfTarget, + ) +} + +// hasBandwidth simulates the MPP splitting logic that will be used by LND when +// attempting to route the payment. This function is used to evaluate if a +// payment will be routable given the splitting logic used by LND. +// It returns true if the amount is routable given the channel set and the +// maximum number of shards allowed. If the amount is routable then the number +// of shards used is also returned. This function makes an assumption that the +// minimum loop amount divided by max parts will not be less than the minimum +// shard amount. If the MPP logic changes, then this function should be updated. +func hasBandwidth(channels []lndclient.ChannelInfo, amt btcutil.Amount, + maxParts int) (bool, int) { + + scratch := make([]btcutil.Amount, len(channels)) + var totalBandwidth btcutil.Amount + for i, channel := range channels { + scratch[i] = channel.LocalBalance + totalBandwidth += channel.LocalBalance + } + + if totalBandwidth < amt { + return false, 0 + } + + split := amt + for shard := 0; shard <= maxParts; { + paid := false + for i := 0; i < len(scratch); i++ { + if scratch[i] >= split { + scratch[i] -= split + amt -= split + paid = true + shard++ + break + } + } + + if amt == 0 { + return true, shard + } + + if !paid { + split /= 2 + } else { + split = amt + } + } + + return false, 0 } diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 47bf141..ea5a89d 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -1,13 +1,19 @@ package loopd import ( + "context" "errors" "testing" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" + "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" "github.com/lightninglabs/loop/labels" + "github.com/lightninglabs/loop/looprpc" + mock_lnd "github.com/lightninglabs/loop/test" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" ) @@ -19,6 +25,50 @@ var ( mainnetAddr, _ = btcutil.NewAddressScriptHash( []byte{123}, &chaincfg.MainNetParams, ) + + chanID1 = lnwire.NewShortChanIDFromInt(1) + chanID2 = lnwire.NewShortChanIDFromInt(2) + chanID3 = lnwire.NewShortChanIDFromInt(3) + chanID4 = lnwire.NewShortChanIDFromInt(4) + + peer1 = route.Vertex{1} + peer2 = route.Vertex{2} + + channel1 = lndclient.ChannelInfo{ + Active: false, + ChannelID: chanID1.ToUint64(), + PubKeyBytes: peer1, + LocalBalance: 10000, + RemoteBalance: 0, + Capacity: 10000, + } + + channel2 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID2.ToUint64(), + PubKeyBytes: peer2, + LocalBalance: 10000, + RemoteBalance: 0, + Capacity: 10000, + } + + channel3 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID3.ToUint64(), + PubKeyBytes: peer2, + LocalBalance: 10000, + RemoteBalance: 0, + Capacity: 10000, + } + + channel4 = lndclient.ChannelInfo{ + Active: true, + ChannelID: chanID4.ToUint64(), + PubKeyBytes: peer2, + LocalBalance: 1000, + RemoteBalance: 0, + Capacity: 1000, + } ) // TestValidateConfTarget tests all failure and success cases for our conf @@ -162,77 +212,239 @@ func TestValidateLoopInRequest(t *testing.T) { // TestValidateLoopOutRequest tests validation of loop out requests. func TestValidateLoopOutRequest(t *testing.T) { tests := []struct { - name string - chain chaincfg.Params - confTarget int32 - destAddr btcutil.Address - label string - err error - expectedTarget int32 + name string + chain chaincfg.Params + confTarget int32 + destAddr btcutil.Address + label string + channels []lndclient.ChannelInfo + outgoingChanSet []uint64 + amount int64 + maxRoutingFee int64 + maxParts uint32 + err error + expectedTarget int32 }{ { - name: "mainnet address with mainnet backend", - chain: chaincfg.MainNetParams, - destAddr: mainnetAddr, - label: "label ok", - confTarget: 2, + name: "mainnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: nil, expectedTarget: 2, }, { - name: "mainnet address with testnet backend", - chain: chaincfg.TestNet3Params, - destAddr: mainnetAddr, - label: "label ok", - confTarget: 2, + name: "mainnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: errIncorrectChain, expectedTarget: 0, }, { - name: "testnet address with testnet backend", - chain: chaincfg.TestNet3Params, - destAddr: testnetAddr, - label: "label ok", - confTarget: 2, + name: "testnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: nil, expectedTarget: 2, }, { - name: "testnet address with mainnet backend", - chain: chaincfg.MainNetParams, - destAddr: testnetAddr, - label: "label ok", - confTarget: 2, + name: "testnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: errIncorrectChain, expectedTarget: 0, }, { - name: "invalid label", - chain: chaincfg.MainNetParams, - destAddr: mainnetAddr, - label: labels.Reserved, - confTarget: 2, + name: "invalid label", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: labels.Reserved, + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: labels.ErrReservedPrefix, expectedTarget: 0, }, { - name: "invalid conf target", - chain: chaincfg.MainNetParams, - destAddr: mainnetAddr, - label: "label ok", - confTarget: 1, + name: "invalid conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 1, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: errConfTargetTooLow, expectedTarget: 0, }, { - name: "default conf target", - chain: chaincfg.MainNetParams, - destAddr: mainnetAddr, - label: "label ok", - confTarget: 0, + name: "default conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 0, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxParts: 5, err: nil, expectedTarget: 9, }, + { + name: "valid amount for default channel set", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel1, channel2, channel3, + }, + amount: 20000, + maxParts: 5, + err: nil, + expectedTarget: 2, + }, + { + name: "invalid amount for default channel set", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel1, channel2, channel3, + }, + amount: 25000, + maxParts: 5, + err: errBalanceTooLow, + expectedTarget: 0, + }, + { + name: "inactive channel in outgoing channel set", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel1, channel2, channel3, + }, + outgoingChanSet: []uint64{ + chanID1.ToUint64(), + }, + amount: 1000, + maxParts: 5, + err: errBalanceTooLow, + expectedTarget: 0, + }, + { + name: "outgoing channel set balance is enough", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel1, channel2, channel3, + }, + outgoingChanSet: []uint64{ + chanID2.ToUint64(), + }, + amount: 1000, + maxParts: 5, + err: nil, + expectedTarget: 2, + }, + { + name: "outgoing channel set balance not sufficient", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel1, channel2, channel3, + }, + outgoingChanSet: []uint64{ + chanID2.ToUint64(), + }, + amount: 20000, + maxParts: 5, + err: errBalanceTooLow, + expectedTarget: 0, + }, + { + name: "amount with routing fee too high", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, + }, + amount: 10000, + maxRoutingFee: 100, + maxParts: 5, + err: errBalanceTooLow, + expectedTarget: 0, + }, + { + name: "can split between channels", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, channel4, + }, + amount: 11000, + maxParts: 16, + err: nil, + expectedTarget: 2, + }, + { + name: "can't split between channels", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + channels: []lndclient.ChannelInfo{ + channel2, channel4, + }, + amount: 11000, + maxParts: 5, + err: errBalanceTooLow, + expectedTarget: 0, + }, } for _, test := range tests { @@ -240,13 +452,119 @@ func TestValidateLoopOutRequest(t *testing.T) { t.Run(test.name, func(t *testing.T) { t.Parallel() + ctx := context.Background() + + lnd := mock_lnd.NewMockLnd() + lnd.Channels = test.channels + + req := &looprpc.LoopOutRequest{ + Amt: test.amount, + MaxSwapRoutingFee: test.maxRoutingFee, + OutgoingChanSet: test.outgoingChanSet, + Label: test.label, + SweepConfTarget: test.confTarget, + } conf, err := validateLoopOutRequest( - &test.chain, test.confTarget, test.destAddr, - test.label, + ctx, lnd.Client, &test.chain, req, + test.destAddr, test.maxParts, ) require.True(t, errors.Is(err, test.err)) require.Equal(t, test.expectedTarget, conf) }) } } + +// TestHasBandwidth tests that the hasBandwidth function correctly simulates +// the MPP logic used by LND. +func TestHasBandwidth(t *testing.T) { + tests := []struct { + name string + channels []lndclient.ChannelInfo + maxParts int + amt btcutil.Amount + expectedRes bool + expectedShards int + }{ + { + name: "can route due to high number of parts", + channels: []lndclient.ChannelInfo{ + { + LocalBalance: 100, + }, + { + LocalBalance: 10, + }, + }, + maxParts: 11, + amt: 110, + expectedRes: true, + expectedShards: 8, + }, + { + name: "can't route due to low number of parts", + channels: []lndclient.ChannelInfo{ + { + LocalBalance: 100, + }, + { + LocalBalance: 10, + }, + }, + maxParts: 5, + amt: 110, + expectedRes: false, + }, + { + name: "can route", + channels: []lndclient.ChannelInfo{ + { + LocalBalance: 1000, + }, + { + LocalBalance: 1000, + }, + }, + maxParts: 5, + amt: 2000, + expectedRes: true, + expectedShards: 2, + }, + { + name: "can route", + channels: []lndclient.ChannelInfo{ + { + LocalBalance: 100, + }, + { + LocalBalance: 100, + }, + { + LocalBalance: 100, + }, + }, + maxParts: 10, + amt: 300, + expectedRes: true, + expectedShards: 10, + }, + { + name: "can't route due to empty channel set", + maxParts: 10, + amt: 300, + expectedRes: false, + expectedShards: 0, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + res, shards := hasBandwidth(test.channels, test.amt, + test.maxParts) + require.Equal(t, test.expectedRes, res) + require.Equal(t, test.expectedShards, shards) + }) + } +}