diff --git a/lndclient/lnd_services.go b/lndclient/lnd_services.go index c230e84..1921025 100644 --- a/lndclient/lnd_services.go +++ b/lndclient/lnd_services.go @@ -52,6 +52,8 @@ type LndServices struct { Versioner VersionerClient ChainParams *chaincfg.Params + NodeAlias string + NodePubkey [33]byte macaroons *macaroonPouch } @@ -133,7 +135,9 @@ func NewLndServices(cfg *LndServicesConfig) (*GrpcLndServices, error) { if err != nil { return nil, err } - err = checkLndCompatibility(conn, chainParams, readonlyMac, cfg.Network) + nodeAlias, nodeKey, err := checkLndCompatibility( + conn, chainParams, readonlyMac, cfg.Network, + ) if err != nil { return nil, err } @@ -189,6 +193,8 @@ func NewLndServices(cfg *LndServicesConfig) (*GrpcLndServices, error) { Router: routerClient, Versioner: versionerClient, ChainParams: chainParams, + NodeAlias: nodeAlias, + NodePubkey: nodeKey, macaroons: macaroons, }, cleanup: cleanup, @@ -210,7 +216,19 @@ func (s *GrpcLndServices) Close() { // checkLndCompatibility makes sure the connected lnd instance is running on the // correct network. func checkLndCompatibility(conn *grpc.ClientConn, chainParams *chaincfg.Params, - readonlyMac serializedMacaroon, network string) error { + readonlyMac serializedMacaroon, network string) (string, [33]byte, + error) { + + // onErr is a closure that simplifies returning multiple values in the + // error case. + onErr := func(err error) (string, [33]byte, error) { + closeErr := conn.Close() + if closeErr != nil { + log.Errorf("Error closing lnd connection: %v", closeErr) + } + + return "", [33]byte{}, err + } // We use our own client with a readonly macaroon here, because we know // that's all we need for the checks. @@ -220,17 +238,18 @@ func checkLndCompatibility(conn *grpc.ClientConn, chainParams *chaincfg.Params, // for lnd matches our expected network. info, err := lightningClient.GetInfo(context.Background()) if err != nil { - conn.Close() - return fmt.Errorf("unable to get info for lnd "+ - "node: %v", err) + err := fmt.Errorf("unable to get info for lnd node: %v", err) + return onErr(err) } if network != info.Network { - conn.Close() - return fmt.Errorf("network mismatch with connected lnd "+ - "node, got '%s', wanted '%s'", info.Network, network) - + err := fmt.Errorf("network mismatch with connected lnd node, "+ + "wanted '%s', got '%s'", network, info.Network) + return onErr(err) } - return nil + + // Return the static part of the info we just queried from the node so + // it can be cached for later use. + return info.Alias, info.IdentityPubkey, nil } var ( diff --git a/test/lnd_services_mock.go b/test/lnd_services_mock.go index 74a1d43..fb1b3e1 100644 --- a/test/lnd_services_mock.go +++ b/test/lnd_services_mock.go @@ -1,6 +1,7 @@ package test import ( + "context" "errors" "sync" "time" @@ -77,6 +78,11 @@ func NewMockLnd() *LndMockServices { router.lnd = &lnd signer.lnd = &lnd + // Also simulate the cached info that is loaded on startup. + info, _ := lightningClient.GetInfo(context.Background()) + lnd.LndServices.NodeAlias = info.Alias + lnd.LndServices.NodePubkey = info.IdentityPubkey + lnd.WaitForFinished = func() { chainNotifier.WaitForFinished() lightningClient.WaitForFinished()