From 9678c7817dbc13502e7850b7bb231f5c6521e0c2 Mon Sep 17 00:00:00 2001 From: carla Date: Mon, 3 Aug 2020 10:55:58 +0200 Subject: [PATCH] multi: add swap label to SwapContract and store under separate key This commits adds an optional label to our swaps, and writes it to disk under a separate key in our swap bucket. This approach is chosen rather than an on-the-fly addition to our existing swap contract field so that we do not need to deal with EOF checking in the future. To allow creation of unique internal labels, we add a reserved prefix which can be used by the daemon to set labels that are distinct from client set ones. --- interface.go | 6 +++++ labels/labels.go | 46 ++++++++++++++++++++++++++++++++ labels/labels_test.go | 53 ++++++++++++++++++++++++++++++++++++ loopdb/loop.go | 3 +++ loopdb/loopin.go | 31 ++++++++++++++++++++++ loopdb/store.go | 29 ++++++++++++++++++++ loopdb/store_test.go | 62 ++++++++++++++++++++++++++++--------------- loopin.go | 7 +++++ loopout.go | 7 +++++ 9 files changed, 223 insertions(+), 21 deletions(-) create mode 100644 labels/labels.go create mode 100644 labels/labels_test.go diff --git a/interface.go b/interface.go index a1be6d6..46d88f0 100644 --- a/interface.go +++ b/interface.go @@ -74,6 +74,9 @@ type OutRequest struct { // Expiry is the absolute expiry height of the on-chain htlc. Expiry int32 + + // Label contains an optional label for the swap. + Label string } // Out contains the full details of a loop out request. This includes things @@ -186,6 +189,9 @@ type LoopInRequest struct { // ExternalHtlc specifies whether the htlc is published by an external // source. ExternalHtlc bool + + // Label contains an optional label for the swap. + Label string } // LoopInTerms are the server terms on which it executes loop in swaps. diff --git a/labels/labels.go b/labels/labels.go new file mode 100644 index 0000000..b2137f4 --- /dev/null +++ b/labels/labels.go @@ -0,0 +1,46 @@ +package labels + +import ( + "errors" +) + +const ( + // MaxLength is the maximum length we allow for labels. + MaxLength = 500 + + // Reserved is used as a prefix to separate labels that are created by + // loopd from those created by users. + Reserved = "[reserved]" +) + +var ( + // ErrLabelTooLong is returned when a label exceeds our length limit. + ErrLabelTooLong = errors.New("label exceeds maximum length") + + // ErrReservedPrefix is returned when a label contains the prefix + // which is reserved for internally produced labels. + ErrReservedPrefix = errors.New("label contains reserved prefix") +) + +// Validate checks that a label is of appropriate length and is not in our list +// of reserved labels. +func Validate(label string) error { + if len(label) > MaxLength { + return ErrLabelTooLong + } + + // If the label is shorter than our reserved prefix, it cannot contain + // it. + if len(label) < len(Reserved) { + return nil + } + + // Check if our label begins with our reserved prefix. We don't mind if + // it has our reserved prefix in another case, we just need to be able + // to reserve a subset of labels with this prefix. + if label[0:len(Reserved)] == Reserved { + return ErrReservedPrefix + } + + return nil +} diff --git a/labels/labels_test.go b/labels/labels_test.go new file mode 100644 index 0000000..6010d55 --- /dev/null +++ b/labels/labels_test.go @@ -0,0 +1,53 @@ +package labels + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestValidate tests validation of labels. +func TestValidate(t *testing.T) { + tests := []struct { + name string + label string + err error + }{ + { + name: "label ok", + label: "label", + err: nil, + }, + { + name: "exceeds limit", + label: strings.Repeat(" ", MaxLength+1), + err: ErrLabelTooLong, + }, + { + name: "exactly reserved prefix", + label: Reserved, + err: ErrReservedPrefix, + }, + { + name: "starts with reserved prefix", + label: fmt.Sprintf("%v test", Reserved), + err: ErrReservedPrefix, + }, + { + name: "ends with reserved prefix", + label: fmt.Sprintf("test %v", Reserved), + err: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, test.err, Validate(test.label)) + }) + } +} diff --git a/loopdb/loop.go b/loopdb/loop.go index 870df9f..51ce917 100644 --- a/loopdb/loop.go +++ b/loopdb/loop.go @@ -43,6 +43,9 @@ type SwapContract struct { // InitiationTime is the time at which the swap was initiated. InitiationTime time.Time + + // Label contains an optional label for the swap. + Label string } // Loop contains fields shared between LoopIn and LoopOut diff --git a/loopdb/loopin.go b/loopdb/loopin.go index 409d679..756656c 100644 --- a/loopdb/loopin.go +++ b/loopdb/loopin.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "github.com/coreos/bbolt" + "github.com/lightninglabs/loop/labels" "github.com/lightningnetwork/lnd/routing/route" ) @@ -24,6 +26,10 @@ type LoopInContract struct { // ExternalHtlc specifies whether the htlc is published by an external // source. ExternalHtlc bool + + // Label contains an optional label for the swap. Note that this field + // is stored separately to the rest of the contract on disk. + Label string } // LoopIn is a combination of the contract and the updates. @@ -112,6 +118,31 @@ func serializeLoopInContract(swap *LoopInContract) ( return b.Bytes(), nil } +// putLabel performs validation of a label and writes it to the bucket provided +// under the label key if it is non-zero. +func putLabel(bucket *bbolt.Bucket, label string) error { + if len(label) == 0 { + return nil + } + + if err := labels.Validate(label); err != nil { + return err + } + + return bucket.Put(labelKey, []byte(label)) +} + +// getLabel attempts to get an optional label stored under the label key in a +// bucket. If it is not present, an empty label is returned. +func getLabel(bucket *bbolt.Bucket) string { + label := bucket.Get(labelKey) + if label == nil { + return "" + } + + return string(label) +} + // deserializeLoopInContract deserializes the loop in contract from a byte slice. func deserializeLoopInContract(value []byte) (*LoopInContract, error) { r := bytes.NewReader(value) diff --git a/loopdb/store.go b/loopdb/store.go index 1d48904..b4a6fad 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -60,6 +60,15 @@ var ( // value: time || rawSwapState contractKey = []byte("contract") + // labelKey is the key that stores an optional label for the swap. If + // a swap was created before we started adding labels, or was created + // without a label, this key will not be present. + // + // path: loopInBucket/loopOutBucket -> swapBucket[hash] -> labelKey + // + // value: string label + labelKey = []byte("label") + // outgoingChanSetKey is the key that stores a list of channel ids that // restrict the loop out swap payment. // @@ -207,6 +216,9 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return err } + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + // Read the list of concatenated outgoing channel ids // that form the outgoing set. setBytes := swapBucket.Get(outgoingChanSetKey) @@ -352,6 +364,9 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return err } + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + updates, err := deserializeUpdates(swapBucket) if err != nil { return err @@ -434,6 +449,10 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + if err := putLabel(swapBucket, swap.Label); err != nil { + return err + } + // Write the outgoing channel set. var b bytes.Buffer for _, chanID := range swap.OutgoingChanSet { @@ -447,6 +466,11 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } + // Write label to disk if we have one. + if err := putLabel(swapBucket, swap.Label); err != nil { + return err + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) @@ -485,6 +509,11 @@ func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, return err } + // Write label to disk if we have one. + if err := putLabel(swapBucket, swap.Label); err != nil { + return err + } + // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. _, err = swapBucket.CreateBucket(updatesBucketKey) diff --git a/loopdb/store_test.go b/loopdb/store_test.go index d373876..bde8c25 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -83,6 +83,13 @@ func TestLoopOutStore(t *testing.T) { t.Run("two channel outgoing set", func(t *testing.T) { testLoopOutStore(t, &restrictedSwap) }) + + labelledSwap := unrestrictedSwap + labelledSwap.Label = "test label" + t.Run("labelled swap", func(t *testing.T) { + testLoopOutStore(t, &labelledSwap) + }) + } // testLoopOutStore tests the basic functionality of the current bbolt @@ -196,27 +203,6 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { // TestLoopInStore tests all the basic functionality of the current bbolt // swap store. func TestLoopInStore(t *testing.T) { - tempDirName, err := ioutil.TempDir("", "clientstore") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDirName) - - store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) - if err != nil { - t.Fatal(err) - } - - // First, verify that an empty database has no active swaps. - swaps, err := store.FetchLoopInSwaps() - if err != nil { - t.Fatal(err) - } - if len(swaps) != 0 { - t.Fatal("expected empty store") - } - - hash := sha256.Sum256(testPreimage[:]) initiationTime := time.Date(2018, 11, 1, 0, 0, 0, 0, time.UTC) // Next, we'll make a new pending swap that we'll insert into the @@ -243,6 +229,38 @@ func TestLoopInStore(t *testing.T) { ExternalHtlc: true, } + t.Run("loop in", func(t *testing.T) { + testLoopInStore(t, pendingSwap) + }) + + labelledSwap := pendingSwap + labelledSwap.Label = "test label" + t.Run("loop in with label", func(t *testing.T) { + testLoopInStore(t, labelledSwap) + }) +} + +func testLoopInStore(t *testing.T, pendingSwap LoopInContract) { + tempDirName, err := ioutil.TempDir("", "clientstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + // First, verify that an empty database has no active swaps. + swaps, err := store.FetchLoopInSwaps() + if err != nil { + t.Fatal(err) + } + if len(swaps) != 0 { + t.Fatal("expected empty store") + } + // checkSwap is a test helper function that'll assert the state of a // swap. checkSwap := func(expectedState SwapState) { @@ -269,6 +287,8 @@ func TestLoopInStore(t *testing.T) { } } + hash := sha256.Sum256(testPreimage[:]) + // If we create a new swap, then it should show up as being initialized // right after. if err := store.CreateLoopIn(hash, &pendingSwap); err != nil { diff --git a/loopin.go b/loopin.go index 703974e..89777d9 100644 --- a/loopin.go +++ b/loopin.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/labels" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightningnetwork/lnd/chainntnfs" @@ -77,6 +78,11 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig, currentHeight int32, request *LoopInRequest) (*loopInInitResult, error) { + // Before we start, check that the label is valid. + if err := labels.Validate(request.Label); err != nil { + return nil, err + } + // Request current server loop in terms and use these to calculate the // swap fee that we should subtract from the swap amount in the payment // request that we send to the server. @@ -165,6 +171,7 @@ func newLoopInSwap(globalCtx context.Context, cfg *swapConfig, CltvExpiry: swapResp.expiry, MaxMinerFee: request.MaxMinerFee, MaxSwapFee: request.MaxSwapFee, + Label: request.Label, }, } diff --git a/loopout.go b/loopout.go index 5e26bc4..4f92ff5 100644 --- a/loopout.go +++ b/loopout.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" + "github.com/lightninglabs/loop/labels" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/sweep" @@ -87,6 +88,11 @@ type loopOutInitResult struct { func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, currentHeight int32, request *OutRequest) (*loopOutInitResult, error) { + // Before we start, check that the label is valid. + if err := labels.Validate(request.Label); err != nil { + return nil, err + } + // Generate random preimage. var swapPreimage [32]byte if _, err := rand.Read(swapPreimage[:]); err != nil { @@ -154,6 +160,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, CltvExpiry: request.Expiry, MaxMinerFee: request.MaxMinerFee, MaxSwapFee: request.MaxSwapFee, + Label: request.Label, }, OutgoingChanSet: chanSet, }