diff --git a/client_test.go b/client_test.go index 68822ac..fc840e0 100644 --- a/client_test.go +++ b/client_test.go @@ -12,6 +12,7 @@ import ( "github.com/btcsuite/btcutil" "github.com/lightninglabs/lndclient" "github.com/lightninglabs/loop/loopdb" + "github.com/lightninglabs/loop/swap" "github.com/lightninglabs/loop/test" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" @@ -68,7 +69,7 @@ func TestSuccess(t *testing.T) { testSuccess(ctx, testRequest.Amount, info.SwapHash, signalPrepaymentResult, signalSwapPaymentResult, false, - confIntent, + confIntent, swap.HtlcV2, ) } @@ -150,22 +151,46 @@ func TestResume(t *testing.T) { defaultConfs := loopdb.DefaultLoopOutHtlcConfirmations - t.Run("not expired", func(t *testing.T) { - testResume(t, defaultConfs, false, false, true) - }) - t.Run("not expired, custom confirmations", func(t *testing.T) { - testResume(t, 3, false, false, true) - }) - t.Run("expired not revealed", func(t *testing.T) { - testResume(t, defaultConfs, true, false, false) - }) - t.Run("expired revealed", func(t *testing.T) { - testResume(t, defaultConfs, true, true, true) - }) + storedVersion := []loopdb.ProtocolVersion{ + loopdb.ProtocolVersionUnrecorded, + loopdb.ProtocolVersionHtlcV2, + } + + for _, version := range storedVersion { + version := version + + t.Run(version.String(), func(t *testing.T) { + t.Run("not expired", func(t *testing.T) { + testResume( + t, defaultConfs, false, false, true, + version, + ) + }) + t.Run("not expired, custom confirmations", + func(t *testing.T) { + testResume( + t, 3, false, false, true, + version, + ) + }) + t.Run("expired not revealed", func(t *testing.T) { + testResume( + t, defaultConfs, true, false, false, + version, + ) + }) + t.Run("expired revealed", func(t *testing.T) { + testResume( + t, defaultConfs, true, true, true, + version, + ) + }) + }) + } } func testResume(t *testing.T, confs uint32, expired, preimageRevealed, - expectSuccess bool) { + expectSuccess bool, protocolVersion loopdb.ProtocolVersion) { defer test.Guard(t)() @@ -222,6 +247,7 @@ func testResume(t *testing.T, confs uint32, expired, preimageRevealed, SenderKey: senderKey, MaxSwapFee: 60000, MaxMinerFee: 50000, + ProtocolVersion: protocolVersion, }, }, Loop: loopdb.Loop{ @@ -250,6 +276,15 @@ func testResume(t *testing.T, confs uint32, expired, preimageRevealed, // Expect client to register for our expected number of confirmations. confIntent := ctx.AssertRegisterConf(preimageRevealed, int32(confs)) + // Assert that the loopout htlc equals to the expected one. + scriptVersion := GetHtlcScriptVersion(protocolVersion) + htlc, err := swap.NewHtlc( + scriptVersion, pendingSwap.Contract.CltvExpiry, senderKey, + receiverKey, hash, swap.HtlcP2WSH, &chaincfg.TestNet3Params, + ) + require.NoError(t, err) + require.Equal(t, htlc.PkScript, confIntent.PkScript) + signalSwapPaymentResult(nil) signalPrepaymentResult(nil) @@ -267,13 +302,14 @@ func testResume(t *testing.T, confs uint32, expired, preimageRevealed, func(r error) {}, func(r error) {}, preimageRevealed, - confIntent, + confIntent, scriptVersion, ) } func testSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, signalPrepaymentResult, signalSwapPaymentResult func(error), - preimageRevealed bool, confIntent *test.ConfRegistration) { + preimageRevealed bool, confIntent *test.ConfRegistration, + scriptVersion swap.ScriptVersion) { htlcOutpoint := ctx.publishHtlc(confIntent.PkScript, amt) @@ -304,8 +340,13 @@ func testSuccess(ctx *testContext, amt btcutil.Amount, hash lntypes.Hash, ctx.T.Fatalf("client not sweeping from htlc tx") } + preImageIndex := 1 + if scriptVersion == swap.HtlcV2 { + preImageIndex = 0 + } + // Check preimage. - clientPreImage := sweepTx.TxIn[0].Witness[1] + clientPreImage := sweepTx.TxIn[0].Witness[preImageIndex] clientPreImageHash := sha256.Sum256(clientPreImage) if clientPreImageHash != hash { ctx.T.Fatalf("incorrect preimage") diff --git a/loopin_test.go b/loopin_test.go index 7b08bcc..f457842 100644 --- a/loopin_test.go +++ b/loopin_test.go @@ -282,20 +282,48 @@ func testLoopInTimeout(t *testing.T, // TestLoopInResume tests resuming swaps in various states. func TestLoopInResume(t *testing.T) { - t.Run("initiated", func(t *testing.T) { - testLoopInResume(t, loopdb.StateInitiated, false) - }) + storedVersion := []loopdb.ProtocolVersion{ + loopdb.ProtocolVersionUnrecorded, + loopdb.ProtocolVersionHtlcV2, + } - t.Run("initiated expired", func(t *testing.T) { - testLoopInResume(t, loopdb.StateInitiated, true) - }) + htlcVersion := []swap.ScriptVersion{ + swap.HtlcV1, + swap.HtlcV2, + } - t.Run("htlc published", func(t *testing.T) { - testLoopInResume(t, loopdb.StateHtlcPublished, false) - }) + for i, version := range storedVersion { + version := version + scriptVersion := htlcVersion[i] + + t.Run(version.String(), func(t *testing.T) { + t.Run("initiated", func(t *testing.T) { + testLoopInResume( + t, loopdb.StateInitiated, false, + version, scriptVersion, + ) + }) + + t.Run("initiated expired", func(t *testing.T) { + testLoopInResume( + t, loopdb.StateInitiated, true, + version, scriptVersion, + ) + }) + + t.Run("htlc published", func(t *testing.T) { + testLoopInResume( + t, loopdb.StateHtlcPublished, false, + version, scriptVersion, + ) + }) + }) + } } -func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool) { +func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool, + storedVersion loopdb.ProtocolVersion, scriptVersion swap.ScriptVersion) { + defer test.Guard(t)() ctx := newLoopInTestContext(t) @@ -314,6 +342,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool) { SenderKey: senderKey, MaxSwapFee: 60000, MaxMinerFee: 50000, + ProtocolVersion: storedVersion, }, } pendSwap := &loopdb.LoopIn{ @@ -331,7 +360,7 @@ func testLoopInResume(t *testing.T, state loopdb.SwapState, expired bool) { } htlc, err := swap.NewHtlc( - swap.HtlcV1, contract.CltvExpiry, contract.SenderKey, + scriptVersion, contract.CltvExpiry, contract.SenderKey, contract.ReceiverKey, testPreimage.Hash(), swap.HtlcNP2WSH, cfg.lnd.ChainParams, )