From 6a0a9556a006bceadfceb2d87b29eaa9953fcd88 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 12 Mar 2019 16:09:57 +0100 Subject: [PATCH] loopdb: add loop in This commit adds the required code to persist loop in swaps. It also introduces the file loop.go to which shared code is moved. Sharing of contract serialization/deserialization code has been reverted. The prepay fields do not apply to loop in, but were part of the shared contract struct. Without also adding a migration, it wouldn't be possible to keep the shared code. In general it is probably more flexible to keep the contract serialization code separated between in and out swaps. --- client_test.go | 12 ++- loopdb/interface.go | 11 +++ loopdb/loop.go | 32 ++++--- loopdb/loopin.go | 168 ++++++++++++++++++++++++++++++++++ loopdb/loopout.go | 189 ++++++++++++++++++-------------------- loopdb/store.go | 212 +++++++++++++++++++++++++++++++------------ loopdb/store_test.go | 155 +++++++++++++++++++++++++++---- loopdb/swapstate.go | 28 +++--- loopout.go | 32 +++---- store_mock_test.go | 100 +++++++++++++++++++- 10 files changed, 711 insertions(+), 228 deletions(-) create mode 100644 loopdb/loopin.go diff --git a/client_test.go b/client_test.go index ca0c61d..5e8f08d 100644 --- a/client_test.go +++ b/client_test.go @@ -188,6 +188,7 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { SwapInvoice: swapPayReq, SweepConfTarget: 2, MaxSwapRoutingFee: 70000, + PrepayInvoice: prePayReq, SwapContract: loopdb.SwapContract{ Preimage: preimage, AmountRequested: amt, @@ -195,16 +196,17 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { ReceiverKey: receiverKey, SenderKey: senderKey, MaxSwapFee: 60000, - PrepayInvoice: prePayReq, MaxMinerFee: 50000, }, }, - Events: []*loopdb.LoopOutEvent{ - { - State: state, + Loop: loopdb.Loop{ + Events: []*loopdb.LoopEvent{ + { + State: state, + }, }, + Hash: hash, }, - Hash: hash, } if expired { diff --git a/loopdb/interface.go b/loopdb/interface.go index 5c3f3d4..72b92d0 100644 --- a/loopdb/interface.go +++ b/loopdb/interface.go @@ -20,6 +20,17 @@ type SwapStore interface { // the various stages in its lifetime. UpdateLoopOut(hash lntypes.Hash, time time.Time, state SwapState) error + // FetchLoopInSwaps returns all swaps currently in the store. + FetchLoopInSwaps() ([]*LoopIn, error) + + // CreateLoopIn adds an initiated swap to the store. + CreateLoopIn(hash lntypes.Hash, swap *LoopInContract) error + + // UpdateLoopIn stores a new event for a target loop in swap. This + // appends to the event log for a particular swap as it goes through + // the various stages in its lifetime. + UpdateLoopIn(hash lntypes.Hash, time time.Time, state SwapState) error + // Close closes the underlying database. Close() error } diff --git a/loopdb/loop.go b/loopdb/loop.go index 81d1959..80298e8 100644 --- a/loopdb/loop.go +++ b/loopdb/loop.go @@ -18,10 +18,6 @@ type SwapContract struct { // AmountRequested is the total amount of the swap. AmountRequested btcutil.Amount - // PrepayInvoice is the invoice that the client should pay to the - // server that will be returned if the swap is complete. - PrepayInvoice string - // SenderKey is the key of the sender that will be used in the on-chain // HTLC. SenderKey [33]byte @@ -33,10 +29,6 @@ type SwapContract struct { // CltvExpiry is the total absolute CLTV expiry of the swap. CltvExpiry int32 - // MaxPrepayRoutingFee is the maximum off-chain fee in msat that may be - // paid for the prepayment to the server. - MaxPrepayRoutingFee btcutil.Amount - // MaxSwapFee is the maximum we are willing to pay the server for the // swap. MaxSwapFee btcutil.Amount @@ -53,8 +45,14 @@ type SwapContract struct { InitiationTime time.Time } -// LoopOutEvent contains the dynamic data of a swap. -type LoopOutEvent struct { +// Loop contains fields shared between LoopIn and LoopOut +type Loop struct { + Hash lntypes.Hash + Events []*LoopEvent +} + +// LoopEvent contains the dynamic data of a swap. +type LoopEvent struct { // State is the new state for this swap as a result of this event. State SwapState @@ -63,7 +61,7 @@ type LoopOutEvent struct { } // State returns the most recent state of this swap. -func (s *LoopOut) State() SwapState { +func (s *Loop) State() SwapState { lastUpdate := s.LastUpdate() if lastUpdate == nil { return StateInitiated @@ -73,7 +71,7 @@ func (s *LoopOut) State() SwapState { } // LastUpdate returns the most recent update of this swap. -func (s *LoopOut) LastUpdate() *LoopOutEvent { +func (s *Loop) LastUpdate() *LoopEvent { eventCount := len(s.Events) if eventCount == 0 { @@ -84,7 +82,9 @@ func (s *LoopOut) LastUpdate() *LoopOutEvent { return lastEvent } -func serializeLoopOutEvent(time time.Time, state SwapState) ( +// serializeLoopEvent serializes a state update of a swap. This is used for both +// in and out swaps. +func serializeLoopEvent(time time.Time, state SwapState) ( []byte, error) { var b bytes.Buffer @@ -100,8 +100,10 @@ func serializeLoopOutEvent(time time.Time, state SwapState) ( return b.Bytes(), nil } -func deserializeLoopOutEvent(value []byte) (*LoopOutEvent, error) { - update := &LoopOutEvent{} +// deserializeLoopEvent deserializes a state update of a swap. This is used for +// both in and out swaps. +func deserializeLoopEvent(value []byte) (*LoopEvent, error) { + update := &LoopEvent{} r := bytes.NewReader(value) diff --git a/loopdb/loopin.go b/loopdb/loopin.go new file mode 100644 index 0000000..13d45d6 --- /dev/null +++ b/loopdb/loopin.go @@ -0,0 +1,168 @@ +package loopdb + +import ( + "bytes" + "encoding/binary" + "fmt" + "time" +) + +// LoopInContract contains the data that is serialized to persistent storage for +// pending loop in swaps. +type LoopInContract struct { + SwapContract + + // SweepConfTarget specifies the targeted confirmation target for the + // client sweep tx. + HtlcConfTarget int32 + + // LoopInChannel is the channel to charge. If zero, any channel may + // be used. + LoopInChannel *uint64 +} + +// LoopIn is a combination of the contract and the updates. +type LoopIn struct { + Loop + + Contract *LoopInContract +} + +// LastUpdateTime returns the last update time of this swap. +func (s *LoopIn) LastUpdateTime() time.Time { + lastUpdate := s.LastUpdate() + if lastUpdate == nil { + return s.Contract.InitiationTime + } + + return lastUpdate.Time +} + +// serializeLoopInContract serialize the loop in contract into a byte slice. +func serializeLoopInContract(swap *LoopInContract) ( + []byte, error) { + + var b bytes.Buffer + + if err := binary.Write(&b, byteOrder, swap.InitiationTime.UnixNano()); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, swap.Preimage); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, swap.AmountRequested); err != nil { + return nil, err + } + + n, err := b.Write(swap.SenderKey[:]) + if err != nil { + return nil, err + } + if n != keyLength { + return nil, fmt.Errorf("sender key has invalid length") + } + + n, err = b.Write(swap.ReceiverKey[:]) + if err != nil { + return nil, err + } + if n != keyLength { + return nil, fmt.Errorf("receiver key has invalid length") + } + + if err := binary.Write(&b, byteOrder, swap.CltvExpiry); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, swap.MaxMinerFee); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, swap.MaxSwapFee); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, swap.InitiationHeight); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, swap.HtlcConfTarget); err != nil { + return nil, err + } + + var chargeChannel uint64 + if swap.LoopInChannel != nil { + chargeChannel = *swap.LoopInChannel + } + if err := binary.Write(&b, byteOrder, chargeChannel); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// deserializeLoopInContract deserializes the loop in contract from a byte slice. +func deserializeLoopInContract(value []byte) (*LoopInContract, error) { + r := bytes.NewReader(value) + + contract := LoopInContract{} + var err error + var unixNano int64 + if err := binary.Read(r, byteOrder, &unixNano); err != nil { + return nil, err + } + contract.InitiationTime = time.Unix(0, unixNano) + + if err := binary.Read(r, byteOrder, &contract.Preimage); err != nil { + return nil, err + } + + binary.Read(r, byteOrder, &contract.AmountRequested) + + n, err := r.Read(contract.SenderKey[:]) + if err != nil { + return nil, err + } + if n != keyLength { + return nil, fmt.Errorf("sender key has invalid length") + } + + n, err = r.Read(contract.ReceiverKey[:]) + if err != nil { + return nil, err + } + if n != keyLength { + return nil, fmt.Errorf("receiver key has invalid length") + } + + if err := binary.Read(r, byteOrder, &contract.CltvExpiry); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &contract.MaxMinerFee); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &contract.MaxSwapFee); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &contract.InitiationHeight); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &contract.HtlcConfTarget); err != nil { + return nil, err + } + + var loopInChannel uint64 + if err := binary.Read(r, byteOrder, &loopInChannel); err != nil { + return nil, err + } + if loopInChannel != 0 { + contract.LoopInChannel = &loopInChannel + } + + return &contract, nil +} diff --git a/loopdb/loopout.go b/loopdb/loopout.go index 31aac63..0176ec9 100644 --- a/loopdb/loopout.go +++ b/loopdb/loopout.go @@ -4,13 +4,11 @@ import ( "bytes" "encoding/binary" "fmt" - "io" "time" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/lightningnetwork/lnd/lntypes" ) // LoopOutContract contains the data that is serialized to persistent storage @@ -39,20 +37,24 @@ type LoopOutContract struct { // TargetChannel is the channel to loop out. If zero, any channel may // be used. UnchargeChannel *uint64 + + // PrepayInvoice is the invoice that the client should pay to the + // server that will be returned if the swap is complete. + PrepayInvoice string + + // MaxPrepayRoutingFee is the maximum off-chain fee in msat that may be + // paid for the prepayment to the server. + MaxPrepayRoutingFee btcutil.Amount } // LoopOut is a combination of the contract and the updates. type LoopOut struct { - // Hash is the hash that uniquely identifies this swap. - Hash lntypes.Hash + Loop // Contract is the active contract for this swap. It describes the // precise details of the swap including the final fee, CLTV value, // etc. Contract *LoopOutContract - - // Events are each of the state transitions that this swap underwent. - Events []*LoopOutEvent } // LastUpdateTime returns the last update time of this swap. @@ -70,104 +72,114 @@ func deserializeLoopOutContract(value []byte, chainParams *chaincfg.Params) ( r := bytes.NewReader(value) - contract, err := deserializeContract(r) - if err != nil { + contract := LoopOutContract{} + var err error + var unixNano int64 + if err := binary.Read(r, byteOrder, &unixNano); err != nil { return nil, err } + contract.InitiationTime = time.Unix(0, unixNano) - swap := LoopOutContract{ - SwapContract: *contract, + if err := binary.Read(r, byteOrder, &contract.Preimage); err != nil { + return nil, err } - addr, err := wire.ReadVarString(r, 0) + binary.Read(r, byteOrder, &contract.AmountRequested) + + contract.PrepayInvoice, err = wire.ReadVarString(r, 0) if err != nil { return nil, err } - swap.DestAddr, err = btcutil.DecodeAddress(addr, chainParams) + + n, err := r.Read(contract.SenderKey[:]) if err != nil { return nil, err } + if n != keyLength { + return nil, fmt.Errorf("sender key has invalid length") + } - swap.SwapInvoice, err = wire.ReadVarString(r, 0) + n, err = r.Read(contract.ReceiverKey[:]) if err != nil { return nil, err } + if n != keyLength { + return nil, fmt.Errorf("receiver key has invalid length") + } - if err := binary.Read(r, byteOrder, &swap.SweepConfTarget); err != nil { + if err := binary.Read(r, byteOrder, &contract.CltvExpiry); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &contract.MaxMinerFee); err != nil { return nil, err } - if err := binary.Read(r, byteOrder, &swap.MaxSwapRoutingFee); err != nil { + if err := binary.Read(r, byteOrder, &contract.MaxSwapFee); err != nil { return nil, err } - var unchargeChannel uint64 - if err := binary.Read(r, byteOrder, &unchargeChannel); err != nil { + if err := binary.Read(r, byteOrder, &contract.MaxPrepayRoutingFee); err != nil { return nil, err } - if unchargeChannel != 0 { - swap.UnchargeChannel = &unchargeChannel + if err := binary.Read(r, byteOrder, &contract.InitiationHeight); err != nil { + return nil, err } - return &swap, nil -} - -func serializeLoopOutContract(swap *LoopOutContract) ( - []byte, error) { - - var b bytes.Buffer - - serializeContract(&swap.SwapContract, &b) - - addr := swap.DestAddr.String() - if err := wire.WriteVarString(&b, 0, addr); err != nil { + addr, err := wire.ReadVarString(r, 0) + if err != nil { + return nil, err + } + contract.DestAddr, err = btcutil.DecodeAddress(addr, chainParams) + if err != nil { return nil, err } - if err := wire.WriteVarString(&b, 0, swap.SwapInvoice); err != nil { + contract.SwapInvoice, err = wire.ReadVarString(r, 0) + if err != nil { return nil, err } - if err := binary.Write(&b, byteOrder, swap.SweepConfTarget); err != nil { + if err := binary.Read(r, byteOrder, &contract.SweepConfTarget); err != nil { return nil, err } - if err := binary.Write(&b, byteOrder, swap.MaxSwapRoutingFee); err != nil { + if err := binary.Read(r, byteOrder, &contract.MaxSwapRoutingFee); err != nil { return nil, err } var unchargeChannel uint64 - if swap.UnchargeChannel != nil { - unchargeChannel = *swap.UnchargeChannel - } - if err := binary.Write(&b, byteOrder, unchargeChannel); err != nil { + if err := binary.Read(r, byteOrder, &unchargeChannel); err != nil { return nil, err } + if unchargeChannel != 0 { + contract.UnchargeChannel = &unchargeChannel + } - return b.Bytes(), nil + return &contract, nil } -func deserializeContract(r io.Reader) (*SwapContract, error) { - swap := SwapContract{} - var err error - var unixNano int64 - if err := binary.Read(r, byteOrder, &unixNano); err != nil { +func serializeLoopOutContract(swap *LoopOutContract) ( + []byte, error) { + + var b bytes.Buffer + + if err := binary.Write(&b, byteOrder, swap.InitiationTime.UnixNano()); err != nil { return nil, err } - swap.InitiationTime = time.Unix(0, unixNano) - if err := binary.Read(r, byteOrder, &swap.Preimage); err != nil { + if err := binary.Write(&b, byteOrder, swap.Preimage); err != nil { return nil, err } - binary.Read(r, byteOrder, &swap.AmountRequested) + if err := binary.Write(&b, byteOrder, swap.AmountRequested); err != nil { + return nil, err + } - swap.PrepayInvoice, err = wire.ReadVarString(r, 0) - if err != nil { + if err := wire.WriteVarString(&b, 0, swap.PrepayInvoice); err != nil { return nil, err } - n, err := r.Read(swap.SenderKey[:]) + n, err := b.Write(swap.SenderKey[:]) if err != nil { return nil, err } @@ -175,7 +187,7 @@ func deserializeContract(r io.Reader) (*SwapContract, error) { return nil, fmt.Errorf("sender key has invalid length") } - n, err = r.Read(swap.ReceiverKey[:]) + n, err = b.Write(swap.ReceiverKey[:]) if err != nil { return nil, err } @@ -183,79 +195,50 @@ func deserializeContract(r io.Reader) (*SwapContract, error) { return nil, fmt.Errorf("receiver key has invalid length") } - if err := binary.Read(r, byteOrder, &swap.CltvExpiry); err != nil { - return nil, err - } - if err := binary.Read(r, byteOrder, &swap.MaxMinerFee); err != nil { + if err := binary.Write(&b, byteOrder, swap.CltvExpiry); err != nil { return nil, err } - if err := binary.Read(r, byteOrder, &swap.MaxSwapFee); err != nil { + if err := binary.Write(&b, byteOrder, swap.MaxMinerFee); err != nil { return nil, err } - if err := binary.Read(r, byteOrder, &swap.MaxPrepayRoutingFee); err != nil { - return nil, err - } - if err := binary.Read(r, byteOrder, &swap.InitiationHeight); err != nil { + if err := binary.Write(&b, byteOrder, swap.MaxSwapFee); err != nil { return nil, err } - return &swap, nil -} - -func serializeContract(swap *SwapContract, b *bytes.Buffer) error { - if err := binary.Write(b, byteOrder, swap.InitiationTime.UnixNano()); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.Preimage); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.AmountRequested); err != nil { - return err - } - - if err := wire.WriteVarString(b, 0, swap.PrepayInvoice); err != nil { - return err + if err := binary.Write(&b, byteOrder, swap.MaxPrepayRoutingFee); err != nil { + return nil, err } - n, err := b.Write(swap.SenderKey[:]) - if err != nil { - return err - } - if n != keyLength { - return fmt.Errorf("sender key has invalid length") + if err := binary.Write(&b, byteOrder, swap.InitiationHeight); err != nil { + return nil, err } - n, err = b.Write(swap.ReceiverKey[:]) - if err != nil { - return err - } - if n != keyLength { - return fmt.Errorf("receiver key has invalid length") + addr := swap.DestAddr.String() + if err := wire.WriteVarString(&b, 0, addr); err != nil { + return nil, err } - if err := binary.Write(b, byteOrder, swap.CltvExpiry); err != nil { - return err + if err := wire.WriteVarString(&b, 0, swap.SwapInvoice); err != nil { + return nil, err } - if err := binary.Write(b, byteOrder, swap.MaxMinerFee); err != nil { - return err + if err := binary.Write(&b, byteOrder, swap.SweepConfTarget); err != nil { + return nil, err } - if err := binary.Write(b, byteOrder, swap.MaxSwapFee); err != nil { - return err + if err := binary.Write(&b, byteOrder, swap.MaxSwapRoutingFee); err != nil { + return nil, err } - if err := binary.Write(b, byteOrder, swap.MaxPrepayRoutingFee); err != nil { - return err + var unchargeChannel uint64 + if swap.UnchargeChannel != nil { + unchargeChannel = *swap.UnchargeChannel } - - if err := binary.Write(b, byteOrder, swap.InitiationHeight); err != nil { - return err + if err := binary.Write(&b, byteOrder, unchargeChannel); err != nil { + return nil, err } - return nil + return b.Bytes(), nil } diff --git a/loopdb/store.go b/loopdb/store.go index b87f6cb..463395f 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -18,19 +18,27 @@ var ( // database. dbFileName = "loop.db" - // unchargeSwapsBucketKey is a bucket that contains all swaps that are - // currently pending or completed. This bucket is keyed by the - // swaphash, and leads to a nested sub-bucket that houses information - // for that swap. + // loopOutBucketKey is a bucket that contains all out swaps that are + // currently pending or completed. This bucket is keyed by the swaphash, + // and leads to a nested sub-bucket that houses information for that + // swap. // // maps: swapHash -> swapBucket - unchargeSwapsBucketKey = []byte("uncharge-swaps") + loopOutBucketKey = []byte("uncharge-swaps") - // unchargeUpdatesBucketKey is a bucket that contains all updates - // pertaining to a swap. This is a sub-bucket of the swap bucket for a - // particular swap. This list only ever grows. + // loopInBucketKey is a bucket that contains all in swaps that are + // currently pending or completed. This bucket is keyed by the swaphash, + // and leads to a nested sub-bucket that houses information for that + // swap. // - // path: unchargeUpdatesBucket -> swapBucket[hash] -> updateBucket + // maps: swapHash -> swapBucket + loopInBucketKey = []byte("loop-in") + + // updatesBucketKey is a bucket that contains all updates pertaining to + // a swap. This is a sub-bucket of the swap bucket for a particular + // swap. This list only ever grows. + // + // path: loopInBucket/loopOutBucket -> swapBucket[hash] -> updatesBucket // // maps: updateNumber -> time || state updatesBucketKey = []byte("updates") @@ -38,7 +46,7 @@ var ( // contractKey is the key that stores the serialized swap contract. It // is nested within the sub-bucket for each active swap. // - // path: unchargeUpdatesBucket -> swapBucket[hash] + // path: loopInBucket/loopOutBucket -> swapBucket[hash] -> contractKey // // value: time || rawSwapState contractKey = []byte("contract") @@ -92,11 +100,11 @@ func NewBoltSwapStore(dbPath string, chainParams *chaincfg.Params) ( // We'll create all the buckets we need if this is the first time we're // starting up. If they already exist, then these calls will be noops. err = bdb.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(unchargeSwapsBucketKey) + _, err := tx.CreateBucketIfNotExists(loopOutBucketKey) if err != nil { return err } - _, err = tx.CreateBucketIfNotExists(updatesBucketKey) + _, err = tx.CreateBucketIfNotExists(loopInBucketKey) if err != nil { return err } @@ -123,15 +131,12 @@ func NewBoltSwapStore(dbPath string, chainParams *chaincfg.Params) ( }, nil } -// FetchLoopOutSwaps returns all swaps currently in the store. -// -// NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { - var swaps []*LoopOut +func (s *boltSwapStore) fetchSwaps(bucketKey []byte, + callback func([]byte, Loop) error) error { - err := s.db.View(func(tx *bbolt.Tx) error { - // First, we'll grab our main loop out swap bucket key. - rootBucket := tx.Bucket(unchargeSwapsBucketKey) + return s.db.View(func(tx *bbolt.Tx) error { + // First, we'll grab our main loop in bucket key. + rootBucket := tx.Bucket(bucketKey) if rootBucket == nil { return errors.New("bucket does not exist") } @@ -159,12 +164,6 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { if contractBytes == nil { return errors.New("contract not found") } - contract, err := deserializeLoopOutContract( - contractBytes, s.chainParams, - ) - if err != nil { - return err - } // Once we have the raw swap, we'll also need to decode // each of the past updates to the swap itself. @@ -175,9 +174,9 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { // De serialize and collect each swap update into our // slice of swap events. - var updates []*LoopOutEvent - err = stateBucket.ForEach(func(k, v []byte) error { - event, err := deserializeLoopOutEvent(v) + var updates []*LoopEvent + err := stateBucket.ForEach(func(k, v []byte) error { + event, err := deserializeLoopEvent(v) if err != nil { return err } @@ -192,16 +191,39 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { var hash lntypes.Hash copy(hash[:], swapHash) - swap := LoopOut{ - Contract: contract, - Hash: hash, - Events: updates, + loop := Loop{ + Hash: hash, + Events: updates, } - swaps = append(swaps, &swap) - return nil + return callback(contractBytes, loop) }) }) +} + +// FetchLoopOutSwaps returns all loop out swaps currently in the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { + var swaps []*LoopOut + + err := s.fetchSwaps(loopOutBucketKey, + func(contractBytes []byte, loop Loop) error { + contract, err := deserializeLoopOutContract( + contractBytes, s.chainParams, + ) + if err != nil { + return err + } + + swaps = append(swaps, &LoopOut{ + Contract: contract, + Loop: loop, + }) + + return nil + }, + ) if err != nil { return nil, err } @@ -209,24 +231,47 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return swaps, nil } -// CreateLoopOut adds an initiated swap to the store. +// FetchLoopInSwaps returns all loop in swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, - swap *LoopOutContract) error { +func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { + var swaps []*LoopIn - // If the hash doesn't match the pre-image, then this is an invalid - // swap so we'll bail out early. - if hash != swap.Preimage.Hash() { - return errors.New("hash and preimage do not match") + err := s.fetchSwaps(loopInBucketKey, + func(contractBytes []byte, loop Loop) error { + contract, err := deserializeLoopInContract( + contractBytes, + ) + if err != nil { + return err + } + + swaps = append(swaps, &LoopIn{ + Contract: contract, + Loop: loop, + }) + + return nil + }, + ) + if err != nil { + return nil, err } + return swaps, nil +} + +// createLoop creates a swap in the store. It requires that the contract is +// already serialized to be able to use this function for both in and out swaps. +func (s *boltSwapStore) createLoop(bucketKey []byte, hash lntypes.Hash, + contractBytes []byte) error { + // Otherwise, we'll create a new swap within the database. return s.db.Update(func(tx *bbolt.Tx) error { // First, we'll grab the root bucket that houses all of our // main swaps. rootBucket, err := tx.CreateBucketIfNotExists( - unchargeSwapsBucketKey, + bucketKey, ) if err != nil { return err @@ -235,8 +280,7 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, // If the swap already exists, then we'll exit as we don't want // to override a swap. if rootBucket.Get(hash[:]) != nil { - return fmt.Errorf("swap %v already exists", - swap.Preimage) + return fmt.Errorf("swap %v already exists", hash) } // From the root bucket, we'll make a new sub swap bucket using @@ -246,15 +290,11 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } - // With out swap bucket created, we'll serialize and store the - // swap itself. - contract, err := serializeLoopOutContract(swap) + // With the swap bucket created, we'll store the swap itself. + err = swapBucket.Put(contractKey, contractBytes) if err != nil { return err } - if err := swapBucket.Put(contractKey, contract); err != nil { - return err - } // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. @@ -263,18 +303,56 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, }) } -// UpdateLoopOut stores a swap updateLoopOut. This appends to the event log for -// a particular swap as it goes through the various stages in its lifetime. +// CreateLoopOut adds an initiated swap to the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, - state SwapState) error { +func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, + swap *LoopOutContract) error { + + // If the hash doesn't match the pre-image, then this is an invalid + // swap so we'll bail out early. + if hash != swap.Preimage.Hash() { + return errors.New("hash and preimage do not match") + } + + contractBytes, err := serializeLoopOutContract(swap) + if err != nil { + return err + } + + return s.createLoop(loopOutBucketKey, hash, contractBytes) +} + +// CreateLoopIn adds an initiated swap to the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, + swap *LoopInContract) error { + + // If the hash doesn't match the pre-image, then this is an invalid + // swap so we'll bail out early. + if hash != swap.Preimage.Hash() { + return errors.New("hash and preimage do not match") + } + + contractBytes, err := serializeLoopInContract(swap) + if err != nil { + return err + } + + return s.createLoop(loopInBucketKey, hash, contractBytes) +} + +// updateLoop saves a new swap state transition to the store. It takes in a +// bucket key so that this function can be used for both in and out swaps. +func (s *boltSwapStore) updateLoop(bucketKey []byte, hash lntypes.Hash, + time time.Time, state SwapState) error { return s.db.Update(func(tx *bbolt.Tx) error { // Starting from the root bucket, we'll traverse the bucket // hierarchy all the way down to the swap bucket, and the // update sub-bucket within that. - rootBucket := tx.Bucket(unchargeSwapsBucketKey) + rootBucket := tx.Bucket(bucketKey) if rootBucket == nil { return errors.New("bucket does not exist") } @@ -295,7 +373,7 @@ func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, } // With the ID obtained, we'll write out this new update value. - updateValue, err := serializeLoopOutEvent(time, state) + updateValue, err := serializeLoopEvent(time, state) if err != nil { return err } @@ -303,6 +381,26 @@ func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, }) } +// UpdateLoopOut stores a swap update. This appends to the event log for +// a particular swap as it goes through the various stages in its lifetime. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, + state SwapState) error { + + return s.updateLoop(loopOutBucketKey, hash, time, state) +} + +// UpdateLoopIn stores a swap update. This appends to the event log for +// a particular swap as it goes through the various stages in its lifetime. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) UpdateLoopIn(hash lntypes.Hash, time time.Time, + state SwapState) error { + + return s.updateLoop(loopInBucketKey, hash, time, state) +} + // Close closes the underlying database. // // NOTE: Part of the loopdb.SwapStore interface. diff --git a/loopdb/store_test.go b/loopdb/store_test.go index 63026bc..4d6bbe2 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -34,9 +34,9 @@ var ( testTime = time.Date(2018, time.January, 9, 14, 00, 00, 0, time.UTC) ) -// TestBoltSwapStore tests all the basic functionality of the current bbolt +// TestLoopOutStore tests all the basic functionality of the current bbolt // swap store. -func TestBoltSwapStore(t *testing.T) { +func TestLoopOutStore(t *testing.T) { tempDirName, err := ioutil.TempDir("", "clientstore") if err != nil { t.Fatal(err) @@ -65,25 +65,27 @@ func TestBoltSwapStore(t *testing.T) { // database shortly. pendingSwap := LoopOutContract{ SwapContract: SwapContract{ - AmountRequested: 100, - Preimage: testPreimage, - CltvExpiry: 144, - SenderKey: senderKey, - PrepayInvoice: "prepayinvoice", - ReceiverKey: receiverKey, - MaxMinerFee: 10, - MaxSwapFee: 20, - MaxPrepayRoutingFee: 40, - InitiationHeight: 99, + AmountRequested: 100, + Preimage: testPreimage, + CltvExpiry: 144, + SenderKey: senderKey, + + ReceiverKey: receiverKey, + MaxMinerFee: 10, + MaxSwapFee: 20, + + InitiationHeight: 99, // Convert to/from unix to remove timezone, so that it // doesn't interfere with DeepEqual. InitiationTime: time.Unix(0, initiationTime.UnixNano()), }, - DestAddr: destAddr, - SwapInvoice: "swapinvoice", - MaxSwapRoutingFee: 30, - SweepConfTarget: 2, + MaxPrepayRoutingFee: 40, + PrepayInvoice: "prepayinvoice", + DestAddr: destAddr, + SwapInvoice: "swapinvoice", + MaxSwapRoutingFee: 30, + SweepConfTarget: 2, } // checkSwap is a test helper function that'll assert the state of a @@ -157,3 +159,124 @@ func TestBoltSwapStore(t *testing.T) { } checkSwap(StateFailInsufficientValue) } + +// TestLoopInStore tests all the basic functionality of the current bbolt +// swap store. +func TestLoopInStore(t *testing.T) { + tempDirName, err := ioutil.TempDir("", "clientstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + // First, verify that an empty database has no active swaps. + swaps, err := store.FetchLoopInSwaps() + if err != nil { + t.Fatal(err) + } + if len(swaps) != 0 { + t.Fatal("expected empty store") + } + + hash := sha256.Sum256(testPreimage[:]) + initiationTime := time.Date(2018, 11, 1, 0, 0, 0, 0, time.UTC) + + // Next, we'll make a new pending swap that we'll insert into the + // database shortly. + loopInChannel := uint64(123) + + pendingSwap := LoopInContract{ + SwapContract: SwapContract{ + AmountRequested: 100, + Preimage: testPreimage, + CltvExpiry: 144, + SenderKey: senderKey, + ReceiverKey: receiverKey, + MaxMinerFee: 10, + MaxSwapFee: 20, + InitiationHeight: 99, + + // Convert to/from unix to remove timezone, so that it + // doesn't interfere with DeepEqual. + InitiationTime: time.Unix(0, initiationTime.UnixNano()), + }, + HtlcConfTarget: 2, + LoopInChannel: &loopInChannel, + } + + // checkSwap is a test helper function that'll assert the state of a + // swap. + checkSwap := func(expectedState SwapState) { + t.Helper() + + swaps, err := store.FetchLoopInSwaps() + if err != nil { + t.Fatal(err) + } + + if len(swaps) != 1 { + t.Fatal("expected pending swap in store") + } + + swap := swaps[0].Contract + if !reflect.DeepEqual(swap, &pendingSwap) { + t.Fatal("invalid pending swap data") + } + + if swaps[0].State() != expectedState { + t.Fatalf("expected state %v, but got %v", + expectedState, swaps[0].State(), + ) + } + } + + // If we create a new swap, then it should show up as being initialized + // right after. + if err := store.CreateLoopIn(hash, &pendingSwap); err != nil { + t.Fatal(err) + } + checkSwap(StateInitiated) + + // Trying to make the same swap again should result in an error. + if err := store.CreateLoopIn(hash, &pendingSwap); err == nil { + t.Fatal("expected error on storing duplicate") + } + checkSwap(StateInitiated) + + // Next, we'll update to the next state of the pre-image being + // revealed. The state should be reflected here again. + err = store.UpdateLoopIn( + hash, testTime, StatePreimageRevealed, + ) + if err != nil { + t.Fatal(err) + } + checkSwap(StatePreimageRevealed) + + // Next, we'll update to the final state to ensure that the state is + // properly updated. + err = store.UpdateLoopIn( + hash, testTime, StateFailInsufficientValue, + ) + if err != nil { + t.Fatal(err) + } + checkSwap(StateFailInsufficientValue) + + if err := store.Close(); err != nil { + t.Fatal(err) + } + + // If we re-open the same store, then the state of the current swap + // should be the same. + store, err = NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + checkSwap(StateFailInsufficientValue) +} diff --git a/loopdb/swapstate.go b/loopdb/swapstate.go index 5b27870..707d16b 100644 --- a/loopdb/swapstate.go +++ b/loopdb/swapstate.go @@ -1,6 +1,8 @@ package loopdb -// SwapState indicates the current state of a swap. +// SwapState indicates the current state of a swap. This enumeration is the +// union of loop in and loop out states. A single type is used for both swap +// types to be able to reduce code duplication that would otherwise be required. type SwapState uint8 const ( @@ -22,23 +24,24 @@ const ( // server pulled the off-chain htlc. StateSuccess = 2 - // StateFailOffchainPayments indicates that it wasn't possible to find routes - // for one or both of the off-chain payments to the server that + // StateFailOffchainPayments indicates that it wasn't possible to find + // routes for one or both of the off-chain payments to the server that // satisfied the payment restrictions (fee and timelock limits). StateFailOffchainPayments = 3 - // StateFailTimeout indicates that the on-chain htlc wasn't confirmed before - // its expiry or confirmed too late (MinPreimageRevealDelta violated). + // StateFailTimeout indicates that the on-chain htlc wasn't confirmed + // before its expiry or confirmed too late (MinPreimageRevealDelta + // violated). StateFailTimeout = 4 - // StateFailSweepTimeout indicates that the on-chain htlc wasn't swept before - // the server revoked the htlc. The server didn't pull the off-chain - // htlc (even though it could have) and we timed out the off-chain htlc - // ourselves. No funds lost. + // StateFailSweepTimeout indicates that the on-chain htlc wasn't swept + // before the server revoked the htlc. The server didn't pull the + // off-chain htlc (even though it could have) and we timed out the + // off-chain htlc ourselves. No funds lost. StateFailSweepTimeout = 5 - // StateFailInsufficientValue indicates that the published on-chain htlc had - // a value lower than the requested amount. + // StateFailInsufficientValue indicates that the published on-chain htlc + // had a value lower than the requested amount. StateFailInsufficientValue = 6 // StateFailTemporary indicates that the swap cannot progress because @@ -90,6 +93,9 @@ func (s SwapState) String() string { case StatePreimageRevealed: return "PreimageRevealed" + case StateHtlcPublished: + return "HtlcPublished" + case StateSuccess: return "Success" diff --git a/loopout.go b/loopout.go index 065eaaf..7e52b5d 100644 --- a/loopout.go +++ b/loopout.go @@ -84,23 +84,23 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, initiationTime := time.Now() contract := loopdb.LoopOutContract{ - SwapInvoice: swapResp.swapInvoice, - DestAddr: request.DestAddr, - MaxSwapRoutingFee: request.MaxSwapRoutingFee, - SweepConfTarget: request.SweepConfTarget, - UnchargeChannel: request.LoopOutChannel, + SwapInvoice: swapResp.swapInvoice, + DestAddr: request.DestAddr, + MaxSwapRoutingFee: request.MaxSwapRoutingFee, + SweepConfTarget: request.SweepConfTarget, + UnchargeChannel: request.LoopOutChannel, + PrepayInvoice: swapResp.prepayInvoice, + MaxPrepayRoutingFee: request.MaxPrepayRoutingFee, SwapContract: loopdb.SwapContract{ - InitiationHeight: currentHeight, - InitiationTime: initiationTime, - PrepayInvoice: swapResp.prepayInvoice, - ReceiverKey: receiverKey, - SenderKey: swapResp.senderKey, - Preimage: swapPreimage, - AmountRequested: request.Amount, - CltvExpiry: swapResp.expiry, - MaxMinerFee: request.MaxMinerFee, - MaxSwapFee: request.MaxSwapFee, - MaxPrepayRoutingFee: request.MaxPrepayRoutingFee, + InitiationHeight: currentHeight, + InitiationTime: initiationTime, + ReceiverKey: receiverKey, + SenderKey: swapResp.senderKey, + Preimage: swapPreimage, + AmountRequested: request.Amount, + CltvExpiry: swapResp.expiry, + MaxMinerFee: request.MaxMinerFee, + MaxSwapFee: request.MaxSwapFee, }, } diff --git a/store_mock_test.go b/store_mock_test.go index fe1832d..9b33014 100644 --- a/store_mock_test.go +++ b/store_mock_test.go @@ -17,6 +17,11 @@ type storeMock struct { loopOutStoreChan chan loopdb.LoopOutContract loopOutUpdateChan chan loopdb.SwapState + loopInSwaps map[lntypes.Hash]*loopdb.LoopInContract + loopInUpdates map[lntypes.Hash][]loopdb.SwapState + loopInStoreChan chan loopdb.LoopInContract + loopInUpdateChan chan loopdb.SwapState + t *testing.T } @@ -33,7 +38,11 @@ func newStoreMock(t *testing.T) *storeMock { loopOutSwaps: make(map[lntypes.Hash]*loopdb.LoopOutContract), loopOutUpdates: make(map[lntypes.Hash][]loopdb.SwapState), - t: t, + loopInStoreChan: make(chan loopdb.LoopInContract, 1), + loopInUpdateChan: make(chan loopdb.SwapState, 1), + loopInSwaps: make(map[lntypes.Hash]*loopdb.LoopInContract), + loopInUpdates: make(map[lntypes.Hash][]loopdb.SwapState), + t: t, } } @@ -45,17 +54,19 @@ func (s *storeMock) FetchLoopOutSwaps() ([]*loopdb.LoopOut, error) { for hash, contract := range s.loopOutSwaps { updates := s.loopOutUpdates[hash] - events := make([]*loopdb.LoopOutEvent, len(updates)) + events := make([]*loopdb.LoopEvent, len(updates)) for i, u := range updates { - events[i] = &loopdb.LoopOutEvent{ + events[i] = &loopdb.LoopEvent{ State: u, } } swap := &loopdb.LoopOut{ - Hash: hash, + Loop: loopdb.Loop{ + Hash: hash, + Events: events, + }, Contract: contract, - Events: events, } result = append(result, swap) } @@ -81,6 +92,50 @@ func (s *storeMock) CreateLoopOut(hash lntypes.Hash, return nil } +// FetchLoopInSwaps returns all in swaps currently in the store. +func (s *storeMock) FetchLoopInSwaps() ([]*loopdb.LoopIn, error) { + result := []*loopdb.LoopIn{} + + for hash, contract := range s.loopInSwaps { + updates := s.loopInUpdates[hash] + events := make([]*loopdb.LoopEvent, len(updates)) + for i, u := range updates { + events[i] = &loopdb.LoopEvent{ + State: u, + } + } + + swap := &loopdb.LoopIn{ + Loop: loopdb.Loop{ + Hash: hash, + Events: events, + }, + Contract: contract, + } + result = append(result, swap) + } + + return result, nil +} + +// CreateLoopIn adds an initiated loop in swap to the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *storeMock) CreateLoopIn(hash lntypes.Hash, + swap *loopdb.LoopInContract) error { + + _, ok := s.loopInSwaps[hash] + if ok { + return errors.New("swap already exists") + } + + s.loopInSwaps[hash] = swap + s.loopInUpdates[hash] = []loopdb.SwapState{} + s.loopInStoreChan <- *swap + + return nil +} + // UpdateLoopOut stores a new event for a target loop out swap. This appends to // the event log for a particular swap as it goes through the various stages in // its lifetime. @@ -101,6 +156,26 @@ func (s *storeMock) UpdateLoopOut(hash lntypes.Hash, time time.Time, return nil } +// UpdateLoopIn stores a new event for a target loop in swap. This appends to +// the event log for a particular swap as it goes through the various stages in +// its lifetime. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *storeMock) UpdateLoopIn(hash lntypes.Hash, time time.Time, + state loopdb.SwapState) error { + + updates, ok := s.loopInUpdates[hash] + if !ok { + return errors.New("swap does not exists") + } + + updates = append(updates, state) + s.loopOutUpdates[hash] = updates + s.loopOutUpdateChan <- state + + return nil +} + func (s *storeMock) Close() error { return nil } @@ -130,6 +205,21 @@ func (s *storeMock) assertLoopOutStored() { } } +func (s *storeMock) assertLoopInStored() { + s.t.Helper() + + <-s.loopInStoreChan +} + +func (s *storeMock) assertLoopInState(expectedState loopdb.SwapState) { + s.t.Helper() + + state := <-s.loopOutUpdateChan + if state != expectedState { + s.t.Fatalf("unexpected state") + } +} + func (s *storeMock) assertStorePreimageReveal() { s.t.Helper()