Browse Source

lsat: add stream interceptor

pull/145/head
Oliver Gugger 10 months ago
parent
commit
79c54b9334
No known key found for this signature in database GPG Key ID: 8E4256593F177720
3 changed files with 273 additions and 174 deletions
  1. +46
    -0
      lsat/interceptor.go
  2. +225
    -172
      lsat/interceptor_test.go
  3. +2
    -2
      lsat/store_test.go

+ 46
- 0
lsat/interceptor.go View File

@ -145,6 +145,52 @@ func (i *Interceptor) UnaryInterceptor(ctx context.Context, method string,
return invoker(rpcCtx2, method, req, reply, cc, iCtx.opts...)
}
// StreamInterceptor is an interceptor method that can be used directly by gRPC
// for streaming calls. If the store contains a token, it is attached as
// credentials to every stream establishment call before patching it through.
// The response error is also intercepted for every initial stream initiation.
// If there is an error returned and it is indicating a payment challenge, a
// token is acquired and paid for automatically. The original request is then
// repeated back to the server, now with the new token attached.
func (i *Interceptor) StreamInterceptor(ctx context.Context,
desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream,
error) {
// To avoid paying for a token twice if two parallel requests are
// happening, we require an exclusive lock here.
i.lock.Lock()
defer i.lock.Unlock()
// Create the context that we'll use to initiate the real request. This
// contains the means to extract response headers and possibly also an
// auth token, if we already have paid for one.
iCtx, err := i.newInterceptContext(ctx, opts)
if err != nil {
return nil, err
}
// Try establishing the stream now. If anything goes wrong, we only
// handle the LSAT error message that comes in the form of a gRPC status
// error. The context of a stream will be used for the whole lifetime of
// it, so we can't really clamp down on the initial call with a timeout.
stream, err := streamer(ctx, desc, cc, method, iCtx.opts...)
if !isPaymentRequired(err) {
return stream, err
}
// Find out if we need to pay for a new token or perhaps resume
// a previously aborted payment.
err = i.handlePayment(iCtx)
if err != nil {
return nil, err
}
// Execute the same request again, now with the LSAT token added
// as an RPC credential.
return streamer(ctx, desc, cc, method, iCtx.opts...)
}
// newInterceptContext creates the initial intercept context that can capture
// metadata from the server and sends the local token to the server if one
// already exists.

+ 225
- 172
lsat/interceptor_test.go View File

