diff --git a/client.go b/client.go index 9b8161f..fcf0d48 100644 --- a/client.go +++ b/client.go @@ -359,9 +359,8 @@ func (s *Client) resumeSwaps(ctx context.Context, func (s *Client) LoopOut(globalCtx context.Context, request *OutRequest) (*lntypes.Hash, btcutil.Address, error) { - log.Infof("LoopOut %v to %v (channel: %v)", - request.Amount, request.DestAddr, - request.LoopOutChannel, + log.Infof("LoopOut %v to %v (channels: %v)", + request.Amount, request.DestAddr, request.OutgoingChanSet, ) if err := s.waitForInitialized(globalCtx); err != nil { diff --git a/interface.go b/interface.go index f702a45..287832d 100644 --- a/interface.go +++ b/interface.go @@ -64,9 +64,9 @@ type OutRequest struct { // client sweep tx. SweepConfTarget int32 - // LoopOutChannel optionally specifies the short channel id of the - // channel to loop out. - LoopOutChannel *uint64 + // OutgoingChanSet optionally specifies the short channel ids of the + // channels that may be used to loop out. + OutgoingChanSet loopdb.ChannelSet // SwapPublicationDeadline can be set by the client to allow the server // delaying publication of the swap HTLC to save on chain fees. diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index 4f9d9a1..c59e714 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -90,7 +90,7 @@ func (s *swapClientServer) LoopOut(ctx context.Context, ), } if in.LoopOutChannel != 0 { - req.LoopOutChannel = &in.LoopOutChannel + req.OutgoingChanSet = loopdb.ChannelSet{in.LoopOutChannel} } hash, htlc, err := s.impl.LoopOut(ctx, req) if err != nil { diff --git a/loopd/view.go b/loopd/view.go index 433158c..aa2d250 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -2,7 +2,6 @@ package loopd import ( "fmt" - "strconv" "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/loop" @@ -64,13 +63,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) fmt.Printf(" Htlc address: %v\n", htlc.Address) - unchargeChannel := "any" - if s.Contract.UnchargeChannel != nil { - unchargeChannel = strconv.FormatUint( - *s.Contract.UnchargeChannel, 10, - ) - } - fmt.Printf(" Uncharge channel: %v\n", unchargeChannel) + fmt.Printf(" Uncharge channels: %v\n", + s.Contract.OutgoingChanSet) fmt.Printf(" Dest: %v\n", s.Contract.DestAddr) fmt.Printf(" Amt: %v, Expiry: %v\n", s.Contract.AmountRequested, s.Contract.CltvExpiry, diff --git a/loopdb/loopout.go b/loopdb/loopout.go index 3df0654..d24a57e 100644 --- a/loopdb/loopout.go +++ b/loopdb/loopout.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "fmt" + "strconv" + "strings" "time" "github.com/btcsuite/btcd/chaincfg" @@ -34,9 +36,9 @@ type LoopOutContract struct { // client sweep tx. SweepConfTarget int32 - // TargetChannel is the channel to loop out. If zero, any channel may - // be used. - UnchargeChannel *uint64 + // OutgoingChanSet is the set of short ids of channels that may be used. + // If empty, any channel may be used. + OutgoingChanSet ChannelSet // PrepayInvoice is the invoice that the client should pay to the // server that will be returned if the swap is complete. @@ -53,6 +55,34 @@ type LoopOutContract struct { SwapPublicationDeadline time.Time } +// ChannelSet stores a set of channels. +type ChannelSet []uint64 + +// String returns the human-readable representation of a channel set. +func (c ChannelSet) String() string { + channelStrings := make([]string, len(c)) + for i, chanID := range c { + channelStrings[i] = strconv.FormatUint(chanID, 10) + } + return strings.Join(channelStrings, ",") +} + +// NewChannelSet instantiates a new channel set and verifies that there are no +// duplicates present. +func NewChannelSet(set []uint64) (ChannelSet, error) { + // Check channel set for duplicates. + chanSet := make(map[uint64]struct{}) + for _, chanID := range set { + if _, exists := chanSet[chanID]; exists { + return nil, fmt.Errorf("duplicate chan in set: id=%v", + chanID) + } + chanSet[chanID] = struct{}{} + } + + return ChannelSet(set), nil +} + // LoopOut is a combination of the contract and the updates. type LoopOut struct { Loop @@ -161,7 +191,7 @@ func deserializeLoopOutContract(value []byte, chainParams *chaincfg.Params) ( return nil, err } if unchargeChannel != 0 { - contract.UnchargeChannel = &unchargeChannel + contract.OutgoingChanSet = ChannelSet{unchargeChannel} } var deadlineNano int64 @@ -248,10 +278,9 @@ func serializeLoopOutContract(swap *LoopOutContract) ( return nil, err } - var unchargeChannel uint64 - if swap.UnchargeChannel != nil { - unchargeChannel = *swap.UnchargeChannel - } + // Always write no outgoing channel. This field is replaced by an + // outgoing channel set. + unchargeChannel := uint64(0) if err := binary.Write(&b, byteOrder, unchargeChannel); err != nil { return nil, err } diff --git a/loopdb/store.go b/loopdb/store.go index 168f03e..505809d 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -1,9 +1,11 @@ package loopdb import ( + "bytes" "encoding/binary" "errors" "fmt" + "io" "os" "path/filepath" "time" @@ -51,6 +53,14 @@ var ( // value: time || rawSwapState contractKey = []byte("contract") + // outgoingChanSetKey is the key that stores a list of channel ids that + // restrict the loop out swap payment. + // + // path: loopOutBucket -> swapBucket[hash] -> outgoingChanSetKey + // + // value: concatenation of uint64 channel ids + outgoingChanSetKey = []byte("outgoing-chan-set") + byteOrder = binary.BigEndian keyLength = 33 @@ -190,6 +200,29 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return err } + // Read the list of concatenated outgoing channel ids + // that form the outgoing set. + setBytes := swapBucket.Get(outgoingChanSetKey) + if outgoingChanSetKey != nil { + r := bytes.NewReader(setBytes) + readLoop: + for { + var chanID uint64 + err := binary.Read(r, byteOrder, &chanID) + switch { + case err == io.EOF: + break readLoop + case err != nil: + return err + } + + contract.OutgoingChanSet = append( + contract.OutgoingChanSet, + chanID, + ) + } + } + updates, err := deserializeUpdates(swapBucket) if err != nil { return err @@ -374,6 +407,19 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + // Write the outgoing channel set. + var b bytes.Buffer + for _, chanID := range swap.OutgoingChanSet { + err := binary.Write(&b, byteOrder, chanID) + if err != nil { + return err + } + } + err = swapBucket.Put(outgoingChanSetKey, b.Bytes()) + if err != nil { + return err + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) diff --git a/loopdb/store_test.go b/loopdb/store_test.go index 9f15adf..799a1d4 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -45,7 +45,7 @@ func TestLoopOutStore(t *testing.T) { // Next, we'll make a new pending swap that we'll insert into the // database shortly. - pendingSwap := LoopOutContract{ + unrestrictedSwap := LoopOutContract{ SwapContract: SwapContract{ AmountRequested: 100, Preimage: testPreimage, @@ -71,7 +71,16 @@ func TestLoopOutStore(t *testing.T) { SwapPublicationDeadline: time.Unix(0, initiationTime.UnixNano()), } - testLoopOutStore(t, &pendingSwap) + t.Run("no outgoing set", func(t *testing.T) { + testLoopOutStore(t, &unrestrictedSwap) + }) + + restrictedSwap := unrestrictedSwap + restrictedSwap.OutgoingChanSet = ChannelSet{1, 2} + + t.Run("two channel outgoing set", func(t *testing.T) { + testLoopOutStore(t, &restrictedSwap) + }) } // testLoopOutStore tests the basic functionality of the current bbolt @@ -373,3 +382,65 @@ func createVersionZeroDb(t *testing.T, dbPath string) { t.Fatal(err) } } + +// TestLegacyOutgoingChannel asserts that a legacy channel restriction is +// properly mapped onto the newer channel set. +func TestLegacyOutgoingChannel(t *testing.T) { + var ( + legacyDbVersion = Hex("00000003") + legacyOutgoingChannel = Hex("0000000000000005") + ) + + legacyDb := map[string]interface{}{ + "loop-in": map[string]interface{}{}, + "metadata": map[string]interface{}{ + "dbp": legacyDbVersion, + }, + "uncharge-swaps": map[string]interface{}{ + Hex("2a595d79a55168970532805ae20c9b5fac98f04db79ba4c6ae9b9ac0f206359e"): map[string]interface{}{ + "contract": Hex("1562d6fbec140000010101010202020203030303040404040101010102020202030303030404040400000000000000640d707265706179696e766f69636501010101010101010101010101010101010101010101010101010101010101010201010101010101010101010101010101010101010101010101010101010101010300000090000000000000000a0000000000000014000000000000002800000063223347454e556d6e4552745766516374344e65676f6d557171745a757a5947507742530b73776170696e766f69636500000002000000000000001e") + legacyOutgoingChannel + Hex("1562d6fbec140000"), + "updates": map[string]interface{}{ + Hex("0000000000000001"): Hex("1508290a92d4c00001000000000000000000000000000000000000000000000000"), + Hex("0000000000000002"): Hex("1508290a92d4c00006000000000000000000000000000000000000000000000000"), + }, + }, + }, + } + + // Restore a legacy database. + tempDirName, err := ioutil.TempDir("", "clientstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + tempPath := filepath.Join(tempDirName, dbFileName) + db, err := bbolt.Open(tempPath, 0600, nil) + if err != nil { + t.Fatal(err) + } + err = db.Update(func(tx *bbolt.Tx) error { + return RestoreDB(tx, legacyDb) + }) + if err != nil { + t.Fatal(err) + } + db.Close() + + // Fetch the legacy swap. + store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + swaps, err := store.FetchLoopOutSwaps() + if err != nil { + t.Fatal(err) + } + + // Assert that the outgoing channel is read properly. + expectedChannelSet := ChannelSet{5} + if !reflect.DeepEqual(swaps[0].Contract.OutgoingChanSet, expectedChannelSet) { + t.Fatal("invalid outgoing channel") + } +} diff --git a/loopout.go b/loopout.go index f81f3d5..b233c0b 100644 --- a/loopout.go +++ b/loopout.go @@ -112,6 +112,12 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, return nil, err } + // Check channel set for duplicates. + chanSet, err := loopdb.NewChannelSet(request.OutgoingChanSet) + if err != nil { + return nil, err + } + // Instantiate a struct that contains all required data to start the // swap. initiationTime := time.Now() @@ -121,7 +127,6 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, DestAddr: request.DestAddr, MaxSwapRoutingFee: request.MaxSwapRoutingFee, SweepConfTarget: request.SweepConfTarget, - UnchargeChannel: request.LoopOutChannel, PrepayInvoice: swapResp.prepayInvoice, MaxPrepayRoutingFee: request.MaxPrepayRoutingFee, SwapPublicationDeadline: request.SwapPublicationDeadline, @@ -136,6 +141,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, MaxMinerFee: request.MaxMinerFee, MaxSwapFee: request.MaxSwapFee, }, + OutgoingChanSet: chanSet, } swapKit := newSwapKit( @@ -430,15 +436,9 @@ func (s *loopOutSwap) payInvoices(ctx context.Context) { // Pay the swap invoice. s.log.Infof("Sending swap payment %v", s.SwapInvoice) - var outgoingChanIds []uint64 - if s.LoopOutContract.UnchargeChannel != nil { - outgoingChanIds = append( - outgoingChanIds, *s.LoopOutContract.UnchargeChannel, - ) - } - s.swapPaymentChan = s.payInvoice( - ctx, s.SwapInvoice, s.MaxSwapRoutingFee, outgoingChanIds, + ctx, s.SwapInvoice, s.MaxSwapRoutingFee, + s.LoopOutContract.OutgoingChanSet, ) // Pay the prepay invoice. @@ -452,7 +452,7 @@ func (s *loopOutSwap) payInvoices(ctx context.Context) { // payInvoice pays a single invoice. func (s *loopOutSwap) payInvoice(ctx context.Context, invoice string, maxFee btcutil.Amount, - outgoingChanIds []uint64) chan lndclient.PaymentResult { + outgoingChanIds loopdb.ChannelSet) chan lndclient.PaymentResult { resultChan := make(chan lndclient.PaymentResult) @@ -481,8 +481,8 @@ func (s *loopOutSwap) payInvoice(ctx context.Context, invoice string, // payInvoiceAsync is the asynchronously executed part of paying an invoice. func (s *loopOutSwap) payInvoiceAsync(ctx context.Context, - invoice string, maxFee btcutil.Amount, outgoingChanIds []uint64) ( - *lndclient.PaymentStatus, error) { + invoice string, maxFee btcutil.Amount, + outgoingChanIds loopdb.ChannelSet) (*lndclient.PaymentStatus, error) { // Extract hash from payment request. Unfortunately the request // components aren't available directly. diff --git a/loopout_test.go b/loopout_test.go index 1b233ea..d5dc9de 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -3,6 +3,7 @@ package loop import ( "context" "errors" + "reflect" "testing" "time" @@ -47,8 +48,11 @@ func TestLoopOutPaymentParameters(t *testing.T) { const maxParts = 5 // Initiate the swap. + req := *testRequest + req.OutgoingChanSet = loopdb.ChannelSet{2, 3} + swap, err := newLoopOutSwap( - context.Background(), cfg, height, testRequest, + context.Background(), cfg, height, &req, ) if err != nil { t.Fatal(err) @@ -99,6 +103,13 @@ func TestLoopOutPaymentParameters(t *testing.T) { maxParts, swapPayment.MaxParts) } + // Verify the outgoing channel set restriction. + if !reflect.DeepEqual( + []uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds, + ) { + t.Fatalf("Unexpected outgoing channel set") + } + // Swap is expected to register for confirmation of the htlc. Assert // this to prevent a blocked channel in the mock. ctx.AssertRegisterConf()