diff --git a/test/chainnotifier_mock.go b/test/chainnotifier_mock.go index 836e50f..f155d11 100644 --- a/test/chainnotifier_mock.go +++ b/test/chainnotifier_mock.go @@ -73,31 +73,40 @@ func (c *mockChainNotifier) RegisterBlockEpochNtfn(ctx context.Context) ( chan int32, chan error, error) { blockErrorChan := make(chan error, 1) - blockEpochChan := make(chan int32) + blockEpochChan := make(chan int32, 1) + + c.lnd.lock.Lock() + c.lnd.blockHeightListeners = append( + c.lnd.blockHeightListeners, blockEpochChan, + ) + c.lnd.lock.Unlock() c.wg.Add(1) go func() { defer c.wg.Done() + defer func() { + c.lnd.lock.Lock() + defer c.lnd.lock.Unlock() + for i := 0; i < len(c.lnd.blockHeightListeners); i++ { + if c.lnd.blockHeightListeners[i] == blockEpochChan { + c.lnd.blockHeightListeners = append( + c.lnd.blockHeightListeners[:i], + c.lnd.blockHeightListeners[i+1:]..., + ) + break + } + } + }() // Send initial block height + c.lnd.lock.Lock() select { case blockEpochChan <- c.lnd.Height: case <-ctx.Done(): - return } + c.lnd.lock.Unlock() - for { - select { - case m := <-c.lnd.epochChannel: - select { - case blockEpochChan <- m: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } - } + <-ctx.Done() }() return blockEpochChan, blockErrorChan, nil diff --git a/test/context.go b/test/context.go index 638bcc7..5fae593 100644 --- a/test/context.go +++ b/test/context.go @@ -259,3 +259,9 @@ func (ctx *Context) GetOutputIndex(tx *wire.MsgTx, func (ctx *Context) NotifyServerHeight(height int32) { require.NoError(ctx.T, ctx.Lnd.NotifyHeight(height)) } + +func (ctx *Context) AssertEpochListeners(numListeners int32) { + require.Eventually(ctx.T, func() bool { + return ctx.Lnd.EpochSubscribers() == numListeners + }, Timeout, time.Millisecond*250) +} diff --git a/test/lnd_services_mock.go b/test/lnd_services_mock.go index 1dad7cc..db44474 100644 --- a/test/lnd_services_mock.go +++ b/test/lnd_services_mock.go @@ -4,7 +4,6 @@ import ( "context" "errors" "sync" - "time" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/wire" @@ -63,13 +62,13 @@ func NewMockLnd() *LndMockServices { SignOutputRawChannel: make(chan SignOutputRawRequest), - FailInvoiceChannel: make(chan lntypes.Hash, 2), - epochChannel: make(chan int32), - Height: testStartingHeight, - NodePubkey: testNodePubkey, - Signature: testSignature, - SignatureMsg: testSignatureMsg, - Invoices: make(map[lntypes.Hash]*lndclient.Invoice), + FailInvoiceChannel: make(chan lntypes.Hash, 2), + blockHeightListeners: make([]chan int32, 0), + Height: testStartingHeight, + NodePubkey: testNodePubkey, + Signature: testSignature, + SignatureMsg: testSignatureMsg, + Invoices: make(map[lntypes.Hash]*lndclient.Invoice), } lightningClient.lnd = &lnd @@ -139,7 +138,7 @@ type LndMockServices struct { SendOutputsChannel chan wire.MsgTx SettleInvoiceChannel chan lntypes.Preimage FailInvoiceChannel chan lntypes.Hash - epochChannel chan int32 + blockHeightListeners []chan int32 ConfChannel chan *chainntnfs.TxConfirmation RegisterConfChannel chan *ConfRegistration @@ -177,15 +176,28 @@ type LndMockServices struct { lock sync.Mutex } +// EpochSubscribers returns the number of subscribers to block epoch +// notifications. +func (s *LndMockServices) EpochSubscribers() int32 { + s.lock.Lock() + defer s.lock.Unlock() + + return int32(len(s.blockHeightListeners)) +} + // NotifyHeight notifies a new block height. func (s *LndMockServices) NotifyHeight(height int32) error { + s.lock.Lock() + defer s.lock.Unlock() s.Height = height - select { - case s.epochChannel <- height: - case <-time.After(Timeout): - return ErrTimeout + for _, listener := range s.blockHeightListeners { + lis := listener + go func() { + lis <- height + }() } + return nil } diff --git a/testcontext_test.go b/testcontext_test.go index 71f0ade..47b1488 100644 --- a/testcontext_test.go +++ b/testcontext_test.go @@ -250,3 +250,11 @@ func (ctx *testContext) assertPreimagePush(preimage lntypes.Preimage) { ctx.Context.T.Fatalf("preimage not pushed") } } + +func (ctx *testContext) AssertEpochListeners(numListeners int32) { + ctx.Context.T.Helper() + + require.Eventually(ctx.Context.T, func() bool { + return ctx.Lnd.EpochSubscribers() == numListeners + }, test.Timeout, time.Millisecond*250) +}