From 503c83c29f4e6a605a83ecfadb153a11021724f9 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 19 May 2020 09:29:39 +0200 Subject: [PATCH] loopdb: unroll shared fetch logic Split the fetch logic so that it is easier to add loop type-specific serialization. --- loopdb/store.go | 126 ++++++++++++++++++++++++++++++------------------ 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/loopdb/store.go b/loopdb/store.go index 9514236..168f03e 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -146,12 +146,15 @@ func NewBoltSwapStore(dbPath string, chainParams *chaincfg.Params) ( }, nil } -func (s *boltSwapStore) fetchSwaps(bucketKey []byte, - callback func([]byte, Loop) error) error { +// FetchLoopOutSwaps returns all loop out swaps currently in the store. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { + var swaps []*LoopOut - return s.db.View(func(tx *bbolt.Tx) error { + err := s.db.View(func(tx *bbolt.Tx) error { // First, we'll grab our main loop in bucket key. - rootBucket := tx.Bucket(bucketKey) + rootBucket := tx.Bucket(loopOutBucketKey) if rootBucket == nil { return errors.New("bucket does not exist") } @@ -180,22 +183,40 @@ func (s *boltSwapStore) fetchSwaps(bucketKey []byte, return errors.New("contract not found") } + contract, err := deserializeLoopOutContract( + contractBytes, s.chainParams, + ) + if err != nil { + return err + } + updates, err := deserializeUpdates(swapBucket) if err != nil { return err } - var hash lntypes.Hash - copy(hash[:], swapHash) + loop := LoopOut{ + Loop: Loop{ + Events: updates, + }, + Contract: contract, + } - loop := Loop{ - Hash: hash, - Events: updates, + loop.Hash, err = lntypes.MakeHash(swapHash) + if err != nil { + return err } - return callback(contractBytes, loop) + swaps = append(swaps, &loop) + + return nil }) }) + if err != nil { + return nil, err + } + + return swaps, nil } // deserializeUpdates deserializes the list of swap updates that are stored as a @@ -227,44 +248,43 @@ func deserializeUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { return updates, nil } -// FetchLoopOutSwaps returns all loop out swaps currently in the store. +// FetchLoopInSwaps returns all loop in swaps currently in the store. // // NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { - var swaps []*LoopOut - - err := s.fetchSwaps(loopOutBucketKey, - func(contractBytes []byte, loop Loop) error { - contract, err := deserializeLoopOutContract( - contractBytes, s.chainParams, - ) - if err != nil { - return err - } +func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { + var swaps []*LoopIn - swaps = append(swaps, &LoopOut{ - Contract: contract, - Loop: loop, - }) + err := s.db.View(func(tx *bbolt.Tx) error { + // First, we'll grab our main loop in bucket key. + rootBucket := tx.Bucket(loopInBucketKey) + if rootBucket == nil { + return errors.New("bucket does not exist") + } - return nil - }, - ) - if err != nil { - return nil, err - } + // We'll now traverse the root bucket for all active swaps. The + // primary key is the swap hash itself. + return rootBucket.ForEach(func(swapHash, v []byte) error { + // Only go into things that we know are sub-bucket + // keys. + if v != nil { + return nil + } - return swaps, nil -} + // From the root bucket, we'll grab the next swap + // bucket for this swap from its swaphash. + swapBucket := rootBucket.Bucket(swapHash) + if swapBucket == nil { + return fmt.Errorf("swap bucket %x not found", + swapHash) + } -// FetchLoopInSwaps returns all loop in swaps currently in the store. -// -// NOTE: Part of the loopdb.SwapStore interface. -func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { - var swaps []*LoopIn + // With the main swap bucket obtained, we'll grab the + // raw swap contract bytes and decode it. + contractBytes := swapBucket.Get(contractKey) + if contractBytes == nil { + return errors.New("contract not found") + } - err := s.fetchSwaps(loopInBucketKey, - func(contractBytes []byte, loop Loop) error { contract, err := deserializeLoopInContract( contractBytes, ) @@ -272,14 +292,28 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return err } - swaps = append(swaps, &LoopIn{ + updates, err := deserializeUpdates(swapBucket) + if err != nil { + return err + } + + loop := LoopIn{ + Loop: Loop{ + Events: updates, + }, Contract: contract, - Loop: loop, - }) + } + + loop.Hash, err = lntypes.MakeHash(swapHash) + if err != nil { + return err + } + + swaps = append(swaps, &loop) return nil - }, - ) + }) + }) if err != nil { return nil, err }