diff --git a/instantout/reservation/actions.go b/instantout/reservation/actions.go index 3d0965b..6acbfc9 100644 --- a/instantout/reservation/actions.go +++ b/instantout/reservation/actions.go @@ -7,6 +7,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/lightninglabs/loop/fsm" looprpc "github.com/lightninglabs/loop/swapserverrpc" + "github.com/lightningnetwork/lnd/chainntnfs" ) // InitReservationContext contains the request parameters for a reservation. @@ -21,18 +22,18 @@ type InitReservationContext struct { // InitAction is the action that is executed when the reservation state machine // is initialized. It creates the reservation in the database and dispatches the // payment to the server. -func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { +func (f *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { // Check if the context is of the correct type. reservationRequest, ok := eventCtx.(*InitReservationContext) if !ok { - return r.HandleError(fsm.ErrInvalidContextType) + return f.HandleError(fsm.ErrInvalidContextType) } - keyRes, err := r.cfg.Wallet.DeriveNextKey( - r.ctx, KeyFamily, + keyRes, err := f.cfg.Wallet.DeriveNextKey( + f.ctx, KeyFamily, ) if err != nil { - return r.HandleError(err) + return f.HandleError(err) } // Send the client reservation details to the server. @@ -44,9 +45,9 @@ func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { ClientKey: keyRes.PubKey.SerializeCompressed(), } - _, err = r.cfg.ReservationClient.OpenReservation(r.ctx, request) + _, err = f.cfg.ReservationClient.OpenReservation(f.ctx, request) if err != nil { - return r.HandleError(err) + return f.HandleError(err) } reservation, err := NewReservation( @@ -59,15 +60,15 @@ func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { keyRes.KeyLocator, ) if err != nil { - return r.HandleError(err) + return f.HandleError(err) } - r.reservation = reservation + f.reservation = reservation // Create the reservation in the database. - err = r.cfg.Store.CreateReservation(r.ctx, reservation) + err = f.cfg.Store.CreateReservation(f.ctx, reservation) if err != nil { - return r.HandleError(err) + return f.HandleError(err) } return OnBroadcast @@ -76,101 +77,163 @@ func (r *FSM) InitAction(eventCtx fsm.EventContext) fsm.EventType { // SubscribeToConfirmationAction is the action that is executed when the // reservation is waiting for confirmation. It subscribes to the confirmation // of the reservation transaction. -func (r *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType { - pkscript, err := r.reservation.GetPkScript() +func (f *FSM) SubscribeToConfirmationAction(_ fsm.EventContext) fsm.EventType { + pkscript, err := f.reservation.GetPkScript() if err != nil { - return r.HandleError(err) + return f.HandleError(err) } - callCtx, cancel := context.WithCancel(r.ctx) + callCtx, cancel := context.WithCancel(f.ctx) defer cancel() // Subscribe to the confirmation of the reservation transaction. log.Debugf("Subscribing to conf for reservation: %x pkscript: %x, "+ - "initiation height: %v", r.reservation.ID, pkscript, - r.reservation.InitiationHeight) + "initiation height: %v", f.reservation.ID, pkscript, + f.reservation.InitiationHeight) - confChan, errConfChan, err := r.cfg.ChainNotifier.RegisterConfirmationsNtfn( + confChan, errConfChan, err := f.cfg.ChainNotifier.RegisterConfirmationsNtfn( callCtx, nil, pkscript, DefaultConfTarget, - r.reservation.InitiationHeight, + f.reservation.InitiationHeight, ) if err != nil { - r.Errorf("unable to subscribe to conf notification: %v", err) - return r.HandleError(err) + f.Errorf("unable to subscribe to conf notification: %v", err) + return f.HandleError(err) } - blockChan, errBlockChan, err := r.cfg.ChainNotifier.RegisterBlockEpochNtfn( + blockChan, errBlockChan, err := f.cfg.ChainNotifier.RegisterBlockEpochNtfn( callCtx, ) if err != nil { - r.Errorf("unable to subscribe to block notifications: %v", err) - return r.HandleError(err) + f.Errorf("unable to subscribe to block notifications: %v", err) + return f.HandleError(err) } // We'll now wait for the confirmation of the reservation transaction. for { select { case err := <-errConfChan: - r.Errorf("conf subscription error: %v", err) - return r.HandleError(err) + f.Errorf("conf subscription error: %v", err) + return f.HandleError(err) case err := <-errBlockChan: - r.Errorf("block subscription error: %v", err) - return r.HandleError(err) + f.Errorf("block subscription error: %v", err) + return f.HandleError(err) case confInfo := <-confChan: - r.Debugf("reservation confirmed: %v", confInfo) - outpoint, err := r.reservation.findReservationOutput( + f.Debugf("confirmed in block %v", confInfo.Block) + outpoint, err := f.reservation.findReservationOutput( confInfo.Tx, ) if err != nil { - return r.HandleError(err) + return f.HandleError(err) } - r.reservation.ConfirmationHeight = confInfo.BlockHeight - r.reservation.Outpoint = outpoint + f.reservation.ConfirmationHeight = confInfo.BlockHeight + f.reservation.Outpoint = outpoint return OnConfirmed case block := <-blockChan: - r.Debugf("block received: %v expiry: %v", block, - r.reservation.Expiry) + f.Debugf("block received: %v expiry: %v", block, + f.reservation.Expiry) - if uint32(block) >= r.reservation.Expiry { + if uint32(block) >= f.reservation.Expiry { return OnTimedOut } - case <-r.ctx.Done(): + case <-f.ctx.Done(): return fsm.NoOp } } } -// ReservationConfirmedAction waits for the reservation to be either expired or -// waits for other actions to happen. -func (r *FSM) ReservationConfirmedAction(_ fsm.EventContext) fsm.EventType { - blockHeightChan, errEpochChan, err := r.cfg.ChainNotifier. - RegisterBlockEpochNtfn(r.ctx) +// AsyncWaitForExpiredOrSweptAction waits for the reservation to be either +// expired or swept. This is non-blocking and can be used to wait for the +// reservation to expire while expecting other events. +func (f *FSM) AsyncWaitForExpiredOrSweptAction(_ fsm.EventContext, +) fsm.EventType { + + notifCtx, cancel := context.WithCancel(f.ctx) + + blockHeightChan, errEpochChan, err := f.cfg.ChainNotifier. + RegisterBlockEpochNtfn(notifCtx) if err != nil { - return r.HandleError(err) + cancel() + return f.HandleError(err) } + pkScript, err := f.reservation.GetPkScript() + if err != nil { + cancel() + return f.HandleError(err) + } + + spendChan, errSpendChan, err := f.cfg.ChainNotifier.RegisterSpendNtfn( + notifCtx, f.reservation.Outpoint, pkScript, + f.reservation.InitiationHeight, + ) + if err != nil { + cancel() + return f.HandleError(err) + } + + go func() { + defer cancel() + op, err := f.handleSubcriptions( + notifCtx, blockHeightChan, spendChan, errEpochChan, + errSpendChan, + ) + if err != nil { + f.handleAsyncError(err) + return + } + if op == fsm.NoOp { + return + } + err = f.SendEvent(op, nil) + if err != nil { + f.Errorf("Error sending %s event: %v", op, err) + } + }() + + return fsm.NoOp +} + +func (f *FSM) handleSubcriptions(ctx context.Context, + blockHeightChan <-chan int32, spendChan <-chan *chainntnfs.SpendDetail, + errEpochChan <-chan error, errSpendChan <-chan error, +) (fsm.EventType, error) { + for { select { case err := <-errEpochChan: - return r.HandleError(err) + return fsm.OnError, err + + case err := <-errSpendChan: + return fsm.OnError, err case blockHeight := <-blockHeightChan: - expired := blockHeight >= int32(r.reservation.Expiry) - if expired { - r.Debugf("Reservation %v expired", - r.reservation.ID) + expired := blockHeight >= int32(f.reservation.Expiry) - return OnTimedOut + if expired { + f.Debugf("Reservation expired") + return OnTimedOut, nil } - case <-r.ctx.Done(): - return fsm.NoOp + case <-spendChan: + return OnSpent, nil + + case <-ctx.Done(): + return fsm.NoOp, nil } } } + +func (f *FSM) handleAsyncError(err error) { + f.LastActionError = err + f.Errorf("Error on async action: %v", err) + err2 := f.SendEvent(fsm.OnError, err) + if err2 != nil { + f.Errorf("Error sending event: %v", err2) + } +} diff --git a/instantout/reservation/actions_test.go b/instantout/reservation/actions_test.go index 13559f1..2f989d9 100644 --- a/instantout/reservation/actions_test.go +++ b/instantout/reservation/actions_test.go @@ -129,6 +129,7 @@ func TestInitReservationAction(t *testing.T) { } for _, tc := range tests { + tc := tc ctxb := context.Background() mockLnd := test.NewMockLnd() mockReservationClient := new(mockReservationClient) @@ -223,6 +224,7 @@ func TestSubscribeToConfirmationAction(t *testing.T) { } for _, tc := range tests { + tc := tc t.Run(tc.name, func(t *testing.T) { chainNotifier := new(MockChainNotifier) @@ -304,14 +306,83 @@ func TestSubscribeToConfirmationAction(t *testing.T) { } } -// TestReservationConfirmedAction tests the ReservationConfirmedAction of the +// AsyncWaitForExpiredOrSweptAction tests the AsyncWaitForExpiredOrSweptAction +// of the reservation state machine. +func TestAsyncWaitForExpiredOrSweptAction(t *testing.T) { + tests := []struct { + name string + blockErr error + spendErr error + expectedEvent fsm.EventType + }{ + { + name: "noop", + expectedEvent: fsm.NoOp, + }, + { + name: "block error", + blockErr: errors.New("block error"), + expectedEvent: fsm.OnError, + }, + { + name: "spend error", + spendErr: errors.New("spend error"), + expectedEvent: fsm.OnError, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { // Create a mock ChainNotifier and Reservation + chainNotifier := new(MockChainNotifier) + + // Define your FSM + r := NewFSMFromReservation( + context.Background(), &Config{ + ChainNotifier: chainNotifier, + }, + &Reservation{ + ServerPubkey: defaultPubkey, + ClientPubkey: defaultPubkey, + Expiry: defaultExpiry, + }, + ) + + // Define the expected return values for your mocks + chainNotifier.On("RegisterBlockEpochNtfn", mock.Anything).Return( + make(chan int32), make(chan error), tc.blockErr, + ) + + chainNotifier.On( + "RegisterSpendNtfn", mock.Anything, + mock.Anything, mock.Anything, + ).Return( + make(chan *chainntnfs.SpendDetail), + make(chan error), tc.spendErr, + ) + + eventType := r.AsyncWaitForExpiredOrSweptAction(nil) + // Assert that the return value is as expected + require.Equal(t, tc.expectedEvent, eventType) + }) + } +} + +// TesthandleSubcriptions tests the handleSubcriptions function of the // reservation state machine. -func TestReservationConfirmedAction(t *testing.T) { +func TestHandleSubcriptions(t *testing.T) { + var ( + blockErr = errors.New("block error") + spendErr = errors.New("spend error") + ) tests := []struct { name string blockHeight int32 blockErr error + spendDetail *chainntnfs.SpendDetail + spendErr error expectedEvent fsm.EventType + expectedErr error }{ { name: "expired", @@ -320,13 +391,25 @@ func TestReservationConfirmedAction(t *testing.T) { }, { name: "block error", - blockHeight: 0, - blockErr: errors.New("block error"), + blockErr: blockErr, + expectedEvent: fsm.OnError, + expectedErr: blockErr, + }, + { + name: "spent", + spendDetail: &chainntnfs.SpendDetail{}, + expectedEvent: OnSpent, + }, + { + name: "spend error", + spendErr: spendErr, expectedEvent: fsm.OnError, + expectedErr: spendErr, }, } for _, tc := range tests { + tc := tc t.Run(tc.name, func(t *testing.T) { chainNotifier := new(MockChainNotifier) @@ -336,36 +419,41 @@ func TestReservationConfirmedAction(t *testing.T) { ChainNotifier: chainNotifier, }, &Reservation{ - Expiry: defaultExpiry, + ServerPubkey: defaultPubkey, + ClientPubkey: defaultPubkey, + Expiry: defaultExpiry, }, ) blockChan := make(chan int32) blockErrChan := make(chan error) - // Define our expected return values for the mocks. - chainNotifier.On("RegisterBlockEpochNtfn", mock.Anything).Return( - blockChan, blockErrChan, nil, - ) + spendChan := make(chan *chainntnfs.SpendDetail) + spendErrChan := make(chan error) + go func() { - // Send the block notification. if tc.blockHeight != 0 { blockChan <- tc.blockHeight } - }() - go func() { - // Send the block notification error. if tc.blockErr != nil { blockErrChan <- tc.blockErr } + + if tc.spendDetail != nil { + spendChan <- tc.spendDetail + } + if tc.spendErr != nil { + spendErrChan <- tc.spendErr + } }() - eventType := r.ReservationConfirmedAction(nil) + eventType, err := r.handleSubcriptions( + context.Background(), blockChan, spendChan, + blockErrChan, spendErrChan, + ) + require.Equal(t, tc.expectedErr, err) require.Equal(t, tc.expectedEvent, eventType) - - // Assert that the expected functions were called on the mocks - chainNotifier.AssertExpectations(t) }) } } diff --git a/instantout/reservation/fsm.go b/instantout/reservation/fsm.go index 3bd0820..af2698a 100644 --- a/instantout/reservation/fsm.go +++ b/instantout/reservation/fsm.go @@ -123,6 +123,18 @@ var ( // OnRecover is the event that is triggered when the reservation FSM // recovers from a restart. OnRecover = fsm.EventType("OnRecover") + + // OnSpent is the event that is triggered when the reservation has been + // spent. + OnSpent = fsm.EventType("OnSpent") + + // OnLocked is the event that is triggered when the reservation has + // been locked. + OnLocked = fsm.EventType("OnLocked") + + // OnUnlocked is the event that is triggered when the reservation has + // been unlocked. + OnUnlocked = fsm.EventType("OnUnlocked") ) // GetReservationStates returns the statemap that defines the reservation @@ -153,14 +165,38 @@ func (f *FSM) GetReservationStates() fsm.States { }, Confirmed: fsm.State{ Transitions: fsm.Transitions{ - OnTimedOut: TimedOut, - OnRecover: Confirmed, + OnSpent: Spent, + OnTimedOut: TimedOut, + OnRecover: Confirmed, + OnLocked: Locked, + fsm.OnError: Confirmed, + }, + Action: f.AsyncWaitForExpiredOrSweptAction, + }, + Locked: fsm.State{ + Transitions: fsm.Transitions{ + OnUnlocked: Confirmed, + OnTimedOut: TimedOut, + OnRecover: Locked, + OnSpent: Spent, + fsm.OnError: Locked, }, - Action: f.ReservationConfirmedAction, + Action: f.AsyncWaitForExpiredOrSweptAction, }, TimedOut: fsm.State{ + Transitions: fsm.Transitions{ + OnTimedOut: TimedOut, + }, Action: fsm.NoOpAction, }, + + Spent: fsm.State{ + Transitions: fsm.Transitions{ + OnSpent: Spent, + }, + Action: fsm.NoOpAction, + }, + Failed: fsm.State{ Action: fsm.NoOpAction, }, diff --git a/instantout/reservation/manager.go b/instantout/reservation/manager.go index 58baeae..655b5f9 100644 --- a/instantout/reservation/manager.go +++ b/instantout/reservation/manager.go @@ -3,6 +3,7 @@ package reservation import ( "context" "fmt" + "strings" "sync" "time" @@ -35,7 +36,6 @@ func NewManager(cfg *Config) *Manager { // Run runs the reservation manager. func (m *Manager) Run(ctx context.Context, height int32) error { - // todo(sputn1ck): recover swaps on startup log.Debugf("Starting reservation manager") runCtx, cancel := context.WithCancel(ctx) @@ -269,3 +269,54 @@ func (m *Manager) RecoverReservations(ctx context.Context) error { func (m *Manager) GetReservations(ctx context.Context) ([]*Reservation, error) { return m.cfg.Store.ListReservations(ctx) } + +// GetReservation returns the reservation for the given id. +func (m *Manager) GetReservation(ctx context.Context, id ID) (*Reservation, + error) { + + return m.cfg.Store.GetReservation(ctx, id) +} + +// LockReservation locks the reservation with the given ID. +func (m *Manager) LockReservation(ctx context.Context, id ID) error { + // Try getting the reservation from the active reservations map. + m.Lock() + reservation, ok := m.activeReservations[id] + m.Unlock() + + if !ok { + return fmt.Errorf("reservation not found") + } + + // Try to send the lock event to the reservation. + err := reservation.SendEvent(OnLocked, nil) + if err != nil { + return err + } + + return nil +} + +// UnlockReservation unlocks the reservation with the given ID. +func (m *Manager) UnlockReservation(ctx context.Context, id ID) error { + // Try getting the reservation from the active reservations map. + m.Lock() + reservation, ok := m.activeReservations[id] + m.Unlock() + + if !ok { + return fmt.Errorf("reservation not found") + } + + // Try to send the unlock event to the reservation. + err := reservation.SendEvent(OnUnlocked, nil) + if err != nil && strings.Contains(err.Error(), "config error") { + // If the error is a config error, we can ignore it, as the + // reservation is already unlocked. + return nil + } else if err != nil { + return err + } + + return nil +} diff --git a/instantout/reservation/manager_test.go b/instantout/reservation/manager_test.go index 1a4fe05..6b9239c 100644 --- a/instantout/reservation/manager_test.go +++ b/instantout/reservation/manager_test.go @@ -33,7 +33,7 @@ func TestManager(t *testing.T) { }() // Create a new reservation. - fsm, err := testContext.manager.newReservation( + reservationFSM, err := testContext.manager.newReservation( ctxb, uint32(testContext.mockLnd.Height), &swapserverrpc.ServerReservationNotification{ ReservationId: defaultReservationId[:], @@ -45,11 +45,11 @@ func TestManager(t *testing.T) { require.NoError(t, err) // We'll expect the spendConfirmation to be sent to the server. - pkScript, err := fsm.reservation.GetPkScript() + pkScript, err := reservationFSM.reservation.GetPkScript() require.NoError(t, err) - conf := <-testContext.mockLnd.RegisterConfChannel - require.Equal(t, conf.PkScript, pkScript) + confReg := <-testContext.mockLnd.RegisterConfChannel + require.Equal(t, confReg.PkScript, pkScript) confTx := &wire.MsgTx{ TxOut: []*wire.TxOut{ @@ -59,23 +59,39 @@ func TestManager(t *testing.T) { }, } // We'll now confirm the spend. - conf.ConfChan <- &chainntnfs.TxConfirmation{ + confReg.ConfChan <- &chainntnfs.TxConfirmation{ BlockHeight: uint32(testContext.mockLnd.Height), Tx: confTx, } // We'll now expect the reservation to be confirmed. - err = fsm.DefaultObserver.WaitForState(ctxb, 5*time.Second, Confirmed) + err = reservationFSM.DefaultObserver.WaitForState(ctxb, 5*time.Second, Confirmed) require.NoError(t, err) - // We'll now expire the reservation. - err = testContext.mockLnd.NotifyHeight( - testContext.mockLnd.Height + int32(defaultExpiry), - ) + // We'll now expect a spend registration. + spendReg := <-testContext.mockLnd.RegisterSpendChannel + require.Equal(t, spendReg.PkScript, pkScript) + + go func() { + // We'll expect a second spend registration. + spendReg = <-testContext.mockLnd.RegisterSpendChannel + require.Equal(t, spendReg.PkScript, pkScript) + }() + + // We'll now try to lock the reservation. + err = testContext.manager.LockReservation(ctxb, defaultReservationId) require.NoError(t, err) + // We'll try to lock the reservation again, which should fail. + err = testContext.manager.LockReservation(ctxb, defaultReservationId) + require.Error(t, err) + + testContext.mockLnd.SpendChannel <- &chainntnfs.SpendDetail{ + SpentOutPoint: spendReg.Outpoint, + } + // We'll now expect the reservation to be expired. - err = fsm.DefaultObserver.WaitForState(ctxb, 5*time.Second, TimedOut) + err = reservationFSM.DefaultObserver.WaitForState(ctxb, 5*time.Second, Spent) require.NoError(t, err) } diff --git a/instantout/reservation/reservation_fsm.md b/instantout/reservation/reservation_fsm.md index 712a1ea..d15c3b0 100644 --- a/instantout/reservation/reservation_fsm.md +++ b/instantout/reservation/reservation_fsm.md @@ -2,6 +2,7 @@ stateDiagram-v2 [*] --> Init: OnServerRequest Confirmed +Confirmed --> SpendBroadcasted: OnSpendBroadcasted Confirmed --> TimedOut: OnTimedOut Confirmed --> Confirmed: OnRecover Failed @@ -9,6 +10,9 @@ Init Init --> Failed: OnError Init --> WaitForConfirmation: OnBroadcast Init --> Failed: OnRecover +SpendBroadcasted +SpendBroadcasted --> SpendConfirmed: OnSpendConfirmed +SpendConfirmed TimedOut WaitForConfirmation WaitForConfirmation --> WaitForConfirmation: OnRecover