diff --git a/liquidity/liquidity.go b/liquidity/liquidity.go index 511db6a..acd56e5 100644 --- a/liquidity/liquidity.go +++ b/liquidity/liquidity.go @@ -34,8 +34,11 @@ package liquidity import ( "context" + "encoding/gob" "errors" "fmt" + "io" + "os" "sort" "strings" "sync" @@ -438,6 +441,11 @@ func (m *Manager) Run(ctx context.Context) error { m.cfg.AutoloopTicker.Resume() defer m.cfg.AutoloopTicker.Stop() + // Before we start, load the parameters. + if err := m.LoadParameters(); err != nil { + log.Errorf("load parameters failed: %v", err) + } + for { select { case <-m.cfg.AutoloopTicker.Ticks(): @@ -453,6 +461,11 @@ func (m *Manager) Run(ctx context.Context) error { } case <-ctx.Done(): + // Before we quit, save the parameters. + if err := m.SaveParameters(); err != nil { + log.Errorf("save parameters failed: %v", err) + } + return ctx.Err() } } @@ -499,6 +512,45 @@ func (m *Manager) SetParameters(ctx context.Context, params Parameters) error { return nil } +// LoadParameters loads the parameters saved in the file specified by the +// config. +func (m *Manager) LoadParameters() error { + m.paramsLock.Lock() + defer m.paramsLock.Unlock() + + params, err := decodeParams(m.cfg.LiquidityParamsPath) + + // We will get an EOF if it's first time loading the params from the + // file. In this case, we will do nothing and let the manager use the + // default params. + if errors.Is(err, io.EOF) { + return nil + } + + // Otherwise return the error as it's unexpected. + if err != nil { + return fmt.Errorf("failed to load params: %w", err) + } + + // Attach the saved params. + m.params = cloneParameters(params) + return nil +} + +// SaveParameters saves the manager's parameters to the file specified by the +// config. +func (m *Manager) SaveParameters() error { + m.paramsLock.Lock() + defer m.paramsLock.Unlock() + + err := encodeParams(m.cfg.LiquidityParamsPath, m.params) + if err != nil { + return fmt.Errorf("failed to save params: %w", err) + } + + return nil +} + // cloneParameters creates a deep clone of a parameters struct so that callers // cannot mutate our parameters. Although our parameters struct itself is not // a reference, we still need to clone the contents of maps. @@ -527,6 +579,42 @@ func cloneParameters(params Parameters) Parameters { return paramCopy } +// encodeParams encodes the given parameters and saves it to a file. +func encodeParams(filepath string, p Parameters) error { + // Create the file if not exists, otherwise open it. + file, err := os.OpenFile(filepath, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + return fmt.Errorf("cannot create params file: %w", err) + } + defer file.Close() + + encoder := gob.NewEncoder(file) + if err := encoder.Encode(p); err != nil { + return fmt.Errorf("encoding params error: %w", err) + } + + return nil +} + +// decodeParams reads the file and decodes a Parameters struct. +func decodeParams(filepath string) (Parameters, error) { + var p Parameters + + // Create the file or truncate the old file. + file, err := os.Create(filepath) + if err != nil { + return p, fmt.Errorf("cannot create params file: %w", err) + } + defer file.Close() + + decoder := gob.NewDecoder(file) + if err := decoder.Decode(&p); err != nil { + return p, fmt.Errorf("decoding params error: %w", err) + } + + return p, nil +} + // autoloop gets a set of suggested swaps and dispatches them automatically if // we have automated looping enabled. func (m *Manager) autoloop(ctx context.Context) error { @@ -1212,3 +1300,13 @@ func ppmToSat(amount btcutil.Amount, ppm uint64) btcutil.Amount { func mSatToSatoshis(amount lnwire.MilliSatoshi) btcutil.Amount { return btcutil.Amount(amount / 1000) } + +func init() { + // Init custom structs for gob. + // + // NOTE: we need to use pointers here to make sure the interface check + // passes. + gob.Register(&SwapRule{}) + gob.Register(&FeePortion{}) + gob.Register(&FeeCategoryLimit{}) +} diff --git a/liquidity/liquidity_test.go b/liquidity/liquidity_test.go index 4104ba0..7df4ab7 100644 --- a/liquidity/liquidity_test.go +++ b/liquidity/liquidity_test.go @@ -2,6 +2,7 @@ package liquidity import ( "context" + "os" "testing" "time" @@ -246,6 +247,63 @@ func TestParameters(t *testing.T) { require.Equal(t, ErrZeroChannelID, err) } +// TestPersistParameters tests loading and saving from a file. +func TestPersistParameters(t *testing.T) { + // Overwrite the filepath. + paramsFile := "liquidity_params.gob.test" + + // Remove the test file when test finishes. + defer func() { + require.NoError(t, os.Remove(paramsFile)) + }() + + cfg, _ := newTestConfig() + cfg.LiquidityParamsPath = paramsFile + manager := NewManager(cfg) + + // Load params for the first time, which should end up not touching the + // manager's params. + err := manager.LoadParameters() + require.NoError(t, err) + + // Check the params are not changed. + startParams := manager.GetParameters() + require.Equal(t, defaultParameters, startParams) + + // Save params should give us no error. + err = manager.SaveParameters() + require.NoError(t, err) + + // We now update the manager's parameters. + chanID := lnwire.NewShortChanIDFromInt(1) + originalRule := &SwapRule{ + ThresholdRule: NewThresholdRule(10, 10), + Type: swap.TypeOut, + } + + // Create a new parameters struct. + expected := defaultParameters + expected.ChannelRules = map[lnwire.ShortChannelID]*SwapRule{ + chanID: originalRule, + } + + // Set the params. + err = manager.SetParameters(context.Background(), expected) + require.NoError(t, err) + + // Now save the updated params. + err = manager.SaveParameters() + require.NoError(t, err) + + // Load the params again which updates the manager's params. + err = manager.LoadParameters() + require.NoError(t, err) + + // Validate that the manager has the updated params. + params := manager.GetParameters() + require.Equal(t, expected, params) +} + // TestValidateRestrictions tests validating client restrictions against a set // of server restrictions. func TestValidateRestrictions(t *testing.T) { diff --git a/version.go b/version.go index c79f9ec..2768ed4 100644 --- a/version.go +++ b/version.go @@ -26,7 +26,7 @@ const semanticAlphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqr const ( // Note: please update release_notes.md when you change these values. appMajor uint = 0 - appMinor uint = 18 + appMinor uint = 19 appPatch uint = 0 // appPreRelease MUST only contain characters from semanticAlphabet per