You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
loop/fsm/fsm.go

342 lines
9.0 KiB
Go

package fsm
import (
"errors"
"fmt"
"sync"
)
// Sentinel errors of the fsm package.
var (
// ErrEventRejected is the error returned when the state machine cannot
// process an event in the state that it is in.
ErrEventRejected = errors.New("event rejected")
// ErrWaitForStateTimedOut carries the message "timed out while waiting
// for event".
// NOTE(review): the code returning it is not part of this section of the
// file; confirm the exact condition at its call site.
ErrWaitForStateTimedOut = errors.New(
"timed out while waiting for event",
)
// ErrInvalidContextType carries the message "invalid context".
// NOTE(review): presumably meant for actions that receive an EventContext
// of an unexpected type; confirm against its users.
ErrInvalidContextType = errors.New("invalid context")
// ErrWaitingForStateEarlyAbortError carries the message "waiting for
// state early abort".
// NOTE(review): the code returning it is not part of this section of the
// file; confirm the exact condition at its call site.
ErrWaitingForStateEarlyAbortError = errors.New(
"waiting for state early abort",
)
)
// Predefined states and events shared by all state machines.
const (
// EmptyState represents the default state of the system. It is the
// initial state of machines created with NewStateMachine.
EmptyState StateType = ""
// NoOp represents a no-op event. An action returning NoOp ends the
// processing of the event passed to SendEvent.
NoOp EventType = "NoOp"
// OnError can be used when an action returns a generic error. It is the
// event returned by HandleError.
OnError EventType = "OnError"
// ContextValidationFailed can be used when the passed context is not of
// the expected type.
ContextValidationFailed EventType = "ContextValidationFailed"
)
// StateType represents an extensible state type in the state machine. It is
// the key type of States and the value type of Transitions.
type StateType string
// EventType represents an extensible event type in the state machine. Actions
// return an EventType and Transitions are keyed by it.
type EventType string
// EventContext represents the context to be passed to the action
// implementation. It is an empty interface, so any value can be passed;
// SendEvent hands it to the actions unmodified.
type EventContext interface{}
// Action represents the action to be executed in a given state. SendEvent
// processes the returned EventType as the next event; returning NoOp ends the
// processing.
type Action func(eventCtx EventContext) EventType
// Transitions represents a mapping of events and states: for every event a
// state accepts, it names the state to move to.
type Transitions map[EventType]StateType
// State binds a state with an action and a set of events it can handle.
type State struct {
// EntryFunc is an optional function that SendEvent calls after the
// machine has transitioned into this state, right before Action runs.
EntryFunc func()
// ExitFunc is an optional function that SendEvent calls right after
// Action has returned, before any follow-up event is processed.
ExitFunc func()
// Action is the action to be executed when the state is entered. A
// state with a nil Action cannot be transitioned into.
Action Action
// Transitions is a mapping of the events this state accepts to the
// states they lead to. With a nil map every event sent while in this
// state is rejected.
Transitions Transitions
}
// States represents a mapping of states and their implementations.
type States map[StateType]State
// Notification describes a single state transition. SendEvent builds one after
// every successful transition and passes it to all registered observers and to
// ActionEntryFunc.
type Notification struct {
// PreviousState is the state the state machine was in before the event
// was processed.
PreviousState StateType
// NextState is the state the state machine is in after the event was
// processed, i.e. the state whose action is about to be executed.
NextState StateType
// Event is the event that was processed.
Event EventType
// LastActionError is the value of StateMachine.LastActionError at the
// time the notification was created.
LastActionError error
}
// Observer is an interface that can be implemented by types that want to
// observe the state machine.
//
// Notify is called synchronously from SendEvent while both the event mutex and
// the observer mutex of the state machine are held. An implementation must
// therefore not call SendEvent, RegisterObserver or RemoveObserver on the same
// machine from within Notify, as those calls would block on the held mutexes.
type Observer interface {
Notify(Notification)
}
// StateMachine represents the state machine. It holds sync.Mutex fields and
// its methods use pointer receivers, so it should be handled through a
// pointer, as returned by the constructors.
type StateMachine struct {
// States holds the definitions of all states of the machine, keyed by
// state type. SendEvent fails with a config error while it is nil.
States States
// ActionEntryFunc is a function that is called before an action is
// executed. It receives the same Notification that was just sent to the
// observers.
ActionEntryFunc func(Notification)
// ActionExitFunc is a function that is called after an action is
// executed, it is called with the EventType returned by the action.
ActionExitFunc func(NextEvent EventType)
// LastActionError is an error set by the last action executed. It is
// written by HandleError and is not reset anywhere in the code shown
// here.
LastActionError error
// DefaultObserver is the default observer that is notified when the
// state machine transitions between states. It is only set when the
// machine was constructed with an observerSize greater than zero.
DefaultObserver *CachedObserver
// previous represents the previous state.
previous StateType
// current represents the current state.
current StateType
// observers is a slice of observers that are notified when the state
// machine transitions between states.
observers []Observer
// observerMutex ensures that observers are only added or removed
// safely. It is also held while the observers are being notified.
observerMutex sync.Mutex
// mutex ensures that only 1 event is processed by the state machine at
// any given time. SendEvent holds it for the whole call.
mutex sync.Mutex
}
// NewStateMachine builds a state machine for the given state definitions that
// starts out in EmptyState. An observerSize greater than zero additionally
// installs a default cached observer of that size.
func NewStateMachine(states States, observerSize int) *StateMachine {
	machine := NewStateMachineWithState(states, EmptyState, observerSize)

	return machine
}
// NewStateMachineWithState builds a state machine for the given state
// definitions that starts out in the state current. If observerSize is greater
// than zero, a cached observer of that size is created, exposed as
// DefaultObserver and registered as the first observer; otherwise the machine
// starts without any observer and DefaultObserver stays nil.
func NewStateMachineWithState(states States, current StateType,
	observerSize int) *StateMachine {

	machine := &StateMachine{
		States:    states,
		current:   current,
		observers: []Observer{},
	}

	// No default observer was requested.
	if observerSize <= 0 {
		return machine
	}

	machine.DefaultObserver = NewCachedObserver(observerSize)
	machine.observers = append(machine.observers, machine.DefaultObserver)

	return machine
}
// getNextState resolves the state that event leads to from the machine's
// current state and, on success, moves the machine into it by updating the
// previous and current state. It returns the definition of the new state.
//
// An ErrConfigError is returned, and the machine left untouched, if the current
// state is unknown, has no transitions, does not accept the event, or if the
// target state is unknown or has no action.
func (s *StateMachine) getNextState(event EventType) (State, error) {
	currentDef, found := s.States[s.current]
	if !found {
		return State{}, NewErrConfigError("current state not found")
	}

	if currentDef.Transitions == nil {
		return State{}, NewErrConfigError(
			"current state has no transitions",
		)
	}

	target, found := currentDef.Transitions[event]
	if !found {
		return State{}, NewErrConfigError(
			"event not found in current transitions",
		)
	}

	// Look up the definition of the state we are about to enter.
	targetDef, found := s.States[target]
	if !found {
		return State{}, NewErrConfigError("next state not found")
	}

	if targetDef.Action == nil {
		return State{}, NewErrConfigError("next state has no action")
	}

	// All checks passed, perform the transition.
	s.previous, s.current = s.current, target

	return targetDef, nil
}
// SendEvent sends an event to the state machine and keeps processing the
// events returned by the executed actions until one of them returns NoOp.
//
// The whole call runs with s.mutex held, so events are processed one at a
// time. An action or callback that calls SendEvent on the same machine would
// block on that mutex.
//
// It returns an ErrConfigError if States is nil, ErrEventRejected if the
// initial event or any follow-up event cannot be handled in the state the
// machine is in at that point (the detailed cause is only logged), and nil
// once an action returns NoOp. Transitions that completed before a rejection
// are not rolled back.
func (s *StateMachine) SendEvent(event EventType, eventCtx EventContext) error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.States == nil {
return NewErrConfigError("state machine config is nil")
}
for {
// Determine the next state for the event given the machine's
// current state. On success the machine has already switched to
// that state, so s.current below refers to the new state.
state, err := s.getNextState(event)
if err != nil {
log.Errorf("unable to get next state: %v from event: "+
"%v, current state: %v", err, event, s.current)
return ErrEventRejected
}
// Notify the state machine's observers. The observer mutex is
// held while the observers run, so the observer list cannot
// change during the notification.
s.observerMutex.Lock()
notification := Notification{
PreviousState: s.previous,
NextState: s.current,
Event: event,
LastActionError: s.LastActionError,
}
for _, observer := range s.observers {
observer.Notify(notification)
}
s.observerMutex.Unlock()
// Execute the state machines ActionEntryFunc.
if s.ActionEntryFunc != nil {
s.ActionEntryFunc(notification)
}
// Execute the entry function of the state that was just entered.
if state.EntryFunc != nil {
state.EntryFunc()
}
// Execute the action of the state that was just entered. The
// same eventCtx is passed to every action of this call.
nextEvent := state.Action(eventCtx)
// Execute the exit function of the state that was just entered.
// It runs right after the action, before the returned event is
// looked at.
if state.ExitFunc != nil {
state.ExitFunc()
}
// Execute the state machines ActionExitFunc.
if s.ActionExitFunc != nil {
s.ActionExitFunc(nextEvent)
}
// If the next event is a no-op, we're done.
if nextEvent == NoOp {
return nil
}
// Otherwise loop over again with the event the action returned.
event = nextEvent
}
}
// RegisterObserver appends observer to the list of observers that are notified
// about state transitions. A nil observer is ignored.
func (s *StateMachine) RegisterObserver(observer Observer) {
	s.observerMutex.Lock()
	defer s.observerMutex.Unlock()

	// Nothing to register.
	if observer == nil {
		return
	}

	s.observers = append(s.observers, observer)
}
// RemoveObserver removes the first registered occurrence of observer from the
// state machine. It returns true if the observer was removed, false otherwise.
//
// Observers are matched with ==, so the same value that was passed to
// RegisterObserver (typically the same pointer) has to be passed here.
func (s *StateMachine) RemoveObserver(observer Observer) bool {
	s.observerMutex.Lock()
	defer s.observerMutex.Unlock()

	for i, o := range s.observers {
		if o != observer {
			continue
		}

		// Shift the tail one slot to the left over the removed entry
		// and clear the vacated last slot. Without clearing it, the
		// backing array would keep a stale reference (to the removed
		// observer, or a duplicate of the last one) that keeps the
		// value reachable for the garbage collector.
		last := len(s.observers) - 1
		copy(s.observers[i:], s.observers[i+1:])
		s.observers[last] = nil
		s.observers = s.observers[:last]

		return true
	}

	return false
}
// HandleError is a helper function that can be used by actions to handle
// errors. It logs err, stores it in LastActionError and returns the OnError
// event, so an action can simply return the result of this call.
//
// The method does not take any lock itself.
// NOTE(review): presumably it is meant to be called from within an action,
// which runs while SendEvent holds the event mutex; confirm against callers.
func (s *StateMachine) HandleError(err error) EventType {
log.Errorf("StateMachine error: %s", err)
s.LastActionError = err
return OnError
}
// NoOpAction is an Action for states that have nothing to execute. It ignores
// the event context and always returns NoOp, which ends the event processing.
func NoOpAction(EventContext) EventType {
	return NoOp
}
// ErrConfigError is the error type used to report a misconfigured state
// machine, e.g. a missing state, transition or action.
type ErrConfigError struct {
	// msg describes what is wrong with the configuration.
	msg string
}

// Error implements the error interface. The message is the description
// prefixed with "config error: ".
func (e ErrConfigError) Error() string {
	text := fmt.Sprintf("config error: %s", e.msg)

	return text
}

// NewErrConfigError returns an ErrConfigError carrying the given description.
func NewErrConfigError(msg string) ErrConfigError {
	var cfgErr ErrConfigError
	cfgErr.msg = msg

	return cfgErr
}
// ErrWaitingForStateTimeout is the error type for a timeout that occurred
// while waiting for the state machine to reach a state.
type ErrWaitingForStateTimeout struct {
	// expected is the state that was being waited for.
	expected StateType
}

// Error implements the error interface. The message names the state that was
// being waited for.
func (e ErrWaitingForStateTimeout) Error() string {
	text := fmt.Sprintf("waiting for state timed out: %s", e.expected)

	return text
}

// NewErrWaitingForStateTimeout returns an ErrWaitingForStateTimeout for the
// given expected state.
func NewErrWaitingForStateTimeout(expected StateType) ErrWaitingForStateTimeout {
	var timeoutErr ErrWaitingForStateTimeout
	timeoutErr.expected = expected

	return timeoutErr
}