Merge pull request #477 from bhandras/taproot-htlc

multi: changes to the taproot HTLC  required for both client and server
pull/506/head
András Bánki-Horváth 2 years ago committed by GitHub
commit 8f23c6789b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -195,7 +195,7 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
htlc, err := swap.NewHtlc( htlc, err := swap.NewHtlc(
GetHtlcScriptVersion(swp.Contract.ProtocolVersion), GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
swp.Contract.CltvExpiry, swp.Contract.SenderKey, swp.Contract.CltvExpiry, swp.Contract.SenderKey,
swp.Contract.ReceiverKey, nil, swp.Hash, swap.HtlcP2WSH, swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH,
s.lndServices.ChainParams, s.lndServices.ChainParams,
) )
if err != nil { if err != nil {
@ -216,7 +216,7 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
htlcNP2WSH, err := swap.NewHtlc( htlcNP2WSH, err := swap.NewHtlc(
GetHtlcScriptVersion(swp.Contract.ProtocolVersion), GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
swp.Contract.CltvExpiry, swp.Contract.SenderKey, swp.Contract.CltvExpiry, swp.Contract.SenderKey,
swp.Contract.ReceiverKey, nil, swp.Hash, swap.HtlcNP2WSH, swp.Contract.ReceiverKey, swp.Hash, swap.HtlcNP2WSH,
s.lndServices.ChainParams, s.lndServices.ChainParams,
) )
if err != nil { if err != nil {
@ -226,7 +226,7 @@ func (s *Client) FetchSwaps() ([]*SwapInfo, error) {
htlcP2WSH, err := swap.NewHtlc( htlcP2WSH, err := swap.NewHtlc(
GetHtlcScriptVersion(swp.Contract.ProtocolVersion), GetHtlcScriptVersion(swp.Contract.ProtocolVersion),
swp.Contract.CltvExpiry, swp.Contract.SenderKey, swp.Contract.CltvExpiry, swp.Contract.SenderKey,
swp.Contract.ReceiverKey, nil, swp.Hash, swap.HtlcP2WSH, swp.Contract.ReceiverKey, swp.Hash, swap.HtlcP2WSH,
s.lndServices.ChainParams, s.lndServices.ChainParams,
) )
if err != nil { if err != nil {

@ -284,7 +284,7 @@ func testResume(t *testing.T, confs uint32, expired, preimageRevealed,
scriptVersion := GetHtlcScriptVersion(protocolVersion) scriptVersion := GetHtlcScriptVersion(protocolVersion)
htlc, err := swap.NewHtlc( htlc, err := swap.NewHtlc(
scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey,
receiverKey, nil, hash, swap.HtlcP2WSH, &chaincfg.TestNet3Params, receiverKey, hash, swap.HtlcP2WSH, &chaincfg.TestNet3Params,
) )
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, htlc.PkScript, confIntent.PkScript) require.Equal(t, htlc.PkScript, confIntent.PkScript)

@ -54,7 +54,7 @@ func viewOut(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.CltvExpiry, s.Contract.CltvExpiry,
s.Contract.SenderKey, s.Contract.SenderKey,
s.Contract.ReceiverKey, s.Contract.ReceiverKey,
nil, s.Hash, swap.HtlcP2WSH, chainParams, s.Hash, swap.HtlcP2WSH, chainParams,
) )
if err != nil { if err != nil {
return err return err
@ -106,7 +106,7 @@ func viewIn(swapClient *loop.Client, chainParams *chaincfg.Params) error {
s.Contract.CltvExpiry, s.Contract.CltvExpiry,
s.Contract.SenderKey, s.Contract.SenderKey,
s.Contract.ReceiverKey, s.Contract.ReceiverKey,
nil, s.Hash, swap.HtlcNP2WSH, chainParams, s.Hash, swap.HtlcNP2WSH, chainParams,
) )
if err != nil { if err != nil {
return err return err

@ -945,14 +945,18 @@ func (s *loopInSwap) publishTimeoutTx(ctx context.Context,
return 0, err return 0, err
} }
// Create a function that will assemble our timeout witness.
witnessFunc := func(sig []byte) (wire.TxWitness, error) { witnessFunc := func(sig []byte) (wire.TxWitness, error) {
return s.htlc.GenTimeoutWitness(sig) return s.htlc.GenTimeoutWitness(sig)
} }
// Retrieve the full script required to unlock the output.
redeemScript := s.htlc.TimeoutScript()
sequence := uint32(0) sequence := uint32(0)
timeoutTx, err := s.sweeper.CreateSweepTx( timeoutTx, err := s.sweeper.CreateSweepTx(
ctx, s.height, sequence, s.htlc, *htlcOutpoint, s.SenderKey, ctx, s.height, sequence, s.htlc, *htlcOutpoint, s.SenderKey,
witnessFunc, htlcValue, fee, s.timeoutAddr, redeemScript, witnessFunc, htlcValue, fee, s.timeoutAddr,
) )
if err != nil { if err != nil {
return 0, err return 0, err

@ -399,7 +399,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool,
htlc, err := swap.NewHtlc( htlc, err := swap.NewHtlc(
scriptVersion, contract.CltvExpiry, contract.SenderKey, scriptVersion, contract.CltvExpiry, contract.SenderKey,
contract.ReceiverKey, nil, testPreimage.Hash(), swap.HtlcNP2WSH, contract.ReceiverKey, testPreimage.Hash(), swap.HtlcNP2WSH,
cfg.lnd.ChainParams, cfg.lnd.ChainParams,
) )
if err != nil { if err != nil {

@ -1240,6 +1240,9 @@ func (s *loopOutSwap) sweep(ctx context.Context,
return s.htlc.GenSuccessWitness(sig, s.Preimage) return s.htlc.GenSuccessWitness(sig, s.Preimage)
} }
// Retrieve the full script required to unlock the output.
redeemScript := s.htlc.SuccessScript()
remainingBlocks := s.CltvExpiry - s.height remainingBlocks := s.CltvExpiry - s.height
blocksToLastReveal := remainingBlocks - MinLoopOutPreimageRevealDelta blocksToLastReveal := remainingBlocks - MinLoopOutPreimageRevealDelta
preimageRevealed := s.state == loopdb.StatePreimageRevealed preimageRevealed := s.state == loopdb.StatePreimageRevealed
@ -1296,7 +1299,8 @@ func (s *loopOutSwap) sweep(ctx context.Context,
// Create sweep tx. // Create sweep tx.
sweepTx, err := s.sweeper.CreateSweepTx( sweepTx, err := s.sweeper.CreateSweepTx(
ctx, s.height, s.htlc.SuccessSequence(), s.htlc, htlcOutpoint, ctx, s.height, s.htlc.SuccessSequence(), s.htlc, htlcOutpoint,
s.ReceiverKey, witnessFunc, htlcValue, fee, s.DestAddr, s.ReceiverKey, redeemScript, witnessFunc, htlcValue, fee,
s.DestAddr,
) )
if err != nil { if err != nil {
return err return err

@ -72,7 +72,7 @@ func (s *swapKit) getHtlc(outputType swap.HtlcOutputType) (*swap.Htlc, error) {
return swap.NewHtlc( return swap.NewHtlc(
GetHtlcScriptVersion(s.contract.ProtocolVersion), GetHtlcScriptVersion(s.contract.ProtocolVersion),
s.contract.CltvExpiry, s.contract.SenderKey, s.contract.CltvExpiry, s.contract.SenderKey,
s.contract.ReceiverKey, nil, s.hash, outputType, s.contract.ReceiverKey, s.hash, outputType,
s.swapConfig.lnd.ChainParams, s.swapConfig.lnd.ChainParams,
) )
} }

@ -6,10 +6,12 @@ import (
"errors" "errors"
"fmt" "fmt"
btcec "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr" "github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
secp "github.com/decred/dcrd/dcrec/secp256k1/v4" secp "github.com/decred/dcrd/dcrec/secp256k1/v4"
@ -62,8 +64,10 @@ type HtlcScript interface {
// redeeming the htlc. // redeeming the htlc.
IsSuccessWitness(witness wire.TxWitness) bool IsSuccessWitness(witness wire.TxWitness) bool
// Script returns the htlc script. // lockingConditions return the address, pkScript and sigScript (if
Script() []byte // required) for a htlc script.
lockingConditions(HtlcOutputType, *chaincfg.Params) (btcutil.Address,
[]byte, []byte, error)
// MaxSuccessWitnessSize returns the maximum witness size for the // MaxSuccessWitnessSize returns the maximum witness size for the
// success case witness. // success case witness.
@ -73,9 +77,21 @@ type HtlcScript interface {
// timeout case witness. // timeout case witness.
MaxTimeoutWitnessSize() int MaxTimeoutWitnessSize() int
// TimeoutScript returns the redeem script required to unlock the htlc
// after timeout.
TimeoutScript() []byte
// SuccessScript returns the redeem script required to unlock the htlc
// using the preimage.
SuccessScript() []byte
// SuccessSequence returns the sequence to spend this htlc in the // SuccessSequence returns the sequence to spend this htlc in the
// success case. // success case.
SuccessSequence() uint32 SuccessSequence() uint32
// SigHash is the signature hash to use for transactions spending from
// the htlc.
SigHash() txscript.SigHashType
} }
// Htlc contains relevant htlc information from the receiver perspective. // Htlc contains relevant htlc information from the receiver perspective.
@ -101,7 +117,7 @@ var (
// script size. // script size.
QuoteHtlc, _ = NewHtlc( QuoteHtlc, _ = NewHtlc(
HtlcV2, HtlcV2,
^int32(0), quoteKey, quoteKey, nil, quoteHash, HtlcP2WSH, ^int32(0), quoteKey, quoteKey, quoteHash, HtlcP2WSH,
&chaincfg.MainNetParams, &chaincfg.MainNetParams,
) )
@ -114,17 +130,6 @@ var (
// selected for a v1 or v2 script. // selected for a v1 or v2 script.
ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " + ErrInvalidOutputSelected = fmt.Errorf("taproot output selected for " +
"non taproot htlc") "non taproot htlc")
// ErrSharedKeyNotNeeded is returned when a shared key is provided for
// either the v1 or v2 script. Shared key is only necessary for the v3
// script.
ErrSharedKeyNotNeeded = fmt.Errorf("shared key not supported for " +
"script version")
// ErrSharedKeyRequired is returned when a script version requires a
// shared key.
ErrSharedKeyRequired = fmt.Errorf("shared key required for script " +
"version")
) )
// String returns the string value of HtlcOutputType. // String returns the string value of HtlcOutputType.
@ -147,9 +152,8 @@ func (h HtlcOutputType) String() string {
// NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated // NewHtlc returns a new instance. For v3 scripts, an internal pubkey generated
// by both participants must be provided. // by both participants must be provided.
func NewHtlc(version ScriptVersion, cltvExpiry int32, func NewHtlc(version ScriptVersion, cltvExpiry int32,
senderKey, receiverKey [33]byte, sharedKey *btcec.PublicKey, senderKey, receiverKey [33]byte, hash lntypes.Hash,
hash lntypes.Hash, outputType HtlcOutputType, outputType HtlcOutputType, chainParams *chaincfg.Params) (*Htlc, error) {
chainParams *chaincfg.Params) (*Htlc, error) {
var ( var (
err error err error
@ -158,28 +162,18 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32,
switch version { switch version {
case HtlcV1: case HtlcV1:
if sharedKey != nil {
return nil, ErrSharedKeyNotNeeded
}
htlc, err = newHTLCScriptV1( htlc, err = newHTLCScriptV1(
cltvExpiry, senderKey, receiverKey, hash, cltvExpiry, senderKey, receiverKey, hash,
) )
case HtlcV2: case HtlcV2:
if sharedKey != nil {
return nil, ErrSharedKeyNotNeeded
}
htlc, err = newHTLCScriptV2( htlc, err = newHTLCScriptV2(
cltvExpiry, senderKey, receiverKey, hash, cltvExpiry, senderKey, receiverKey, hash,
) )
case HtlcV3: case HtlcV3:
if sharedKey == nil {
return nil, ErrSharedKeyRequired
}
htlc, err = newHTLCScriptV3( htlc, err = newHTLCScriptV3(
cltvExpiry, senderKey, receiverKey, cltvExpiry, senderKey, receiverKey, hash,
sharedKey, hash,
) )
default: default:
@ -190,14 +184,36 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32,
return nil, err return nil, err
} }
var pkScript, sigScript []byte address, pkScript, sigScript, err := htlc.lockingConditions(
var address btcutil.Address outputType, chainParams,
)
if err != nil {
return nil, fmt.Errorf("could not get address: %w", err)
}
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: version,
PkScript: pkScript,
OutputType: outputType,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
}, nil
}
// segwitV0LockingConditions provides the address, pkScript and sigScript (if
// required) for the segwit v0 script and output type provided.
func segwitV0LockingConditions(outputType HtlcOutputType,
chainParams *chaincfg.Params, script []byte) (btcutil.Address,
[]byte, []byte, error) {
switch outputType { switch outputType {
case HtlcNP2WSH: case HtlcNP2WSH:
p2wshPkScript, err := input.WitnessScriptHash(htlc.Script()) p2wshPkScript, err := input.WitnessScriptHash(script)
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
// Generate p2sh script for p2wsh (nested). // Generate p2sh script for p2wsh (nested).
@ -210,78 +226,54 @@ func NewHtlc(version ScriptVersion, cltvExpiry int32,
builder.AddData(hash160) builder.AddData(hash160)
builder.AddOp(txscript.OP_EQUAL) builder.AddOp(txscript.OP_EQUAL)
pkScript, err = builder.Script() pkScript, err := builder.Script()
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
// Generate a valid sigScript that will allow us to spend the // Generate a valid sigScript that will allow us to spend the
// p2sh output. The sigScript will contain only a single push of // p2sh output. The sigScript will contain only a single push of
// the p2wsh witness program corresponding to the matching // the p2wsh witness program corresponding to the matching
// public key of this address. // public key of this address.
sigScript, err = txscript.NewScriptBuilder(). sigScript, err := txscript.NewScriptBuilder().
AddData(p2wshPkScript). AddData(p2wshPkScript).
Script() Script()
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
address, err = btcutil.NewAddressScriptHash( address, err := btcutil.NewAddressScriptHash(
p2wshPkScript, chainParams, p2wshPkScript, chainParams,
) )
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
return address, pkScript, sigScript, nil
case HtlcP2WSH: case HtlcP2WSH:
pkScript, err = input.WitnessScriptHash(htlc.Script()) pkScript, err := input.WitnessScriptHash(script)
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
address, err = btcutil.NewAddressWitnessScriptHash( address, err := btcutil.NewAddressWitnessScriptHash(
pkScript[2:], pkScript[2:],
chainParams, chainParams,
) )
if err != nil { if err != nil {
return nil, err return nil, nil, nil, err
} }
case HtlcP2TR: // Pay to witness script hash (segwit v0) does not need a
// Confirm we have a v3 htlc. // sigScript (we provide it in the witness instead), so we
trHtlc, ok := htlc.(*HtlcScriptV3) // return nil for our sigScript.
if !ok { return address, pkScript, nil, nil
return nil, ErrInvalidOutputSelected
}
// Generate a tapscript address from our HTLC's taptree.
address, err = btcutil.NewAddressTaproot(
schnorr.SerializePubKey(trHtlc.TaprootKey), chainParams,
)
if err != nil {
return nil, err
}
// Generate locking script.
pkScript, err = txscript.PayToAddrScript(address)
if err != nil {
return nil, err
}
default: default:
return nil, errors.New("unknown output type") return nil, nil, nil, fmt.Errorf("unexpected output type: %d",
outputType)
} }
return &Htlc{
HtlcScript: htlc,
Hash: hash,
Version: version,
PkScript: pkScript,
OutputType: outputType,
ChainParams: chainParams,
Address: address,
SigScript: sigScript,
}, nil
} }
// GenSuccessWitness returns the success script to spend this htlc with // GenSuccessWitness returns the success script to spend this htlc with
@ -307,8 +299,8 @@ func (h *Htlc) AddSuccessToEstimator(estimator *input.TxWeightEstimator) error {
if !ok { if !ok {
return ErrInvalidOutputSelected return ErrInvalidOutputSelected
} }
successLeaf := txscript.NewBaseTapLeaf(trHtlc.SuccessScript) successLeaf := txscript.NewBaseTapLeaf(trHtlc.SuccessScript())
timeoutLeaf := txscript.NewBaseTapLeaf(trHtlc.TimeoutScript) timeoutLeaf := txscript.NewBaseTapLeaf(trHtlc.TimeoutScript())
timeoutLeafHash := timeoutLeaf.TapHash() timeoutLeafHash := timeoutLeaf.TapHash()
tapscript := input.TapscriptPartialReveal( tapscript := input.TapscriptPartialReveal(
@ -338,8 +330,8 @@ func (h *Htlc) AddTimeoutToEstimator(estimator *input.TxWeightEstimator) error {
if !ok { if !ok {
return ErrInvalidOutputSelected return ErrInvalidOutputSelected
} }
successLeaf := txscript.NewBaseTapLeaf(trHtlc.SuccessScript) successLeaf := txscript.NewBaseTapLeaf(trHtlc.SuccessScript())
timeoutLeaf := txscript.NewBaseTapLeaf(trHtlc.TimeoutScript) timeoutLeaf := txscript.NewBaseTapLeaf(trHtlc.TimeoutScript())
successLeafHash := successLeaf.TapHash() successLeafHash := successLeaf.TapHash()
tapscript := input.TapscriptPartialReveal( tapscript := input.TapscriptPartialReveal(
@ -453,8 +445,19 @@ func (h *HtlcScriptV1) IsSuccessWitness(witness wire.TxWitness) bool {
return !isTimeoutTx return !isTimeoutTx
} }
// Script returns the htlc script. // TimeoutScript returns the redeem script required to unlock the htlc after
func (h *HtlcScriptV1) Script() []byte { // timeout.
//
// In the case of HtlcScriptV1, this is the full segwit v0 script.
func (h *HtlcScriptV1) TimeoutScript() []byte {
return h.script
}
// SuccessScript returns the redeem script required to unlock the htlc using
// the preimage.
//
// In the case of HtlcScriptV1, this is the full segwit v0 script.
func (h *HtlcScriptV1) SuccessScript() []byte {
return h.script return h.script
} }
@ -491,6 +494,19 @@ func (h *HtlcScriptV1) SuccessSequence() uint32 {
return 0 return 0
} }
// Sighash is the signature hash to use for transactions spending from the htlc.
func (h *HtlcScriptV1) SigHash() txscript.SigHashType {
return txscript.SigHashAll
}
// lockingConditions return the address, pkScript and sigScript (if
// required) for a htlc script.
func (h *HtlcScriptV1) lockingConditions(htlcOutputType HtlcOutputType,
params *chaincfg.Params) (btcutil.Address, []byte, []byte, error) {
return segwitV0LockingConditions(htlcOutputType, params, h.script)
}
// HtlcScriptV2 encapsulates the htlc v2 script. // HtlcScriptV2 encapsulates the htlc v2 script.
type HtlcScriptV2 struct { type HtlcScriptV2 struct {
script []byte script []byte
@ -586,8 +602,19 @@ func (h *HtlcScriptV2) GenTimeoutWitness(
return witnessStack, nil return witnessStack, nil
} }
// Script returns the htlc script. // TimeoutScript returns the redeem script required to unlock the htlc after
func (h *HtlcScriptV2) Script() []byte { // timeout.
//
// In the case of HtlcScriptV2, this is the full segwit v0 script.
func (h *HtlcScriptV2) TimeoutScript() []byte {
return h.script
}
// SuccessScript returns the redeem script required to unlock the htlc using
// the preimage.
//
// In the case of HtlcScriptV2, this is the full segwit v0 script.
func (h *HtlcScriptV2) SuccessScript() []byte {
return h.script return h.script
} }
@ -625,51 +652,66 @@ func (h *HtlcScriptV2) SuccessSequence() uint32 {
return 1 return 1
} }
// Sighash is the signature hash to use for transactions spending from the htlc.
func (h *HtlcScriptV2) SigHash() txscript.SigHashType {
return txscript.SigHashAll
}
// lockingConditions return the address, pkScript and sigScript (if
// required) for a htlc script.
func (h *HtlcScriptV2) lockingConditions(htlcOutputType HtlcOutputType,
params *chaincfg.Params) (btcutil.Address, []byte, []byte, error) {
return segwitV0LockingConditions(htlcOutputType, params, h.script)
}
// HtlcScriptV3 encapsulates the htlc v3 script. // HtlcScriptV3 encapsulates the htlc v3 script.
type HtlcScriptV3 struct { type HtlcScriptV3 struct {
// The final locking script for the timeout path which is available to // timeoutScript is the final locking script for the timeout path which
// the sender after the set blockheight. // is available to the sender after the set blockheight.
TimeoutScript []byte timeoutScript []byte
// The final locking script for the success path in which the receiver // successScript is the final locking script for the success path in
// reveals the preimage. // which the receiver reveals the preimage.
SuccessScript []byte successScript []byte
// The public key for the keyspend path which bypasses the above two // InternalPubKey is the public key for the keyspend path which bypasses
// locking scripts. // the above two locking scripts.
InternalPubKey *btcec.PublicKey InternalPubKey *btcec.PublicKey
// The taproot public key which is created with the above 3 inputs. // TaprootKey is the taproot public key which is created with the above
// 3 inputs.
TaprootKey *btcec.PublicKey TaprootKey *btcec.PublicKey
// RootHash is the root hash of the taptree.
RootHash chainhash.Hash
} }
// newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script. // newHTLCScriptV3 constructs a HtlcScipt with the HTLC V3 taproot script.
func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey, receiverHtlcKey [33]byte,
receiverHtlcKey [33]byte, sharedKey *btcec.PublicKey,
swapHash lntypes.Hash) (*HtlcScriptV3, error) { swapHash lntypes.Hash) (*HtlcScriptV3, error) {
receiverPubKey, err := btcec.ParsePubKey( senderPubKey, err := schnorr.ParsePubKey(senderHtlcKey[1:])
receiverHtlcKey[:],
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
senderPubKey, err := btcec.ParsePubKey( receiverPubKey, err := schnorr.ParsePubKey(receiverHtlcKey[1:])
senderHtlcKey[:],
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var schnorrSenderKey, schnorrReceiverKey [32]byte aggregateKey, _, _, err := musig2.AggregateKeys(
copy(schnorrSenderKey[:], schnorr.SerializePubKey(senderPubKey)) []*btcec.PublicKey{senderPubKey, receiverPubKey}, true,
copy(schnorrReceiverKey[:], schnorr.SerializePubKey(receiverPubKey)) )
if err != nil {
return nil, err
}
// Create our success path script, we'll use this separately // Create our success path script, we'll use this separately
// to generate the success path leaf. // to generate the success path leaf.
successPathScript, err := GenSuccessPathScript( successPathScript, err := GenSuccessPathScript(
schnorrReceiverKey, swapHash, receiverPubKey, swapHash,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -678,7 +720,7 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey,
// Create our timeout path leaf, we'll use this separately // Create our timeout path leaf, we'll use this separately
// to generate the timeout path leaf. // to generate the timeout path leaf.
timeoutPathScript, err := GenTimeoutPathScript( timeoutPathScript, err := GenTimeoutPathScript(
schnorrSenderKey, int64(cltvExpiry), senderPubKey, int64(cltvExpiry),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -694,14 +736,15 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey,
// Calculate top level taproot key. // Calculate top level taproot key.
taprootKey := txscript.ComputeTaprootOutputKey( taprootKey := txscript.ComputeTaprootOutputKey(
sharedKey, rootHash[:], aggregateKey.PreTweakedKey, rootHash[:],
) )
return &HtlcScriptV3{ return &HtlcScriptV3{
TimeoutScript: timeoutPathScript, timeoutScript: timeoutPathScript,
SuccessScript: successPathScript, successScript: successPathScript,
InternalPubKey: sharedKey, InternalPubKey: aggregateKey.PreTweakedKey,
TaprootKey: taprootKey, TaprootKey: taprootKey,
RootHash: rootHash,
}, nil }, nil
} }
@ -709,11 +752,11 @@ func newHTLCScriptV3(cltvExpiry int32, senderHtlcKey,
// Largest possible bytesize of the script is 32 + 1 + 2 + 1 = 36. // Largest possible bytesize of the script is 32 + 1 + 2 + 1 = 36.
// //
// <senderHtlcKey> OP_CHECKSIGVERIFY <cltvExpiry> OP_CHECKLOCKTIMEVERIFY // <senderHtlcKey> OP_CHECKSIGVERIFY <cltvExpiry> OP_CHECKLOCKTIMEVERIFY
func GenTimeoutPathScript( func GenTimeoutPathScript(senderHtlcKey *btcec.PublicKey, cltvExpiry int64) (
senderHtlcKey [32]byte, cltvExpiry int64) ([]byte, error) { []byte, error) {
builder := txscript.NewScriptBuilder() builder := txscript.NewScriptBuilder()
builder.AddData(senderHtlcKey[:]) builder.AddData(schnorr.SerializePubKey(senderHtlcKey))
builder.AddOp(txscript.OP_CHECKSIGVERIFY) builder.AddOp(txscript.OP_CHECKSIGVERIFY)
builder.AddInt64(cltvExpiry) builder.AddInt64(cltvExpiry)
builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY) builder.AddOp(txscript.OP_CHECKLOCKTIMEVERIFY)
@ -727,12 +770,12 @@ func GenTimeoutPathScript(
// OP_SIZE 32 OP_EQUALVERIFY // OP_SIZE 32 OP_EQUALVERIFY
// OP_HASH160 <ripemd160h(swapHash)> OP_EQUALVERIFY // OP_HASH160 <ripemd160h(swapHash)> OP_EQUALVERIFY
// 1 OP_CHECKSEQUENCEVERIFY // 1 OP_CHECKSEQUENCEVERIFY
func GenSuccessPathScript(receiverHtlcKey [32]byte, func GenSuccessPathScript(receiverHtlcKey *btcec.PublicKey,
swapHash lntypes.Hash) ([]byte, error) { swapHash lntypes.Hash) ([]byte, error) {
builder := txscript.NewScriptBuilder() builder := txscript.NewScriptBuilder()
builder.AddData(receiverHtlcKey[:]) builder.AddData(schnorr.SerializePubKey(receiverHtlcKey))
builder.AddOp(txscript.OP_CHECKSIGVERIFY) builder.AddOp(txscript.OP_CHECKSIGVERIFY)
builder.AddOp(txscript.OP_SIZE) builder.AddOp(txscript.OP_SIZE)
builder.AddInt64(32) builder.AddInt64(32)
@ -777,7 +820,7 @@ func (h *HtlcScriptV3) genControlBlock(leafScript []byte) ([]byte, error) {
func (h *HtlcScriptV3) genSuccessWitness( func (h *HtlcScriptV3) genSuccessWitness(
receiverSig []byte, preimage lntypes.Preimage) (wire.TxWitness, error) { receiverSig []byte, preimage lntypes.Preimage) (wire.TxWitness, error) {
controlBlockBytes, err := h.genControlBlock(h.TimeoutScript) controlBlockBytes, err := h.genControlBlock(h.timeoutScript)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -785,7 +828,7 @@ func (h *HtlcScriptV3) genSuccessWitness(
return wire.TxWitness{ return wire.TxWitness{
preimage[:], preimage[:],
receiverSig, receiverSig,
h.SuccessScript, h.successScript,
controlBlockBytes, controlBlockBytes,
}, nil }, nil
} }
@ -795,14 +838,14 @@ func (h *HtlcScriptV3) genSuccessWitness(
func (h *HtlcScriptV3) GenTimeoutWitness( func (h *HtlcScriptV3) GenTimeoutWitness(
senderSig []byte) (wire.TxWitness, error) { senderSig []byte) (wire.TxWitness, error) {
controlBlockBytes, err := h.genControlBlock(h.SuccessScript) controlBlockBytes, err := h.genControlBlock(h.successScript)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return wire.TxWitness{ return wire.TxWitness{
senderSig, senderSig,
h.TimeoutScript, h.timeoutScript,
controlBlockBytes, controlBlockBytes,
}, nil }, nil
} }
@ -813,9 +856,20 @@ func (h *HtlcScriptV3) IsSuccessWitness(witness wire.TxWitness) bool {
return len(witness) == 4 return len(witness) == 4
} }
// Script is not implemented, but necessary to conform to interface. // TimeoutScript returns the redeem script required to unlock the htlc after
func (h *HtlcScriptV3) Script() []byte { // timeout.
return nil //
// In the case of HtlcScriptV3, this is the timeout tapleaf.
func (h *HtlcScriptV3) TimeoutScript() []byte {
return h.timeoutScript
}
// SuccessScript returns the redeem script required to unlock the htlc using
// the preimage.
//
// In the case of HtlcScriptV3, this is the claim tapleaf.
func (h *HtlcScriptV3) SuccessScript() []byte {
return h.successScript
} }
// MaxSuccessWitnessSize returns the maximum witness size for the // MaxSuccessWitnessSize returns the maximum witness size for the
@ -861,3 +915,39 @@ func (h *HtlcScriptV3) MaxTimeoutWitnessSize() int {
func (h *HtlcScriptV3) SuccessSequence() uint32 { func (h *HtlcScriptV3) SuccessSequence() uint32 {
return 1 return 1
} }
// Sighash is the signature hash to use for transactions spending from the htlc.
func (h *HtlcScriptV3) SigHash() txscript.SigHashType {
return txscript.SigHashDefault
}
// lockingConditions return the address, pkScript and sigScript (if required)
// for a htlc script.
func (h *HtlcScriptV3) lockingConditions(outputType HtlcOutputType,
chainParams *chaincfg.Params) (btcutil.Address, []byte, []byte, error) {
// HtlcV3 can only have taproot output type, because we utilize
// tapscript claim paths.
if outputType != HtlcP2TR {
return nil, nil, nil, fmt.Errorf("htlc v3 only supports P2TR "+
"outputs, got: %v", outputType)
}
// Generate a tapscript address from our tree.
address, err := btcutil.NewAddressTaproot(
schnorr.SerializePubKey(h.TaprootKey), chainParams,
)
if err != nil {
return nil, nil, nil, err
}
// Generate locking script.
pkScript, err := txscript.PayToAddrScript(address)
if err != nil {
return nil, nil, nil, err
}
// Taproot (segwit v1) does not need a sigScript (we provide it in the
// witness instead), so we return nil for our sigScript.
return address, pkScript, nil, nil
}

@ -3,12 +3,10 @@ package swap
import ( import (
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"testing" "testing"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcec/v2/schnorr"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
@ -137,7 +135,7 @@ func TestHtlcV2(t *testing.T) {
// Create the htlc. // Create the htlc.
htlc, err := NewHtlc( htlc, err := NewHtlc(
HtlcV2, testCltvExpiry, senderKey, receiverKey, nil, hash, HtlcV2, testCltvExpiry, senderKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams, HtlcP2WSH, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -160,16 +158,17 @@ func TestHtlcV2(t *testing.T) {
) )
signTx := func(tx *wire.MsgTx, pubkey *btcec.PublicKey, signTx := func(tx *wire.MsgTx, pubkey *btcec.PublicKey,
signer *input.MockSigner) (input.Signature, error) { signer *input.MockSigner, witnessScript []byte) (
input.Signature, error) {
signDesc := &input.SignDescriptor{ signDesc := &input.SignDescriptor{
KeyDesc: keychain.KeyDescriptor{ KeyDesc: keychain.KeyDescriptor{
PubKey: pubkey, PubKey: pubkey,
}, },
WitnessScript: htlc.Script(), WitnessScript: witnessScript,
Output: htlcOutput, Output: htlcOutput,
HashType: txscript.SigHashAll, HashType: htlc.SigHash(),
SigHashes: txscript.NewTxSigHashes( SigHashes: txscript.NewTxSigHashes(
tx, prevOutFetcher, tx, prevOutFetcher,
), ),
@ -191,6 +190,7 @@ func TestHtlcV2(t *testing.T) {
sweepTx.TxIn[0].Sequence = htlc.SuccessSequence() sweepTx.TxIn[0].Sequence = htlc.SuccessSequence()
sweepSig, err := signTx( sweepSig, err := signTx(
sweepTx, receiverPubKey, receiverSigner, sweepTx, receiverPubKey, receiverSigner,
htlc.SuccessScript(),
) )
require.NoError(t, err) require.NoError(t, err)
@ -210,6 +210,7 @@ func TestHtlcV2(t *testing.T) {
sweepTx.TxIn[0].Sequence = 0 sweepTx.TxIn[0].Sequence = 0
sweepSig, err := signTx( sweepSig, err := signTx(
sweepTx, receiverPubKey, receiverSigner, sweepTx, receiverPubKey, receiverSigner,
htlc.SuccessScript(),
) )
require.NoError(t, err) require.NoError(t, err)
@ -228,6 +229,7 @@ func TestHtlcV2(t *testing.T) {
sweepTx.LockTime = testCltvExpiry - 1 sweepTx.LockTime = testCltvExpiry - 1
sweepSig, err := signTx( sweepSig, err := signTx(
sweepTx, senderPubKey, senderSigner, sweepTx, senderPubKey, senderSigner,
htlc.TimeoutScript(),
) )
require.NoError(t, err) require.NoError(t, err)
@ -246,6 +248,7 @@ func TestHtlcV2(t *testing.T) {
sweepTx.LockTime = testCltvExpiry sweepTx.LockTime = testCltvExpiry
sweepSig, err := signTx( sweepSig, err := signTx(
sweepTx, senderPubKey, senderSigner, sweepTx, senderPubKey, senderSigner,
htlc.TimeoutScript(),
) )
require.NoError(t, err) require.NoError(t, err)
@ -264,6 +267,7 @@ func TestHtlcV2(t *testing.T) {
sweepTx.LockTime = testCltvExpiry sweepTx.LockTime = testCltvExpiry
sweepSig, err := signTx( sweepSig, err := signTx(
sweepTx, receiverPubKey, receiverSigner, sweepTx, receiverPubKey, receiverSigner,
htlc.TimeoutScript(),
) )
require.NoError(t, err) require.NoError(t, err)
@ -285,7 +289,7 @@ func TestHtlcV2(t *testing.T) {
// Create the htlc with the bogus key. // Create the htlc with the bogus key.
htlc, err = NewHtlc( htlc, err = NewHtlc(
HtlcV2, testCltvExpiry, HtlcV2, testCltvExpiry,
bogusKey, receiverKey, nil, hash, bogusKey, receiverKey, hash,
HtlcP2WSH, &chaincfg.MainNetParams, HtlcP2WSH, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -299,6 +303,7 @@ func TestHtlcV2(t *testing.T) {
sweepTx.LockTime = testCltvExpiry sweepTx.LockTime = testCltvExpiry
sweepSig, err := signTx( sweepSig, err := signTx(
sweepTx, senderPubKey, senderSigner, sweepTx, senderPubKey, senderSigner,
htlc.TimeoutScript(),
) )
require.NoError(t, err) require.NoError(t, err)
@ -352,17 +357,8 @@ func TestHtlcV3(t *testing.T) {
copy(receiverKey[:], receiverPubKey.SerializeCompressed()) copy(receiverKey[:], receiverPubKey.SerializeCompressed())
copy(senderKey[:], senderPubKey.SerializeCompressed()) copy(senderKey[:], senderPubKey.SerializeCompressed())
randomSharedKey, err := hex.DecodeString(
"03fcb7d1b502bd59f4dbc6cf503e5c280189e0e6dd2d10c4c14d97ed8611" +
"a99178",
)
require.NoError(t, err)
randomSharedPubKey, err := btcec.ParsePubKey(randomSharedKey)
require.NoError(t, err)
htlc, err := NewHtlc( htlc, err := NewHtlc(
HtlcV3, cltvExpiry, senderKey, receiverKey, randomSharedPubKey, HtlcV3, cltvExpiry, senderKey, receiverKey,
hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams, hashedPreimage, HtlcP2TR, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -401,7 +397,7 @@ func TestHtlcV3(t *testing.T) {
sig, err := txscript.RawTxInTapscriptSignature( sig, err := txscript.RawTxInTapscriptSignature(
tx, hashCache, 0, value, p2trPkScript, leaf, tx, hashCache, 0, value, p2trPkScript, leaf,
txscript.SigHashDefault, privateKey, htlc.SigHash(), privateKey,
) )
require.NoError(t, err) require.NoError(t, err)
@ -426,7 +422,7 @@ func TestHtlcV3(t *testing.T) {
sig := signTx( sig := signTx(
tx, receiverPrivKey, tx, receiverPrivKey,
txscript.NewBaseTapLeaf( txscript.NewBaseTapLeaf(
trHtlc.SuccessScript, trHtlc.SuccessScript(),
), ),
) )
witness, err := htlc.genSuccessWitness( witness, err := htlc.genSuccessWitness(
@ -450,7 +446,7 @@ func TestHtlcV3(t *testing.T) {
sig := signTx( sig := signTx(
tx, receiverPrivKey, tx, receiverPrivKey,
txscript.NewBaseTapLeaf( txscript.NewBaseTapLeaf(
trHtlc.SuccessScript, trHtlc.SuccessScript(),
), ),
) )
witness, err := htlc.genSuccessWitness( witness, err := htlc.genSuccessWitness(
@ -474,7 +470,7 @@ func TestHtlcV3(t *testing.T) {
sig := signTx( sig := signTx(
tx, senderPrivKey, tx, senderPrivKey,
txscript.NewBaseTapLeaf( txscript.NewBaseTapLeaf(
trHtlc.TimeoutScript, trHtlc.TimeoutScript(),
), ),
) )
@ -497,7 +493,7 @@ func TestHtlcV3(t *testing.T) {
sig := signTx( sig := signTx(
tx, senderPrivKey, tx, senderPrivKey,
txscript.NewBaseTapLeaf( txscript.NewBaseTapLeaf(
trHtlc.TimeoutScript, trHtlc.TimeoutScript(),
), ),
) )
@ -520,7 +516,7 @@ func TestHtlcV3(t *testing.T) {
sig := signTx( sig := signTx(
tx, receiverPrivKey, tx, receiverPrivKey,
txscript.NewBaseTapLeaf( txscript.NewBaseTapLeaf(
trHtlc.TimeoutScript, trHtlc.TimeoutScript(),
), ),
) )
@ -544,18 +540,9 @@ func TestHtlcV3(t *testing.T) {
bogusKey.SerializeCompressed(), bogusKey.SerializeCompressed(),
) )
var shnorrSenderKey [32]byte
copy(
shnorrSenderKey[:],
schnorr.SerializePubKey(
senderPubKey,
),
)
htlc, err := NewHtlc( htlc, err := NewHtlc(
HtlcV3, cltvExpiry, bogusKeyBytes, HtlcV3, cltvExpiry, bogusKeyBytes,
receiverKey, randomSharedPubKey, receiverKey, hashedPreimage, HtlcP2TR,
hashedPreimage, HtlcP2TR,
&chaincfg.MainNetParams, &chaincfg.MainNetParams,
) )
require.NoError(t, err) require.NoError(t, err)
@ -576,7 +563,7 @@ func TestHtlcV3(t *testing.T) {
) )
timeoutScript, err := GenTimeoutPathScript( timeoutScript, err := GenTimeoutPathScript(
shnorrSenderKey, int64(cltvExpiry), senderPubKey, int64(cltvExpiry),
) )
require.NoError(t, err) require.NoError(t, err)

@ -23,7 +23,7 @@ type Sweeper struct {
func (s *Sweeper) CreateSweepTx( func (s *Sweeper) CreateSweepTx(
globalCtx context.Context, height int32, sequence uint32, globalCtx context.Context, height int32, sequence uint32,
htlc *swap.Htlc, htlcOutpoint wire.OutPoint, htlc *swap.Htlc, htlcOutpoint wire.OutPoint,
keyBytes [33]byte, keyBytes [33]byte, witnessScript []byte,
witnessFunc func(sig []byte) (wire.TxWitness, error), witnessFunc func(sig []byte) (wire.TxWitness, error),
amount, fee btcutil.Amount, amount, fee btcutil.Amount,
destAddr btcutil.Address) (*wire.MsgTx, error) { destAddr btcutil.Address) (*wire.MsgTx, error) {
@ -59,20 +59,30 @@ func (s *Sweeper) CreateSweepTx(
} }
signDesc := lndclient.SignDescriptor{ signDesc := lndclient.SignDescriptor{
WitnessScript: htlc.Script(), WitnessScript: witnessScript,
Output: &wire.TxOut{ Output: &wire.TxOut{
Value: int64(amount), Value: int64(amount),
PkScript: htlc.PkScript, PkScript: htlc.PkScript,
}, },
HashType: txscript.SigHashAll, HashType: htlc.SigHash(),
InputIndex: 0, InputIndex: 0,
KeyDesc: keychain.KeyDescriptor{ KeyDesc: keychain.KeyDescriptor{
PubKey: key, PubKey: key,
}, },
} }
// We need our previous outputs for taproot spends, and there's no
// harm including them for segwit v0, so we always include our prevOut.
prevOut := []*wire.TxOut{
{
Value: int64(amount),
PkScript: htlc.PkScript,
},
}
rawSigs, err := s.Lnd.Signer.SignOutputRaw( rawSigs, err := s.Lnd.Signer.SignOutputRaw(
globalCtx, sweepTx, []*lndclient.SignDescriptor{&signDesc}, nil, globalCtx, sweepTx, []*lndclient.SignDescriptor{&signDesc},
prevOut,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("signing: %v", err) return nil, fmt.Errorf("signing: %v", err)

Loading…
Cancel
Save