diff --git a/client.go b/client.go index cf54815..24562dc 100644 --- a/client.go +++ b/client.go @@ -17,6 +17,7 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/sweep" + "github.com/lightninglabs/loop/utils" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/routing/route" "google.golang.org/grpc" @@ -232,7 +233,7 @@ func (s *Client) FetchSwaps(ctx context.Context) ([]*SwapInfo, error) { LastUpdate: swp.LastUpdateTime(), } - htlc, err := GetHtlc( + htlc, err := utils.GetHtlc( swp.Hash, &swp.Contract.SwapContract, s.lndServices.ChainParams, ) @@ -265,7 +266,7 @@ func (s *Client) FetchSwaps(ctx context.Context) ([]*SwapInfo, error) { LastUpdate: swp.LastUpdateTime(), } - htlc, err := GetHtlc( + htlc, err := utils.GetHtlc( swp.Hash, &swp.Contract.SwapContract, s.lndServices.ChainParams, ) @@ -540,7 +541,7 @@ func (s *Client) getLoopOutSweepFee(ctx context.Context, confTarget int32) ( return 0, err } - scriptVersion := GetHtlcScriptVersion( + scriptVersion := utils.GetHtlcScriptVersion( loopdb.CurrentProtocolVersion(), ) @@ -731,7 +732,7 @@ func (s *Client) estimateFee(ctx context.Context, amt btcutil.Amount, // Generate a dummy address for fee estimation. witnessProg := [32]byte{} - scriptVersion := GetHtlcScriptVersion( + scriptVersion := utils.GetHtlcScriptVersion( loopdb.CurrentProtocolVersion(), ) diff --git a/loopd/view.go b/loopd/view.go index a6eaf13..aedd39c 100644 --- a/loopd/view.go +++ b/loopd/view.go @@ -8,6 +8,7 @@ import ( "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/utils" ) // view prints all swaps currently in the database. @@ -56,7 +57,7 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error { for _, s := range swaps { s := s - htlc, err := loop.GetHtlc( + htlc, err := utils.GetHtlc( s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { @@ -107,7 +108,7 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error { for _, s := range swaps { s := s - htlc, err := loop.GetHtlc( + htlc, err := utils.GetHtlc( s.Hash, &s.Contract.SwapContract, chainParams, ) if err != nil { diff --git a/loopin.go b/loopin.go index 5fb7908..f2e495e 100644 --- a/loopin.go +++ b/loopin.go @@ -18,6 +18,7 @@ import ( "github.com/lightninglabs/loop/labels" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" + "github.com/lightninglabs/loop/utils" "github.com/lightningnetwork/lnd/chainntnfs" invpkg "github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/keychain" @@ -442,7 +443,7 @@ func validateLoopInContract(height int32, response *newLoopInResponse) error { // initHtlcs creates and updates the native and nested segwit htlcs of the // loopInSwap. func (s *loopInSwap) initHtlcs() error { - htlc, err := GetHtlc( + htlc, err := utils.GetHtlc( s.hash, &s.SwapContract, s.swapKit.lnd.ChainParams, ) if err != nil { diff --git a/loopout.go b/loopout.go index f96265f..3bb5f5b 100644 --- a/loopout.go +++ b/loopout.go @@ -22,12 +22,12 @@ import ( "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/sweep" + "github.com/lightninglabs/loop/utils" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/zpay32" ) const ( @@ -207,7 +207,7 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, swapKit.lastUpdateTime = initiationTime // Create the htlc. - htlc, err := GetHtlc( + htlc, err := utils.GetHtlc( swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, ) if err != nil { @@ -220,7 +220,9 @@ func newLoopOutSwap(globalCtx context.Context, cfg *swapConfig, // Obtain the payment addr since we'll need it later for routing plugin // recommendation and possibly for cancel. - paymentAddr, err := obtainSwapPaymentAddr(contract.SwapInvoice, cfg) + paymentAddr, err := utils.ObtainSwapPaymentAddr( + contract.SwapInvoice, cfg.lnd.ChainParams, + ) if err != nil { return nil, err } @@ -263,7 +265,7 @@ func resumeLoopOutSwap(cfg *swapConfig, pend *loopdb.LoopOut, ) // Create the htlc. - htlc, err := GetHtlc( + htlc, err := utils.GetHtlc( swapKit.hash, swapKit.contract, swapKit.lnd.ChainParams, ) if err != nil { @@ -275,8 +277,8 @@ func resumeLoopOutSwap(cfg *swapConfig, pend *loopdb.LoopOut, // Obtain the payment addr since we'll need it later for routing plugin // recommendation and possibly for cancel. - paymentAddr, err := obtainSwapPaymentAddr( - pend.Contract.SwapInvoice, cfg, + paymentAddr, err := utils.ObtainSwapPaymentAddr( + pend.Contract.SwapInvoice, cfg.lnd.ChainParams, ) if err != nil { return nil, err @@ -302,24 +304,6 @@ func resumeLoopOutSwap(cfg *swapConfig, pend *loopdb.LoopOut, return swap, nil } -// obtainSwapPaymentAddr will retrieve the payment addr from the passed invoice. -func obtainSwapPaymentAddr(swapInvoice string, cfg *swapConfig) ( - *[32]byte, error) { - - swapPayReq, err := zpay32.Decode( - swapInvoice, cfg.lnd.ChainParams, - ) - if err != nil { - return nil, err - } - - if swapPayReq.PaymentAddr == nil { - return nil, fmt.Errorf("expected payment address for invoice") - } - - return swapPayReq.PaymentAddr, nil -} - // sendUpdate reports an update to the swap state. func (s *loopOutSwap) sendUpdate(ctx context.Context) error { info := s.swapInfo() diff --git a/swap.go b/swap.go index 61db839..bd5bea9 100644 --- a/swap.go +++ b/swap.go @@ -4,11 +4,10 @@ import ( "context" "time" - "github.com/btcsuite/btcd/chaincfg" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" - "github.com/lightningnetwork/lnd/input" + "github.com/lightninglabs/loop/utils" "github.com/lightningnetwork/lnd/lntypes" ) @@ -50,59 +49,10 @@ func newSwapKit(hash lntypes.Hash, swapType swap.Type, cfg *swapConfig, } } -// GetHtlcScriptVersion returns the correct HTLC script version for the passed -// protocol version. -func GetHtlcScriptVersion( - protocolVersion loopdb.ProtocolVersion) swap.ScriptVersion { - - // If the swap was initiated before we had our v3 script, use v2. - if protocolVersion < loopdb.ProtocolVersionHtlcV3 || - protocolVersion == loopdb.ProtocolVersionUnrecorded { - - return swap.HtlcV2 - } - - return swap.HtlcV3 -} - // IsTaproot returns true if the swap referenced by the passed swap contract // uses the v3 (taproot) htlc. func IsTaprootSwap(swapContract *loopdb.SwapContract) bool { - return GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 -} - -// GetHtlc composes and returns the on-chain swap script. -func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, - chainParams *chaincfg.Params) (*swap.Htlc, error) { - - switch GetHtlcScriptVersion(contract.ProtocolVersion) { - case swap.HtlcV2: - return swap.NewHtlcV2( - contract.CltvExpiry, contract.HtlcKeys.SenderScriptKey, - contract.HtlcKeys.ReceiverScriptKey, hash, - chainParams, - ) - - case swap.HtlcV3: - // Swaps that implement the new MuSig2 protocol will be expected - // to use the 1.0RC2 MuSig2 key derivation scheme. - muSig2Version := input.MuSig2Version040 - if contract.ProtocolVersion >= loopdb.ProtocolVersionMuSig2 { - muSig2Version = input.MuSig2Version100RC2 - } - - return swap.NewHtlcV3( - muSig2Version, - contract.CltvExpiry, - contract.HtlcKeys.SenderInternalPubKey, - contract.HtlcKeys.ReceiverInternalPubKey, - contract.HtlcKeys.SenderScriptKey, - contract.HtlcKeys.ReceiverScriptKey, - hash, chainParams, - ) - } - - return nil, swap.ErrInvalidScriptVersion + return utils.GetHtlcScriptVersion(swapContract.ProtocolVersion) == swap.HtlcV3 } // swapInfo constructs and returns a filled SwapInfo from diff --git a/utils/htlc_utils.go b/utils/htlc_utils.go new file mode 100644 index 0000000..3d5bd24 --- /dev/null +++ b/utils/htlc_utils.go @@ -0,0 +1,77 @@ +package utils + +import ( + "fmt" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/swap" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/zpay32" +) + +// GetHtlc composes and returns the on-chain swap script. +func GetHtlc(hash lntypes.Hash, contract *loopdb.SwapContract, + chainParams *chaincfg.Params) (*swap.Htlc, error) { + + switch GetHtlcScriptVersion(contract.ProtocolVersion) { + case swap.HtlcV2: + return swap.NewHtlcV2( + contract.CltvExpiry, contract.HtlcKeys.SenderScriptKey, + contract.HtlcKeys.ReceiverScriptKey, hash, + chainParams, + ) + + case swap.HtlcV3: + // Swaps that implement the new MuSig2 protocol will be expected + // to use the 1.0RC2 MuSig2 key derivation scheme. + muSig2Version := input.MuSig2Version040 + if contract.ProtocolVersion >= loopdb.ProtocolVersionMuSig2 { + muSig2Version = input.MuSig2Version100RC2 + } + + return swap.NewHtlcV3( + muSig2Version, + contract.CltvExpiry, + contract.HtlcKeys.SenderInternalPubKey, + contract.HtlcKeys.ReceiverInternalPubKey, + contract.HtlcKeys.SenderScriptKey, + contract.HtlcKeys.ReceiverScriptKey, + hash, chainParams, + ) + } + + return nil, swap.ErrInvalidScriptVersion +} + +// GetHtlcScriptVersion returns the correct HTLC script version for the passed +// protocol version. +func GetHtlcScriptVersion( + protocolVersion loopdb.ProtocolVersion) swap.ScriptVersion { + + // If the swap was initiated before we had our v3 script, use v2. + if protocolVersion < loopdb.ProtocolVersionHtlcV3 || + protocolVersion == loopdb.ProtocolVersionUnrecorded { + + return swap.HtlcV2 + } + + return swap.HtlcV3 +} + +// ObtainSwapPaymentAddr will retrieve the payment addr from the passed invoice. +func ObtainSwapPaymentAddr(swapInvoice string, chainParams *chaincfg.Params) ( + *[32]byte, error) { + + swapPayReq, err := zpay32.Decode(swapInvoice, chainParams) + if err != nil { + return nil, err + } + + if swapPayReq.PaymentAddr == nil { + return nil, fmt.Errorf("expected payment address for invoice") + } + + return swapPayReq.PaymentAddr, nil +}