diff --git a/api/api_test.go b/api/api_test.go index cf988593..8090c6d4 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -884,16 +884,12 @@ func Test_Sign(t *testing.T) { CsrPEM: CertificateRequest{csr}, OTT: "foobarzar", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) invalid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, OTT: "", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index c33dfa23..1e08b8b7 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -15,7 +15,7 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/certificates/logging" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" @@ -171,9 +171,8 @@ retry: return nil, err } - requestID, ok := logging.GetRequestID(ctx) - if ok { - req.Header.Set("X-Request-ID", requestID) + if requestID, ok := requestid.FromContext(ctx); ok { + req.Header.Set("X-Request-Id", requestID) } secret, err := base64.StdEncoding.DecodeString(w.Secret) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 60dcdbc7..90583418 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,13 +17,15 @@ import ( "testing" "time" - "github.com/smallstep/certificates/logging" - "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" + + "github.com/smallstep/certificates/internal/requestid" + "github.com/smallstep/certificates/webhook" ) func TestWebhookController_isCertTypeOK(t *testing.T) { @@ -101,10 +103,11 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } } -// withRequestID is a helper that calls into [logging.WithRequestID] and returns -// a new context with the requestID added to the provided context. -func withRequestID(ctx context.Context, requestID string) context.Context { - return logging.WithRequestID(ctx, requestID) +// withRequestID is a helper that calls into [requestid.NewContext] and returns +// a new context with the requestID added. +func withRequestID(t *testing.T, ctx context.Context, requestID string) context.Context { + t.Helper() + return requestid.NewContext(ctx, requestID) } func TestWebhookController_Enrich(t *testing.T) { @@ -138,7 +141,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -153,7 +156,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -177,7 +180,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -197,7 +200,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -220,7 +223,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -235,7 +238,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -296,7 +299,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -307,7 +310,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -318,7 +321,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -339,7 +342,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -352,7 +355,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, - ctx: withRequestID(context.Background(), "reqID"), + ctx: withRequestID(t, context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -568,7 +571,7 @@ func TestWebhook_Do(t *testing.T) { ctx := context.Background() if tc.requestID != "" { - ctx = withRequestID(context.Background(), tc.requestID) + ctx = withRequestID(t, ctx, tc.requestID) } ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() diff --git a/ca/acmeClient.go b/ca/acmeClient.go index bb3b1d84..3ef2f191 100644 --- a/ca/acmeClient.go +++ b/ca/acmeClient.go @@ -48,6 +48,7 @@ func NewACMEClient(endpoint string, contact []string, opts ...ClientOption) (*AC return nil, errors.Wrapf(err, "creating GET request %s failed", endpoint) } req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) resp, err := ac.client.Do(req) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", endpoint) @@ -109,6 +110,7 @@ func (c *ACMEClient) GetNonce() (string, error) { return "", errors.Wrapf(err, "creating GET request %s failed", c.dir.NewNonce) } req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) resp, err := c.client.Do(req) if err != nil { return "", errors.Wrapf(err, "client GET %s failed", c.dir.NewNonce) @@ -188,6 +190,7 @@ func (c *ACMEClient) post(payload []byte, url string, headerOps ...withHeaderOpt } req.Header.Set("Content-Type", "application/jose+json") req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) resp, err := c.client.Do(req) if err != nil { return nil, errors.Wrapf(err, "client POST %s failed", c.dir.NewOrder) diff --git a/ca/ca.go b/ca/ca.go index 4146466d..ab4a1a9b 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -29,6 +29,7 @@ import ( "github.com/smallstep/certificates/cas/apiv1" "github.com/smallstep/certificates/db" "github.com/smallstep/certificates/internal/metrix" + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/monitoring" "github.com/smallstep/certificates/scep" @@ -329,15 +330,21 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) { } // Add logger if configured + var legacyTraceHeader string if len(cfg.Logger) > 0 { logger, err := logging.New("ca", cfg.Logger) if err != nil { return nil, err } + legacyTraceHeader = logger.GetTraceHeader() handler = logger.Middleware(handler) insecureHandler = logger.Middleware(insecureHandler) } + // always use request ID middleware; traceHeader is provided for backwards compatibility (for now) + handler = requestid.New(legacyTraceHeader).Middleware(handler) + insecureHandler = requestid.New(legacyTraceHeader).Middleware(insecureHandler) + // Create context with all the necessary values. baseContext := buildContext(auth, scepAuthority, acmeDB, acmeLinker) diff --git a/ca/ca_test.go b/ca/ca_test.go index 7ad25cc6..a8c173c4 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -289,6 +289,9 @@ ZEp7knvU2psWRw== if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var sign api.SignResponse assert.FatalError(t, readJSON(body, &sign)) @@ -325,7 +328,7 @@ ZEp7knvU2psWRw== assert.FatalError(t, err) assert.Equals(t, intermediate, realIntermediate) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -369,6 +372,9 @@ func TestCAProvisioners(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var resp api.ProvisionersResponse @@ -379,7 +385,7 @@ func TestCAProvisioners(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, a, b) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -436,12 +442,15 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var ek api.ProvisionerKeyResponse assert.FatalError(t, readJSON(body, &ek)) assert.Equals(t, ek.Key, tc.expectedKey) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -498,12 +507,15 @@ func TestCARoot(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var root api.RootResponse assert.FatalError(t, readJSON(body, &root)) assert.Equals(t, root.RootPEM.Certificate, rootCrt) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -641,6 +653,9 @@ func TestCARenew(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var sign api.SignResponse assert.FatalError(t, readJSON(body, &sign)) @@ -673,7 +688,7 @@ func TestCARenew(t *testing.T) { assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } diff --git a/ca/client.go b/ca/client.go index ac13e1fe..b18efbaf 100644 --- a/ca/client.go +++ b/ca/client.go @@ -27,12 +27,14 @@ import ( "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/errs" "go.step.sm/cli-utils/step" "go.step.sm/crypto/jose" "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" "go.step.sm/crypto/x509util" "golang.org/x/net/http2" "google.golang.org/protobuf/encoding/protojson" @@ -83,8 +85,7 @@ func (c *uaClient) GetWithContext(ctx context.Context, u string) (*http.Response if err != nil { return nil, errors.Wrapf(err, "create GET %s request failed", u) } - req.Header.Set("User-Agent", UserAgent) - return c.Client.Do(req) + return c.Do(req) } func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) { @@ -97,12 +98,43 @@ func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, b return nil, errors.Wrapf(err, "create POST %s request failed", u) } req.Header.Set("Content-Type", contentType) - req.Header.Set("User-Agent", UserAgent) - return c.Client.Do(req) + return c.Do(req) +} + +// requestIDHeader is the header name used for propagating request IDs from +// the CA client to the CA and back again. +const requestIDHeader = "X-Request-Id" + +// newRequestID generates a new random UUIDv4 request ID. If it fails, +// the request ID will be the empty string. +func newRequestID() string { + requestID, err := randutil.UUIDv4() + if err != nil { + return "" + } + + return requestID +} + +// enforceRequestID checks if the X-Request-Id HTTP header is filled. If it's +// empty, the context is searched for a request ID. If that's also empty, a new +// request ID is generated. +func enforceRequestID(r *http.Request) { + if requestID := r.Header.Get(requestIDHeader); requestID == "" { + if reqID, ok := client.RequestIDFromContext(r.Context()); ok { + // TODO(hs): ensure the request ID from the context is fresh, and thus hasn't been + // used before by the client (unless it's a retry for the same request)? + requestID = reqID + } else { + requestID = newRequestID() + } + r.Header.Set(requestIDHeader, requestID) + } } func (c *uaClient) Do(req *http.Request) (*http.Response, error) { req.Header.Set("User-Agent", UserAgent) + enforceRequestID(req) return c.Client.Do(req) } @@ -375,8 +407,8 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { if err != nil { return nil, err } - client := &Client{endpoint: u} - root, err := client.Root(sum) + caClient := &Client{endpoint: u} + root, err := caClient.Root(sum) if err != nil { return nil, err } @@ -610,7 +642,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var version api.VersionResponse if err := readJSON(resp.Body, &version); err != nil { @@ -640,7 +672,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var health api.HealthResponse if err := readJSON(resp.Body, &health); err != nil { @@ -675,7 +707,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var root api.RootResponse if err := readJSON(resp.Body, &root); err != nil { @@ -714,7 +746,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -737,14 +769,14 @@ func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) { var retried bool u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"}) - client := &http.Client{Transport: tr} + httpClient := &http.Client{Transport: tr} retry: req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") - resp, err := client.Do(req) + resp, err := httpClient.Do(req) if err != nil { return nil, clientError(err) } @@ -753,7 +785,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -790,7 +822,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -814,14 +846,14 @@ func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr return nil, errors.Wrap(err, "error marshaling request") } u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) - client := &http.Client{Transport: tr} + httpClient := &http.Client{Transport: tr} retry: httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") - resp, err := client.Do(httpReq) + resp, err := httpClient.Do(httpReq) if err != nil { return nil, clientError(err) } @@ -830,7 +862,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -853,16 +885,16 @@ func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest, if err != nil { return nil, errors.Wrap(err, "error marshaling request") } - var client *uaClient + var uaClient *uaClient retry: if tr != nil { - client = newClient(tr) + uaClient = newClient(tr) } else { - client = c.client + uaClient = c.client } u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"}) - resp, err := client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) + resp, err := uaClient.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body)) if err != nil { return nil, clientError(err) } @@ -871,7 +903,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var revoke api.RevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { @@ -914,7 +946,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var provisioners api.ProvisionersResponse if err := readJSON(resp.Body, &provisioners); err != nil { @@ -946,7 +978,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var key api.ProvisionerKeyResponse if err := readJSON(resp.Body, &key); err != nil { @@ -976,7 +1008,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var roots api.RootsResponse if err := readJSON(resp.Body, &roots); err != nil { @@ -1006,7 +1038,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var federation api.FederationResponse if err := readJSON(resp.Body, &federation); err != nil { @@ -1040,7 +1072,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SSHSignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -1074,7 +1106,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var renew api.SSHRenewResponse if err := readJSON(resp.Body, &renew); err != nil { @@ -1108,7 +1140,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var rekey api.SSHRekeyResponse if err := readJSON(resp.Body, &rekey); err != nil { @@ -1142,7 +1174,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var revoke api.SSHRevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { @@ -1172,7 +1204,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { @@ -1202,7 +1234,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { @@ -1236,7 +1268,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var cfg api.SSHConfigResponse if err := readJSON(resp.Body, &cfg); err != nil { @@ -1275,7 +1307,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { @@ -1304,7 +1336,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var hosts api.SSHGetHostsResponse if err := readJSON(resp.Body, &hosts); err != nil { @@ -1336,7 +1368,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var bastion api.SSHBastionResponse if err := readJSON(resp.Body, &bastion); err != nil { @@ -1504,12 +1536,13 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } -func readError(r io.ReadCloser) error { - defer r.Close() +func readError(r *http.Response) error { + defer r.Body.Close() apiErr := new(errs.Error) - if err := json.NewDecoder(r).Decode(apiErr); err != nil { - return err + if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { + return fmt.Errorf("failed decoding CA error response: %w", err) } + apiErr.RequestID = r.Header.Get("X-Request-Id") return apiErr } diff --git a/ca/client/requestid.go b/ca/client/requestid.go new file mode 100644 index 00000000..1fb785eb --- /dev/null +++ b/ca/client/requestid.go @@ -0,0 +1,18 @@ +package client + +import "context" + +type contextKey struct{} + +// NewRequestIDContext returns a new context with the given request ID added to the +// context. +func NewRequestIDContext(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, contextKey{}, requestID) +} + +// RequestIDFromContext returns the request ID from the context if it exists. +// and is not empty. +func RequestIDFromContext(ctx context.Context) (string, bool) { + v, ok := ctx.Value(contextKey{}).(string) + return v, ok && v != "" +} diff --git a/ca/client_test.go b/ca/client_test.go index 6292e3ea..44d24c6e 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -9,24 +9,26 @@ import ( "encoding/json" "encoding/pem" "errors" - "fmt" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "time" - "go.step.sm/crypto/x509util" - "golang.org/x/crypto/ssh" - - "github.com/smallstep/assert" + "github.com/google/uuid" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api/read" "github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca/client" "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/x509util" + "golang.org/x/crypto/ssh" ) const ( @@ -106,52 +108,49 @@ DCbKzWTW8lqVdp9Kyf7XEhhc2R8C5w== -----END CERTIFICATE REQUEST-----` ) -func mustKey() *ecdsa.PrivateKey { +func mustKey(t *testing.T) *ecdsa.PrivateKey { + t.Helper() priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - panic(err) - } + require.NoError(t, err) return priv } -func parseCertificate(data string) *x509.Certificate { +func parseCertificate(t *testing.T, data string) *x509.Certificate { + t.Helper() block, _ := pem.Decode([]byte(data)) if block == nil { - panic("failed to parse certificate PEM") + require.Fail(t, "failed to parse certificate PEM") + return nil } cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - panic("failed to parse certificate: " + err.Error()) - } + require.NoError(t, err, "failed to parse certificate") return cert } -func parseCertificateRequest(string) *x509.CertificateRequest { +func parseCertificateRequest(t *testing.T, csrPEM string) *x509.CertificateRequest { + t.Helper() block, _ := pem.Decode([]byte(csrPEM)) if block == nil { - panic("failed to parse certificate request PEM") + require.Fail(t, "failed to parse certificate request PEM") + return nil } csr, err := x509.ParseCertificateRequest(block.Bytes) - if err != nil { - panic("failed to parse certificate request: " + err.Error()) - } + require.NoError(t, err, "failed to parse certificate request") return csr } func equalJSON(t *testing.T, a, b interface{}) bool { + t.Helper() if reflect.DeepEqual(a, b) { return true } + ab, err := json.Marshal(a) - if err != nil { - t.Error(err) - return false - } + require.NoError(t, err) + bb, err := json.Marshal(b) - if err != nil { - t.Error(err) - return false - } + require.NoError(t, err) + return bytes.Equal(ab, bb) } @@ -176,32 +175,23 @@ func TestClient_Version(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Version() - if (err != nil) != tt.wantErr { - t.Errorf("Client.Version() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Version() = %v, want nil", got) - } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Version() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -226,40 +216,30 @@ func TestClient_Health(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Health() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Health() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Health() = %v, want nil", got) - } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Health() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Root(t *testing.T) { ok := &api.RootResponse{ - RootPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + RootPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, } tests := []struct { @@ -280,10 +260,7 @@ func TestClient_Root(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { expected := "/root/" + tt.shasum @@ -294,37 +271,31 @@ func TestClient_Root(t *testing.T) { }) got, err := c.Root(tt.shasum) - if (err != nil) != tt.wantErr { - t.Errorf("Client.Root() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Root() = %v, want nil", got) - } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Root() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Sign(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.SignRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, OTT: "the-ott", NotBefore: api.NewTimeDuration(time.Now()), NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), @@ -350,16 +321,13 @@ func TestClient_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.SignRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - assert.Fatal(t, ok, "response expected to be error type") + require.True(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -375,23 +343,16 @@ func TestClient_Sign(t *testing.T) { }) got, err := c.Sign(tt.request) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Sign() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.EqualError(t, err, tt.expectedErr.Error()) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Sign() = %v, want nil", got) - } - assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Sign() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -422,16 +383,13 @@ func TestClient_Revoke(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { body := new(api.RevokeRequest) if err := read.JSON(req.Body, body); err != nil { e, ok := tt.response.(error) - assert.Fatal(t, ok, "response expected to be error type") + require.True(t, ok, "response expected to be error type") render.Error(w, e) return } else if !equalJSON(t, body, tt.request) { @@ -447,34 +405,27 @@ func TestClient_Revoke(t *testing.T) { }) got, err := c.Revoke(tt.request, nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Revoke() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.True(t, strings.HasPrefix(err.Error(), tt.expectedErr.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Revoke() = %v, want nil", got) - } - assert.HasPrefix(t, err.Error(), tt.expectedErr.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Renew(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -497,49 +448,38 @@ func TestClient_Renew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Renew(nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Renew() = %v, want nil", got) - } - - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Renew() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_RenewWithToken(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -562,10 +502,7 @@ func TestClient_RenewWithToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.Header.Get("Authorization") != "Bearer token" { @@ -576,44 +513,36 @@ func TestClient_RenewWithToken(t *testing.T) { }) got, err := c.RenewWithToken("token") - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.RenewWithToken() = %v, want nil", got) - } - - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_Rekey(t *testing.T) { ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } request := &api.RekeyRequest{ - CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(t, csrPEM)}, } tests := []struct { @@ -636,38 +565,27 @@ func TestClient_Rekey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Rekey(tt.request, nil) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Renew() = %v, want nil", got) - } - - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Renew() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -699,10 +617,7 @@ func TestClient_Provisioners(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.RequestURI != tt.expectedURI { @@ -712,22 +627,16 @@ func TestClient_Provisioners(t *testing.T) { }) got, err := c.Provisioners(tt.args...) - if (err != nil) != tt.wantErr { - t.Errorf("Client.Provisioners() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + assert.True(t, strings.HasPrefix(err.Error(), errs.InternalServerErrorDefaultMsg)) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Provisioners() = %v, want nil", got) - } - assert.HasPrefix(t, errs.InternalServerErrorDefaultMsg, err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Provisioners() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -755,10 +664,7 @@ func TestClient_ProvisionerKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { expected := "/provisioners/" + tt.kid + "/encrypted-key" @@ -769,27 +675,20 @@ func TestClient_ProvisionerKey(t *testing.T) { }) got, err := c.ProvisionerKey(tt.kid) - if (err != nil) != tt.wantErr { - t.Errorf("Client.ProvisionerKey() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.ProvisionerKey() = %v, want nil", got) - } - - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.ProvisionerKey() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -797,7 +696,7 @@ func TestClient_ProvisionerKey(t *testing.T) { func TestClient_Roots(t *testing.T) { ok := &api.RootsResponse{ Certificates: []api.Certificate{ - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -819,37 +718,27 @@ func TestClient_Roots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Roots() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Roots() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Roots() = %v, want nil", got) - } - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, err.Error(), tt.err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Roots() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -857,7 +746,7 @@ func TestClient_Roots(t *testing.T) { func TestClient_Federation(t *testing.T) { ok := &api.FederationResponse{ Certificates: []api.Certificate{ - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } @@ -878,46 +767,34 @@ func TestClient_Federation(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.Federation() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.Federation() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.Federation() = %v, want nil", got) - } - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.Federation() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } func TestClient_SSHRoots(t *testing.T) { - key, err := ssh.NewPublicKey(mustKey().Public()) - if err != nil { - t.Fatal(err) - } + key, err := ssh.NewPublicKey(mustKey(t).Public()) + require.NoError(t, err) ok := &api.SSHRootsResponse{ HostKeys: []api.SSHPublicKey{{PublicKey: key}}, @@ -941,37 +818,27 @@ func TestClient_SSHRoots(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHRoots() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.SSHKeys() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + if assert.Error(t, err) { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) + } + assert.Nil(t, got) return } - switch { - case err != nil: - if got != nil { - t.Errorf("Client.SSHKeys() = %v, want nil", got) - } - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) - } - assert.HasPrefix(t, tt.err.Error(), err.Error()) - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.SSHKeys() = %v, want %v", got, tt.response) - } - } + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -1003,13 +870,14 @@ func Test_parseEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := parseEndpoint(tt.args.endpoint) - if (err != nil) != tt.wantErr { - t.Errorf("parseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseEndpoint() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -1042,24 +910,21 @@ func TestClient_RootFingerprint(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tr := tt.server.Client().Transport c, err := NewClient(tt.server.URL, WithTransport(tr)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.RootFingerprint() - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -1068,12 +933,12 @@ func TestClient_RootFingerprintWithServer(t *testing.T) { srv := startCABootstrapServer() defer srv.Close() - client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) - assert.FatalError(t, err) + caClient, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) + require.NoError(t, err) - fp, err := client.RootFingerprint() - assert.FatalError(t, err) - assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) + fp, err := caClient.RootFingerprint() + assert.NoError(t, err) + assert.Equal(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) } func TestClient_SSHBastion(t *testing.T) { @@ -1103,39 +968,29 @@ func TestClient_SSHBastion(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } + require.NoError(t, err) srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { render.JSONStatus(w, tt.response, tt.responseCode) }) got, err := c.SSHBastion(tt.request) - if (err != nil) != tt.wantErr { - fmt.Printf("%+v", err) - t.Errorf("Client.SSHBastion() error = %v, wantErr %v", err, tt.wantErr) - return - } - - switch { - case err != nil: - if got != nil { - t.Errorf("Client.SSHBastion() = %v, want nil", got) - } - if tt.responseCode != 200 { - var sc render.StatusCodedError - if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { - assert.Equals(t, sc.StatusCode(), tt.responseCode) + if tt.wantErr { + if assert.Error(t, err) { + if tt.responseCode != 200 { + var sc render.StatusCodedError + if assert.ErrorAs(t, err, &sc) { + assert.Equal(t, tt.responseCode, sc.StatusCode()) + } + assert.True(t, strings.HasPrefix(err.Error(), tt.err.Error())) } - assert.HasPrefix(t, err.Error(), tt.err.Error()) - } - default: - if !reflect.DeepEqual(got, tt.response) { - t.Errorf("Client.SSHBastion() = %v, want %v", got, tt.response) } + assert.Nil(t, got) + return } + + assert.NoError(t, err) + assert.Equal(t, tt.response, got) }) } } @@ -1154,13 +1009,60 @@ func TestClient_GetCaURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Errorf("NewClient() error = %v", err) - return - } - if got := c.GetCaURL(); got != tt.want { - t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want) + require.NoError(t, err) + + got := c.GetCaURL() + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_enforceRequestID(t *testing.T) { + set := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + set.Header.Set("X-Request-Id", "already-set") + inContext := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + inContext = inContext.WithContext(client.NewRequestIDContext(inContext.Context(), "from-context")) + newRequestID := httptest.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + + tests := []struct { + name string + r *http.Request + want string + }{ + { + name: "set", + r: set, + want: "already-set", + }, + { + name: "context", + r: inContext, + want: "from-context", + }, + { + name: "new", + r: newRequestID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + enforceRequestID(tt.r) + + v := tt.r.Header.Get("X-Request-Id") + if assert.NotEmpty(t, v) { + if tt.want != "" { + assert.Equal(t, tt.want, v) + } } }) } } + +func Test_newRequestID(t *testing.T) { + requestID := newRequestID() + u, err := uuid.Parse(requestID) + assert.NoError(t, err) + assert.Equal(t, uuid.Version(0x4), u.Version()) + assert.Equal(t, uuid.RFC4122, u.Variant()) + assert.Equal(t, requestID, u.String()) +} diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index 39193f3f..5a754f08 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" @@ -41,14 +43,12 @@ func getTestProvisioner(t *testing.T, caURL string) *Provisioner { } func TestNewProvisioner(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() want := getTestProvisioner(t, ca.URL) caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type args struct { name string diff --git a/ca/tls_options_test.go b/ca/tls_options_test.go index 7dea3dc8..4ac6ff85 100644 --- a/ca/tls_options_test.go +++ b/ca/tls_options_test.go @@ -10,6 +10,8 @@ import ( "sort" "testing" + "github.com/stretchr/testify/require" + "github.com/smallstep/certificates/api" ) @@ -130,7 +132,7 @@ func TestVerifyClientCertIfGiven(t *testing.T) { //nolint:gosec // test tls config func TestAddRootCA(t *testing.T) { - cert := parseCertificate(rootPEM) + cert := parseCertificate(t, rootPEM) pool := x509.NewCertPool() pool.AddCert(cert) @@ -163,7 +165,7 @@ func TestAddRootCA(t *testing.T) { //nolint:gosec // test tls config func TestAddClientCA(t *testing.T) { - cert := parseCertificate(rootPEM) + cert := parseCertificate(t, rootPEM) pool := x509.NewCertPool() pool.AddCert(cert) @@ -196,25 +198,19 @@ func TestAddClientCA(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToRootCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -251,25 +247,19 @@ func TestAddRootsToRootCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToClientCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -306,31 +296,23 @@ func TestAddRootsToClientCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToRootCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) @@ -371,31 +353,23 @@ func TestAddFederationToRootCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToClientCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) @@ -436,25 +410,19 @@ func TestAddFederationToClientCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddRootsToCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - cert := parseCertificate(string(root)) + cert := parseCertificate(t, string(root)) pool := x509.NewCertPool() pool.AddCert(cert) @@ -491,31 +459,23 @@ func TestAddRootsToCAs(t *testing.T) { //nolint:gosec // test tls config func TestAddFederationToCAs(t *testing.T) { - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() client, err := NewClient(ca.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) clientFail, err := NewClient(ca.URL, WithTransport(http.DefaultTransport)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) root, err := os.ReadFile("testdata/secrets/root_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) federated, err := os.ReadFile("testdata/secrets/federated_ca.crt") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - crt1 := parseCertificate(string(root)) - crt2 := parseCertificate(string(federated)) + crt1 := parseCertificate(t, string(root)) + crt2 := parseCertificate(t, string(federated)) pool := x509.NewCertPool() pool.AddCert(crt1) pool.AddCert(crt2) diff --git a/ca/tls_test.go b/ca/tls_test.go index dbcc6023..d1ce11ea 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -17,27 +17,28 @@ import ( "testing" "time" - "github.com/smallstep/certificates/api" - "github.com/smallstep/certificates/authority" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority" ) -func generateOTT(subject string) string { +func generateOTT(t *testing.T, subject string) string { + t.Helper() now := time.Now() jwk, err := jose.ReadKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) - if err != nil { - panic(err) - } + require.NoError(t, err) + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID) sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, opts) - if err != nil { - panic(err) - } + require.NoError(t, err) + id, err := randutil.ASCII(64) - if err != nil { - panic(err) - } + require.NoError(t, err) + cl := struct { jose.Claims SANS []string `json:"sans"` @@ -53,9 +54,8 @@ func generateOTT(subject string) string { SANS: []string{subject}, } raw, err := jose.Signed(sig).Claims(cl).CompactSerialize() - if err != nil { - panic(err) - } + require.NoError(t, err) + return raw } @@ -72,32 +72,28 @@ func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler return srv } -func startCATestServer() *httptest.Server { +func startCATestServer(t *testing.T) *httptest.Server { config, err := authority.LoadConfiguration("testdata/ca.json") - if err != nil { - panic(err) - } + require.NoError(t, err) ca, err := New(config) - if err != nil { - panic(err) - } + require.NoError(t, err) // Use a httptest.Server instead baseContext := buildContext(ca.auth, nil, nil, nil) srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } -func sign(domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { - srv := startCATestServer() +func sign(t *testing.T, domain string) (*Client, *api.SignResponse, crypto.PrivateKey) { + t.Helper() + srv := startCATestServer(t) defer srv.Close() - return signDuration(srv, domain, 0) + return signDuration(t, srv, domain, 0) } -func signDuration(srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { - req, pk, err := CreateSignRequest(generateOTT(domain)) - if err != nil { - panic(err) - } +func signDuration(t *testing.T, srv *httptest.Server, domain string, duration time.Duration) (*Client, *api.SignResponse, crypto.PrivateKey) { + t.Helper() + req, pk, err := CreateSignRequest(generateOTT(t, domain)) + require.NoError(t, err) if duration > 0 { req.NotBefore = api.NewTimeDuration(time.Now()) @@ -105,13 +101,11 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) ( } client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt")) - if err != nil { - panic(err) - } + require.NoError(t, err) + sr, err := client.Sign(req) - if err != nil { - panic(err) - } + require.NoError(t, err) + return client, sr, pk } @@ -145,7 +139,7 @@ func serverHandler(t *testing.T, clientDomain string) http.Handler { func TestClient_GetServerTLSConfig_http(t *testing.T) { clientDomain := "test.domain" - client, sr, pk := sign("127.0.0.1") + client, sr, pk := sign(t, "127.0.0.1") // Create mTLS server ctx, cancel := context.WithCancel(context.Background()) @@ -212,7 +206,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - client, sr, pk := sign(clientDomain) + client, sr, pk := sign(t, clientDomain) cli := tt.getClient(t, client, sr, pk) if cli == nil { return @@ -246,19 +240,18 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { defer reset() // Start CA - ca := startCATestServer() + ca := startCATestServer(t) defer ca.Close() clientDomain := "test.domain" - client, sr, pk := signDuration(ca, "127.0.0.1", 5*time.Second) + client, sr, pk := signDuration(t, ca, "127.0.0.1", 5*time.Second) // Start mTLS server ctx, cancel := context.WithCancel(context.Background()) defer cancel() tlsConfig, err := client.GetServerTLSConfig(ctx, sr, pk) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } + require.NoError(t, err) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() @@ -266,30 +259,26 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) defer cancel() tlsConfig, err = client.GetServerTLSConfig(ctx, sr, pk, VerifyClientCertIfGiven()) - if err != nil { - t.Fatalf("Client.GetServerTLSConfig() error = %v", err) - } + require.NoError(t, err) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() // Transport - client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) + client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tr1, err := client.Transport(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.Transport() error = %v", err) - } + require.NoError(t, err) + // Transport with tlsConfig - client, sr, pk = signDuration(ca, clientDomain, 5*time.Second) + client, sr, pk = signDuration(t, ca, clientDomain, 5*time.Second) tlsConfig, err = client.GetClientTLSConfig(context.Background(), sr, pk) - if err != nil { - t.Fatalf("Client.GetClientTLSConfig() error = %v", err) - } + require.NoError(t, err) + tr2 := getDefaultTransport(tlsConfig) // No client cert root, err := RootCertificate(sr) - if err != nil { - t.Fatalf("RootCertificate() error = %v", err) - } + require.NoError(t, err) + tlsConfig = getDefaultTLSConfig(sr) tlsConfig.RootCAs = x509.NewCertPool() tlsConfig.RootCAs.AddCert(root) @@ -401,13 +390,13 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { } func TestCertificate(t *testing.T) { - cert := parseCertificate(certPEM) + cert := parseCertificate(t, certPEM) ok := &api.SignResponse{ ServerPEM: api.Certificate{Certificate: cert}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ {Certificate: cert}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { @@ -434,12 +423,12 @@ func TestCertificate(t *testing.T) { } func TestIntermediateCertificate(t *testing.T) { - intermediate := parseCertificate(rootPEM) + intermediate := parseCertificate(t, rootPEM) ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, CaPEM: api.Certificate{Certificate: intermediate}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(t, certPEM)}, {Certificate: intermediate}, }, } @@ -467,24 +456,24 @@ func TestIntermediateCertificate(t *testing.T) { } func TestRootCertificateCertificate(t *testing.T) { - root := parseCertificate(rootPEM) + root := parseCertificate(t, rootPEM) ok := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, TLS: &tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{ {root, root}, }}, } noTLS := &api.SignResponse{ - ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, - CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + ServerPEM: api.Certificate{Certificate: parseCertificate(t, certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(t, rootPEM)}, CertChainPEM: []api.Certificate{ - {Certificate: parseCertificate(certPEM)}, - {Certificate: parseCertificate(rootPEM)}, + {Certificate: parseCertificate(t, certPEM)}, + {Certificate: parseCertificate(t, rootPEM)}, }, } tests := []struct { diff --git a/errs/error.go b/errs/error.go index ba066925..c9ad92a6 100644 --- a/errs/error.go +++ b/errs/error.go @@ -49,10 +49,11 @@ func WithKeyVal(key string, val interface{}) Option { // Error represents the CA API errors. type Error struct { - Status int - Err error - Msg string - Details map[string]interface{} + Status int + Err error + Msg string + Details map[string]interface{} + RequestID string `json:"-"` } // ErrorResponse represents an error in JSON format. diff --git a/errs/errors_test.go b/errs/errors_test.go index 7b83c8d9..11590d7d 100644 --- a/errs/errors_test.go +++ b/errs/errors_test.go @@ -2,8 +2,9 @@ package errs import ( "fmt" - "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestError_MarshalJSON(t *testing.T) { @@ -27,13 +28,14 @@ func TestError_MarshalJSON(t *testing.T) { Err: tt.fields.Err, } got, err := e.MarshalJSON() - if (err != nil) != tt.wantErr { - t.Errorf("Error.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Empty(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Error.MarshalJSON() = %s, want %s", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } } @@ -54,13 +56,14 @@ func TestError_UnmarshalJSON(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { e := new(Error) - if err := e.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { - t.Errorf("Error.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) - } - //nolint:govet // best option - if !reflect.DeepEqual(tt.expected, e) { - t.Errorf("Error.UnmarshalJSON() wants = %+v, got %+v", tt.expected, e) + err := e.UnmarshalJSON(tt.args.data) + if tt.wantErr { + assert.Error(t, err) + return } + + assert.NoError(t, err) + assert.Equal(t, tt.expected, e) }) } } diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go new file mode 100644 index 00000000..ace08f16 --- /dev/null +++ b/internal/requestid/requestid.go @@ -0,0 +1,91 @@ +package requestid + +import ( + "context" + "net/http" + + "github.com/rs/xid" + + "go.step.sm/crypto/randutil" +) + +const ( + // requestIDHeader is the header name used for propagating request IDs. If + // available in an HTTP request, it'll be used instead of the X-Smallstep-Id + // header. It'll always be used in response and set to the request ID. + requestIDHeader = "X-Request-Id" + + // defaultTraceHeader is the default Smallstep tracing header that's currently + // in use. It is used as a fallback to retrieve a request ID from, if the + // "X-Request-Id" request header is not set. + defaultTraceHeader = "X-Smallstep-Id" +) + +type Handler struct { + legacyTraceHeader string +} + +// New creates a new request ID [handler]. It takes a trace header, +// which is used keep the legacy behavior intact, which relies on the +// X-Smallstep-Id header instead of X-Request-Id. +func New(legacyTraceHeader string) *Handler { + if legacyTraceHeader == "" { + legacyTraceHeader = defaultTraceHeader + } + + return &Handler{legacyTraceHeader: legacyTraceHeader} +} + +// Middleware wraps an [http.Handler] with request ID extraction +// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id +// header if not set. If both are not set, a new request ID is generated. +// In all cases, the request ID is added to the request context, and +// set to be reflected in the response. +func (h *Handler) Middleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, req *http.Request) { + requestID := req.Header.Get(requestIDHeader) + if requestID == "" { + requestID = req.Header.Get(h.legacyTraceHeader) + } + + if requestID == "" { + requestID = newRequestID() + req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior + } + + // immediately set the request ID to be reflected in the response + w.Header().Set(requestIDHeader, requestID) + + // continue down the handler chain + ctx := NewContext(req.Context(), requestID) + next.ServeHTTP(w, req.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} + +// newRequestID generates a new random UUIDv4 request ID. If UUIDv4 +// generation fails, it'll fallback to generating a random ID using +// github.com/rs/xid. +func newRequestID() string { + requestID, err := randutil.UUIDv4() + if err != nil { + requestID = xid.New().String() + } + + return requestID +} + +type contextKey struct{} + +// NewContext returns a new context with the given request ID added to the +// context. +func NewContext(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, contextKey{}, requestID) +} + +// FromContext returns the request ID from the context if it exists and +// is not the empty value. +func FromContext(ctx context.Context) (string, bool) { + v, ok := ctx.Value(contextKey{}).(string) + return v, ok && v != "" +} diff --git a/internal/requestid/requestid_test.go b/internal/requestid/requestid_test.go new file mode 100644 index 00000000..84a9021f --- /dev/null +++ b/internal/requestid/requestid_test.go @@ -0,0 +1,105 @@ +package requestid + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newRequest(t *testing.T) *http.Request { + t.Helper() + r, err := http.NewRequest(http.MethodGet, "https://example.com", http.NoBody) + require.NoError(t, err) + return r +} + +func Test_Middleware(t *testing.T) { + requestWithID := newRequest(t) + requestWithID.Header.Set("X-Request-Id", "reqID") + + requestWithoutID := newRequest(t) + + requestWithEmptyHeader := newRequest(t) + requestWithEmptyHeader.Header.Set("X-Request-Id", "") + + requestWithSmallstepID := newRequest(t) + requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID") + + tests := []struct { + name string + traceHeader string + next http.HandlerFunc + req *http.Request + }{ + { + name: "default-request-id", + traceHeader: defaultTraceHeader, + next: func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + assert.Equal(t, "reqID", r.Header.Get("X-Request-Id")) + reqID, ok := FromContext(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "reqID", reqID) + } + assert.Equal(t, "reqID", w.Header().Get("X-Request-Id")) + }, + req: requestWithID, + }, + { + name: "no-request-id", + traceHeader: "X-Request-Id", + next: func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Smallstep-Id")) + value := r.Header.Get("X-Request-Id") + assert.NotEmpty(t, value) + reqID, ok := FromContext(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + assert.Equal(t, value, w.Header().Get("X-Request-Id")) + }, + req: requestWithoutID, + }, + { + name: "empty-header", + traceHeader: "", + next: func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + value := r.Header.Get("X-Smallstep-Id") + assert.NotEmpty(t, value) + reqID, ok := FromContext(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, value, reqID) + } + assert.Equal(t, value, w.Header().Get("X-Request-Id")) + }, + req: requestWithEmptyHeader, + }, + { + name: "fallback-header-name", + traceHeader: defaultTraceHeader, + next: func(w http.ResponseWriter, r *http.Request) { + assert.Empty(t, r.Header.Get("X-Request-Id")) + assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id")) + reqID, ok := FromContext(r.Context()) + if assert.True(t, ok) { + assert.Equal(t, "smallstepID", reqID) + } + assert.Equal(t, "smallstepID", w.Header().Get("X-Request-Id")) + }, + req: requestWithSmallstepID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := New(tt.traceHeader).Middleware(tt.next) + + w := httptest.NewRecorder() + handler.ServeHTTP(w, tt.req) + assert.NotEmpty(t, w.Header().Get("X-Request-Id")) + }) + } +} diff --git a/internal/userid/userid.go b/internal/userid/userid.go new file mode 100644 index 00000000..48087da8 --- /dev/null +++ b/internal/userid/userid.go @@ -0,0 +1,20 @@ +package userid + +import "context" + +type contextKey struct{} + +// NewContext returns a new context with the given user ID added to the +// context. +// TODO(hs): this doesn't seem to be used / set currently; implement +// when/where it makes sense. +func NewContext(ctx context.Context, userID string) context.Context { + return context.WithValue(ctx, contextKey{}, userID) +} + +// FromContext returns the user ID from the context if it exists +// and is not empty. +func FromContext(ctx context.Context) (string, bool) { + v, ok := ctx.Value(contextKey{}).(string) + return v, ok && v != "" +} diff --git a/logging/context.go b/logging/context.go deleted file mode 100644 index b24b3638..00000000 --- a/logging/context.go +++ /dev/null @@ -1,66 +0,0 @@ -package logging - -import ( - "context" - "net/http" - - "github.com/rs/xid" -) - -type key int - -const ( - // RequestIDKey is the context key that should store the request identifier. - RequestIDKey key = iota - // UserIDKey is the context key that should store the user identifier. - UserIDKey -) - -// NewRequestID creates a new request id using github.com/rs/xid. -func NewRequestID() string { - return xid.New().String() -} - -// RequestID returns a new middleware that gets the given header and sets it -// in the context so it can be written in the logger. If the header does not -// exists or it's the empty string, it uses github.com/rs/xid to create a new -// one. -func RequestID(headerName string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, req *http.Request) { - requestID := req.Header.Get(headerName) - if requestID == "" { - requestID = NewRequestID() - req.Header.Set(headerName, requestID) - } - - ctx := WithRequestID(req.Context(), requestID) - next.ServeHTTP(w, req.WithContext(ctx)) - } - return http.HandlerFunc(fn) - } -} - -// WithRequestID returns a new context with the given requestID added to the -// context. -func WithRequestID(ctx context.Context, requestID string) context.Context { - return context.WithValue(ctx, RequestIDKey, requestID) -} - -// GetRequestID returns the request id from the context if it exists. -func GetRequestID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(RequestIDKey).(string) - return v, ok -} - -// WithUserID decodes the token, extracts the user from the payload and stores -// it in the context. -func WithUserID(ctx context.Context, userID string) context.Context { - return context.WithValue(ctx, UserIDKey, userID) -} - -// GetUserID returns the request id from the context if it exists. -func GetUserID(ctx context.Context) (string, bool) { - v, ok := ctx.Value(UserIDKey).(string) - return v, ok -} diff --git a/logging/handler.go b/logging/handler.go index a8b77d60..06fc56d3 100644 --- a/logging/handler.go +++ b/logging/handler.go @@ -9,6 +9,9 @@ import ( "time" "github.com/sirupsen/logrus" + + "github.com/smallstep/certificates/internal/requestid" + "github.com/smallstep/certificates/internal/userid" ) // LoggerHandler creates a logger handler @@ -29,16 +32,15 @@ type options struct { // NewLoggerHandler returns the given http.Handler with the logger integrated. func NewLoggerHandler(name string, logger *Logger, next http.Handler) http.Handler { - h := RequestID(logger.GetTraceHeader()) onlyTraceHealthEndpoint, _ := strconv.ParseBool(os.Getenv("STEP_LOGGER_ONLY_TRACE_HEALTH_ENDPOINT")) - return h(&LoggerHandler{ + return &LoggerHandler{ name: name, logger: logger.GetImpl(), options: options{ onlyTraceHealthEndpoint: onlyTraceHealthEndpoint, }, next: next, - }) + } } // ServeHTTP implements the http.Handler and call to the handler to log with a @@ -54,14 +56,14 @@ func (l *LoggerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // writeEntry writes to the Logger writer the request information in the logger. func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Time, d time.Duration) { - var reqID, user string + var requestID, userID string ctx := r.Context() - if v, ok := ctx.Value(RequestIDKey).(string); ok && v != "" { - reqID = v + if v, ok := requestid.FromContext(ctx); ok { + requestID = v } - if v, ok := ctx.Value(UserIDKey).(string); ok && v != "" { - user = v + if v, ok := userid.FromContext(ctx); ok { + userID = v } // Remote hostname @@ -85,10 +87,10 @@ func (l *LoggerHandler) writeEntry(w ResponseLogger, r *http.Request, t time.Tim status := w.StatusCode() fields := logrus.Fields{ - "request-id": reqID, + "request-id": requestID, "remote-address": addr, "name": l.name, - "user-id": user, + "user-id": userID, "time": t.Format(time.RFC3339), "duration-ns": d.Nanoseconds(), "duration": d.String(), diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index a0d0886b..2ca2ef54 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -9,6 +9,8 @@ import ( "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + + "github.com/smallstep/certificates/internal/requestid" "github.com/smallstep/certificates/logging" ) @@ -82,7 +84,7 @@ func newRelicMiddleware(app *newrelic.Application) Middleware { txn.AddAttribute("httpResponseCode", strconv.Itoa(status)) // Add custom attributes - if v, ok := logging.GetRequestID(r.Context()); ok { + if v, ok := requestid.FromContext(r.Context()); ok { txn.AddAttribute("request.id", v) } diff --git a/test/integration/requestid_test.go b/test/integration/requestid_test.go new file mode 100644 index 00000000..54fd2eb0 --- /dev/null +++ b/test/integration/requestid_test.go @@ -0,0 +1,289 @@ +package integration + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/jose" + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" + "go.step.sm/crypto/randutil" + "go.step.sm/crypto/x509util" + + "github.com/smallstep/certificates/api" + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/ca/client" + "github.com/smallstep/certificates/errs" +) + +// reservePort "reserves" a TCP port by opening a listener on a random +// port and immediately closing it. The port can then be assumed to be +// available for running a server on. +func reservePort(t *testing.T) (host, port string) { + t.Helper() + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + address := l.Addr().String() + err = l.Close() + require.NoError(t, err) + + host, port, err = net.SplitHostPort(address) + require.NoError(t, err) + + return +} + +func Test_reflectRequestID(t *testing.T) { + dir := t.TempDir() + m, err := minica.New(minica.WithName("Step E2E")) + require.NoError(t, err) + + rootFilepath := filepath.Join(dir, "root.crt") + _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) + require.NoError(t, err) + + intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") + _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) + require.NoError(t, err) + + intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") + _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) + require.NoError(t, err) + + // get a random address to listen on and connect to; currently no nicer way to get one before starting the server + // TODO(hs): find/implement a nicer way to expose the CA URL, similar to how e.g. httptest.Server exposes it? + host, port := reservePort(t) + + authorizingSrv := newAuthorizingServer(t, m) + defer authorizingSrv.Close() + authorizingSrv.StartTLS() + + password := []byte("1234") + jwk, jwe, err := jose.GenerateDefaultKeyPair(password) + require.NoError(t, err) + encryptedKey, err := jwe.CompactSerialize() + require.NoError(t, err) + prov := &provisioner.JWK{ + ID: "jwk", + Name: "jwk", + Type: "JWK", + Key: jwk, + EncryptedKey: encryptedKey, + Claims: &config.GlobalProvisionerClaims, + Options: &provisioner.Options{ + Webhooks: []*provisioner.Webhook{ + { + ID: "webhook", + Name: "webhook-test", + URL: fmt.Sprintf("%s/authorize", authorizingSrv.URL), + Kind: "AUTHORIZING", + CertType: "X509", + }, + }, + }, + } + err = prov.Init(provisioner.Config{}) + require.NoError(t, err) + + cfg := &config.Config{ + Root: []string{rootFilepath}, + IntermediateCert: intermediateCertFilepath, + IntermediateKey: intermediateKeyFilepath, + Address: net.JoinHostPort(host, port), // reuse the address that was just "reserved" + DNSNames: []string{"127.0.0.1", "[::1]", "localhost"}, + AuthorityConfig: &config.AuthConfig{ + AuthorityID: "stepca-test", + DeploymentType: "standalone-test", + Provisioners: provisioner.List{prov}, + }, + Logger: json.RawMessage(`{"format": "text"}`), + } + c, err := ca.New(cfg) + require.NoError(t, err) + + // instantiate a client for the CA running at the random address + caClient, err := ca.NewClient( + fmt.Sprintf("https://localhost:%s", port), + ca.WithRootFile(rootFilepath), + ) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + err = c.Run() + require.ErrorIs(t, err, http.ErrServerClosed) + }() + + // require OK health response as the baseline + ctx := context.Background() + healthResponse, err := caClient.HealthWithContext(ctx) + require.NoError(t, err) + if assert.NotNil(t, healthResponse) { + require.Equal(t, "ok", healthResponse.Status) + } + + // expect an error when retrieving an invalid root + rootResponse, err := caClient.RootWithContext(ctx, "invalid") + var firstErr *errs.Error + if assert.ErrorAs(t, err, &firstErr) { + assert.Equal(t, 404, firstErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", firstErr.Err.Error()) + assert.NotEmpty(t, firstErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs. Also see https://github.com/smallstep/certificates/pull/759 + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + assert.Nil(t, rootResponse) + + // expect an error when retrieving an invalid root and provided request ID + rootResponse, err = caClient.RootWithContext(client.NewRequestIDContext(ctx, "reqID"), "invalid") + var secondErr *errs.Error + if assert.ErrorAs(t, err, &secondErr) { + assert.Equal(t, 404, secondErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", secondErr.Err.Error()) + assert.Equal(t, "reqID", secondErr.RequestID) + } + assert.Nil(t, rootResponse) + + // prepare a Sign request + subject := "test" + decryptedJWK := decryptPrivateKey(t, jwe, password) + ott := generateOTT(t, decryptedJWK, subject) + + signer, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + csr, err := x509util.CreateCertificateRequest(subject, []string{subject}, signer) + require.NoError(t, err) + + // perform the Sign request using the OTT and CSR + signResponse, err := caClient.SignWithContext(client.NewRequestIDContext(ctx, "signRequestID"), &api.SignRequest{ + CsrPEM: api.CertificateRequest{CertificateRequest: csr}, + OTT: ott, + NotAfter: api.NewTimeDuration(time.Now().Add(1 * time.Hour)), + NotBefore: api.NewTimeDuration(time.Now().Add(-1 * time.Hour)), + }) + assert.NoError(t, err) + + // assert a certificate was returned for the subject "test" + if assert.NotNil(t, signResponse) { + assert.Len(t, signResponse.CertChainPEM, 2) + cert, err := x509.ParseCertificate(signResponse.CertChainPEM[0].Raw) + assert.NoError(t, err) + if assert.NotNil(t, cert) { + assert.Equal(t, "test", cert.Subject.CommonName) + assert.Contains(t, cert.DNSNames, "test") + } + } + + // done testing; stop and wait for the server to quit + err = c.Stop() + require.NoError(t, err) + + wg.Wait() +} + +func decryptPrivateKey(t *testing.T, jwe *jose.JSONWebEncryption, pass []byte) *jose.JSONWebKey { + t.Helper() + d, err := jwe.Decrypt(pass) + require.NoError(t, err) + + jwk := &jose.JSONWebKey{} + err = json.Unmarshal(d, jwk) + require.NoError(t, err) + + return jwk +} + +func generateOTT(t *testing.T, jwk *jose.JSONWebKey, subject string) string { + t.Helper() + now := time.Now() + + keyID, err := jose.Thumbprint(jwk) + require.NoError(t, err) + + opts := new(jose.SignerOptions).WithType("JWT").WithHeader("kid", keyID) + signer, err := jose.NewSigner(jose.SigningKey{Key: jwk.Key}, opts) + require.NoError(t, err) + + id, err := randutil.ASCII(64) + require.NoError(t, err) + + cl := struct { + jose.Claims + SANS []string `json:"sans"` + }{ + Claims: jose.Claims{ + ID: id, + Subject: subject, + Issuer: "jwk", + NotBefore: jose.NewNumericDate(now), + Expiry: jose.NewNumericDate(now.Add(time.Minute)), + Audience: []string{"https://127.0.0.1/1.0/sign"}, + }, + SANS: []string{subject}, + } + raw, err := jose.Signed(signer).Claims(cl).CompactSerialize() + require.NoError(t, err) + + return raw +} + +func newAuthorizingServer(t *testing.T, mca *minica.CA) *httptest.Server { + t.Helper() + + key, err := keyutil.GenerateDefaultSigner() + require.NoError(t, err) + + csr, err := x509util.CreateCertificateRequest("127.0.0.1", []string{"127.0.0.1"}, key) + require.NoError(t, err) + + crt, err := mca.SignCSR(csr) + require.NoError(t, err) + + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if assert.Equal(t, "signRequestID", r.Header.Get("X-Request-Id")) { + json.NewEncoder(w).Encode(struct{ Allow bool }{Allow: true}) + w.WriteHeader(http.StatusOK) + return + } + + w.WriteHeader(http.StatusBadRequest) + })) + trustedRoots := x509.NewCertPool() + trustedRoots.AddCert(mca.Root) + + srv.TLS = &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{crt.Raw, mca.Intermediate.Raw}, + PrivateKey: key, + Leaf: crt, + }, + }, + ClientCAs: trustedRoots, + ClientAuth: tls.RequireAndVerifyClientCert, + ServerName: "localhost", + } + + return srv +}