loopdb: store outgoing channel set

Upgrade the database schema to allow for multiple outgoing channels.
This is implemented as an on-the-fly migration leaving the old key in
place.
pull/205/head
Joost Jager 4 years ago
parent 044c1c12dd
commit 8c544bf2ba
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7

@ -359,9 +359,8 @@ func (s *Client) resumeSwaps(ctx context.Context,
func (s *Client) LoopOut(globalCtx context.Context, func (s *Client) LoopOut(globalCtx context.Context,
request *OutRequest) (*lntypes.Hash, btcutil.Address, error) { request *OutRequest) (*lntypes.Hash, btcutil.Address, error) {
log.Infof("LoopOut %v to %v (channel: %v)", log.Infof("LoopOut %v to %v (channels: %v)",
request.Amount, request.DestAddr, request.Amount, request.DestAddr, request.OutgoingChanSet,
request.LoopOutChannel,
) )
if err := s.waitForInitialized(globalCtx); err != nil { if err := s.waitForInitialized(globalCtx); err != nil {

@ -64,9 +64,9 @@ type OutRequest struct {
// client sweep tx. // client sweep tx.
SweepConfTarget int32 SweepConfTarget int32
// LoopOutChannel optionally specifies the short channel id of the // OutgoingChanSet optionally specifies the short channel ids of the
// channel to loop out. // channels that may be used to loop out.
LoopOutChannel *uint64 OutgoingChanSet loopdb.ChannelSet
// SwapPublicationDeadline can be set by the client to allow the server // SwapPublicationDeadline can be set by the client to allow the server
// delaying publication of the swap HTLC to save on chain fees. // delaying publication of the swap HTLC to save on chain fees.

@ -90,7 +90,7 @@ func (s *swapClientServer) LoopOut(ctx context.Context,
), ),
} }
if in.LoopOutChannel != 0 { if in.LoopOutChannel != 0 {
req.LoopOutChannel = &in.LoopOutChannel req.OutgoingChanSet = loopdb.ChannelSet{in.LoopOutChannel}
} }
hash, htlc, err := s.impl.LoopOut(ctx, req) hash, htlc, err := s.impl.LoopOut(ctx, req)
if err != nil { if err != nil {

@ -2,7 +2,6 @@ package loopd
import ( import (
"fmt" "fmt"
"strconv"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/lightninglabs/loop" "github.com/lightninglabs/loop"
@ -64,13 +63,8 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
fmt.Printf(" Preimage: %v\n", s.Contract.Preimage) fmt.Printf(" Preimage: %v\n", s.Contract.Preimage)
fmt.Printf(" Htlc address: %v\n", htlc.Address) fmt.Printf(" Htlc address: %v\n", htlc.Address)
unchargeChannel := "any" fmt.Printf(" Uncharge channels: %v\n",
if s.Contract.UnchargeChannel != nil { s.Contract.OutgoingChanSet)
unchargeChannel = strconv.FormatUint(
*s.Contract.UnchargeChannel, 10,
)
}
fmt.Printf(" Uncharge channel: %v\n", unchargeChannel)
fmt.Printf(" Dest: %v\n", s.Contract.DestAddr) fmt.Printf(" Dest: %v\n", s.Contract.DestAddr)
fmt.Printf(" Amt: %v, Expiry: %v\n", fmt.Printf(" Amt: %v, Expiry: %v\n",
s.Contract.AmountRequested, s.Contract.CltvExpiry, s.Contract.AmountRequested, s.Contract.CltvExpiry,

@ -4,6 +4,8 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"strconv"
"strings"
"time" "time"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
@ -34,9 +36,9 @@ type LoopOutContract struct {
// client sweep tx. // client sweep tx.
SweepConfTarget int32 SweepConfTarget int32
// TargetChannel is the channel to loop out. If zero, any channel may // OutgoingChanSet is the set of short ids of channels that may be used.
// be used. // If empty, any channel may be used.
UnchargeChannel *uint64 OutgoingChanSet ChannelSet
// PrepayInvoice is the invoice that the client should pay to the // PrepayInvoice is the invoice that the client should pay to the
// server that will be returned if the swap is complete. // server that will be returned if the swap is complete.
@ -53,6 +55,34 @@ type LoopOutContract struct {
SwapPublicationDeadline time.Time SwapPublicationDeadline time.Time
} }
// ChannelSet stores a set of channels.
type ChannelSet []uint64
// String returns the human-readable representation of a channel set.
func (c ChannelSet) String() string {
channelStrings := make([]string, len(c))
for i, chanID := range c {
channelStrings[i] = strconv.FormatUint(chanID, 10)
}
return strings.Join(channelStrings, ",")
}
// NewChannelSet instantiates a new channel set and verifies that there are no
// duplicates present.
func NewChannelSet(set []uint64) (ChannelSet, error) {
// Check channel set for duplicates.
chanSet := make(map[uint64]struct{})
for _, chanID := range set {
if _, exists := chanSet[chanID]; exists {
return nil, fmt.Errorf("duplicate chan in set: id=%v",
chanID)
}
chanSet[chanID] = struct{}{}
}
return ChannelSet(set), nil
}
// LoopOut is a combination of the contract and the updates. // LoopOut is a combination of the contract and the updates.
type LoopOut struct { type LoopOut struct {
Loop Loop
@ -161,7 +191,7 @@ func deserializeLoopOutContract(value []byte, chainParams *chaincfg.Params) (
return nil, err return nil, err
} }
if unchargeChannel != 0 { if unchargeChannel != 0 {
contract.UnchargeChannel = &unchargeChannel contract.OutgoingChanSet = ChannelSet{unchargeChannel}
} }
var deadlineNano int64 var deadlineNano int64
@ -248,10 +278,9 @@ func serializeLoopOutContract(swap *LoopOutContract) (
return nil, err return nil, err
} }
var unchargeChannel uint64 // Always write no outgoing channel. This field is replaced by an
if swap.UnchargeChannel != nil { // outgoing channel set.
unchargeChannel = *swap.UnchargeChannel unchargeChannel := uint64(0)
}
if err := binary.Write(&b, byteOrder, unchargeChannel); err != nil { if err := binary.Write(&b, byteOrder, unchargeChannel); err != nil {
return nil, err return nil, err
} }

@ -1,9 +1,11 @@
package loopdb package loopdb
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"time" "time"
@ -51,6 +53,14 @@ var (
// value: time || rawSwapState // value: time || rawSwapState
contractKey = []byte("contract") contractKey = []byte("contract")
// outgoingChanSetKey is the key that stores a list of channel ids that
// restrict the loop out swap payment.
//
// path: loopOutBucket -> swapBucket[hash] -> outgoingChanSetKey
//
// value: concatenation of uint64 channel ids
outgoingChanSetKey = []byte("outgoing-chan-set")
byteOrder = binary.BigEndian byteOrder = binary.BigEndian
keyLength = 33 keyLength = 33
@ -190,6 +200,29 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) {
return err return err
} }
// Read the list of concatenated outgoing channel ids
// that form the outgoing set.
setBytes := swapBucket.Get(outgoingChanSetKey)
if outgoingChanSetKey != nil {
r := bytes.NewReader(setBytes)
readLoop:
for {
var chanID uint64
err := binary.Read(r, byteOrder, &chanID)
switch {
case err == io.EOF:
break readLoop
case err != nil:
return err
}
contract.OutgoingChanSet = append(
contract.OutgoingChanSet,
chanID,
)
}
}
updates, err := deserializeUpdates(swapBucket) updates, err := deserializeUpdates(swapBucket)
if err != nil { if err != nil {
return err return err
@ -374,6 +407,19 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash,
return err return err
} }
// Write the outgoing channel set.
var b bytes.Buffer
for _, chanID := range swap.OutgoingChanSet {
err := binary.Write(&b, byteOrder, chanID)
if err != nil {
return err
}
}
err = swapBucket.Put(outgoingChanSetKey, b.Bytes())
if err != nil {
return err
}
// Finally, we'll create an empty updates bucket for this swap // Finally, we'll create an empty updates bucket for this swap
// to track any future updates to the swap itself. // to track any future updates to the swap itself.
_, err = swapBucket.CreateBucket(updatesBucketKey) _, err = swapBucket.CreateBucket(updatesBucketKey)

@ -45,7 +45,7 @@ func TestLoopOutStore(t *testing.T) {
// Next, we'll make a new pending swap that we'll insert into the // Next, we'll make a new pending swap that we'll insert into the
// database shortly. // database shortly.
pendingSwap := LoopOutContract{ unrestrictedSwap := LoopOutContract{
SwapContract: SwapContract{ SwapContract: SwapContract{
AmountRequested: 100, AmountRequested: 100,
Preimage: testPreimage, Preimage: testPreimage,
@ -71,7 +71,16 @@ func TestLoopOutStore(t *testing.T) {
SwapPublicationDeadline: time.Unix(0, initiationTime.UnixNano()), SwapPublicationDeadline: time.Unix(0, initiationTime.UnixNano()),
} }
testLoopOutStore(t, &pendingSwap) t.Run("no outgoing set", func(t *testing.T) {
testLoopOutStore(t, &unrestrictedSwap)
})
restrictedSwap := unrestrictedSwap
restrictedSwap.OutgoingChanSet = ChannelSet{1, 2}
t.Run("two channel outgoing set", func(t *testing.T) {
testLoopOutStore(t, &restrictedSwap)
})
} }
// testLoopOutStore tests the basic functionality of the current bbolt // testLoopOutStore tests the basic functionality of the current bbolt
@ -373,3 +382,65 @@ func createVersionZeroDb(t *testing.T, dbPath string) {
t.Fatal(err) t.Fatal(err)
} }
} }
// TestLegacyOutgoingChannel asserts that a legacy channel restriction is
// properly mapped onto the newer channel set.
func TestLegacyOutgoingChannel(t *testing.T) {
var (
legacyDbVersion = Hex("00000003")
legacyOutgoingChannel = Hex("0000000000000005")
)
legacyDb := map[string]interface{}{
"loop-in": map[string]interface{}{},
"metadata": map[string]interface{}{
"dbp": legacyDbVersion,
},
"uncharge-swaps": map[string]interface{}{
Hex("2a595d79a55168970532805ae20c9b5fac98f04db79ba4c6ae9b9ac0f206359e"): map[string]interface{}{
"contract": Hex("1562d6fbec140000010101010202020203030303040404040101010102020202030303030404040400000000000000640d707265706179696e766f69636501010101010101010101010101010101010101010101010101010101010101010201010101010101010101010101010101010101010101010101010101010101010300000090000000000000000a0000000000000014000000000000002800000063223347454e556d6e4552745766516374344e65676f6d557171745a757a5947507742530b73776170696e766f69636500000002000000000000001e") + legacyOutgoingChannel + Hex("1562d6fbec140000"),
"updates": map[string]interface{}{
Hex("0000000000000001"): Hex("1508290a92d4c00001000000000000000000000000000000000000000000000000"),
Hex("0000000000000002"): Hex("1508290a92d4c00006000000000000000000000000000000000000000000000000"),
},
},
},
}
// Restore a legacy database.
tempDirName, err := ioutil.TempDir("", "clientstore")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempDirName)
tempPath := filepath.Join(tempDirName, dbFileName)
db, err := bbolt.Open(tempPath, 0600, nil)
if err != nil {
t.Fatal(err)
}
err = db.Update(func(tx *bbolt.Tx) error {
return RestoreDB(tx, legacyDb)
})
if err != nil {
t.Fatal(err)
}
db.Close()
// Fetch the legacy swap.
store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams)
if err != nil {
t.Fatal(err)
}
swaps, err := store.FetchLoopOutSwaps()
if err != nil {
t.Fatal(err)
}
// Assert that the outgoing channel is read properly.
expectedChannelSet := ChannelSet{5}
if !reflect.DeepEqual(swaps[0].Contract.OutgoingChanSet, expectedChannelSet) {
t.Fatal("invalid outgoing channel")
}
}

@ -112,6 +112,12 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
return nil, err return nil, err
} }
// Check channel set for duplicates.
chanSet, err := loopdb.NewChannelSet(request.OutgoingChanSet)
if err != nil {
return nil, err
}
// Instantiate a struct that contains all required data to start the // Instantiate a struct that contains all required data to start the
// swap. // swap.
initiationTime := time.Now() initiationTime := time.Now()
@ -121,7 +127,6 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
DestAddr: request.DestAddr, DestAddr: request.DestAddr,
MaxSwapRoutingFee: request.MaxSwapRoutingFee, MaxSwapRoutingFee: request.MaxSwapRoutingFee,
SweepConfTarget: request.SweepConfTarget, SweepConfTarget: request.SweepConfTarget,
UnchargeChannel: request.LoopOutChannel,
PrepayInvoice: swapResp.prepayInvoice, PrepayInvoice: swapResp.prepayInvoice,
MaxPrepayRoutingFee: request.MaxPrepayRoutingFee, MaxPrepayRoutingFee: request.MaxPrepayRoutingFee,
SwapPublicationDeadline: request.SwapPublicationDeadline, SwapPublicationDeadline: request.SwapPublicationDeadline,
@ -136,6 +141,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig,
MaxMinerFee: request.MaxMinerFee, MaxMinerFee: request.MaxMinerFee,
MaxSwapFee: request.MaxSwapFee, MaxSwapFee: request.MaxSwapFee,
}, },
OutgoingChanSet: chanSet,
} }
swapKit := newSwapKit( swapKit := newSwapKit(
@ -430,15 +436,9 @@ func (s *loopOutSwap) payInvoices(ctx context.Context) {
// Pay the swap invoice. // Pay the swap invoice.
s.log.Infof("Sending swap payment %v", s.SwapInvoice) s.log.Infof("Sending swap payment %v", s.SwapInvoice)
var outgoingChanIds []uint64
if s.LoopOutContract.UnchargeChannel != nil {
outgoingChanIds = append(
outgoingChanIds, *s.LoopOutContract.UnchargeChannel,
)
}
s.swapPaymentChan = s.payInvoice( s.swapPaymentChan = s.payInvoice(
ctx, s.SwapInvoice, s.MaxSwapRoutingFee, outgoingChanIds, ctx, s.SwapInvoice, s.MaxSwapRoutingFee,
s.LoopOutContract.OutgoingChanSet,
) )
// Pay the prepay invoice. // Pay the prepay invoice.
@ -452,7 +452,7 @@ func (s *loopOutSwap) payInvoices(ctx context.Context) {
// payInvoice pays a single invoice. // payInvoice pays a single invoice.
func (s *loopOutSwap) payInvoice(ctx context.Context, invoice string, func (s *loopOutSwap) payInvoice(ctx context.Context, invoice string,
maxFee btcutil.Amount, maxFee btcutil.Amount,
outgoingChanIds []uint64) chan lndclient.PaymentResult { outgoingChanIds loopdb.ChannelSet) chan lndclient.PaymentResult {
resultChan := make(chan lndclient.PaymentResult) resultChan := make(chan lndclient.PaymentResult)
@ -481,8 +481,8 @@ func (s *loopOutSwap) payInvoice(ctx context.Context, invoice string,
// payInvoiceAsync is the asynchronously executed part of paying an invoice. // payInvoiceAsync is the asynchronously executed part of paying an invoice.
func (s *loopOutSwap) payInvoiceAsync(ctx context.Context, func (s *loopOutSwap) payInvoiceAsync(ctx context.Context,
invoice string, maxFee btcutil.Amount, outgoingChanIds []uint64) ( invoice string, maxFee btcutil.Amount,
*lndclient.PaymentStatus, error) { outgoingChanIds loopdb.ChannelSet) (*lndclient.PaymentStatus, error) {
// Extract hash from payment request. Unfortunately the request // Extract hash from payment request. Unfortunately the request
// components aren't available directly. // components aren't available directly.

@ -3,6 +3,7 @@ package loop
import ( import (
"context" "context"
"errors" "errors"
"reflect"
"testing" "testing"
"time" "time"
@ -47,8 +48,11 @@ func TestLoopOutPaymentParameters(t *testing.T) {
const maxParts = 5 const maxParts = 5
// Initiate the swap. // Initiate the swap.
req := *testRequest
req.OutgoingChanSet = loopdb.ChannelSet{2, 3}
swap, err := newLoopOutSwap( swap, err := newLoopOutSwap(
context.Background(), cfg, height, testRequest, context.Background(), cfg, height, &req,
) )
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -99,6 +103,13 @@ func TestLoopOutPaymentParameters(t *testing.T) {
maxParts, swapPayment.MaxParts) maxParts, swapPayment.MaxParts)
} }
// Verify the outgoing channel set restriction.
if !reflect.DeepEqual(
[]uint64(req.OutgoingChanSet), swapPayment.OutgoingChanIds,
) {
t.Fatalf("Unexpected outgoing channel set")
}
// Swap is expected to register for confirmation of the htlc. Assert // Swap is expected to register for confirmation of the htlc. Assert
// this to prevent a blocked channel in the mock. // this to prevent a blocked channel in the mock.
ctx.AssertRegisterConf() ctx.AssertRegisterConf()

Loading…
Cancel
Save