diff --git a/lndclient/chainnotifier_client.go b/lndclient/chainnotifier_client.go index 319788e..c48a2ed 100644 --- a/lndclient/chainnotifier_client.go +++ b/lndclient/chainnotifier_client.go @@ -28,13 +28,16 @@ type ChainNotifierClient interface { } type chainNotifierClient struct { - client chainrpc.ChainNotifierClient - wg sync.WaitGroup + client chainrpc.ChainNotifierClient + chainMac serializedMacaroon + + wg sync.WaitGroup } -func newChainNotifierClient(conn *grpc.ClientConn) *chainNotifierClient { +func newChainNotifierClient(conn *grpc.ClientConn, chainMac serializedMacaroon) *chainNotifierClient { return &chainNotifierClient{ - client: chainrpc.NewChainNotifierClient(conn), + client: chainrpc.NewChainNotifierClient(conn), + chainMac: chainMac, } } @@ -54,7 +57,8 @@ func (s *chainNotifierClient) RegisterSpendNtfn(ctx context.Context, } } - resp, err := s.client.RegisterSpendNtfn(ctx, &chainrpc.SpendRequest{ + macaroonAuth := s.chainMac.WithMacaroonAuth(ctx) + resp, err := s.client.RegisterSpendNtfn(macaroonAuth, &chainrpc.SpendRequest{ HeightHint: uint32(heightHint), Outpoint: rpcOutpoint, Script: pkScript, @@ -125,16 +129,15 @@ func (s *chainNotifierClient) RegisterConfirmationsNtfn(ctx context.Context, if txid != nil { txidSlice = txid[:] } - confStream, err := s.client. - RegisterConfirmationsNtfn( - ctx, - &chainrpc.ConfRequest{ - Script: pkScript, - NumConfs: uint32(numConfs), - HeightHint: uint32(heightHint), - Txid: txidSlice, - }, - ) + confStream, err := s.client.RegisterConfirmationsNtfn( + s.chainMac.WithMacaroonAuth(ctx), + &chainrpc.ConfRequest{ + Script: pkScript, + NumConfs: uint32(numConfs), + HeightHint: uint32(heightHint), + Txid: txidSlice, + }, + ) if err != nil { return nil, nil, err } @@ -203,8 +206,9 @@ func (s *chainNotifierClient) RegisterConfirmationsNtfn(ctx context.Context, func (s *chainNotifierClient) RegisterBlockEpochNtfn(ctx context.Context) ( chan int32, chan error, error) { - blockEpochClient, err := s.client. - RegisterBlockEpochNtfn(ctx, &chainrpc.BlockEpoch{}) + blockEpochClient, err := s.client.RegisterBlockEpochNtfn( + s.chainMac.WithMacaroonAuth(ctx), &chainrpc.BlockEpoch{}, + ) if err != nil { return nil, nil, err } diff --git a/lndclient/invoices_client.go b/lndclient/invoices_client.go index 5e92138..60663ba 100644 --- a/lndclient/invoices_client.go +++ b/lndclient/invoices_client.go @@ -33,13 +33,15 @@ type InvoiceUpdate struct { } type invoicesClient struct { - client invoicesrpc.InvoicesClient - wg sync.WaitGroup + client invoicesrpc.InvoicesClient + invoiceMac serializedMacaroon + wg sync.WaitGroup } -func newInvoicesClient(conn *grpc.ClientConn) *invoicesClient { +func newInvoicesClient(conn *grpc.ClientConn, invoiceMac serializedMacaroon) *invoicesClient { return &invoicesClient{ - client: invoicesrpc.NewInvoicesClient(conn), + client: invoicesrpc.NewInvoicesClient(conn), + invoiceMac: invoiceMac, } } @@ -53,6 +55,7 @@ func (s *invoicesClient) SettleInvoice(ctx context.Context, rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = s.invoiceMac.WithMacaroonAuth(ctx) _, err := s.client.SettleInvoice(rpcCtx, &invoicesrpc.SettleInvoiceMsg{ Preimage: preimage[:], }) @@ -66,6 +69,7 @@ func (s *invoicesClient) CancelInvoice(ctx context.Context, rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = s.invoiceMac.WithMacaroonAuth(rpcCtx) _, err := s.client.CancelInvoice(rpcCtx, &invoicesrpc.CancelInvoiceMsg{ PaymentHash: hash[:], }) @@ -77,11 +81,12 @@ func (s *invoicesClient) SubscribeSingleInvoice(ctx context.Context, hash lntypes.Hash) (<-chan InvoiceUpdate, <-chan error, error) { - invoiceStream, err := s.client. - SubscribeSingleInvoice(ctx, - &lnrpc.PaymentHash{ - RHash: hash[:], - }) + invoiceStream, err := s.client.SubscribeSingleInvoice( + s.invoiceMac.WithMacaroonAuth(ctx), + &lnrpc.PaymentHash{ + RHash: hash[:], + }, + ) if err != nil { return nil, nil, err } @@ -135,6 +140,7 @@ func (s *invoicesClient) AddHoldInvoice(ctx context.Context, Private: true, } + rpcCtx = s.invoiceMac.WithMacaroonAuth(rpcCtx) resp, err := s.client.AddHoldInvoice(rpcCtx, rpcIn) if err != nil { return "", err diff --git a/lndclient/lightning_client.go b/lndclient/lightning_client.go index d74c538..23d7456 100644 --- a/lndclient/lightning_client.go +++ b/lndclient/lightning_client.go @@ -78,17 +78,19 @@ var ( ) type lightningClient struct { - client lnrpc.LightningClient - wg sync.WaitGroup - params *chaincfg.Params + client lnrpc.LightningClient + wg sync.WaitGroup + params *chaincfg.Params + adminMac serializedMacaroon } func newLightningClient(conn *grpc.ClientConn, - params *chaincfg.Params) *lightningClient { + params *chaincfg.Params, adminMac serializedMacaroon) *lightningClient { return &lightningClient{ - client: lnrpc.NewLightningClient(conn), - params: params, + client: lnrpc.NewLightningClient(conn), + params: params, + adminMac: adminMac, } } @@ -110,6 +112,7 @@ func (s *lightningClient) ConfirmedWalletBalance(ctx context.Context) ( rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = s.adminMac.WithMacaroonAuth(rpcCtx) resp, err := s.client.WalletBalance(rpcCtx, &lnrpc.WalletBalanceRequest{}) if err != nil { return 0, err @@ -122,6 +125,7 @@ func (s *lightningClient) GetInfo(ctx context.Context) (*Info, error) { rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = s.adminMac.WithMacaroonAuth(rpcCtx) resp, err := s.client.GetInfo(rpcCtx, &lnrpc.GetInfoRequest{}) if err != nil { return nil, err @@ -159,6 +163,7 @@ func (s *lightningClient) EstimateFeeToP2WSH(ctx context.Context, return 0, err } + rpcCtx = s.adminMac.WithMacaroonAuth(rpcCtx) resp, err := s.client.EstimateFee( rpcCtx, &lnrpc.EstimateFeeRequest{ @@ -216,6 +221,7 @@ func (s *lightningClient) payInvoice(ctx context.Context, invoice string, hash := lntypes.Hash(*payReq.PaymentHash) + ctx = s.adminMac.WithMacaroonAuth(ctx) for { // Create no timeout context as this call can block for a long // time. @@ -329,6 +335,7 @@ func (s *lightningClient) AddInvoice(ctx context.Context, rpcIn.RHash = in.Hash[:] } + rpcCtx = s.adminMac.WithMacaroonAuth(rpcCtx) resp, err := s.client.AddInvoice(rpcCtx, rpcIn) if err != nil { return lntypes.Hash{}, "", err diff --git a/lndclient/lnd_services.go b/lndclient/lnd_services.go index 2249d67..8350d6f 100644 --- a/lndclient/lnd_services.go +++ b/lndclient/lnd_services.go @@ -68,12 +68,17 @@ func NewLndServices(lndAddress, application, network, macaroonDir, return nil, err } - lightningClient := newLightningClient(conn, chainParams) + lightningClient := newLightningClient( + conn, chainParams, macaroons.adminMac, + ) + // With our macaroons obtained, we'll ensure that the network for lnd + // matches our expected network. info, err := lightningClient.GetInfo(context.Background()) if err != nil { conn.Close() - return nil, err + return nil, fmt.Errorf("unable to get info for lnd "+ + "node: %v", err) } if network != info.Network { conn.Close() @@ -82,10 +87,12 @@ func NewLndServices(lndAddress, application, network, macaroonDir, ) } - notifierClient := newChainNotifierClient(conn) - signerClient := newSignerClient(conn) - walletKitClient := newWalletKitClient(conn) - invoicesClient := newInvoicesClient(conn) + // With the network check passed, we'll now initialize the rest of the + // sub-sever connections, giving each of them their specific macaroon. + notifierClient := newChainNotifierClient(conn, macaroons.chainMac) + signerClient := newSignerClient(conn, macaroons.signerMac) + walletKitClient := newWalletKitClient(conn, macaroons.walletKitMac) + invoicesClient := newInvoicesClient(conn, macaroons.invoiceMac) cleanup := func() { logger.Debugf("Closing lnd connection") diff --git a/lndclient/signer_client.go b/lndclient/signer_client.go index e9397ba..a5cad28 100644 --- a/lndclient/signer_client.go +++ b/lndclient/signer_client.go @@ -17,12 +17,16 @@ type SignerClient interface { } type signerClient struct { - client signrpc.SignerClient + client signrpc.SignerClient + signerMac serializedMacaroon } -func newSignerClient(conn *grpc.ClientConn) *signerClient { +func newSignerClient(conn *grpc.ClientConn, + signerMac serializedMacaroon) *signerClient { + return &signerClient{ - client: signrpc.NewSignerClient(conn), + client: signrpc.NewSignerClient(conn), + signerMac: signerMac, } } @@ -76,6 +80,7 @@ func (s *signerClient) SignOutputRaw(ctx context.Context, tx *wire.MsgTx, rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = s.signerMac.WithMacaroonAuth(rpcCtx) resp, err := s.client.SignOutputRaw(rpcCtx, &signrpc.SignReq{ RawTxBytes: txRaw, diff --git a/lndclient/walletkit_client.go b/lndclient/walletkit_client.go index a477978..7b7156b 100644 --- a/lndclient/walletkit_client.go +++ b/lndclient/walletkit_client.go @@ -34,12 +34,16 @@ type WalletKitClient interface { } type walletKitClient struct { - client walletrpc.WalletKitClient + client walletrpc.WalletKitClient + walletKitMac serializedMacaroon } -func newWalletKitClient(conn *grpc.ClientConn) *walletKitClient { +func newWalletKitClient(conn *grpc.ClientConn, + walletKitMac serializedMacaroon) *walletKitClient { + return &walletKitClient{ - client: walletrpc.NewWalletKitClient(conn), + client: walletrpc.NewWalletKitClient(conn), + walletKitMac: walletKitMac, } } @@ -49,6 +53,7 @@ func (m *walletKitClient) DeriveNextKey(ctx context.Context, family int32) ( rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = m.walletKitMac.WithMacaroonAuth(rpcCtx) resp, err := m.client.DeriveNextKey(rpcCtx, &walletrpc.KeyReq{ KeyFamily: family, }) @@ -76,6 +81,7 @@ func (m *walletKitClient) DeriveKey(ctx context.Context, in *keychain.KeyLocator rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = m.walletKitMac.WithMacaroonAuth(rpcCtx) resp, err := m.client.DeriveKey(rpcCtx, &signrpc.KeyLocator{ KeyFamily: int32(in.Family), KeyIndex: int32(in.Index), @@ -101,6 +107,7 @@ func (m *walletKitClient) NextAddr(ctx context.Context) ( rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = m.walletKitMac.WithMacaroonAuth(rpcCtx) resp, err := m.client.NextAddr(rpcCtx, &walletrpc.AddrRequest{}) if err != nil { return nil, err @@ -125,6 +132,7 @@ func (m *walletKitClient) PublishTransaction(ctx context.Context, rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = m.walletKitMac.WithMacaroonAuth(rpcCtx) _, err = m.client.PublishTransaction(rpcCtx, &walletrpc.Transaction{ TxHex: txHex, }) @@ -147,6 +155,7 @@ func (m *walletKitClient) SendOutputs(ctx context.Context, rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = m.walletKitMac.WithMacaroonAuth(rpcCtx) resp, err := m.client.SendOutputs(rpcCtx, &walletrpc.SendOutputsRequest{ Outputs: rpcOutputs, SatPerKw: int64(feeRate), @@ -169,6 +178,7 @@ func (m *walletKitClient) EstimateFee(ctx context.Context, confTarget int32) ( rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout) defer cancel() + rpcCtx = m.walletKitMac.WithMacaroonAuth(rpcCtx) resp, err := m.client.EstimateFee(rpcCtx, &walletrpc.EstimateFeeRequest{ ConfTarget: int32(confTarget), })