From 5399e605545376be4bc6e79fa75a725a981b71f4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 8 Mar 2021 10:52:57 +0200 Subject: [PATCH] loopd: verify that dest addr is for correct chain This commit adds verification to the loop out request to ensure that the formatting of the specified destination address matches the network that lnd is running on. --- loopd/swapclient_server.go | 25 ++++++++++++++---- loopd/swapclient_server_test.go | 47 ++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/loopd/swapclient_server.go b/loopd/swapclient_server.go index 1e3f925..550504a 100644 --- a/loopd/swapclient_server.go +++ b/loopd/swapclient_server.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop" @@ -35,6 +36,11 @@ const ( ) var ( + // errIncorrectChain is returned when the format of the + // destination address provided does not match the active chain. + errIncorrectChain = errors.New("invalid address format for the " + + "active chain") + // errConfTargetTooLow is returned when the chosen confirmation target // is below the allowed minimum. errConfTargetTooLow = errors.New("confirmation target too low") @@ -82,8 +88,9 @@ func (s *swapClientServer) LoopOut(ctx context.Context, } } - sweepConfTarget, err := validateLoopOutRequest(in.SweepConfTarget, - in.Label) + sweepConfTarget, err := validateLoopOutRequest( + s.lnd.ChainParams, in.SweepConfTarget, sweepAddr, in.Label, + ) if err != nil { return nil, err } @@ -922,9 +929,17 @@ func validateLoopInRequest(htlcConfTarget int32, external bool) (int32, error) { return validateConfTarget(htlcConfTarget, loop.DefaultHtlcConfTarget) } -// validateLoopOutRequest validates the confirmation target and label of the -// loop out request. -func validateLoopOutRequest(confTarget int32, label string) (int32, error) { +// validateLoopOutRequest validates the confirmation target, destination +// address and label of the loop out request. +func validateLoopOutRequest(chainParams *chaincfg.Params, confTarget int32, + sweepAddr btcutil.Address, label string) (int32, error) { + // Check that the provided destination address has the correct format + // for the active network. + if !sweepAddr.IsForNet(chainParams) { + return 0, fmt.Errorf("%w: Current active network is %s", + errIncorrectChain, chainParams.Name) + } + // Check that the label is valid. if err := labels.Validate(label); err != nil { return 0, err diff --git a/loopd/swapclient_server_test.go b/loopd/swapclient_server_test.go index 418eb36..47bf141 100644 --- a/loopd/swapclient_server_test.go +++ b/loopd/swapclient_server_test.go @@ -163,13 +163,53 @@ func TestValidateLoopInRequest(t *testing.T) { func TestValidateLoopOutRequest(t *testing.T) { tests := []struct { name string + chain chaincfg.Params confTarget int32 + destAddr btcutil.Address label string err error expectedTarget int32 }{ + { + name: "mainnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + err: nil, + expectedTarget: 2, + }, + { + name: "mainnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: mainnetAddr, + label: "label ok", + confTarget: 2, + err: errIncorrectChain, + expectedTarget: 0, + }, + { + name: "testnet address with testnet backend", + chain: chaincfg.TestNet3Params, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + err: nil, + expectedTarget: 2, + }, + { + name: "testnet address with mainnet backend", + chain: chaincfg.MainNetParams, + destAddr: testnetAddr, + label: "label ok", + confTarget: 2, + err: errIncorrectChain, + expectedTarget: 0, + }, { name: "invalid label", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, label: labels.Reserved, confTarget: 2, err: labels.ErrReservedPrefix, @@ -177,6 +217,8 @@ func TestValidateLoopOutRequest(t *testing.T) { }, { name: "invalid conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, label: "label ok", confTarget: 1, err: errConfTargetTooLow, @@ -184,6 +226,8 @@ func TestValidateLoopOutRequest(t *testing.T) { }, { name: "default conf target", + chain: chaincfg.MainNetParams, + destAddr: mainnetAddr, label: "label ok", confTarget: 0, err: nil, @@ -198,7 +242,8 @@ func TestValidateLoopOutRequest(t *testing.T) { t.Parallel() conf, err := validateLoopOutRequest( - test.confTarget, test.label, + &test.chain, test.confTarget, test.destAddr, + test.label, ) require.True(t, errors.Is(err, test.err)) require.Equal(t, test.expectedTarget, conf)