diff --git a/loopdb/interface.go b/loopdb/interface.go index 41172da..6f85067 100644 --- a/loopdb/interface.go +++ b/loopdb/interface.go @@ -12,6 +12,9 @@ type SwapStore interface { // FetchLoopOutSwaps returns all swaps currently in the store. FetchLoopOutSwaps() ([]*LoopOut, error) + // FetchLoopOutSwap returns the loop out swap with the given hash. + FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) + // CreateLoopOut adds an initiated swap to the store. CreateLoopOut(hash lntypes.Hash, swap *LoopOutContract) error diff --git a/loopdb/store.go b/loopdb/store.go index 1343974..b86008d 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -255,111 +255,12 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return nil } - // From the root bucket, we'll grab the next swap - // bucket for this swap from its swaphash. - swapBucket := rootBucket.Bucket(swapHash) - if swapBucket == nil { - return fmt.Errorf("swap bucket %x not found", - swapHash) - } - - // With the main swap bucket obtained, we'll grab the - // raw swap contract bytes and decode it. - contractBytes := swapBucket.Get(contractKey) - if contractBytes == nil { - return errors.New("contract not found") - } - - contract, err := deserializeLoopOutContract( - contractBytes, s.chainParams, - ) + loop, err := s.fetchLoopOutSwap(rootBucket, swapHash) if err != nil { return err } - // Get our label for this swap, if it is present. - contract.Label = getLabel(swapBucket) - - // Read the list of concatenated outgoing channel ids - // that form the outgoing set. - setBytes := swapBucket.Get(outgoingChanSetKey) - if outgoingChanSetKey != nil { - r := bytes.NewReader(setBytes) - readLoop: - for { - var chanID uint64 - err := binary.Read(r, byteOrder, &chanID) - switch { - case err == io.EOF: - break readLoop - case err != nil: - return err - } - - contract.OutgoingChanSet = append( - contract.OutgoingChanSet, - chanID, - ) - } - } - - // Set our default number of confirmations for the swap. - contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations - - // If we have the number of confirmations stored for - // this swap, we overwrite our default with the stored - // value. - confBytes := swapBucket.Get(confirmationsKey) - if confBytes != nil { - r := bytes.NewReader(confBytes) - err := binary.Read( - r, byteOrder, &contract.HtlcConfirmations, - ) - if err != nil { - return err - } - } - - updates, err := deserializeUpdates(swapBucket) - if err != nil { - return err - } - - // Try to unmarshal the protocol version for the swap. - // If the protocol version is not stored (which is - // the case for old clients), we'll assume the - // ProtocolVersionUnrecorded instead. - contract.ProtocolVersion, err = - UnmarshalProtocolVersion( - swapBucket.Get(protocolVersionKey), - ) - if err != nil { - return err - } - - // Try to unmarshal the key locator. - if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { - contract.ClientKeyLocator, err = UnmarshalKeyLocator( - swapBucket.Get(keyLocatorKey), - ) - if err != nil { - return err - } - } - - loop := LoopOut{ - Loop: Loop{ - Events: updates, - }, - Contract: contract, - } - - loop.Hash, err = lntypes.MakeHash(swapHash) - if err != nil { - return err - } - - swaps = append(swaps, &loop) + swaps = append(swaps, loop) return nil }) @@ -371,53 +272,33 @@ func (s *boltSwapStore) FetchLoopOutSwaps() ([]*LoopOut, error) { return swaps, nil } -// deserializeUpdates deserializes the list of swap updates that are stored as a -// key of the given bucket. -func deserializeUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { - // Once we have the raw swap, we'll also need to decode - // each of the past updates to the swap itself. - stateBucket := swapBucket.Bucket(updatesBucketKey) - if stateBucket == nil { - return nil, errors.New("updates bucket not found") - } - - // Deserialize and collect each swap update into our slice of swap - // events. - var updates []*LoopEvent - err := stateBucket.ForEach(func(k, v []byte) error { - updateBucket := stateBucket.Bucket(k) - if updateBucket == nil { - return fmt.Errorf("expected state sub-bucket for %x", k) - } +// FetchLoopOutSwap returns the loop out swap with the given hash. +// +// NOTE: Part of the loopdb.SwapStore interface. +func (s *boltSwapStore) FetchLoopOutSwap(hash lntypes.Hash) (*LoopOut, error) { + var swap *LoopOut - basicState := updateBucket.Get(basicStateKey) - if basicState == nil { - return errors.New("no basic state for update") + err := s.db.View(func(tx *bbolt.Tx) error { + // First, we'll grab our main loop out bucket key. + rootBucket := tx.Bucket(loopOutBucketKey) + if rootBucket == nil { + return errors.New("bucket does not exist") } - event, err := deserializeLoopEvent(basicState) + loop, err := s.fetchLoopOutSwap(rootBucket, hash[:]) if err != nil { return err } - // Deserialize htlc tx hash if this updates contains one. - htlcTxHashBytes := updateBucket.Get(htlcTxHashKey) - if htlcTxHashBytes != nil { - htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes) - if err != nil { - return err - } - event.HtlcTxHash = htlcTxHash - } + swap = loop - updates = append(updates, event) return nil }) if err != nil { return nil, err } - return updates, nil + return swap, nil } // FetchLoopInSwaps returns all loop in swaps currently in the store. @@ -442,71 +323,12 @@ func (s *boltSwapStore) FetchLoopInSwaps() ([]*LoopIn, error) { return nil } - // From the root bucket, we'll grab the next swap - // bucket for this swap from its swaphash. - swapBucket := rootBucket.Bucket(swapHash) - if swapBucket == nil { - return fmt.Errorf("swap bucket %x not found", - swapHash) - } - - // With the main swap bucket obtained, we'll grab the - // raw swap contract bytes and decode it. - contractBytes := swapBucket.Get(contractKey) - if contractBytes == nil { - return errors.New("contract not found") - } - - contract, err := deserializeLoopInContract( - contractBytes, - ) + loop, err := s.fetchLoopInSwap(rootBucket, swapHash) if err != nil { return err } - // Get our label for this swap, if it is present. - contract.Label = getLabel(swapBucket) - - updates, err := deserializeUpdates(swapBucket) - if err != nil { - return err - } - - // Try to unmarshal the protocol version for the swap. - // If the protocol version is not stored (which is - // the case for old clients), we'll assume the - // ProtocolVersionUnrecorded instead. - contract.ProtocolVersion, err = - UnmarshalProtocolVersion( - swapBucket.Get(protocolVersionKey), - ) - if err != nil { - return err - } - - // Try to unmarshal the key locator. - if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { - contract.ClientKeyLocator, err = UnmarshalKeyLocator( - swapBucket.Get(keyLocatorKey), - ) - if err != nil { - return err - } - } - - loop := LoopIn{ - Loop: Loop{ - Events: updates, - }, - Contract: contract, - } - - loop.Hash, err = lntypes.MakeHash(swapHash) - if err != nil { - return err - } - - swaps = append(swaps, &loop) + swaps = append(swaps, loop) return nil }) @@ -824,3 +646,243 @@ func (s *boltSwapStore) FetchLiquidityParams() ([]byte, error) { return params, err } + +// fetchUpdates deserializes the list of swap updates that are stored as a +// key of the given bucket. +func fetchUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { + // Once we have the raw swap, we'll also need to decode + // each of the past updates to the swap itself. + stateBucket := swapBucket.Bucket(updatesBucketKey) + if stateBucket == nil { + return nil, errors.New("updates bucket not found") + } + + // Deserialize and collect each swap update into our slice of swap + // events. + var updates []*LoopEvent + err := stateBucket.ForEach(func(k, v []byte) error { + updateBucket := stateBucket.Bucket(k) + if updateBucket == nil { + return fmt.Errorf("expected state sub-bucket for %x", k) + } + + basicState := updateBucket.Get(basicStateKey) + if basicState == nil { + return errors.New("no basic state for update") + } + + event, err := deserializeLoopEvent(basicState) + if err != nil { + return err + } + + // Deserialize htlc tx hash if this updates contains one. + htlcTxHashBytes := updateBucket.Get(htlcTxHashKey) + if htlcTxHashBytes != nil { + htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes) + if err != nil { + return err + } + event.HtlcTxHash = htlcTxHash + } + + updates = append(updates, event) + return nil + }) + if err != nil { + return nil, err + } + + return updates, nil +} + +// fetchLoopOutSwap fetches and deserializes the raw swap bytes into a LoopOut +// struct. +func (s *boltSwapStore) fetchLoopOutSwap(rootBucket *bbolt.Bucket, + swapHash []byte) (*LoopOut, error) { + + // From the root bucket, we'll grab the next swap + // bucket for this swap from its swaphash. + swapBucket := rootBucket.Bucket(swapHash) + if swapBucket == nil { + return nil, fmt.Errorf("swap bucket %x not found", + swapHash) + } + + hash, err := lntypes.MakeHash(swapHash) + if err != nil { + return nil, err + } + + // With the main swap bucket obtained, we'll grab the + // raw swap contract bytes and decode it. + contractBytes := swapBucket.Get(contractKey) + if contractBytes == nil { + return nil, errors.New("contract not found") + } + + contract, err := deserializeLoopOutContract( + contractBytes, s.chainParams, + ) + if err != nil { + return nil, err + } + + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + + // Read the list of concatenated outgoing channel ids + // that form the outgoing set. + setBytes := swapBucket.Get(outgoingChanSetKey) + if outgoingChanSetKey != nil { + r := bytes.NewReader(setBytes) + readLoop: + for { + var chanID uint64 + err := binary.Read(r, byteOrder, &chanID) + switch { + case err == io.EOF: + break readLoop + case err != nil: + return nil, err + } + + contract.OutgoingChanSet = append( + contract.OutgoingChanSet, + chanID, + ) + } + } + + // Set our default number of confirmations for the swap. + contract.HtlcConfirmations = DefaultLoopOutHtlcConfirmations + + // If we have the number of confirmations stored for + // this swap, we overwrite our default with the stored + // value. + confBytes := swapBucket.Get(confirmationsKey) + if confBytes != nil { + r := bytes.NewReader(confBytes) + err := binary.Read( + r, byteOrder, &contract.HtlcConfirmations, + ) + if err != nil { + return nil, err + } + } + + updates, err := fetchUpdates(swapBucket) + if err != nil { + return nil, err + } + + // Try to unmarshal the protocol version for the swap. + // If the protocol version is not stored (which is + // the case for old clients), we'll assume the + // ProtocolVersionUnrecorded instead. + contract.ProtocolVersion, err = + UnmarshalProtocolVersion( + swapBucket.Get(protocolVersionKey), + ) + if err != nil { + return nil, err + } + + // Try to unmarshal the key locator. + if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { + contract.ClientKeyLocator, err = UnmarshalKeyLocator( + swapBucket.Get(keyLocatorKey), + ) + if err != nil { + return nil, err + } + } + + loop := LoopOut{ + Loop: Loop{ + Events: updates, + }, + Contract: contract, + } + + loop.Hash, err = lntypes.MakeHash(hash[:]) + if err != nil { + return nil, err + } + + return &loop, nil +} + +// fetchLoopInSwap fetches and deserializes the raw swap bytes into a LoopIn +// struct. +func (s *boltSwapStore) fetchLoopInSwap(rootBucket *bbolt.Bucket, + swapHash []byte) (*LoopIn, error) { + + // From the root bucket, we'll grab the next swap + // bucket for this swap from its swaphash. + swapBucket := rootBucket.Bucket(swapHash) + if swapBucket == nil { + return nil, fmt.Errorf("swap bucket %x not found", + swapHash) + } + + hash, err := lntypes.MakeHash(swapHash) + if err != nil { + return nil, err + } + + // With the main swap bucket obtained, we'll grab the + // raw swap contract bytes and decode it. + contractBytes := swapBucket.Get(contractKey) + if contractBytes == nil { + return nil, errors.New("contract not found") + } + + contract, err := deserializeLoopInContract( + contractBytes, + ) + if err != nil { + return nil, err + } + + // Get our label for this swap, if it is present. + contract.Label = getLabel(swapBucket) + + updates, err := fetchUpdates(swapBucket) + if err != nil { + return nil, err + } + + // Try to unmarshal the protocol version for the swap. + // If the protocol version is not stored (which is + // the case for old clients), we'll assume the + // ProtocolVersionUnrecorded instead. + contract.ProtocolVersion, err = + UnmarshalProtocolVersion( + swapBucket.Get(protocolVersionKey), + ) + if err != nil { + return nil, err + } + + // Try to unmarshal the key locator. + if contract.ProtocolVersion >= ProtocolVersionHtlcV3 { + contract.ClientKeyLocator, err = UnmarshalKeyLocator( + swapBucket.Get(keyLocatorKey), + ) + if err != nil { + return nil, err + } + } + + loop := LoopIn{ + Loop: Loop{ + Events: updates, + }, + Contract: contract, + } + + loop.Hash = hash + + return &loop, nil +}