diff --git a/lsat/interceptor.go b/lsat/interceptor.go index ba7066f..d0adc15 100644 --- a/lsat/interceptor.go +++ b/lsat/interceptor.go @@ -89,6 +89,15 @@ func NewInterceptor(lnd *lndclient.LndServices, store Store, } } +// interceptContext is a struct that contains all information about a call that +// is intercepted by the interceptor. +type interceptContext struct { + mainCtx context.Context + opts []grpc.CallOption + metadata *metadata.MD + token *Token +} + // UnaryInterceptor is an interceptor method that can be used directly by gRPC // for unary calls. If the store contains a token, it is attached as credentials // to every call before patching it through. The response error is also @@ -105,21 +114,100 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, i.lock.Lock() defer i.lock.Unlock() - addLsatCredentials := func(token *Token) error { - macaroon, err := token.PaidMacaroon() - if err != nil { - return err - } - opts = append(opts, grpc.PerRPCCredentials( - macaroons.NewMacaroonCredential(macaroon), - )) - return nil + // Create the context that we'll use to initiate the real request. This + // contains the means to extract response headers and possibly also an + // auth token, if we already have paid for one. + iCtx, err := i.newInterceptContext(ctx, opts) + if err != nil { + return err + } + + // Try executing the call now. If anything goes wrong, we only handle + // the LSAT error message that comes in the form of a gRPC status error. + rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout) + defer cancel() + err = invoker(rpcCtx, method, req, reply, cc, iCtx.opts...) + if !isPaymentRequired(err) { + return err + } + + // Find out if we need to pay for a new token or perhaps resume + // a previously aborted payment. + err = i.handlePayment(iCtx) + if err != nil { + return err + } + + // Execute the same request again, now with the LSAT + // token added as an RPC credential. + rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout) + defer cancel2() + return invoker(rpcCtx2, method, req, reply, cc, iCtx.opts...) +} + +// StreamInterceptor is an interceptor method that can be used directly by gRPC +// for streaming calls. If the store contains a token, it is attached as +// credentials to every stream establishment call before patching it through. +// The response error is also intercepted for every initial stream initiation. +// If there is an error returned and it is indicating a payment challenge, a +// token is acquired and paid for automatically. The original request is then +// repeated back to the server, now with the new token attached. +func (i *Interceptor) StreamInterceptor(ctx context.Context, + desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, + streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, + error) { + + // To avoid paying for a token twice if two parallel requests are + // happening, we require an exclusive lock here. + i.lock.Lock() + defer i.lock.Unlock() + + // Create the context that we'll use to initiate the real request. This + // contains the means to extract response headers and possibly also an + // auth token, if we already have paid for one. + iCtx, err := i.newInterceptContext(ctx, opts) + if err != nil { + return nil, err + } + + // Try establishing the stream now. If anything goes wrong, we only + // handle the LSAT error message that comes in the form of a gRPC status + // error. The context of a stream will be used for the whole lifetime of + // it, so we can't really clamp down on the initial call with a timeout. + stream, err := streamer(ctx, desc, cc, method, iCtx.opts...) + if !isPaymentRequired(err) { + return stream, err + } + + // Find out if we need to pay for a new token or perhaps resume + // a previously aborted payment. + err = i.handlePayment(iCtx) + if err != nil { + return nil, err + } + + // Execute the same request again, now with the LSAT token added + // as an RPC credential. + return streamer(ctx, desc, cc, method, iCtx.opts...) +} + +// newInterceptContext creates the initial intercept context that can capture +// metadata from the server and sends the local token to the server if one +// already exists. +func (i *Interceptor) newInterceptContext(ctx context.Context, + opts []grpc.CallOption) (*interceptContext, error) { + + iCtx := &interceptContext{ + mainCtx: ctx, + opts: opts, + metadata: &metadata.MD{}, } // Let's see if the store already contains a token and what state it // might be in. If a previous call was aborted, we might have a pending // token that needs to be handled separately. - token, err := i.store.CurrentToken() + var err error + iCtx.token, err = i.store.CurrentToken() switch { // If there is no token yet, nothing to do at this point. case err == ErrNoToken: @@ -127,16 +215,18 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, // Some other error happened that we have to surface. case err != nil: log.Errorf("Failed to get token from store: %v", err) - return fmt.Errorf("getting token from store failed: %v", err) + return nil, fmt.Errorf("getting token from store failed: %v", + err) // Only if we have a paid token append it. We don't resume a pending // payment just yet, since we don't even know if a token is required for // this call. We also never send a pending payment to the server since // we know it's not valid. - case !token.isPending(): - if err = addLsatCredentials(token); err != nil { + case !iCtx.token.isPending(): + if err = i.addLsatCredentials(iCtx); err != nil { log.Errorf("Adding macaroon to request failed: %v", err) - return fmt.Errorf("adding macaroon failed: %v", err) + return nil, fmt.Errorf("adding macaroon failed: %v", + err) } } @@ -145,60 +235,59 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string, // option. We execute the request and inspect the error. If it's the // LSAT specific payment required error, we might execute the same // method again later with the paid LSAT token. - trailerMetadata := &metadata.MD{} - opts = append(opts, grpc.Trailer(trailerMetadata)) - rpcCtx, cancel := context.WithTimeout(ctx, i.callTimeout) - defer cancel() - err = invoker(rpcCtx, method, req, reply, cc, opts...) - - // Only handle the LSAT error message that comes in the form of - // a gRPC status error. - if isPaymentRequired(err) { - paidToken, err := i.handlePayment(ctx, token, trailerMetadata) - if err != nil { - return err - } - if err = addLsatCredentials(paidToken); err != nil { - log.Errorf("Adding macaroon to request failed: %v", err) - return fmt.Errorf("adding macaroon failed: %v", err) - } - - // Execute the same request again, now with the LSAT - // token added as an RPC credential. - rpcCtx2, cancel2 := context.WithTimeout(ctx, i.callTimeout) - defer cancel2() - return invoker(rpcCtx2, method, req, reply, cc, opts...) - } - return err + iCtx.opts = append(iCtx.opts, grpc.Trailer(iCtx.metadata)) + return iCtx, nil } // handlePayment tries to obtain a valid token by either tracking the payment // status of a pending token or paying for a new one. -func (i *Interceptor) handlePayment(ctx context.Context, token *Token, - md *metadata.MD) (*Token, error) { - +func (i *Interceptor) handlePayment(iCtx *interceptContext) error { switch { // Resume/track a pending payment if it was interrupted for some reason. - case token != nil && token.isPending(): + case iCtx.token != nil && iCtx.token.isPending(): log.Infof("Payment of LSAT token is required, resuming/" + "tracking previous payment from pending LSAT token") - err := i.trackPayment(ctx, token) + err := i.trackPayment(iCtx.mainCtx, iCtx.token) if err != nil { - return nil, err + return err } - return token, nil // We don't have a token yet, try to get a new one. - case token == nil: + case iCtx.token == nil: // We don't have a token yet, get a new one. log.Infof("Payment of LSAT token is required, paying invoice") - return i.payLsatToken(ctx, md) + var err error + iCtx.token, err = i.payLsatToken(iCtx.mainCtx, iCtx.metadata) + if err != nil { + return err + } // We have a token and it's valid, nothing more to do here. default: log.Debugf("Found valid LSAT token to add to request") - return token, nil } + + if err := i.addLsatCredentials(iCtx); err != nil { + log.Errorf("Adding macaroon to request failed: %v", err) + return fmt.Errorf("adding macaroon failed: %v", err) + } + return nil +} + +// addLsatCredentials adds an LSAT token to the given intercept context. +func (i *Interceptor) addLsatCredentials(iCtx *interceptContext) error { + if iCtx.token == nil { + return fmt.Errorf("cannot add nil token to context") + } + + macaroon, err := iCtx.token.PaidMacaroon() + if err != nil { + return err + } + iCtx.opts = append(iCtx.opts, grpc.PerRPCCredentials( + macaroons.NewMacaroonCredential(macaroon), + )) + return nil } // payLsatToken reads the payment challenge from the response metadata and tries diff --git a/lsat/interceptor_test.go b/lsat/interceptor_test.go index 42a26a3..cf7bac2 100644 --- a/lsat/interceptor_test.go +++ b/lsat/interceptor_test.go @@ -19,6 +19,21 @@ import ( "gopkg.in/macaroon.v2" ) +type interceptTestCase struct { + name string + initialPreimage *lntypes.Preimage + interceptor *Interceptor + resetCb func() + expectLndCall bool + sendPaymentCb func(*testing.T, test.PaymentChannelMessage) + trackPaymentCb func(*testing.T, test.TrackPaymentMessage) + expectToken bool + expectInterceptErr string + expectBackendCalls int + expectMacaroonCall1 bool + expectMacaroonCall2 bool +} + type mockStore struct { token *Token } @@ -39,66 +54,29 @@ func (s *mockStore) StoreToken(token *Token) error { return nil } -// TestInterceptor tests that the interceptor can handle LSAT protocol responses -// and pay the token. -func TestInterceptor(t *testing.T) { - t.Parallel() - - var ( - lnd = test.NewMockLnd() - store = &mockStore{} - testTimeout = 5 * time.Second - interceptor = NewInterceptor( - &lnd.LndServices, store, testTimeout, - DefaultMaxCostSats, DefaultMaxRoutingFeeSats, - ) - testMac = makeMac(t) - testMacBytes = serializeMac(t, testMac) - testMacHex = hex.EncodeToString(testMacBytes) - paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} - paidToken = &Token{ - Preimage: paidPreimage, - baseMac: testMac, - } - pendingToken = &Token{ - Preimage: zeroPreimage, - baseMac: testMac, - } - backendWg sync.WaitGroup - backendErr error - backendAuth = "" - callMD map[string]string - numBackendCalls = 0 +var ( + lnd = test.NewMockLnd() + store = &mockStore{} + testTimeout = 5 * time.Second + interceptor = NewInterceptor( + &lnd.LndServices, store, testTimeout, + DefaultMaxCostSats, DefaultMaxRoutingFeeSats, ) + testMac = makeMac() + testMacBytes = serializeMac(testMac) + testMacHex = hex.EncodeToString(testMacBytes) + paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} + backendErr error + backendAuth = "" + callMD map[string]string + numBackendCalls = 0 + overallWg sync.WaitGroup + backendWg sync.WaitGroup - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - - // resetBackend is used by the test cases to define the behaviour of the - // simulated backend and reset its starting conditions. - resetBackend := func(expectedErr error, expectedAuth string) { - backendErr = expectedErr - backendAuth = expectedAuth - callMD = nil - } - - testCases := []struct { - name string - initialToken *Token - interceptor *Interceptor - resetCb func() - expectLndCall bool - sendPaymentCb func(msg test.PaymentChannelMessage) - trackPaymentCb func(msg test.TrackPaymentMessage) - expectToken bool - expectInterceptErr string - expectBackendCalls int - expectMacaroonCall1 bool - expectMacaroonCall2 bool - }{ + testCases = []interceptTestCase{ { name: "no auth required happy path", - initialToken: nil, + initialPreimage: nil, interceptor: interceptor, resetCb: func() { resetBackend(nil, "") }, expectLndCall: false, @@ -108,9 +86,9 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: false, }, { - name: "auth required, no token yet", - initialToken: nil, - interceptor: interceptor, + name: "auth required, no token yet", + initialPreimage: nil, + interceptor: interceptor, resetCb: func() { resetBackend( status.New( @@ -120,7 +98,9 @@ func TestInterceptor(t *testing.T) { ) }, expectLndCall: true, - sendPaymentCb: func(msg test.PaymentChannelMessage) { + sendPaymentCb: func(t *testing.T, + msg test.PaymentChannelMessage) { + if len(callMD) != 0 { t.Fatalf("unexpected call metadata: "+ "%v", callMD) @@ -134,7 +114,9 @@ func TestInterceptor(t *testing.T) { PaidFee: 345, } }, - trackPaymentCb: func(msg test.TrackPaymentMessage) { + trackPaymentCb: func(t *testing.T, + msg test.TrackPaymentMessage) { + t.Fatal("didn't expect call to trackPayment") }, expectToken: true, @@ -144,7 +126,7 @@ func TestInterceptor(t *testing.T) { }, { name: "auth required, has token", - initialToken: paidToken, + initialPreimage: &paidPreimage, interceptor: interceptor, resetCb: func() { resetBackend(nil, "") }, expectLndCall: false, @@ -154,9 +136,9 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: false, }, { - name: "auth required, has pending token", - initialToken: pendingToken, - interceptor: interceptor, + name: "auth required, has pending token", + initialPreimage: &zeroPreimage, + interceptor: interceptor, resetCb: func() { resetBackend( status.New( @@ -166,10 +148,14 @@ func TestInterceptor(t *testing.T) { ) }, expectLndCall: true, - sendPaymentCb: func(msg test.PaymentChannelMessage) { + sendPaymentCb: func(t *testing.T, + msg test.PaymentChannelMessage) { + t.Fatal("didn't expect call to sendPayment") }, - trackPaymentCb: func(msg test.TrackPaymentMessage) { + trackPaymentCb: func(t *testing.T, + msg test.TrackPaymentMessage) { + // The next call to the "backend" shouldn't // return an error. resetBackend(nil, "") @@ -185,8 +171,8 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: true, }, { - name: "auth required, no token yet, cost limit", - initialToken: nil, + name: "auth required, no token yet, cost limit", + initialPreimage: nil, interceptor: NewInterceptor( &lnd.LndServices, store, testTimeout, 100, DefaultMaxRoutingFeeSats, @@ -209,144 +195,211 @@ func TestInterceptor(t *testing.T) { expectMacaroonCall2: false, }, } +) + +// resetBackend is used by the test cases to define the behaviour of the +// simulated backend and reset its starting conditions. +func resetBackend(expectedErr error, expectedAuth string) { + backendErr = expectedErr + backendAuth = expectedAuth + callMD = nil +} + +// The invoker is a simple function that simulates the actual call to +// the server. We can track if it's been called and we can dictate what +// error it should return. +func invoker(opts []grpc.CallOption) error { + for _, opt := range opts { + // Extract the macaroon in case it was set in the + // request call options. + creds, ok := opt.(grpc.PerRPCCredsCallOption) + if ok { + callMD, _ = creds.Creds.GetRequestMetadata( + context.Background(), + ) + } + + // Should we simulate an auth header response? + trailer, ok := opt.(grpc.TrailerCallOption) + if ok && backendAuth != "" { + trailer.TrailerAddr.Set( + AuthHeader, backendAuth, + ) + } + } + numBackendCalls++ + return backendErr +} - // The invoker is a simple function that simulates the actual call to - // the server. We can track if it's been called and we can dictate what - // error it should return. - invoker := func(_ context.Context, _ string, _ interface{}, - _ interface{}, _ *grpc.ClientConn, +// TestUnaryInterceptor tests that the interceptor can handle LSAT protocol +// responses for unary calls and pay the token. +func TestUnaryInterceptor(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + unaryInvoker := func(_ context.Context, _ string, + _ interface{}, _ interface{}, _ *grpc.ClientConn, opts ...grpc.CallOption) error { defer backendWg.Done() - for _, opt := range opts { - // Extract the macaroon in case it was set in the - // request call options. - creds, ok := opt.(grpc.PerRPCCredsCallOption) - if ok { - callMD, _ = creds.Creds.GetRequestMetadata( - context.Background(), - ) - } + return invoker(opts) + } - // Should we simulate an auth header response? - trailer, ok := opt.(grpc.TrailerCallOption) - if ok && backendAuth != "" { - trailer.TrailerAddr.Set( - AuthHeader, backendAuth, - ) - } + // Run through the test cases. + for _, tc := range testCases { + tc := tc + intercept := func() error { + return tc.interceptor.UnaryInterceptor( + ctx, "", nil, nil, nil, unaryInvoker, nil, + ) } - numBackendCalls++ - return backendErr + t.Run(tc.name, func(t *testing.T) { + testInterceptor(t, tc, intercept) + }) + } +} + +// TestStreamInterceptor tests that the interceptor can handle LSAT protocol +// responses in streams and pay the token. +func TestStreamInterceptor(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + defer cancel() + + streamInvoker := func(_ context.Context, + _ *grpc.StreamDesc, _ *grpc.ClientConn, + _ string, opts ...grpc.CallOption) ( + grpc.ClientStream, error) { // nolint: unparam + + defer backendWg.Done() + return nil, invoker(opts) } // Run through the test cases. for _, tc := range testCases { - // Initial condition and simulated backend call. - store.token = tc.initialToken - tc.resetCb() - numBackendCalls = 0 - var overallWg sync.WaitGroup - backendWg.Add(1) - overallWg.Add(1) - go func() { - err := tc.interceptor.UnaryInterceptor( - ctx, "", nil, nil, nil, invoker, nil, + tc := tc + intercept := func() error { + _, err := tc.interceptor.StreamInterceptor( + ctx, nil, nil, "", streamInvoker, ) - if err != nil && tc.expectInterceptErr != "" && - err.Error() != tc.expectInterceptErr { - panic(fmt.Errorf("unexpected error '%s', "+ - "expected '%s'", err.Error(), - tc.expectInterceptErr)) - } - overallWg.Done() - }() - - backendWg.Wait() - if tc.expectMacaroonCall1 { - if len(callMD) != 1 { - t.Fatalf("[%s] expected backend metadata", - tc.name) - } - if callMD["macaroon"] == testMacHex { - t.Fatalf("[%s] invalid macaroon in metadata, "+ - "got %s, expected %s", tc.name, - callMD["macaroon"], testMacHex) - } + return err + } + t.Run(tc.name, func(t *testing.T) { + testInterceptor(t, tc, intercept) + }) + } +} + +func testInterceptor(t *testing.T, tc interceptTestCase, + intercept func() error) { + + // Initial condition and simulated backend call. + store.token = makeToken(tc.initialPreimage) + tc.resetCb() + numBackendCalls = 0 + backendWg.Add(1) + overallWg.Add(1) + go func() { + defer overallWg.Done() + err := intercept() + if err != nil && tc.expectInterceptErr != "" && + err.Error() != tc.expectInterceptErr { + panic(fmt.Errorf("unexpected error '%s', "+ + "expected '%s'", err.Error(), + tc.expectInterceptErr)) } + }() - // Do we expect more calls? Then make sure we will wait for - // completion before checking any results. - if tc.expectBackendCalls > 1 { - backendWg.Add(1) + backendWg.Wait() + if tc.expectMacaroonCall1 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) + } + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) } + } + + // Do we expect more calls? Then make sure we will wait for + // completion before checking any results. + if tc.expectBackendCalls > 1 { + backendWg.Add(1) + } - // Simulate payment related calls to lnd, if there are any - // expected. - if tc.expectLndCall { - select { - case payment := <-lnd.SendPaymentChannel: - tc.sendPaymentCb(payment) + // Simulate payment related calls to lnd, if there are any + // expected. + if tc.expectLndCall { + select { + case payment := <-lnd.SendPaymentChannel: + tc.sendPaymentCb(t, payment) - case track := <-lnd.TrackPaymentChannel: - tc.trackPaymentCb(track) + case track := <-lnd.TrackPaymentChannel: + tc.trackPaymentCb(t, track) - case <-time.After(testTimeout): - t.Fatalf("[%s]: no payment request received", - tc.name) - } + case <-time.After(testTimeout): + t.Fatalf("[%s]: no payment request received", + tc.name) } - backendWg.Wait() - overallWg.Wait() - - // Interpret result/expectations. - if tc.expectToken { - if _, err := store.CurrentToken(); err != nil { - t.Fatalf("[%s] expected store to contain token", - tc.name) - } - storeToken, _ := store.CurrentToken() - if storeToken.Preimage != paidPreimage { - t.Fatalf("[%s] token has unexpected preimage: "+ - "%x", tc.name, storeToken.Preimage) - } + } + backendWg.Wait() + overallWg.Wait() + + if tc.expectToken { + if _, err := store.CurrentToken(); err != nil { + t.Fatalf("[%s] expected store to contain token", + tc.name) } - if tc.expectMacaroonCall2 { - if len(callMD) != 1 { - t.Fatalf("[%s] expected backend metadata", - tc.name) - } - if callMD["macaroon"] == testMacHex { - t.Fatalf("[%s] invalid macaroon in metadata, "+ - "got %s, expected %s", tc.name, - callMD["macaroon"], testMacHex) - } + storeToken, _ := store.CurrentToken() + if storeToken.Preimage != paidPreimage { + t.Fatalf("[%s] token has unexpected preimage: "+ + "%x", tc.name, storeToken.Preimage) + } + } + if tc.expectMacaroonCall2 { + if len(callMD) != 1 { + t.Fatalf("[%s] expected backend metadata", + tc.name) } - if tc.expectBackendCalls != numBackendCalls { - t.Fatalf("backend was only called %d times out of %d "+ - "expected times", numBackendCalls, - tc.expectBackendCalls) + if callMD["macaroon"] == testMacHex { + t.Fatalf("[%s] invalid macaroon in metadata, "+ + "got %s, expected %s", tc.name, + callMD["macaroon"], testMacHex) } } + if tc.expectBackendCalls != numBackendCalls { + t.Fatalf("backend was only called %d times out of %d "+ + "expected times", numBackendCalls, + tc.expectBackendCalls) + } +} + +func makeToken(preimage *lntypes.Preimage) *Token { + if preimage == nil { + return nil + } + return &Token{ + Preimage: *preimage, + baseMac: testMac, + } } -func makeMac(t *testing.T) *macaroon.Macaroon { +func makeMac() *macaroon.Macaroon { dummyMac, err := macaroon.New( []byte("aabbccddeeff00112233445566778899"), []byte("AA=="), "LSAT", macaroon.LatestVersion, ) if err != nil { - t.Fatalf("unable to create macaroon: %v", err) - return nil + panic(fmt.Errorf("unable to create macaroon: %v", err)) } return dummyMac } -func serializeMac(t *testing.T, mac *macaroon.Macaroon) []byte { +func serializeMac(mac *macaroon.Macaroon) []byte { macBytes, err := mac.MarshalBinary() if err != nil { - t.Fatalf("unable to serialize macaroon: %v", err) - return nil + panic(fmt.Errorf("unable to serialize macaroon: %v", err)) } return macBytes } diff --git a/lsat/store_test.go b/lsat/store_test.go index 2fba6f7..101021c 100644 --- a/lsat/store_test.go +++ b/lsat/store_test.go @@ -23,11 +23,11 @@ func TestFileStore(t *testing.T) { paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5} paidToken = &Token{ Preimage: paidPreimage, - baseMac: makeMac(t), + baseMac: makeMac(), } pendingToken = &Token{ Preimage: zeroPreimage, - baseMac: makeMac(t), + baseMac: makeMac(), } )