@ -19,6 +19,21 @@ import (
"gopkg.in/macaroon.v2"
)
type interceptTestCase struct {
name string
initialPreimage *lntypes.Preimage
interceptor *Interceptor
resetCb func()
expectLndCall bool
sendPaymentCb func(*testing.T, test.PaymentChannelMessage)
trackPaymentCb func(*testing.T, test.TrackPaymentMessage)
expectToken bool
expectInterceptErr string
expectBackendCalls int
expectMacaroonCall1 bool
expectMacaroonCall2 bool
}
type mockStore struct {
token *Token
}
@ -39,66 +54,29 @@ func (s *mockStore) StoreToken(token *Token) error {
return nil
}
// TestInterceptor tests that the interceptor can handle LSAT protocol responses
// and pay the token.
func TestInterceptor(t *testing.T) {
t.Parallel()
var (
lnd = test.NewMockLnd()
store = &mockStore{}
testTimeout = 5 * time.Second
interceptor = NewInterceptor(
&lnd.LndServices, store, testTimeout,
DefaultMaxCostSats, DefaultMaxRoutingFeeSats,
)
testMac = makeMac(t)
testMacBytes = serializeMac(t, testMac)
testMacHex = hex.EncodeToString(testMacBytes)
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
paidToken = &Token{
Preimage: paidPreimage,
baseMac: testMac,
}
pendingToken = &Token{
Preimage: zeroPreimage,
baseMac: testMac,
}
backendWg sync.WaitGroup
backendErr error
backendAuth = ""
callMD map[string]string
numBackendCalls = 0
var (
lnd = test.NewMockLnd()
store = &mockStore{}
testTimeout = 5 * time.Second
interceptor = NewInterceptor(
&lnd.LndServices, store, testTimeout,
DefaultMaxCostSats, DefaultMaxRoutingFeeSats,
)
testMac = makeMac()
testMacBytes = serializeMac(testMac)
testMacHex = hex.EncodeToString(testMacBytes)
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
backendErr error
backendAuth = ""
callMD map[string]string
numBackendCalls = 0
overallWg sync.WaitGroup
backendWg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
// resetBackend is used by the test cases to define the behaviour of the
// simulated backend and reset its starting conditions.
resetBackend := func(expectedErr error, expectedAuth string) {
backendErr = expectedErr
backendAuth = expectedAuth
callMD = nil
}
testCases := []struct {
name string
initialToken *Token
interceptor *Interceptor
resetCb func()
expectLndCall bool
sendPaymentCb func(msg test.PaymentChannelMessage)
trackPaymentCb func(msg test.TrackPaymentMessage)
expectToken bool
expectInterceptErr string
expectBackendCalls int
expectMacaroonCall1 bool
expectMacaroonCall2 bool
}{
testCases = []interceptTestCase{
{
name: "no auth required happy path",
initialToken: nil,
initialPreimage: nil,
interceptor: interceptor,
resetCb: func() { resetBackend(nil, "") },
expectLndCall: false,
@ -108,9 +86,9 @@ func TestInterceptor(t *testing.T) {
expectMacaroonCall2: false,
},
{
name: "auth required, no token yet",
initialToken: nil,
interceptor: interceptor,
name: "auth required, no token yet",
initialPreimage: nil,
interceptor: interceptor,
resetCb: func() {
resetBackend(
status.New(
@ -120,7 +98,9 @@ func TestInterceptor(t *testing.T) {
)
},
expectLndCall: true,
sendPaymentCb: func(msg test.PaymentChannelMessage) {
sendPaymentCb: func(t *testing.T,
msg test.PaymentChannelMessage) {
if len(callMD) != 0 {
t.Fatalf("unexpected call metadata: "+
"%v", callMD)
@ -134,7 +114,9 @@ func TestInterceptor(t *testing.T) {
PaidFee: 345,
}
},
trackPaymentCb: func(msg test.TrackPaymentMessage) {
trackPaymentCb: func(t *testing.T,
msg test.TrackPaymentMessage) {
t.Fatal("didn't expect call to trackPayment")
},
expectToken: true,
@ -144,7 +126,7 @@ func TestInterceptor(t *testing.T) {
},
{
name: "auth required, has token",
initialToken: paidToken,
initialPreimage: &paidPreimage,
interceptor: interceptor,
resetCb: func() { resetBackend(nil, "") },
expectLndCall: false,
@ -154,9 +136,9 @@ func TestInterceptor(t *testing.T) {
expectMacaroonCall2: false,
},
{
name: "auth required, has pending token",
initialToken: pendingToken,
interceptor: interceptor,
name: "auth required, has pending token",
initialPreimage: &zeroPreimage,
interceptor: interceptor,
resetCb: func() {
resetBackend(
status.New(
@ -166,10 +148,14 @@ func TestInterceptor(t *testing.T) {
)
},
expectLndCall: true,
sendPaymentCb: func(msg test.PaymentChannelMessage) {
sendPaymentCb: func(t *testing.T,
msg test.PaymentChannelMessage) {
t.Fatal("didn't expect call to sendPayment")
},
trackPaymentCb: func(msg test.TrackPaymentMessage) {
trackPaymentCb: func(t *testing.T,
msg test.TrackPaymentMessage) {
// The next call to the "backend" shouldn't
// return an error.
resetBackend(nil, "")
@ -185,8 +171,8 @@ func TestInterceptor(t *testing.T) {
expectMacaroonCall2: true,
},
{
name: "auth required, no token yet, cost limit",
initialToken: nil,
name: "auth required, no token yet, cost limit",
initialPreimage: nil,
interceptor: NewInterceptor(
&lnd.LndServices, store, testTimeout,
100, DefaultMaxRoutingFeeSats,
@ -209,144 +195,211 @@ func TestInterceptor(t *testing.T) {
expectMacaroonCall2: false,
},
}
)
// resetBackend is used by the test cases to define the behaviour of the
// simulated backend and reset its starting conditions.
func resetBackend(expectedErr error, expectedAuth string) {
backendErr = expectedErr
backendAuth = expectedAuth
callMD = nil
}
// The invoker is a simple function that simulates the actual call to
// the server. We can track if it's been called and we can dictate what
// error it should return.
func invoker(opts []grpc.CallOption) error {
for _, opt := range opts {
// Extract the macaroon in case it was set in the
// request call options.
creds, ok := opt.(grpc.PerRPCCredsCallOption)
if ok {
callMD, _ = creds.Creds.GetRequestMetadata(
context.Background(),
)
}
// Should we simulate an auth header response?
trailer, ok := opt.(grpc.TrailerCallOption)
if ok && backendAuth != "" {
trailer.TrailerAddr.Set(
AuthHeader, backendAuth,
)
}
}
numBackendCalls++
return backendErr
}
// The invoker is a simple function that simulates the actual call to
// the server. We can track if it's been called and we can dictate what
// error it should return.
invoker := func(_ context.Context, _ string, _ interface{},
_ interface{}, _ *grpc.ClientConn,
// TestUnaryInterceptor tests that the interceptor can handle LSAT protocol
// responses for unary calls and pay the token.
func TestUnaryInterceptor(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
unaryInvoker := func(_ context.Context, _ string,
_ interface{}, _ interface{}, _ *grpc.ClientConn,
opts ...grpc.CallOption) error {
defer backendWg.Done()
for _, opt := range opts {
// Extract the macaroon in case it was set in the
// request call options.
creds, ok := opt.(grpc.PerRPCCredsCallOption)
if ok {
callMD, _ = creds.Creds.GetRequestMetadata(
context.Background(),
)
}
return invoker(opts)
}
// Should we simulate an auth header response?
trailer, ok := opt.(grpc.TrailerCallOption)
if ok && backendAuth != "" {
trailer.TrailerAddr.Set(
AuthHeader, backendAuth,
)
}
// Run through the test cases.
for _, tc := range testCases {
tc := tc
intercept := func() error {
return tc.interceptor.UnaryInterceptor(
ctx, "", nil, nil, nil, unaryInvoker, nil,
)
}
numBackendCalls++
return backendErr
t.Run(tc.name, func(t *testing.T) {
testInterceptor(t, tc, intercept)
})
}
}
// TestStreamInterceptor tests that the interceptor can handle LSAT protocol
// responses in streams and pay the token.
func TestStreamInterceptor(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
streamInvoker := func(_ context.Context,
_ *grpc.StreamDesc, _ *grpc.ClientConn,
_ string, opts ...grpc.CallOption) (
grpc.ClientStream, error) { // nolint: unparam
defer backendWg.Done()
return nil, invoker(opts)
}
// Run through the test cases.
for _, tc := range testCases {
// Initial condition and simulated backend call.
store.token = tc.initialToken
tc.resetCb()
numBackendCalls = 0
var overallWg sync.WaitGroup
backendWg.Add(1)
overallWg.Add(1)
go func() {
err := tc.interceptor.UnaryInterceptor(
ctx, "", nil, nil, nil, invoker, nil,
tc := tc
intercept := func() error {
_, err := tc.interceptor.StreamInterceptor(
ctx, nil, nil, "", streamInvoker,
)
if err != nil && tc.expectInterceptErr != "" &&
err.Error() != tc.expectInterceptErr {
panic(fmt.Errorf("unexpected error '%s', "+
"expected '%s'", err.Error(),
tc.expectInterceptErr))
}
overallWg.Done()
}()
backendWg.Wait()
if tc.expectMacaroonCall1 {
if len(callMD) != 1 {
t.Fatalf("[%s] expected backend metadata",
tc.name)
}
if callMD["macaroon"] == testMacHex {
t.Fatalf("[%s] invalid macaroon in metadata, "+
"got %s, expected %s", tc.name,
callMD["macaroon"], testMacHex)
}
return err
}
t.Run(tc.name, func(t *testing.T) {
testInterceptor(t, tc, intercept)
})
}
}
func testInterceptor(t *testing.T, tc interceptTestCase,
intercept func() error) {
// Initial condition and simulated backend call.
store.token = makeToken(tc.initialPreimage)
tc.resetCb()
numBackendCalls = 0
backendWg.Add(1)
overallWg.Add(1)
go func() {
defer overallWg.Done()
err := intercept()
if err != nil && tc.expectInterceptErr != "" &&
err.Error() != tc.expectInterceptErr {
panic(fmt.Errorf("unexpected error '%s', "+
"expected '%s'", err.Error(),
tc.expectInterceptErr))
}
}()
// Do we expect more calls? Then make sure we will wait for
// completion before checking any results.
if tc.expectBackendCalls > 1 {
backendWg.Add(1)
backendWg.Wait()
if tc.expectMacaroonCall1 {
if len(callMD) != 1 {
t.Fatalf("[%s] expected backend metadata",
tc.name)
}
if callMD["macaroon"] == testMacHex {
t.Fatalf("[%s] invalid macaroon in metadata, "+
"got %s, expected %s", tc.name,
callMD["macaroon"], testMacHex)
}
}
// Do we expect more calls? Then make sure we will wait for
// completion before checking any results.
if tc.expectBackendCalls > 1 {
backendWg.Add(1)
}
// Simulate payment related calls to lnd, if there are any
// expected.
if tc.expectLndCall {
select {
case payment := <-lnd.SendPaymentChannel:
tc.sendPaymentCb(payment)
// Simulate payment related calls to lnd, if there are any
// expected.
if tc.expectLndCall {
select {
case payment := <-lnd.SendPaymentChannel:
tc.sendPaymentCb(t, payment)
case track := <-lnd.TrackPaymentChannel:
tc.trackPaymentCb(track)
case track := <-lnd.TrackPaymentChannel:
tc.trackPaymentCb(t, track)
case <-time.After(testTimeout):
t.Fatalf("[%s]: no payment request received",
tc.name)
}
case <-time.After(testTimeout):
t.Fatalf("[%s]: no payment request received",
tc.name)
}
backendWg.Wait()
overallWg.Wait()
// Interpret result/expectations.
if tc.expectToken {
if _, err := store.CurrentToken(); err != nil {
t.Fatalf("[%s] expected store to contain token",
tc.name)
}
storeToken, _ := store.CurrentToken()
if storeToken.Preimage != paidPreimage {
t.Fatalf("[%s] token has unexpected preimage: "+
"%x", tc.name, storeToken.Preimage)
}
}
backendWg.Wait()
overallWg.Wait()
if tc.expectToken {
if _, err := store.CurrentToken(); err != nil {
t.Fatalf("[%s] expected store to contain token",
tc.name)
}
if tc.expectMacaroonCall2 {
if len(callMD) != 1 {
t.Fatalf("[%s] expected backend metadata",
tc.name)
}
if callMD["macaroon"] == testMacHex {
t.Fatalf("[%s] invalid macaroon in metadata, "+
"got %s, expected %s", tc.name,
callMD["macaroon"], testMacHex)
}
storeToken, _ := store.CurrentToken()
if storeToken.Preimage != paidPreimage {
t.Fatalf("[%s] token has unexpected preimage: "+
"%x", tc.name, storeToken.Preimage)
}
}
if tc.expectMacaroonCall2 {
if len(callMD) != 1 {
t.Fatalf("[%s] expected backend metadata",
tc.name)
}
if tc.expectBackendCalls != numBackendCalls {
t.Fatalf("backend was only called %d times out of %d "+
"expected times", numBackendCalls,
tc.expectBackendCalls)
if callMD["macaroon"] == testMacHex {
t.Fatalf("[%s] invalid macaroon in metadata, "+
"got %s, expected %s", tc.name,
callMD["macaroon"], testMacHex)
}
}
if tc.expectBackendCalls != numBackendCalls {
t.Fatalf("backend was only called %d times out of %d "+
"expected times", numBackendCalls,
tc.expectBackendCalls)
}
}
func makeToken(preimage *lntypes.Preimage) *Token {
if preimage == nil {
return nil
}
return &Token{
Preimage: *preimage,
baseMac: testMac,
}
}
func makeMac(t *testing.T) *macaroon.Macaroon {
func makeMac() *macaroon.Macaroon {
dummyMac, err := macaroon.New(
[]byte("aabbccddeeff00112233445566778899"), []byte("AA=="),
"LSAT", macaroon.LatestVersion,
)
if err != nil {
t.Fatalf("unable to create macaroon: %v", err)
return nil
panic(fmt.Errorf("unable to create macaroon: %v", err))
}
return dummyMac
}
func serializeMac(t *testing.T, mac *macaroon.Macaroon) []byte {
func serializeMac(mac *macaroon.Macaroon) []byte {
macBytes, err := mac.MarshalBinary()
if err != nil {
t.Fatalf("unable to serialize macaroon: %v", err)
return nil
panic(fmt.Errorf("unable to serialize macaroon: %v", err))
}
return macBytes
}

+ 2
- 2
lsat/store_test.go View File

@ -23,11 +23,11 @@ func TestFileStore(t *testing.T) {
paidPreimage = lntypes.Preimage{1, 2, 3, 4, 5}
paidToken = &Token{
Preimage: paidPreimage,
baseMac: makeMac(t),
baseMac: makeMac(),
}
pendingToken = &Token{
Preimage: zeroPreimage,
baseMac: makeMac(t),
baseMac: makeMac(),
}
)

Loading…
Cancel
Save