From 4213a190d5204176132e2f27e7df235639d4adbf Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 16:17:09 +0100 Subject: [PATCH] Use `X-Request-Id` as canonical request identifier (if available) If `X-Request-Id` is available in an HTTP request made against the CA server, it'll be used as the identifier for the request. This slightly changes the existing behavior, which relied on the custom `X-Smallstep-Id` header, but usage of that header is currently not very widespread, and `X-Request-Id` is more generally known for the use case `X-Smallstep-Id` is used for. `X-Smallstep-Id` is currently still considered, but it'll only be used if `X-Request-Id` is not set. --- logging/context.go | 23 +++++++--- logging/context_test.go | 94 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 5 deletions(-) create mode 100644 logging/context_test.go diff --git a/logging/context.go b/logging/context.go index b24b3638..ab8464d0 100644 --- a/logging/context.go +++ b/logging/context.go @@ -21,14 +21,27 @@ func NewRequestID() string { return xid.New().String() } -// RequestID returns a new middleware that gets the given header and sets it -// in the context so it can be written in the logger. If the header does not -// exists or it's the empty string, it uses github.com/rs/xid to create a new -// one. +// defaultRequestIDHeader is the header name used for propagating +// request IDs. If available in an HTTP request, it'll be used instead +// of the X-Smallstep-Id header. +const defaultRequestIDHeader = "X-Request-Id" + +// RequestID returns a new middleware that obtains the current request ID +// and sets it in the context. It first tries to read the request ID from +// the "X-Request-Id" header. If that's not set, it tries to read it from +// the provided header name. If the header does not exist or its value is +// the empty string, it uses github.com/rs/xid to create a new one. func RequestID(headerName string) func(next http.Handler) http.Handler { + if headerName == "" { + headerName = defaultTraceIDHeader + } return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(headerName) + requestID := req.Header.Get(defaultRequestIDHeader) + if requestID == "" { + requestID = req.Header.Get(headerName) + } + if requestID == "" { requestID = NewRequestID() req.Header.Set(headerName, requestID) diff --git a/logging/context_test.go b/logging/context_test.go new file mode 100644 index 00000000..c519539d --- /dev/null +++ b/logging/context_test.go @@ -0,0 +1,94 @@ +package logging + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newRequest(t *testing.T) *http.Request { + r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + require.NoError(t, err) + return r +} + +func TestRequestID(t *testing.T) { + requestWithID := newRequest(t) + requestWithID.Header.Set("X-Request-Id", "reqID") + requestWithoutID := newRequest(t) + requestWithEmptyHeader := newRequest(t) + requestWithEmptyHeader.Header.Set("X-Request-Id", "") + requestWithSmallstepID := newRequest(t) + requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") + + tests := []struct { + name string + headerName string + handler http.HandlerFunc + req *http.Request + }{ + { + name: "default-request-id", + headerName: defaultTraceIDHeader, + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "reqID", reqID) + } + }, + req: requestWithID, + }, + { + name: "no-request-id", + headerName: "X-Request-Id", + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + value := r.Header.Get("X-Request-Id") + assert.NotEmpty(t, value) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + }, + req: requestWithoutID, + }, + { + name: "empty-header-name", + headerName: "", + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + value := r.Header.Get("X-Smallstep-Id") + assert.NotEmpty(t, value) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + }, + req: requestWithEmptyHeader, + }, + { + name: "fallback-header-name", + headerName: defaultTraceIDHeader, + handler: func(_ http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) + reqID, ok := GetRequestID(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "smallstepID", reqID) + } + }, + req: requestWithSmallstepID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := RequestID(tt.headerName) + h(tt.handler).ServeHTTP(httptest.NewRecorder(), tt.req) + }) + } +}