diff --git a/client_test.go b/client_test.go index 911732a..f4a9382 100644 --- a/client_test.go +++ b/client_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" @@ -56,7 +57,7 @@ func TestSuccess(t *testing.T) { signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) // Expect client to register for conf. - confIntent := ctx.AssertRegisterConf() + confIntent := ctx.AssertRegisterConf(false) testSuccess(ctx, testRequest.Amount, *hash, signalPrepaymentResult, signalSwapPaymentResult, false, @@ -82,7 +83,7 @@ func TestFailOffchain(t *testing.T) { signalSwapPaymentResult := ctx.AssertPaid(swapInvoiceDesc) signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) signalSwapPaymentResult( errors.New(lndclient.PaymentResultUnknownPaymentHash), @@ -179,10 +180,17 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { var receiverKey [33]byte copy(receiverKey[:], receiverPubKey.SerializeCompressed()) - state := loopdb.StateInitiated + update := loopdb.LoopEvent{ + SwapStateData: loopdb.SwapStateData{ + State: loopdb.StateInitiated, + }, + } + if preimageRevealed { - state = loopdb.StatePreimageRevealed + update.State = loopdb.StatePreimageRevealed + update.HtlcTxHash = &chainhash.Hash{1, 2, 6} } + pendingSwap := &loopdb.LoopOut{ Contract: &loopdb.LoopOutContract{ DestAddr: dest, @@ -201,14 +209,8 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { }, }, Loop: loopdb.Loop{ - Events: []*loopdb.LoopEvent{ - { - SwapStateData: loopdb.SwapStateData{ - State: state, - }, - }, - }, - Hash: hash, + Events: []*loopdb.LoopEvent{&update}, + Hash: hash, }, } @@ -230,7 +232,7 @@ func testResume(t *testing.T, expired, preimageRevealed, expectSuccess bool) { signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) // Expect client to register for conf. - confIntent := ctx.AssertRegisterConf() + confIntent := ctx.AssertRegisterConf(preimageRevealed) signalSwapPaymentResult(nil) signalPrepaymentResult(nil) diff --git a/go.sum b/go.sum index fa8cd9f..f2fe2fa 100644 --- a/go.sum +++ b/go.sum @@ -237,6 +237,7 @@ github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4k github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/loopdb/meta.go b/loopdb/meta.go index 1d81a06..88fdf9a 100644 --- a/loopdb/meta.go +++ b/loopdb/meta.go @@ -36,6 +36,7 @@ var ( migrateCosts, migrateSwapPublicationDeadline, migrateLastHop, + migrateUpdates, } latestDBVersion = uint32(len(migrations)) diff --git a/loopdb/migration_04_updates.go b/loopdb/migration_04_updates.go new file mode 100644 index 0000000..9ad3a29 --- /dev/null +++ b/loopdb/migration_04_updates.go @@ -0,0 +1,112 @@ +package loopdb + +import ( + "errors" + "fmt" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/coreos/bbolt" +) + +// migrateUpdates migrates the swap updates to add an additional level of +// nesting, allowing for optional keys to be added. +func migrateUpdates(tx *bbolt.Tx, chainParams *chaincfg.Params) error { + for _, key := range [][]byte{loopInBucketKey, loopOutBucketKey} { + rootBucket := tx.Bucket(key) + if rootBucket == nil { + return fmt.Errorf("bucket %v does not exist", key) + } + + err := migrateSwapTypeUpdates(rootBucket) + if err != nil { + return err + } + } + + return nil +} + +// migrateSwapTypeUpdates migrates updates for swaps in the specified bucket. +func migrateSwapTypeUpdates(rootBucket *bbolt.Bucket) error { + var swaps [][]byte + + // Do not modify inside the for each. + err := rootBucket.ForEach(func(swapHash, v []byte) error { + // Only go into things that we know are sub-bucket + // keys. + if rootBucket.Bucket(swapHash) != nil { + swaps = append(swaps, swapHash) + } + + return nil + }) + if err != nil { + return err + } + + // With the swaps listed, migrate them one by one. + for _, swapHash := range swaps { + swapBucket := rootBucket.Bucket(swapHash) + if swapBucket == nil { + return fmt.Errorf("swap bucket %x not found", + swapHash) + } + + err := migrateSwapUpdates(swapBucket) + if err != nil { + return err + } + } + + return nil +} + +// migrateSwapUpdates migrates updates for the swap stored in the specified +// bucket. +func migrateSwapUpdates(swapBucket *bbolt.Bucket) error { + // With the main swap bucket obtained, we'll grab the + // raw swap contract bytes. + updatesBucket := swapBucket.Bucket(updatesBucketKey) + if updatesBucket == nil { + return errors.New("updates bucket not found") + } + + type state struct { + id, state []byte + } + + var existingStates []state + + // Do not modify inside the for each. + err := updatesBucket.ForEach(func(k, v []byte) error { + existingStates = append(existingStates, state{id: k, state: v}) + return nil + }) + if err != nil { + return err + } + + for _, existingState := range existingStates { + // Delete the existing state key. + err := updatesBucket.Delete(existingState.id) + if err != nil { + return err + } + + // Re-create as a bucket. + updateBucket, err := updatesBucket.CreateBucket( + existingState.id, + ) + if err != nil { + return err + } + + // Write back the basic state as a sub-key. + err = updateBucket.Put(basicStateKey, existingState.state) + if err != nil { + return err + } + } + + return nil +} diff --git a/loopdb/migration_04_updates_test.go b/loopdb/migration_04_updates_test.go new file mode 100644 index 0000000..80c1156 --- /dev/null +++ b/loopdb/migration_04_updates_test.go @@ -0,0 +1,87 @@ +package loopdb + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/coreos/bbolt" + "github.com/stretchr/testify/require" +) + +// TestMigrationUpdates asserts that the swap updates migration is carried out +// correctly. +func TestMigrationUpdates(t *testing.T) { + var ( + legacyDbVersion = Hex("00000003") + ) + + legacyDb := map[string]interface{}{ + "metadata": map[string]interface{}{ + "dbp": legacyDbVersion, + }, + "loop-in": map[string]interface{}{ + Hex("acae09fec9020b7996042613eede68a9eaf29eb28c21ea9943b19e344365a4bb"): map[string]interface{}{ + "contract": Hex("161b25277262bdb5c7c2827b975b2cbc7eb13e222b30cf88ea6daef4bcf22bdac4116c23071472cb000000000000ea6003f2f513a8fd7958b6a229dfb8835f6ab2c9c63cc3e138784d3e8c0e0ebbdd4e61033f26c40666977ed497eea4694d6dd3f07dbcf037089234ff665cd0a07fea329400007b8a00000000000059a600000000000009ca000077a20000000600000000000000000000000000000000000000000000000000000000000000000000"), + "updates": map[string]interface{}{ + Hex("0000000000000001"): Hex("161b252772cb524508000000000000000000000000000000000000000000000000"), + Hex("0000000000000002"): Hex("161b252837115e9b09ffffffffffff1f6a00000000000000000000000000000000"), + Hex("0000000000000003"): Hex("161b252ab670360d0200000000000009ca00000000000000000000000000000000"), + }, + }, + }, + "uncharge-swaps": map[string]interface{}{ + Hex("c3b3d7a145dbd2bab5aa1f505305f31ee432fe23b0801f065fac453dd9b1f923"): map[string]interface{}{ + "contract": Hex("161b2526643767387ca76e58c964a8f2b6c0a13392b2dea93bde260226a263fb836954054ed1756b000000000000c350fd11016c6e6263727431333337306e3170303072343775707035366c7671663836753565766135647868686c706c78303733756a70676e3979767977376130766a37746d307678793276683576716471327770657832757270307963717a7279787139377a76757173703570373232733970686a6e6e6e706c3778716e796a78353373706863346c396735306b396e347836703761793577707539306b6673397179397173717a353766676a7a67676838343439377375716b383436787a3333336a713036736c6b38637a323872657466363672796b7876396a746e6a3072683979666a6170777065617265713071396679797a666664676d6874687973617370757565746e6b72306b32376370326173366a750269d66fd2cea620dc06f1f7de7838f0c8b145b82c7033080c398862f3421a23230382cb637badbb07f9926a06ecd88b6150513ea0060dc8d6dc1c1fb623926b0a0f000077d400000000000b458c00000000000005f10000000000000024000077a22c6263727431713271756332666777737971376463617a73666e3332636a7874667671647671366a6c70706574fd0f016c6e626372743530313834306e317030307234377570703563776561306732396d30667434646432726167397870306e726d6a72396c33726b7a717037706a6c34337a6e6d6b64336c79337364713877646d6b7a757163717a7279787139377a767571737035616478717538766168643730743776747165777578366d6d64337977636639767835736476717567753833327230676e373466733971793971737168746773636638386e377664767136716e71307a657775366d7471616e326c7a306e7534737a72376c6b36646d343673336c78726572656e333972616b7a6c777378346c613538733966773630356d6767766b766879716e743339713976737367777879367571707236713273780000000600000000000003f20000000000000000161b25262710ce00"), + "outgoing-chan-set": nil, + "updates": map[string]interface{}{ + Hex("0000000000000001"): Hex("161b252a770e649b01000000000000053900000000000000000000000000000001"), + Hex("0000000000000002"): Hex("161b252ab671bdd90200000000000005f10000000000001a9c0000000000000003"), + }, + }, + }, + } + + // Restore a legacy database. + tempDirName, err := ioutil.TempDir("", "clientstore") + require.NoError(t, err) + defer os.RemoveAll(tempDirName) + + tempPath := filepath.Join(tempDirName, dbFileName) + db, err := bbolt.Open(tempPath, 0600, nil) + require.NoError(t, err) + + err = db.Update(func(tx *bbolt.Tx) error { + return RestoreDB(tx, legacyDb) + }) + + // Close database regardless of update result. + db.Close() + + // Assert update was successful. + require.NoError(t, err) + + // Open db and migrate to the latest version. + store, err := NewBoltSwapStore(tempDirName, &chaincfg.MainNetParams) + require.NoError(t, err) + + // Fetch the legacy loop out swap and assert that the updates are still + // there. + outSwaps, err := store.FetchLoopOutSwaps() + require.NoError(t, err) + + outSwap := outSwaps[0] + require.Len(t, outSwap.Events, 2) + require.Equal(t, StateSuccess, outSwap.Events[1].State) + + // Fetch the legacy loop in swap and assert that the updates are still + // there. + inSwaps, err := store.FetchLoopInSwaps() + require.NoError(t, err) + + inSwap := inSwaps[0] + require.Len(t, inSwap.Events, 3) + require.Equal(t, StateSuccess, outSwap.Events[1].State) +} diff --git a/loopdb/raw_db_test.go b/loopdb/raw_db_test.go index 358cedb..52df920 100644 --- a/loopdb/raw_db_test.go +++ b/loopdb/raw_db_test.go @@ -2,7 +2,6 @@ package loopdb import ( "encoding/hex" - "errors" "fmt" "strings" @@ -88,6 +87,15 @@ func restoreDB(bucket *bbolt.Bucket, data map[string]interface{}) error { for k, v := range data { key := []byte(k) + // Store nil values. + if v == nil { + err := bucket.Put(key, nil) + if err != nil { + return err + } + continue + } + switch value := v.(type) { // Key contains value. @@ -109,7 +117,7 @@ func restoreDB(bucket *bbolt.Bucket, data map[string]interface{}) error { } default: - return errors.New("invalid type") + return fmt.Errorf("invalid type %T", value) } } diff --git a/loopdb/store.go b/loopdb/store.go index 505809d..1d48904 100644 --- a/loopdb/store.go +++ b/loopdb/store.go @@ -11,6 +11,7 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lntypes" ) @@ -45,6 +46,12 @@ var ( // maps: updateNumber -> time || state updatesBucketKey = []byte("updates") + // basicStateKey contains the serialized basic swap state. + basicStateKey = []byte{0} + + // htlcTxHashKey contains the confirmed htlc tx id. + htlcTxHashKey = []byte{1} + // contractKey is the key that stores the serialized swap contract. It // is nested within the sub-bucket for each active swap. // @@ -265,12 +272,32 @@ func deserializeUpdates(swapBucket *bbolt.Bucket) ([]*LoopEvent, error) { // Deserialize and collect each swap update into our slice of swap // events. var updates []*LoopEvent - err := stateBucket.ForEach(func(_, v []byte) error { - event, err := deserializeLoopEvent(v) + err := stateBucket.ForEach(func(k, v []byte) error { + updateBucket := stateBucket.Bucket(k) + if updateBucket == nil { + return fmt.Errorf("expected state sub-bucket for %x", k) + } + + basicState := updateBucket.Get(basicStateKey) + if basicState == nil { + return errors.New("no basic state for update") + } + + event, err := deserializeLoopEvent(basicState) if err != nil { return err } + // Deserialize htlc tx hash if this updates contains one. + htlcTxHashBytes := updateBucket.Get(htlcTxHashKey) + if htlcTxHashBytes != nil { + htlcTxHash, err := chainhash.NewHash(htlcTxHashBytes) + if err != nil { + return err + } + event.HtlcTxHash = htlcTxHash + } + updates = append(updates, event) return nil }) @@ -482,24 +509,45 @@ func (s *boltSwapStore) updateLoop(bucketKey []byte, hash lntypes.Hash, if swapBucket == nil { return errors.New("swap not found") } - updateBucket := swapBucket.Bucket(updatesBucketKey) - if updateBucket == nil { + updatesBucket := swapBucket.Bucket(updatesBucketKey) + if updatesBucket == nil { return errors.New("udpate bucket not found") } // Each update for this swap will get a new monotonically // increasing ID number that we'll obtain now. - id, err := updateBucket.NextSequence() + id, err := updatesBucket.NextSequence() if err != nil { return err } + nextUpdateBucket, err := updatesBucket.CreateBucket(itob(id)) + if err != nil { + return fmt.Errorf("cannot create update bucket") + } + // With the ID obtained, we'll write out this new update value. updateValue, err := serializeLoopEvent(time, state) if err != nil { return err } - return updateBucket.Put(itob(id), updateValue) + + err = nextUpdateBucket.Put(basicStateKey, updateValue) + if err != nil { + return err + } + + // Write the htlc tx hash if available. + if state.HtlcTxHash != nil { + err := nextUpdateBucket.Put( + htlcTxHashKey, state.HtlcTxHash[:], + ) + if err != nil { + return err + } + } + + return nil }) } diff --git a/loopdb/store_test.go b/loopdb/store_test.go index 799a1d4..d373876 100644 --- a/loopdb/store_test.go +++ b/loopdb/store_test.go @@ -10,10 +10,12 @@ import ( "time" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/coreos/bbolt" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" ) var ( @@ -130,6 +132,10 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { expectedState, swaps[0].State(), ) } + + if expectedState == StatePreimageRevealed { + require.NotNil(t, swaps[0].State().HtlcTxHash) + } } hash := pendingSwap.Preimage.Hash() @@ -152,7 +158,8 @@ func testLoopOutStore(t *testing.T, pendingSwap *LoopOutContract) { err = store.UpdateLoopOut( hash, testTime, SwapStateData{ - State: StatePreimageRevealed, + State: StatePreimageRevealed, + HtlcTxHash: &chainhash.Hash{1, 6, 2}, }, ) if err != nil { diff --git a/loopdb/swapstate.go b/loopdb/swapstate.go index eb10f3f..1227f16 100644 --- a/loopdb/swapstate.go +++ b/loopdb/swapstate.go @@ -1,6 +1,9 @@ package loopdb -import "github.com/btcsuite/btcutil" +import ( + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcutil" +) // SwapState indicates the current state of a swap. This enumeration is the // union of loop in and loop out states. A single type is used for both swap @@ -19,44 +22,44 @@ const ( // confirmed. This state will mostly coalesce with StateHtlcConfirmed, // except in the case where we wait for fees to come down before we // sweep. - StatePreimageRevealed = 1 + StatePreimageRevealed SwapState = 1 // StateSuccess is the final swap state that is reached when the sweep // tx has the required confirmation depth (SweepConfDepth) and the // server pulled the off-chain htlc. - StateSuccess = 2 + StateSuccess SwapState = 2 // StateFailOffchainPayments indicates that it wasn't possible to find // routes for one or both of the off-chain payments to the server that // satisfied the payment restrictions (fee and timelock limits). - StateFailOffchainPayments = 3 + StateFailOffchainPayments SwapState = 3 // StateFailTimeout indicates that the on-chain htlc wasn't confirmed // before its expiry or confirmed too late (MinPreimageRevealDelta // violated). - StateFailTimeout = 4 + StateFailTimeout SwapState = 4 // StateFailSweepTimeout indicates that the on-chain htlc wasn't swept // before the server revoked the htlc. The server didn't pull the // off-chain htlc (even though it could have) and we timed out the // off-chain htlc ourselves. No funds lost. - StateFailSweepTimeout = 5 + StateFailSweepTimeout SwapState = 5 // StateFailInsufficientValue indicates that the published on-chain htlc // had a value lower than the requested amount. - StateFailInsufficientValue = 6 + StateFailInsufficientValue SwapState = 6 // StateFailTemporary indicates that the swap cannot progress because // of an internal error. This is not a final state. Manual intervention // (like a restart) is required to solve this problem. - StateFailTemporary = 7 + StateFailTemporary SwapState = 7 // StateHtlcPublished means that the client published the on-chain htlc. - StateHtlcPublished = 8 + StateHtlcPublished SwapState = 8 // StateInvoiceSettled means that the swap invoice has been paid by the // server. - StateInvoiceSettled = 9 + StateInvoiceSettled SwapState = 9 ) // SwapStateType defines the types of swap states that exist. Every swap state @@ -147,4 +150,7 @@ type SwapStateData struct { // Cost are the accrued (final) costs so far. Cost SwapCost + + // HtlcTxHash is the tx id of the confirmed htlc. + HtlcTxHash *chainhash.Hash } diff --git a/loopout.go b/loopout.go index 4418f28..12af47e 100644 --- a/loopout.go +++ b/loopout.go @@ -8,6 +8,7 @@ import ( "fmt" "time" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" @@ -56,6 +57,9 @@ type loopOutSwap struct { htlc *swap.Htlc + // htlcTxHash is the confirmed htlc tx id. + htlcTxHash *chainhash.Hash + swapPaymentChan chan lndclient.PaymentResult prePaymentChan chan lndclient.PaymentResult } @@ -210,6 +214,7 @@ func resumeLoopOutSwap(reqContext context.Context, cfg *swapConfig, } else { swap.state = lastUpdate.State swap.lastUpdateTime = lastUpdate.Time + swap.htlcTxHash = lastUpdate.HtlcTxHash } return swap, nil @@ -376,6 +381,7 @@ func (s *loopOutSwap) executeSwap(globalCtx context.Context) error { // Try to spend htlc and continue (rbf) until a spend has confirmed. spendDetails, err := s.waitForHtlcSpendConfirmed(globalCtx, + *htlcOutpoint, func() error { return s.sweep(globalCtx, *htlcOutpoint, htlcValue) }, @@ -419,8 +425,9 @@ func (s *loopOutSwap) persistState(ctx context.Context) error { err := s.store.UpdateLoopOut( s.hash, updateTime, loopdb.SwapStateData{ - State: s.state, - Cost: s.cost, + State: s.state, + Cost: s.cost, + HtlcTxHash: s.htlcTxHash, }, ) if err != nil { @@ -563,11 +570,21 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( s.InitiationHeight, ) + // If we've revealed the preimage in a previous run, we expect to have + // recorded the htlc tx hash. We use this to re-register for + // confirmation, to be sure that we'll keep tracking the same htlc. For + // older swaps, this field may not be populated even though the preimage + // has already been revealed. + if s.state == loopdb.StatePreimageRevealed && s.htlcTxHash == nil { + s.log.Warnf("No htlc tx hash available, registering with " + + "just the pkscript") + } + ctx, cancel := context.WithCancel(globalCtx) defer cancel() htlcConfChan, htlcErrChan, err := s.lnd.ChainNotifier.RegisterConfirmationsNtfn( - ctx, nil, s.htlc.PkScript, 1, + ctx, s.htlcTxHash, s.htlc.PkScript, 1, s.InitiationHeight, ) if err != nil { @@ -680,8 +697,10 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( } } - s.log.Infof("Htlc tx %v at height %v", txConf.Tx.TxHash(), - txConf.BlockHeight) + htlcTxHash := txConf.Tx.TxHash() + s.log.Infof("Htlc tx %v at height %v", htlcTxHash, txConf.BlockHeight) + + s.htlcTxHash = &htlcTxHash return txConf, nil } @@ -694,13 +713,14 @@ func (s *loopOutSwap) waitForConfirmedHtlc(globalCtx context.Context) ( // sweep offchain. So we must make sure we sweep successfully before on-chain // timeout. func (s *loopOutSwap) waitForHtlcSpendConfirmed(globalCtx context.Context, - spendFunc func() error) (*chainntnfs.SpendDetail, error) { + htlc wire.OutPoint, spendFunc func() error) (*chainntnfs.SpendDetail, + error) { // Register the htlc spend notification. ctx, cancel := context.WithCancel(globalCtx) defer cancel() spendChan, spendErr, err := s.lnd.ChainNotifier.RegisterSpendNtfn( - ctx, nil, s.htlc.PkScript, s.InitiationHeight, + ctx, &htlc, s.htlc.PkScript, s.InitiationHeight, ) if err != nil { return nil, fmt.Errorf("register spend ntfn: %v", err) diff --git a/loopout_test.go b/loopout_test.go index fbbddfe..2263ff0 100644 --- a/loopout_test.go +++ b/loopout_test.go @@ -115,7 +115,7 @@ func TestLoopOutPaymentParameters(t *testing.T) { // Swap is expected to register for confirmation of the htlc. Assert // this to prevent a blocked channel in the mock. - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) // Cancel the swap. There is nothing else we need to assert. The payment // parameters don't play a role in the remainder of the swap process. @@ -187,7 +187,7 @@ func TestLateHtlcPublish(t *testing.T) { signalPrepaymentResult := ctx.AssertPaid(prepayInvoiceDesc) // Expect client to register for conf - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) // // Wait too long before publishing htlc. blockEpochChan <- int32(swap.CltvExpiry - 10) @@ -283,7 +283,7 @@ func TestCustomSweepConfTarget(t *testing.T) { signalPrepaymentResult(nil) // Notify the confirmation notification for the HTLC. - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) blockEpochChan <- ctx.Lnd.Height + 1 @@ -484,7 +484,7 @@ func TestPreimagePush(t *testing.T) { signalPrepaymentResult(nil) // Notify the confirmation notification for the HTLC. - ctx.AssertRegisterConf() + ctx.AssertRegisterConf(false) blockEpochChan <- ctx.Lnd.Height + 1 @@ -529,7 +529,7 @@ func TestPreimagePush(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StatePreimageRevealed) status := <-statusChan require.Equal( - t, status.State, loopdb.SwapState(loopdb.StatePreimageRevealed), + t, status.State, loopdb.StatePreimageRevealed, ) // We expect the sweep tx to have been published. @@ -578,7 +578,7 @@ func TestPreimagePush(t *testing.T) { cfg.store.(*storeMock).assertLoopOutState(loopdb.StateSuccess) status = <-statusChan require.Equal( - t, status.State, loopdb.SwapState(loopdb.StateSuccess), + t, status.State, loopdb.StateSuccess, ) require.NoError(t, <-errChan) diff --git a/test/context.go b/test/context.go index 81740c4..146573c 100644 --- a/test/context.go +++ b/test/context.go @@ -113,14 +113,18 @@ func (ctx *Context) AssertTrackPayment() TrackPaymentMessage { } // AssertRegisterConf asserts that a register for conf has been received. -func (ctx *Context) AssertRegisterConf() *ConfRegistration { +func (ctx *Context) AssertRegisterConf(expectTxHash bool) *ConfRegistration { ctx.T.Helper() // Expect client to register for conf var confIntent *ConfRegistration select { case confIntent = <-ctx.Lnd.RegisterConfChannel: - if confIntent.TxID != nil { + switch { + case expectTxHash && confIntent.TxID == nil: + ctx.T.Fatalf("expected tx id for registration") + + case !expectTxHash && confIntent.TxID != nil: ctx.T.Fatalf("expected script only registration") } case <-time.After(Timeout):