reservation: update reservation state machine

This commit updates the reservation statemachine to
allow for locking and spending of the
initial reservation.
pull/651/head
sputn1ck 8 months ago
parent 112e612c7a
commit 89b5c00cfa
No known key found for this signature in database
GPG Key ID: 671103D881A5F0E4

@ -7,6 +7,7 @@ import (
"github.com/btcsuite/btcd/btcutil"
"github.com/lightninglabs/loop/fsm"
looprpc "github.com/lightninglabs/loop/swapserverrpc"
"github.com/lightningnetwork/lnd/chainntnfs"
)
// InitReservationContext contains the request parameters for a reservation.
@ -21,18 +22,18 @@ type InitReservationContext struct {
// InitAction is the action that is executed when the reservation state machine
// is initialized. It creates the reservation in the database and dispatches the
// payment to the server.
func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType {
func (f *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType {
// Check if the context is of the correct type.
reservationRequest, ok := eventCtx.(*InitReservationContext)
if !ok {
return r.HandleError(fsm.ErrInvalidContextType)
return f.HandleError(fsm.ErrInvalidContextType)
}
keyRes, err := r.cfg.Wallet.DeriveNextKey(
r.ctx, KeyFamily,
keyRes, err := f.cfg.Wallet.DeriveNextKey(
f.ctx, KeyFamily,
)
if err != nil {
return r.HandleError(err)
return f.HandleError(err)
}
// Send the client reservation details to the server.
@ -44,9 +45,9 @@ func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType {
ClientKey: keyRes.PubKey.SerializeCompressed(),
}
_, err = r.cfg.ReservationClient.OpenReservation(r.ctx, request)
_, err = f.cfg.ReservationClient.OpenReservation(f.ctx, request)
if err != nil {
return r.HandleError(err)
return f.HandleError(err)
}
reservation, err := NewReservation(
@ -59,15 +60,15 @@ func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType {
keyRes.KeyLocator,
)
if err != nil {
return r.HandleError(err)
return f.HandleError(err)
}
r.reservation = reservation
f.reservation = reservation
// Create the reservation in the database.
err = r.cfg.Store.CreateReservation(r.ctx, reservation)
err = f.cfg.Store.CreateReservation(f.ctx, reservation)
if err != nil {
return r.HandleError(err)
return f.HandleError(err)
}
return OnBroadcast
@ -76,101 +77,163 @@ func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType {
// SubscribeToConfirmationAction is the action that is executed when the
// reservation is waiting for confirmation. It subscribes to the confirmation
// of the reservation transaction.
func (r *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType {
pkscript, err := r.reservation.GetPkScript()
func (f *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType {
pkscript, err := f.reservation.GetPkScript()
if err != nil {
return r.HandleError(err)
return f.HandleError(err)
}
callCtx, cancel := context.WithCancel(r.ctx)
callCtx, cancel := context.WithCancel(f.ctx)
defer cancel()
// Subscribe to the confirmation of the reservation transaction.
log.Debugf("Subscribing to conf for reservation: %x pkscript: %x, "+
"initiation height: %v", r.reservation.ID, pkscript,
r.reservation.InitiationHeight)
"initiation height: %v", f.reservation.ID, pkscript,
f.reservation.InitiationHeight)
confChan, errConfChan, err := r.cfg.ChainNotifier.RegisterConfirmationsNtfn(
confChan, errConfChan, err := f.cfg.ChainNotifier.RegisterConfirmationsNtfn(
callCtx, nil, pkscript, DefaultConfTarget,
r.reservation.InitiationHeight,
f.reservation.InitiationHeight,
)
if err != nil {
r.Errorf("unable to subscribe to conf notification: %v", err)
return r.HandleError(err)
f.Errorf("unable to subscribe to conf notification: %v", err)
return f.HandleError(err)
}
blockChan, errBlockChan, err := r.cfg.ChainNotifier.RegisterBlockEpochNtfn(
blockChan, errBlockChan, err := f.cfg.ChainNotifier.RegisterBlockEpochNtfn(
callCtx,
)
if err != nil {
r.Errorf("unable to subscribe to block notifications: %v", err)
return r.HandleError(err)
f.Errorf("unable to subscribe to block notifications: %v", err)
return f.HandleError(err)
}
// We'll now wait for the confirmation of the reservation transaction.
for {
select {
case err := <-errConfChan:
r.Errorf("conf subscription error: %v", err)
return r.HandleError(err)
f.Errorf("conf subscription error: %v", err)
return f.HandleError(err)
case err := <-errBlockChan:
r.Errorf("block subscription error: %v", err)
return r.HandleError(err)
f.Errorf("block subscription error: %v", err)
return f.HandleError(err)
case confInfo := <-confChan:
r.Debugf("reservation confirmed: %v", confInfo)
outpoint, err := r.reservation.findReservationOutput(
f.Debugf("confirmed in block %v", confInfo.Block)
outpoint, err := f.reservation.findReservationOutput(
confInfo.Tx,
)
if err != nil {
return r.HandleError(err)
return f.HandleError(err)
}
r.reservation.ConfirmationHeight = confInfo.BlockHeight
r.reservation.Outpoint = outpoint
f.reservation.ConfirmationHeight = confInfo.BlockHeight
f.reservation.Outpoint = outpoint
return OnConfirmed
case block := <-blockChan:
r.Debugf("block received: %v expiry: %v", block,
r.reservation.Expiry)
f.Debugf("block received: %v expiry: %v", block,
f.reservation.Expiry)
if uint32(block) >= r.reservation.Expiry {
if uint32(block) >= f.reservation.Expiry {
return OnTimedOut
}
case <-r.ctx.Done():
case <-f.ctx.Done():
return fsm.NoOp
}
}
}
// ReservationConfirmedAction waits for the reservation to be either expired or
// waits for other actions to happen.
func (r *FSM) ReservationConfirmedAction(_ fsm.EventContext) fsm.EventType {
blockHeightChan, errEpochChan, err := r.cfg.ChainNotifier.
RegisterBlockEpochNtfn(r.ctx)
// AsyncWaitForExpiredOrSweptAction waits for the reservation to be either
// expired or swept. This is non-blocking and can be used to wait for the
// reservation to expire while expecting other events.
func (f *FSM) AsyncWaitForExpiredOrSweptAction(_ fsm.EventContext,
) fsm.EventType {
notifCtx, cancel := context.WithCancel(f.ctx)
blockHeightChan, errEpochChan, err := f.cfg.ChainNotifier.
RegisterBlockEpochNtfn(notifCtx)
if err != nil {
return r.HandleError(err)
cancel()
return f.HandleError(err)
}
pkScript, err := f.reservation.GetPkScript()
if err != nil {
cancel()
return f.HandleError(err)
}
spendChan, errSpendChan, err := f.cfg.ChainNotifier.RegisterSpendNtfn(
notifCtx, f.reservation.Outpoint, pkScript,
f.reservation.InitiationHeight,
)
if err != nil {
cancel()
return f.HandleError(err)
}
go func() {
defer cancel()
op, err := f.handleSubcriptions(
notifCtx, blockHeightChan, spendChan, errEpochChan,
errSpendChan,
)
if err != nil {
f.handleAsyncError(err)
return
}
if op == fsm.NoOp {
return
}
err = f.SendEvent(op, nil)
if err != nil {
f.Errorf("Error sending %s event: %v", op, err)
}
}()
return fsm.NoOp
}
func (f *FSM) handleSubcriptions(ctx context.Context,
blockHeightChan <-chan int32, spendChan <-chan *chainntnfs.SpendDetail,
errEpochChan <-chan error, errSpendChan <-chan error,
) (fsm.EventType, error) {
for {
select {
case err := <-errEpochChan:
return r.HandleError(err)
return fsm.OnError, err
case err := <-errSpendChan:
return fsm.OnError, err
case blockHeight := <-blockHeightChan:
expired := blockHeight >= int32(r.reservation.Expiry)
if expired {
r.Debugf("Reservation %v expired",
r.reservation.ID)
expired := blockHeight >= int32(f.reservation.Expiry)
return OnTimedOut
if expired {
f.Debugf("Reservation expired")
return OnTimedOut, nil
}
case <-r.ctx.Done():
return fsm.NoOp
case <-spendChan:
return OnSpent, nil
case <-ctx.Done():
return fsm.NoOp, nil
}
}
}
func (f *FSM) handleAsyncError(err error) {
f.LastActionError = err
f.Errorf("Error on async action: %v", err)
err2 := f.SendEvent(fsm.OnError, err)
if err2 != nil {
f.Errorf("Error sending event: %v", err2)
}
}

@ -129,6 +129,7 @@ func TestInitReservationAction(t *testing.T) {
}
for _, tc := range tests {
tc := tc
ctxb := context.Background()
mockLnd := test.NewMockLnd()
mockReservationClient := new(mockReservationClient)
@ -223,6 +224,7 @@ func TestSubscribeToConfirmationAction(t *testing.T) {
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
chainNotifier := new(MockChainNotifier)
@ -304,14 +306,83 @@ func TestSubscribeToConfirmationAction(t *testing.T) {
}
}
// TestReservationConfirmedAction tests the ReservationConfirmedAction of the
// AsyncWaitForExpiredOrSweptAction tests the AsyncWaitForExpiredOrSweptAction
// of the reservation state machine.
func TestAsyncWaitForExpiredOrSweptAction(t *testing.T) {
tests := []struct {
name string
blockErr error
spendErr error
expectedEvent fsm.EventType
}{
{
name: "noop",
expectedEvent: fsm.NoOp,
},
{
name: "block error",
blockErr: errors.New("block error"),
expectedEvent: fsm.OnError,
},
{
name: "spend error",
spendErr: errors.New("spend error"),
expectedEvent: fsm.OnError,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) { // Create a mock ChainNotifier and Reservation
chainNotifier := new(MockChainNotifier)
// Define your FSM
r := NewFSMFromReservation(
context.Background(), &Config{
ChainNotifier: chainNotifier,
},
&Reservation{
ServerPubkey: defaultPubkey,
ClientPubkey: defaultPubkey,
Expiry: defaultExpiry,
},
)
// Define the expected return values for your mocks
chainNotifier.On("RegisterBlockEpochNtfn", mock.Anything).Return(
make(chan int32), make(chan error), tc.blockErr,
)
chainNotifier.On(
"RegisterSpendNtfn", mock.Anything,
mock.Anything, mock.Anything,
).Return(
make(chan *chainntnfs.SpendDetail),
make(chan error), tc.spendErr,
)
eventType := r.AsyncWaitForExpiredOrSweptAction(nil)
// Assert that the return value is as expected
require.Equal(t, tc.expectedEvent, eventType)
})
}
}
// TesthandleSubcriptions tests the handleSubcriptions function of the
// reservation state machine.
func TestReservationConfirmedAction(t *testing.T) {
func TestHandleSubcriptions(t *testing.T) {
var (
blockErr = errors.New("block error")
spendErr = errors.New("spend error")
)
tests := []struct {
name string
blockHeight int32
blockErr error
spendDetail *chainntnfs.SpendDetail
spendErr error
expectedEvent fsm.EventType
expectedErr error
}{
{
name: "expired",
@ -320,13 +391,25 @@ func TestReservationConfirmedAction(t *testing.T) {
},
{
name: "block error",
blockHeight: 0,
blockErr: errors.New("block error"),
blockErr: blockErr,
expectedEvent: fsm.OnError,
expectedErr: blockErr,
},
{
name: "spent",
spendDetail: &chainntnfs.SpendDetail{},
expectedEvent: OnSpent,
},
{
name: "spend error",
spendErr: spendErr,
expectedEvent: fsm.OnError,
expectedErr: spendErr,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
chainNotifier := new(MockChainNotifier)
@ -336,36 +419,41 @@ func TestReservationConfirmedAction(t *testing.T) {
ChainNotifier: chainNotifier,
},
&Reservation{
Expiry: defaultExpiry,
ServerPubkey: defaultPubkey,
ClientPubkey: defaultPubkey,
Expiry: defaultExpiry,
},
)
blockChan := make(chan int32)
blockErrChan := make(chan error)
// Define our expected return values for the mocks.
chainNotifier.On("RegisterBlockEpochNtfn", mock.Anything).Return(
blockChan, blockErrChan, nil,
)
spendChan := make(chan *chainntnfs.SpendDetail)
spendErrChan := make(chan error)
go func() {
// Send the block notification.
if tc.blockHeight != 0 {
blockChan <- tc.blockHeight
}
}()
go func() {
// Send the block notification error.
if tc.blockErr != nil {
blockErrChan <- tc.blockErr
}
if tc.spendDetail != nil {
spendChan <- tc.spendDetail
}
if tc.spendErr != nil {
spendErrChan <- tc.spendErr
}
}()
eventType := r.ReservationConfirmedAction(nil)
eventType, err := r.handleSubcriptions(
context.Background(), blockChan, spendChan,
blockErrChan, spendErrChan,
)
require.Equal(t, tc.expectedErr, err)
require.Equal(t, tc.expectedEvent, eventType)
// Assert that the expected functions were called on the mocks
chainNotifier.AssertExpectations(t)
})
}
}

@ -123,6 +123,18 @@ var (
// OnRecover is the event that is triggered when the reservation FSM
// recovers from a restart.
OnRecover = fsm.EventType("OnRecover")
// OnSpent is the event that is triggered when the reservation has been
// spent.
OnSpent = fsm.EventType("OnSpent")
// OnLocked is the event that is triggered when the reservation has
// been locked.
OnLocked = fsm.EventType("OnLocked")
// OnUnlocked is the event that is triggered when the reservation has
// been unlocked.
OnUnlocked = fsm.EventType("OnUnlocked")
)
// GetReservationStates returns the statemap that defines the reservation
@ -153,14 +165,38 @@ func (f *FSM) GetReservationStates() fsm.States {
},
Confirmed: fsm.State{
Transitions: fsm.Transitions{
OnTimedOut: TimedOut,
OnRecover: Confirmed,
OnSpent: Spent,
OnTimedOut: TimedOut,
OnRecover: Confirmed,
OnLocked: Locked,
fsm.OnError: Confirmed,
},
Action: f.AsyncWaitForExpiredOrSweptAction,
},
Locked: fsm.State{
Transitions: fsm.Transitions{
OnUnlocked: Confirmed,
OnTimedOut: TimedOut,
OnRecover: Locked,
OnSpent: Spent,
fsm.OnError: Locked,
},
Action: f.ReservationConfirmedAction,
Action: f.AsyncWaitForExpiredOrSweptAction,
},
TimedOut: fsm.State{
Transitions: fsm.Transitions{
OnTimedOut: TimedOut,
},
Action: fsm.NoOpAction,
},
Spent: fsm.State{
Transitions: fsm.Transitions{
OnSpent: Spent,
},
Action: fsm.NoOpAction,
},
Failed: fsm.State{
Action: fsm.NoOpAction,
},

@ -3,6 +3,7 @@ package reservation
import (
"context"
"fmt"
"strings"
"sync"
"time"
@ -35,7 +36,6 @@ func NewManager(cfg *Config) *Manager {
// Run runs the reservation manager.
func (m *Manager) Run(ctx context.Context, height int32) error {
// todo(sputn1ck): recover swaps on startup
log.Debugf("Starting reservation manager")
runCtx, cancel := context.WithCancel(ctx)
@ -269,3 +269,54 @@ func (m *Manager) RecoverReservations(ctx context.Context) error {
func (m *Manager) GetReservations(ctx context.Context) ([]*Reservation, error) {
return m.cfg.Store.ListReservations(ctx)
}
// GetReservation returns the reservation for the given id.
func (m *Manager) GetReservation(ctx context.Context, id ID) (*Reservation,
error) {
return m.cfg.Store.GetReservation(ctx, id)
}
// LockReservation locks the reservation with the given ID.
func (m *Manager) LockReservation(ctx context.Context, id ID) error {
// Try getting the reservation from the active reservations map.
m.Lock()
reservation, ok := m.activeReservations[id]
m.Unlock()
if !ok {
return fmt.Errorf("reservation not found")
}
// Try to send the lock event to the reservation.
err := reservation.SendEvent(OnLocked, nil)
if err != nil {
return err
}
return nil
}
// UnlockReservation unlocks the reservation with the given ID.
func (m *Manager) UnlockReservation(ctx context.Context, id ID) error {
// Try getting the reservation from the active reservations map.
m.Lock()
reservation, ok := m.activeReservations[id]
m.Unlock()
if !ok {
return fmt.Errorf("reservation not found")
}
// Try to send the unlock event to the reservation.
err := reservation.SendEvent(OnUnlocked, nil)
if err != nil && strings.Contains(err.Error(), "config error") {
// If the error is a config error, we can ignore it, as the
// reservation is already unlocked.
return nil
} else if err != nil {
return err
}
return nil
}

@ -33,7 +33,7 @@ func TestManager(t *testing.T) {
}()
// Create a new reservation.
fsm, err := testContext.manager.newReservation(
reservationFSM, err := testContext.manager.newReservation(
ctxb, uint32(testContext.mockLnd.Height),
&swapserverrpc.ServerReservationNotification{
ReservationId: defaultReservationId[:],
@ -45,11 +45,11 @@ func TestManager(t *testing.T) {
require.NoError(t, err)
// We'll expect the spendConfirmation to be sent to the server.
pkScript, err := fsm.reservation.GetPkScript()
pkScript, err := reservationFSM.reservation.GetPkScript()
require.NoError(t, err)
conf := <-testContext.mockLnd.RegisterConfChannel
require.Equal(t, conf.PkScript, pkScript)
confReg := <-testContext.mockLnd.RegisterConfChannel
require.Equal(t, confReg.PkScript, pkScript)
confTx := &wire.MsgTx{
TxOut: []*wire.TxOut{
@ -59,23 +59,39 @@ func TestManager(t *testing.T) {
},
}
// We'll now confirm the spend.
conf.ConfChan <- &chainntnfs.TxConfirmation{
confReg.ConfChan <- &chainntnfs.TxConfirmation{
BlockHeight: uint32(testContext.mockLnd.Height),
Tx: confTx,
}
// We'll now expect the reservation to be confirmed.
err = fsm.DefaultObserver.WaitForState(ctxb, 5*time.Second, Confirmed)
err = reservationFSM.DefaultObserver.WaitForState(ctxb, 5*time.Second, Confirmed)
require.NoError(t, err)
// We'll now expire the reservation.
err = testContext.mockLnd.NotifyHeight(
testContext.mockLnd.Height + int32(defaultExpiry),
)
// We'll now expect a spend registration.
spendReg := <-testContext.mockLnd.RegisterSpendChannel
require.Equal(t, spendReg.PkScript, pkScript)
go func() {
// We'll expect a second spend registration.
spendReg = <-testContext.mockLnd.RegisterSpendChannel
require.Equal(t, spendReg.PkScript, pkScript)
}()
// We'll now try to lock the reservation.
err = testContext.manager.LockReservation(ctxb, defaultReservationId)
require.NoError(t, err)
// We'll try to lock the reservation again, which should fail.
err = testContext.manager.LockReservation(ctxb, defaultReservationId)
require.Error(t, err)
testContext.mockLnd.SpendChannel <- &chainntnfs.SpendDetail{
SpentOutPoint: spendReg.Outpoint,
}
// We'll now expect the reservation to be expired.
err = fsm.DefaultObserver.WaitForState(ctxb, 5*time.Second, TimedOut)
err = reservationFSM.DefaultObserver.WaitForState(ctxb, 5*time.Second, Spent)
require.NoError(t, err)
}

@ -2,6 +2,7 @@
stateDiagram-v2
[*] --> Init: OnServerRequest
Confirmed
Confirmed --> SpendBroadcasted: OnSpendBroadcasted
Confirmed --> TimedOut: OnTimedOut
Confirmed --> Confirmed: OnRecover
Failed
@ -9,6 +10,9 @@ Init
Init --> Failed: OnError
Init --> WaitForConfirmation: OnBroadcast
Init --> Failed: OnRecover
SpendBroadcasted
SpendBroadcasted --> SpendConfirmed: OnSpendConfirmed
SpendConfirmed
TimedOut
WaitForConfirmation
WaitForConfirmation --> WaitForConfirmation: OnRecover

Loading…
Cancel
Save