From 0795a894f5a5f7ea3955f400c8fe1e21bf3b3078 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 12 Mar 2019 16:09:57 +0100 Subject: [PATCH] loopdb: add loop in This commit adds the required code to persist loop in swaps. It also introduces the file loop.go to which shared code is moved. --- loopdb/interface.go | 11 ++ loopdb/loop.go | 241 +++++++++++++++++++++++++++++++++++++++++++ loopdb/loopin.go | 89 ++++++++++++++++ loopdb/loopout.go | 231 +---------------------------------------- loopdb/store.go | 190 +++++++++++++++++++++++++--------- loopdb/store_test.go | 124 +++++++++++++++++++++- loopdb/swapstate.go | 3 + 7 files changed, 608 insertions(+), 281 deletions(-) create mode 100644 loopdb/loop.go create mode 100644 loopdb/loopin.go diff --git a/loopdb/interface.go b/loopdb/interface.go index 5c3f3d4..72b92d0 100644 --- a/loopdb/interface.go +++ b/loopdb/interface.go @@ -20,6 +20,17 @@ type SwapStore interface { // the various stages in its lifetime. UpdateLoopOut(hash lntypes.Hash, time time.Time, state SwapState) error + // FetchLoopInSwaps returns all swaps currently in the store. + FetchLoopInSwaps() ([]*LoopIn, error) + + // CreateLoopIn adds an initiated swap to the store. + CreateLoopIn(hash lntypes.Hash, swap *LoopInContract) error + + // UpdateLoopIn stores a new event for a target loop in swap. This + // appends to the event log for a particular swap as it goes through + // the various stages in its lifetime. + UpdateLoopIn(hash lntypes.Hash, time time.Time, state SwapState) error + // Close closes the underlying database. Close() error } diff --git a/loopdb/loop.go b/loopdb/loop.go new file mode 100644 index 0000000..3afa7ea --- /dev/null +++ b/loopdb/loop.go @@ -0,0 +1,241 @@ +package loopdb + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "time" + + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/lntypes" +) + +// SwapContract contains the base data that is serialized to persistent storage +// for pending swaps. +type SwapContract struct { + // Preimage is the preimage for the swap. + Preimage lntypes.Preimage + + // AmountRequested is the total amount of the swap. + AmountRequested btcutil.Amount + + // PrepayInvoice is the invoice that the client should pay to the + // server that will be returned if the swap is complete. + PrepayInvoice string + + // SenderKey is the key of the sender that will be used in the on-chain + // HTLC. + SenderKey [33]byte + + // ReceiverKey is the of the receiver that will be used in the on-chain + // HTLC. + ReceiverKey [33]byte + + // CltvExpiry is the total absolute CLTV expiry of the swap. + CltvExpiry int32 + + // MaxPrepayRoutingFee is the maximum off-chain fee in msat that may be + // paid for the prepayment to the server. + MaxPrepayRoutingFee btcutil.Amount + + // MaxSwapFee is the maximum we are willing to pay the server for the + // swap. + MaxSwapFee btcutil.Amount + + // MaxMinerFee is the maximum in on-chain fees that we are willing to + // spend. + MaxMinerFee btcutil.Amount + + // InitiationHeight is the block height at which the swap was + // initiated. + InitiationHeight int32 + + // InitiationTime is the time at which the swap was initiated. + InitiationTime time.Time +} + +// Loop contains fields shared between LoopIn and LoopOut +type Loop struct { + Hash lntypes.Hash + Events []*LoopEvent +} + +// LoopEvent contains the dynamic data of a swap. +type LoopEvent struct { + // State is the new state for this swap as a result of this event. + State SwapState + + // Time is the time that this swap had its state changed. + Time time.Time +} + +// State returns the most recent state of this swap. +func (s *Loop) State() SwapState { + lastUpdate := s.LastUpdate() + if lastUpdate == nil { + return StateInitiated + } + + return lastUpdate.State +} + +// LastUpdate returns the most recent update of this swap. +func (s *Loop) LastUpdate() *LoopEvent { + eventCount := len(s.Events) + + if eventCount == 0 { + return nil + } + + lastEvent := s.Events[eventCount-1] + return lastEvent +} + +func deserializeContract(r io.Reader) (*SwapContract, error) { + swap := SwapContract{} + var err error + var unixNano int64 + if err := binary.Read(r, byteOrder, &unixNano); err != nil { + return nil, err + } + swap.InitiationTime = time.Unix(0, unixNano) + + if err := binary.Read(r, byteOrder, &swap.Preimage); err != nil { + return nil, err + } + + binary.Read(r, byteOrder, &swap.AmountRequested) + + swap.PrepayInvoice, err = wire.ReadVarString(r, 0) + if err != nil { + return nil, err + } + + n, err := r.Read(swap.SenderKey[:]) + if err != nil { + return nil, err + } + if n != keyLength { + return nil, fmt.Errorf("sender key has invalid length") + } + + n, err = r.Read(swap.ReceiverKey[:]) + if err != nil { + return nil, err + } + if n != keyLength { + return nil, fmt.Errorf("receiver key has invalid length") + } + + if err := binary.Read(r, byteOrder, &swap.CltvExpiry); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &swap.MaxMinerFee); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &swap.MaxSwapFee); err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &swap.MaxPrepayRoutingFee); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &swap.InitiationHeight); err != nil { + return nil, err + } + + return &swap, nil +} + +func serializeContract(swap *SwapContract, b *bytes.Buffer) error { + if err := binary.Write(b, byteOrder, swap.InitiationTime.UnixNano()); err != nil { + return err + } + + if err := binary.Write(b, byteOrder, swap.Preimage); err != nil { + return err + } + + if err := binary.Write(b, byteOrder, swap.AmountRequested); err != nil { + return err + } + + if err := wire.WriteVarString(b, 0, swap.PrepayInvoice); err != nil { + return err + } + + n, err := b.Write(swap.SenderKey[:]) + if err != nil { + return err + } + if n != keyLength { + return fmt.Errorf("sender key has invalid length") + } + + n, err = b.Write(swap.ReceiverKey[:]) + if err != nil { + return err + } + if n != keyLength { + return fmt.Errorf("receiver key has invalid length") + } + + if err := binary.Write(b, byteOrder, swap.CltvExpiry); err != nil { + return err + } + + if err := binary.Write(b, byteOrder, swap.MaxMinerFee); err != nil { + return err + } + + if err := binary.Write(b, byteOrder, swap.MaxSwapFee); err != nil { + return err + } + + if err := binary.Write(b, byteOrder, swap.MaxPrepayRoutingFee); err != nil { + return err + } + + if err := binary.Write(b, byteOrder, swap.InitiationHeight); err != nil { + return err + } + + return nil +} + +func serializeLoopEvent(time time.Time, state SwapState) ( + []byte, error) { + + var b bytes.Buffer + + if err := binary.Write(&b, byteOrder, time.UnixNano()); err != nil { + return nil, err + } + + if err := binary.Write(&b, byteOrder, state); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +func deserializeLoopEvent(value []byte) (*LoopEvent, error) { + update := &LoopEvent{} + + r := bytes.NewReader(value) + + var unixNano int64 + if err := binary.Read(r, byteOrder, &unixNano); err != nil { + return nil, err + } + update.Time = time.Unix(0, unixNano) + + if err := binary.Read(r, byteOrder, &update.State); err != nil { + return nil, err + } + + return update, nil +} diff --git a/loopdb/loopin.go b/loopdb/loopin.go new file mode 100644 index 0000000..b63bf70 --- /dev/null +++ b/loopdb/loopin.go @@ -0,0 +1,89 @@ +package loopdb + +import ( + "bytes" + "encoding/binary" + "time" +) + +// LoopInContract contains the data that is serialized to persistent storage for +// pending loop in swaps. +type LoopInContract struct { + SwapContract + + // SweepConfTarget specifies the targeted confirmation target for the + // client sweep tx. + HtlcConfTarget int32 + + // LoopInChannel is the channel to charge. If zero, any channel may + // be used. + LoopInChannel *uint64 +} + +// LoopIn is a combination of the contract and the updates. +type LoopIn struct { + Loop + + Contract *LoopInContract +} + +// LastUpdateTime returns the last update time of this swap. +func (s *LoopIn) LastUpdateTime() time.Time { + lastUpdate := s.LastUpdate() + if lastUpdate == nil { + return s.Contract.InitiationTime + } + + return lastUpdate.Time +} + +// serializeLoopInContract serialize the loop in contract into a byte slice. +func serializeLoopInContract(swap *LoopInContract) ( + []byte, error) { + + var b bytes.Buffer + + serializeContract(&swap.SwapContract, &b) + + if err := binary.Write(&b, byteOrder, swap.HtlcConfTarget); err != nil { + return nil, err + } + + var chargeChannel uint64 + if swap.LoopInChannel != nil { + chargeChannel = *swap.LoopInChannel + } + if err := binary.Write(&b, byteOrder, chargeChannel); err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// deserializeLoopInContract deserializes the loop in contract from a byte slice. +func deserializeLoopInContract(value []byte) (*LoopInContract, error) { + r := bytes.NewReader(value) + + contract, err := deserializeContract(r) + if err != nil { + return nil, err + } + + swap := LoopInContract{ + SwapContract: *contract, + } + + if err := binary.Read(r, byteOrder, &swap.HtlcConfTarget); err != nil { + return nil, err + } + + var loopInChannel uint64 + if err := binary.Read(r, byteOrder, &loopInChannel); err != nil { + return nil, err + } + if loopInChannel != 0 { + swap.LoopInChannel = &loopInChannel + } + + return &swap, nil +} diff --git a/loopdb/loopout.go b/loopdb/loopout.go index acbf667..0ade89b 100644 --- a/loopdb/loopout.go +++ b/loopdb/loopout.go @@ -3,59 +3,12 @@ package loopdb import ( "bytes" "encoding/binary" - "fmt" - "io" "time" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/lightningnetwork/lnd/lntypes" ) -// SwapContract contains the base data that is serialized to persistent storage -// for pending swaps. -type SwapContract struct { - // Preimage is the preimage for the swap. - Preimage lntypes.Preimage - - // AmountRequested is the total amount of the swap. - AmountRequested btcutil.Amount - - // PrepayInvoice is the invoice that the client should pay to the - // server that will be returned if the swap is complete. - PrepayInvoice string - - // SenderKey is the key of the sender that will be used in the on-chain - // HTLC. - SenderKey [33]byte - - // ReceiverKey is the of the receiver that will be used in the on-chain - // HTLC. - ReceiverKey [33]byte - - // CltvExpiry is the total absolute CLTV expiry of the swap. - CltvExpiry int32 - - // MaxPrepayRoutingFee is the maximum off-chain fee in msat that may be - // paid for the prepayment to the server. - MaxPrepayRoutingFee btcutil.Amount - - // MaxSwapFee is the maximum we are willing to pay the server for the - // swap. - MaxSwapFee btcutil.Amount - - // MaxMinerFee is the maximum in on-chain fees that we are willing to - // spend. - MaxMinerFee btcutil.Amount - - // InitiationHeight is the block height at which the swap was - // initiated. - InitiationHeight int32 - - // InitiationTime is the time at which the swap was initiated. - InitiationTime time.Time -} - // LoopOutContract contains the data that is serialized to persistent storage // for pending swaps. type LoopOutContract struct { @@ -84,49 +37,14 @@ type LoopOutContract struct { UnchargeChannel *uint64 } -// LoopOutEvent contains the dynamic data of a swap. -type LoopOutEvent struct { - // State is the new state for this swap as a result of this event. - State SwapState - - // Time is the time that this swap had its state changed. - Time time.Time -} - // LoopOut is a combination of the contract and the updates. type LoopOut struct { - // Hash is the hash that uniquely identifies this swap. - Hash lntypes.Hash + Loop // Contract is the active contract for this swap. It describes the // precise details of the swap including the final fee, CLTV value, // etc. Contract *LoopOutContract - - // Events are each of the state transitions that this swap underwent. - Events []*LoopOutEvent -} - -// State returns the most recent state of this swap. -func (s *LoopOut) State() SwapState { - lastUpdate := s.LastUpdate() - if lastUpdate == nil { - return StateInitiated - } - - return lastUpdate.State -} - -// LastUpdate returns the most recent update of this swap. -func (s *LoopOut) LastUpdate() *LoopOutEvent { - eventCount := len(s.Events) - - if eventCount == 0 { - return nil - } - - lastEvent := s.Events[eventCount-1] - return lastEvent } // LastUpdateTime returns the last update time of this swap. @@ -218,150 +136,3 @@ func serializeLoopOutContract(swap *LoopOutContract) ( return b.Bytes(), nil } - -func deserializeContract(r io.Reader) (*SwapContract, error) { - swap := SwapContract{} - var err error - var unixNano int64 - if err := binary.Read(r, byteOrder, &unixNano); err != nil { - return nil, err - } - swap.InitiationTime = time.Unix(0, unixNano) - - if err := binary.Read(r, byteOrder, &swap.Preimage); err != nil { - return nil, err - } - - binary.Read(r, byteOrder, &swap.AmountRequested) - - swap.PrepayInvoice, err = wire.ReadVarString(r, 0) - if err != nil { - return nil, err - } - - n, err := r.Read(swap.SenderKey[:]) - if err != nil { - return nil, err - } - if n != keyLength { - return nil, fmt.Errorf("sender key has invalid length") - } - - n, err = r.Read(swap.ReceiverKey[:]) - if err != nil { - return nil, err - } - if n != keyLength { - return nil, fmt.Errorf("receiver key has invalid length") - } - - if err := binary.Read(r, byteOrder, &swap.CltvExpiry); err != nil { - return nil, err - } - if err := binary.Read(r, byteOrder, &swap.MaxMinerFee); err != nil { - return nil, err - } - - if err := binary.Read(r, byteOrder, &swap.MaxSwapFee); err != nil { - return nil, err - } - - if err := binary.Read(r, byteOrder, &swap.MaxPrepayRoutingFee); err != nil { - return nil, err - } - if err := binary.Read(r, byteOrder, &swap.InitiationHeight); err != nil { - return nil, err - } - - return &swap, nil -} - -func serializeContract(swap *SwapContract, b *bytes.Buffer) error { - if err := binary.Write(b, byteOrder, swap.InitiationTime.UnixNano()); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.Preimage); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.AmountRequested); err != nil { - return err - } - - if err := wire.WriteVarString(b, 0, swap.PrepayInvoice); err != nil { - return err - } - - n, err := b.Write(swap.SenderKey[:]) - if err != nil { - return err - } - if n != keyLength { - return fmt.Errorf("sender key has invalid length") - } - - n, err = b.Write(swap.ReceiverKey[:]) - if err != nil { - return err - } - if n != keyLength { - return fmt.Errorf("receiver key has invalid length") - } - - if err := binary.Write(b, byteOrder, swap.CltvExpiry); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.MaxMinerFee); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.MaxSwapFee); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.MaxPrepayRoutingFee); err != nil { - return err - } - - if err := binary.Write(b, byteOrder, swap.InitiationHeight); err != nil { - return err - } - - return nil -} - -func serializeLoopOutEvent(time time.Time, state SwapState) ( - []byte, error) { - - var b bytes.Buffer - - if err := binary.Write(&b, byteOrder, time.UnixNano()); err != nil { - return nil, err - } - - if err := binary.Write(&b, byteOrder, state); err != nil { - return nil, err - } - - return b.Bytes(), nil -} - -func deserializeLoopOutEvent(value []byte) (*LoopOutEvent, error) { - update := &LoopOutEvent{} - - r := bytes.NewReader(value) - - var unixNano int64 - if err := binary.Read(r, byteOrder, &unixNano); err != nil { - return nil, err - } - update.Time = time.Unix(0, unixNano) - - if err := binary.Read(r, byteOrder, &update.State); err != nil { - return nil, err - } - - return update, nil -} diff --git a/loopdb/store.go b/loopdb/store.go index 7feb588..01b33d3 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -17,13 +17,19 @@ var ( // database. dbFileName = "loop.db" - // unchargeSwapsBucketKey is a bucket that contains all swaps that are + // loopOutBucketKey is a bucket that contains all swaps that are // currently pending or completed. This bucket is keyed by the // swaphash, and leads to a nested sub-bucket that houses information // for that swap. // // maps: swapHash -> swapBucket - unchargeSwapsBucketKey = []byte("uncharge-swaps") + loopOutBucketKey = []byte("uncharge-swaps") + + // chargeSwapsBucketKey is a bucket that contains all swaps that are + // currently pending or completed. + // + // maps: swap_hash -> chargeContract + loopInBucketKey = []byte("loop-in") // unchargeUpdatesBucketKey is a bucket that contains all updates // pertaining to a swap. This is a sub-bucket of the swap bucket for a @@ -88,11 +94,11 @@ func NewBoltSwapStore(dbPath string) (*boltSwapStore, error) { // We'll create all the buckets we need if this is the first time we're // starting up. If they already exist, then these calls will be noops. err = bdb.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists(unchargeSwapsBucketKey) + _, err := tx.CreateBucketIfNotExists(loopOutBucketKey) if err != nil { return err } - _, err = tx.CreateBucketIfNotExists(updatesBucketKey) + _, err = tx.CreateBucketIfNotExists(loopInBucketKey) if err != nil { return err } @@ -118,15 +124,12 @@ func NewBoltSwapStore(dbPath string) (*boltSwapStore, error) { }, nil } -// FetchLoopOutSwaps returns all swaps currently in the store. -// -// NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { - var swaps []*LoopOut +func (s *boltSwapStore) fetchSwaps(bucketKey []byte, + callback func([]byte, Loop) error) error { - err := s.db.View(func(tx *bbolt.Tx) error { - // First, we'll grab our main loop out swap bucket key. - rootBucket := tx.Bucket(unchargeSwapsBucketKey) + return s.db.View(func(tx *bbolt.Tx) error { + // First, we'll grab our main loop in bucket key. + rootBucket := tx.Bucket(bucketKey) if rootBucket == nil { return errors.New("bucket does not exist") } @@ -154,12 +157,6 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { if contractBytes == nil { return errors.New("contract not found") } - contract, err := deserializeLoopOutContract( - contractBytes, - ) - if err != nil { - return err - } // Once we have the raw swap, we'll also need to decode // each of the past updates to the swap itself. @@ -170,9 +167,9 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { // De serialize and collect each swap update into our // slice of swap events. - var updates []*LoopOutEvent - err = stateBucket.ForEach(func(k, v []byte) error { - event, err := deserializeLoopOutEvent(v) + var updates []*LoopEvent + err := stateBucket.ForEach(func(k, v []byte) error { + event, err := deserializeLoopEvent(v) if err != nil { return err } @@ -187,16 +184,39 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { var hash lntypes.Hash copy(hash[:], swapHash) - swap := LoopOut{ - Contract: contract, - Hash: hash, - Events: updates, + loop := Loop{ + Hash: hash, + Events: updates, } - swaps = append(swaps, &swap) - return nil + return callback(contractBytes, loop) }) }) +} + +// FetchLoopOutSwaps returns all loop out swaps currently in the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { + var swaps []*LoopOut + + err := s.fetchSwaps(loopOutBucketKey, + func(contractBytes []byte, loop Loop) error { + contract, err := deserializeLoopOutContract( + contractBytes, + ) + if err != nil { + return err + } + + swaps = append(swaps, &LoopOut{ + Contract: contract, + Loop: loop, + }) + + return nil + }, + ) if err != nil { return nil, err } @@ -204,24 +224,45 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return swaps, nil } -// CreateLoopOut adds an initiated swap to the store. +// FetchLoopInSwaps returns all loop in swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, - swap *LoopOutContract) error { +func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { + var swaps []*LoopIn - // If the hash doesn't match the pre-image, then this is an invalid - // swap so we'll bail out early. - if hash != swap.Preimage.Hash() { - return errors.New("hash and preimage do not match") + err := s.fetchSwaps(loopInBucketKey, + func(contractBytes []byte, loop Loop) error { + contract, err := deserializeLoopInContract( + contractBytes, + ) + if err != nil { + return err + } + + swaps = append(swaps, &LoopIn{ + Contract: contract, + Loop: loop, + }) + + return nil + }, + ) + if err != nil { + return nil, err } + return swaps, nil +} + +func (s *boltSwapStore) createLoop(bucketKey []byte, hash lntypes.Hash, + contractBytes []byte) error { + // Otherwise, we'll create a new swap within the database. return s.db.Update(func(tx *bbolt.Tx) error { // First, we'll grab the root bucket that houses all of our // main swaps. rootBucket, err := tx.CreateBucketIfNotExists( - unchargeSwapsBucketKey, + bucketKey, ) if err != nil { return err @@ -230,8 +271,7 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, // If the swap already exists, then we'll exit as we don't want // to override a swap. if rootBucket.Get(hash[:]) != nil { - return fmt.Errorf("swap %v already exists", - swap.Preimage) + return fmt.Errorf("swap %v already exists", hash) } // From the root bucket, we'll make a new sub swap bucket using @@ -241,15 +281,11 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, return err } - // With out swap bucket created, we'll serialize and store the - // swap itself. - contract, err := serializeLoopOutContract(swap) + // With the swap bucket created, we'll store the swap itself. + err = swapBucket.Put(contractKey, contractBytes) if err != nil { return err } - if err := swapBucket.Put(contractKey, contract); err != nil { - return err - } // Finally, we'll create an empty updates bucket for this swap // to track any future updates to the swap itself. @@ -258,18 +294,54 @@ func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, }) } -// UpdateLoopOut stores a swap updateLoopOut. This appends to the event log for -// a particular swap as it goes through the various stages in its lifetime. +// CreateLoopOut adds an initiated swap to the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, - state SwapState) error { +func (s *boltSwapStore) CreateLoopOut(hash lntypes.Hash, + swap *LoopOutContract) error { + + // If the hash doesn't match the pre-image, then this is an invalid + // swap so we'll bail out early. + if hash != swap.Preimage.Hash() { + return errors.New("hash and preimage do not match") + } + + contractBytes, err := serializeLoopOutContract(swap) + if err != nil { + return err + } + + return s.createLoop(loopOutBucketKey, hash, contractBytes) +} + +// CreateLoopIn adds an initiated swap to the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) CreateLoopIn(hash lntypes.Hash, + swap *LoopInContract) error { + + // If the hash doesn't match the pre-image, then this is an invalid + // swap so we'll bail out early. + if hash != swap.Preimage.Hash() { + return errors.New("hash and preimage do not match") + } + + contractBytes, err := serializeLoopInContract(swap) + if err != nil { + return err + } + + return s.createLoop(loopInBucketKey, hash, contractBytes) +} + +func (s *boltSwapStore) updateLoop(bucketKey []byte, hash lntypes.Hash, + time time.Time, state SwapState) error { return s.db.Update(func(tx *bbolt.Tx) error { // Starting from the root bucket, we'll traverse the bucket // hierarchy all the way down to the swap bucket, and the // update sub-bucket within that. - rootBucket := tx.Bucket(unchargeSwapsBucketKey) + rootBucket := tx.Bucket(bucketKey) if rootBucket == nil { return errors.New("bucket does not exist") } @@ -290,7 +362,7 @@ func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, } // With the ID obtained, we'll write out this new update value. - updateValue, err := serializeLoopOutEvent(time, state) + updateValue, err := serializeLoopEvent(time, state) if err != nil { return err } @@ -298,6 +370,26 @@ func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, }) } +// UpdateLoopOut stores a swap update. This appends to the event log for +// a particular swap as it goes through the various stages in its lifetime. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) UpdateLoopOut(hash lntypes.Hash, time time.Time, + state SwapState) error { + + return s.updateLoop(loopOutBucketKey, hash, time, state) +} + +// UpdateLoopIn stores a swap update. This appends to the event log for +// a particular swap as it goes through the various stages in its lifetime. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) UpdateLoopIn(hash lntypes.Hash, time time.Time, + state SwapState) error { + + return s.updateLoop(loopInBucketKey, hash, time, state) +} + // Close closes the underlying database. // // NOTE: Part of the loopdb.SwapStore interface. diff --git a/loopdb/store_test.go b/loopdb/store_test.go index ca11fa9..fab5795 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -33,9 +33,9 @@ var ( testTime = time.Date(2018, time.January, 9, 14, 00, 00, 0, time.UTC) ) -// TestBoltSwapStore tests all the basic functionality of the current bbolt +// TestLoopOutStore tests all the basic functionality of the current bbolt // swap store. -func TestBoltSwapStore(t *testing.T) { +func TestLoopOutStore(t *testing.T) { tempDirName, err := ioutil.TempDir("", "clientstore") if err != nil { t.Fatal(err) @@ -156,3 +156,123 @@ func TestBoltSwapStore(t *testing.T) { } checkSwap(StateFailInsufficientValue) } + +// TestLoopInStore tests all the basic functionality of the current bbolt +// swap store. +func TestLoopInStore(t *testing.T) { + tempDirName, err := ioutil.TempDir("", "clientstore") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempDirName) + + store, err := NewBoltSwapStore(tempDirName) + if err != nil { + t.Fatal(err) + } + + // First, verify that an empty database has no active swaps. + swaps, err := store.FetchLoopInSwaps() + if err != nil { + t.Fatal(err) + } + if len(swaps) != 0 { + t.Fatal("expected empty store") + } + + hash := sha256.Sum256(testPreimage[:]) + initiationTime := time.Date(2018, 11, 1, 0, 0, 0, 0, time.UTC) + + // Next, we'll make a new pending swap that we'll insert into the + // database shortly. + pendingSwap := LoopInContract{ + SwapContract: SwapContract{ + AmountRequested: 100, + Preimage: testPreimage, + CltvExpiry: 144, + SenderKey: senderKey, + PrepayInvoice: "prepayinvoice", + ReceiverKey: receiverKey, + MaxMinerFee: 10, + MaxSwapFee: 20, + MaxPrepayRoutingFee: 40, + InitiationHeight: 99, + + // Convert to/from unix to remove timezone, so that it + // doesn't interfere with DeepEqual. + InitiationTime: time.Unix(0, initiationTime.UnixNano()), + }, + HtlcConfTarget: 2, + } + + // checkSwap is a test helper function that'll assert the state of a + // swap. + checkSwap := func(expectedState SwapState) { + t.Helper() + + swaps, err := store.FetchLoopInSwaps() + if err != nil { + t.Fatal(err) + } + + if len(swaps) != 1 { + t.Fatal("expected pending swap in store") + } + + swap := swaps[0].Contract + if !reflect.DeepEqual(swap, &pendingSwap) { + t.Fatal("invalid pending swap data") + } + + if swaps[0].State() != expectedState { + t.Fatalf("expected state %v, but got %v", + expectedState, swaps[0].State(), + ) + } + } + + // If we create a new swap, then it should show up as being initialized + // right after. + if err := store.CreateLoopIn(hash, &pendingSwap); err != nil { + t.Fatal(err) + } + checkSwap(StateInitiated) + + // Trying to make the same swap again should result in an error. + if err := store.CreateLoopIn(hash, &pendingSwap); err == nil { + t.Fatal("expected error on storing duplicate") + } + checkSwap(StateInitiated) + + // Next, we'll update to the next state of the pre-image being + // revealed. The state should be reflected here again. + err = store.UpdateLoopIn( + hash, testTime, StatePreimageRevealed, + ) + if err != nil { + t.Fatal(err) + } + checkSwap(StatePreimageRevealed) + + // Next, we'll update to the final state to ensure that the state is + // properly updated. + err = store.UpdateLoopIn( + hash, testTime, StateFailInsufficientValue, + ) + if err != nil { + t.Fatal(err) + } + checkSwap(StateFailInsufficientValue) + + if err := store.Close(); err != nil { + t.Fatal(err) + } + + // If we re-open the same store, then the state of the current swap + // should be the same. + store, err = NewBoltSwapStore(tempDirName) + if err != nil { + t.Fatal(err) + } + checkSwap(StateFailInsufficientValue) +} diff --git a/loopdb/swapstate.go b/loopdb/swapstate.go index 5b27870..8a76009 100644 --- a/loopdb/swapstate.go +++ b/loopdb/swapstate.go @@ -90,6 +90,9 @@ func (s SwapState) String() string { case StatePreimageRevealed: return "PreimageRevealed" + case StateHtlcPublished: + return "HtlcPublished" + case StateSuccess: return "Success"