package loop import ( "context" "fmt" "testing" "time" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/sweep" "github.com/lightninglabs/loop/sweepbatcher" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/stretchr/testify/require" ) var ( testPreimage = lntypes.Preimage([32]byte{ 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, }) ) // testContext contains functionality to support client unit tests. type testContext struct { test.Context serverMock *serverMock swapClient *Client statusChan chan SwapInfo store *loopdb.StoreMock expiryChan chan time.Time runErr chan error stop func() } // mockVerifySchnorrSigFail is used to simulate failed taproot keyspend // signature verification. If passed to the executeConfig we'll test an // uncooperative server and will fall back to scriptspend sweep. func mockVerifySchnorrSigFail(pubKey *btcec.PublicKey, hash, sig []byte) error { return fmt.Errorf("invalid sig") } // mockVerifySchnorrSigSuccess is used to simulate successful taproot keyspend // signature verification. If passed to the executeConfig we'll test an // uncooperative server and will fall back to scriptspend sweep. func mockVerifySchnorrSigSuccess(pubKey *btcec.PublicKey, hash, sig []byte) error { return fmt.Errorf("invalid sig") } func mockMuSig2SignSweep(ctx context.Context, protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, prevoutMap map[wire.OutPoint]*wire.TxOut) ( []byte, []byte, error) { return nil, nil, nil } func newSwapClient(config *clientConfig) *Client { sweeper := &sweep.Sweeper{ Lnd: config.LndServices, } lndServices := config.LndServices batcherStore := sweepbatcher.NewStoreMock() batcher := sweepbatcher.NewBatcher( config.LndServices.WalletKit, config.LndServices.ChainNotifier, config.LndServices.Signer, mockMuSig2SignSweep, mockVerifySchnorrSigSuccess, config.LndServices.ChainParams, batcherStore, config.Store, ) executor := newExecutor(&executorConfig{ lnd: lndServices, store: config.Store, sweeper: sweeper, batcher: batcher, createExpiryTimer: config.CreateExpiryTimer, cancelSwap: config.Server.CancelLoopOutSwap, verifySchnorrSig: mockVerifySchnorrSigFail, }) return &Client{ errChan: make(chan error), clientConfig: *config, lndServices: lndServices, sweeper: sweeper, executor: executor, resumeReady: make(chan struct{}), } } func createClientTestContext(t *testing.T, pendingSwaps []*loopdb.LoopOut) *testContext { clientLnd := test.NewMockLnd() serverMock := newServerMock(clientLnd) store := loopdb.NewStoreMock(t) for _, s := range pendingSwaps { store.LoopOutSwaps[s.Hash] = s.Contract updates := []loopdb.SwapStateData{} for _, e := range s.Events { updates = append(updates, e.SwapStateData) } store.LoopOutUpdates[s.Hash] = updates } expiryChan := make(chan time.Time) timerFactory := func(expiry time.Duration) <-chan time.Time { return expiryChan } swapClient := newSwapClient(&clientConfig{ LndServices: &clientLnd.LndServices, Server: serverMock, Store: store, CreateExpiryTimer: timerFactory, }) statusChan := make(chan SwapInfo) ctx := &testContext{ Context: test.NewContext( t, clientLnd, ), swapClient: swapClient, statusChan: statusChan, expiryChan: expiryChan, store: store, serverMock: serverMock, } ctx.runErr = make(chan error) runCtx, stop := context.WithCancel(context.Background()) ctx.stop = stop go func() { err := swapClient.Run(runCtx, statusChan) log.Errorf("client run: %v", err) ctx.runErr <- err }() return ctx } func (ctx *testContext) finish() { ctx.stop() select { case err := <-ctx.runErr: require.NoError(ctx.Context.T, err) case <-time.After(test.Timeout): ctx.Context.T.Fatal("client not stopping") } ctx.assertIsDone() } func (ctx *testContext) assertIsDone() { require.NoError(ctx.Context.T, ctx.Context.Lnd.IsDone()) require.NoError(ctx.Context.T, ctx.store.IsDone()) select { case <-ctx.statusChan: ctx.Context.T.Fatalf("not all status updates read") default: } } func (ctx *testContext) assertStored() { ctx.Context.T.Helper() ctx.store.AssertLoopOutStored() } func (ctx *testContext) assertStorePreimageReveal() { ctx.Context.T.Helper() ctx.store.AssertStorePreimageReveal() } func (ctx *testContext) assertStoreFinished(expectedResult loopdb.SwapState) { ctx.Context.T.Helper() ctx.store.AssertStoreFinished(expectedResult) } func (ctx *testContext) assertStatus(expectedState loopdb.SwapState) { ctx.Context.T.Helper() for { select { case update := <-ctx.statusChan: if update.SwapType != swap.TypeOut { continue } if update.State == expectedState { return } case <-time.After(test.Timeout): ctx.Context.T.Fatalf("expected status %v not "+ "received in time", expectedState) } } } func (ctx *testContext) publishHtlc(script []byte, amt btcutil.Amount) wire.OutPoint { // Create the htlc tx. htlcTx := wire.MsgTx{} htlcTx.AddTxIn(&wire.TxIn{ PreviousOutPoint: wire.OutPoint{}, }) htlcTx.AddTxOut(&wire.TxOut{ PkScript: script, Value: int64(amt), }) htlcTxHash := htlcTx.TxHash() // Signal client that script has been published. select { case ctx.Lnd.ConfChannel <- &chainntnfs.TxConfirmation{ Tx: &htlcTx, }: case <-time.After(test.Timeout): ctx.Context.T.Fatalf("htlc confirmed not consumed") } return wire.OutPoint{ Hash: htlcTxHash, Index: 0, } } // trackPayment asserts that a call to track payment was sent and sends the // status provided into the updates channel. func (ctx *testContext) trackPayment(status lnrpc.Payment_PaymentStatus) { trackPayment := ctx.Context.AssertTrackPayment() select { case trackPayment.Updates <- lndclient.PaymentStatus{ State: status, }: case <-time.After(test.Timeout): ctx.Context.T.Fatalf("could not send payment update") } } // assertPreimagePush asserts that we made an attempt to push our preimage to // the server. func (ctx *testContext) assertPreimagePush(preimage lntypes.Preimage) { select { case pushedPreimage := <-ctx.serverMock.preimagePush: require.Equal(ctx.Context.T, preimage, pushedPreimage) case <-time.After(test.Timeout): ctx.Context.T.Fatalf("preimage not pushed") } } func (ctx *testContext) AssertEpochListeners(numListeners int32) { ctx.Context.T.Helper() require.Eventually(ctx.Context.T, func() bool { return ctx.Lnd.EpochSubscribers() == numListeners }, test.Timeout, time.Millisecond*250) }