diff --git a/liquidity/liquidity_test.go b/liquidity/liquidity_test.go index a38dd1b..d341d5c 100644 --- a/liquidity/liquidity_test.go +++ b/liquidity/liquidity_test.go @@ -351,13 +351,16 @@ func TestRestrictedSuggestions(t *testing.T) { return testCase.loopIn, nil } - rules := map[lnwire.ShortChannelID]*ThresholdRule{ + lnd.Channels = testCase.channels + + params := defaultParameters + params.ChannelRules = map[lnwire.ShortChannelID]*ThresholdRule{ chanID1: chanRule, chanID2: chanRule, } testSuggestSwaps( - t, cfg, lnd, testCase.channels, rules, + t, newSuggestSwapsSetup(cfg, lnd, params), testCase.expected, ) }) @@ -397,16 +400,18 @@ func TestSweepFeeLimit(t *testing.T) { loop.DefaultSweepConfTarget, testCase.feeRate, ) - channels := []lndclient.ChannelInfo{ + lnd.Channels = []lndclient.ChannelInfo{ channel1, } - rules := map[lnwire.ShortChannelID]*ThresholdRule{ + params := defaultParameters + params.ChannelRules = map[lnwire.ShortChannelID]*ThresholdRule{ chanID1: chanRule, } testSuggestSwaps( - t, cfg, lnd, channels, rules, testCase.swaps, + t, newSuggestSwapsSetup(cfg, lnd, params), + testCase.swaps, ) }) } @@ -448,12 +453,15 @@ func TestSuggestSwaps(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { cfg, lnd := newTestConfig() - channels := []lndclient.ChannelInfo{ + lnd.Channels = []lndclient.ChannelInfo{ channel1, } + params := defaultParameters + params.ChannelRules = testCase.rules + testSuggestSwaps( - t, cfg, lnd, channels, testCase.rules, + t, newSuggestSwapsSetup(cfg, lnd, params), testCase.swaps, ) }) @@ -513,40 +521,78 @@ func TestFeeLimits(t *testing.T) { return testCase.quote, nil } - channels := []lndclient.ChannelInfo{ + lnd.Channels = []lndclient.ChannelInfo{ channel1, } - rules := map[lnwire.ShortChannelID]*ThresholdRule{ + + params := defaultParameters + params.ChannelRules = map[lnwire.ShortChannelID]*ThresholdRule{ chanID1: chanRule, } testSuggestSwaps( - t, cfg, lnd, channels, rules, testCase.expected, + t, newSuggestSwapsSetup(cfg, lnd, params), + testCase.expected, ) }) } } -// testSuggestSwaps tests getting swap suggestions. -func testSuggestSwaps(t *testing.T, cfg *Config, lnd *test.LndMockServices, - channels []lndclient.ChannelInfo, - rules map[lnwire.ShortChannelID]*ThresholdRule, +// testSuggestSwapsSetup contains the elements that are used to create a +// suggest swaps test. +type testSuggestSwapsSetup struct { + cfg *Config + lnd *test.LndMockServices + params Parameters +} + +// newSuggestSwapsSetup creates a suggest swaps setup struct. +func newSuggestSwapsSetup(cfg *Config, lnd *test.LndMockServices, + params Parameters) *testSuggestSwapsSetup { + + return &testSuggestSwapsSetup{ + cfg: cfg, + lnd: lnd, + params: params, + } +} + +// testSuggestSwaps tests getting swap suggestions. It takes a setup struct +// which contains custom setup for the test. If this struct is nil, it will +// use the default parameters and setup two channels (channel1 + channel2) with +// chanRule set for each. +func testSuggestSwaps(t *testing.T, setup *testSuggestSwapsSetup, expected []loop.OutRequest) { t.Parallel() - // Create a mock lnd with the set of channels set in our test case and - // update our test case lnd to use these channels. - lnd.Channels = channels + // If our setup struct is nil, we replace it with our default test + // values. + if setup == nil { + cfg, lnd := newTestConfig() + + lnd.Channels = []lndclient.ChannelInfo{ + channel1, channel2, + } + + params := defaultParameters + params.ChannelRules = map[lnwire.ShortChannelID]*ThresholdRule{ + chanID1: chanRule, + chanID2: chanRule, + } + + setup = &testSuggestSwapsSetup{ + cfg: cfg, + lnd: lnd, + params: params, + } + } // Create a new manager, get our current set of parameters and update // them to use the rules set by the test. - manager := NewManager(cfg) - - currentParams := manager.GetParameters() - currentParams.ChannelRules = rules + manager := NewManager(setup.cfg) - err := manager.SetParameters(currentParams) + err := manager.SetParameters(setup.params) require.NoError(t, err) actual, err := manager.SuggestSwaps(context.Background())