From f3229d3e3ca2bf226fae706aac4443ad0926497c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 13:48:43 +0200 Subject: [PATCH 1/8] Propagate (original) request ID to webhook requests Technically the webhook request is a new request, so maybe the `X-Request-ID` should not be set to the value of the original request? But then the original request ID should be propageted in the webhook request body, or using a different header. The way the request ID is used in this functionality is actually more like a tracing ID, so that may be an option too. --- authority/provisioner/webhook.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 407b84d8..1097c003 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -15,6 +15,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/webhook" "go.step.sm/linkedca" @@ -169,6 +170,11 @@ retry: return nil, err } + requestID, ok := logging.GetRequestID(ctx) + if ok { + req.Header.Set("X-Request-ID", requestID) + } + secret, err := base64.StdEncoding.DecodeString(w.Secret) if err != nil { return nil, err From b2301ea12731a35f3795505a8b5b2f3ec736e83f Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 15:39:54 +0200 Subject: [PATCH 2/8] Remove the webhook `Do` method --- authority/provisioner/webhook.go | 20 +++++++++++--------- authority/provisioner/webhook_test.go | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 1097c003..14d357f1 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -56,7 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + // TODO(hs): propagate context from above + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -88,7 +92,12 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - resp, err := wh.Do(wc.client, req, wc.TemplateData) + + // TODO(hs): propagate context from above + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -124,13 +133,6 @@ type Webhook struct { } `json:"-"` } -func (w *Webhook) Do(client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - return w.DoWithContext(ctx, client, reqBody, data) -} - func (w *Webhook) DoWithContext(ctx context.Context, client *http.Client, reqBody *webhook.RequestBody, data any) (*webhook.ResponseBody, error) { tmpl, err := template.New("url").Funcs(templates.StepFuncMap()).Parse(w.URL) if err != nil { diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 656d75d8..a61da39c 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/hmac" "crypto/sha256" "crypto/tls" @@ -13,6 +14,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/pkg/errors" "github.com/smallstep/assert" @@ -522,7 +524,11 @@ func TestWebhook_Do(t *testing.T) { reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) assert.FatalError(t, err) - got, err := tc.webhook.Do(http.DefaultClient, reqBody, tc.dataArg) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { assert.Equals(t, tc.expectErr.Error(), err.Error()) return @@ -553,11 +559,18 @@ func TestWebhook_Do(t *testing.T) { } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) assert.FatalError(t, err) - _, err = wh.Do(client, reqBody, nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _, err = wh.DoWithContext(ctx, client, reqBody, nil) assert.FatalError(t, err) + ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + wh.DisableTLSClientAuth = true - _, err = wh.Do(client, reqBody, nil) + _, err = wh.DoWithContext(ctx, client, reqBody, nil) assert.Error(t, err) }) } From 4e06bdbc514826ee65983e7f7d5f201f18023130 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 16:17:36 +0200 Subject: [PATCH 3/8] Add `SignWithContext` method to authority and mocks --- acme/api/revoke_test.go | 4 ++++ acme/common.go | 1 + acme/order_test.go | 10 ++++++++++ api/api.go | 1 + api/api_test.go | 8 ++++++++ authority/provisioner/webhook.go | 19 +++++++++---------- authority/provisioner/webhook_test.go | 4 ++-- authority/ssh.go | 14 +++++++------- authority/tls.go | 23 ++++++++++++++++------- authority/webhook.go | 10 +++++++--- authority/webhook_test.go | 6 ++++-- 11 files changed, 69 insertions(+), 31 deletions(-) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index a225aa19..e8edcc41 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -285,6 +285,10 @@ func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...prov return nil, nil } +func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { + return nil, nil +} + func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { if m.MockAreSANsallowed != nil { return m.MockAreSANsallowed(ctx, sans) diff --git a/acme/common.go b/acme/common.go index 7d58305f..afab13b2 100644 --- a/acme/common.go +++ b/acme/common.go @@ -22,6 +22,7 @@ var clock Clock // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error diff --git a/acme/order_test.go b/acme/order_test.go index 2851bb19..3fa99b9b 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -272,6 +272,7 @@ func TestOrder_UpdateStatus(t *testing.T) { type mockSignAuth struct { sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} @@ -287,6 +288,15 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, csr, signOpts, extraOpts...) + } else if m.err != nil { + return nil, m.err + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { if m.areSANsAllowed != nil { return m.areSANsAllowed(ctx, sans) diff --git a/api/api.go b/api/api.go index c9820351..2d6c0bf7 100644 --- a/api/api.go +++ b/api/api.go @@ -42,6 +42,7 @@ type Authority interface { GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index d96015f9..90acf759 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -193,6 +193,7 @@ type mockAuthority struct { getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) @@ -261,6 +262,13 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignO return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.renew != nil { return m.renew(cert) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 14d357f1..1cc2047c 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -37,7 +37,7 @@ type WebhookController struct { // Enrich fetches data from remote servers and adds returned data to the // templateData -func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { +func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -56,11 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - // TODO(hs): propagate context from above - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout + + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -73,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { } // Authorize checks that all remote servers allow the request -func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { +func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -93,11 +93,10 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { continue } - // TODO(hs): propagate context from above - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout - resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index a61da39c..cc79a09b 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -242,7 +242,7 @@ func TestWebhookController_Enrich(t *testing.T) { wh.URL = ts.URL } - err := test.ctl.Enrich(test.req) + err := test.ctl.Enrich(context.Background(), test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -352,7 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) { wh.URL = ts.URL } - err := test.ctl.Authorize(test.req) + err := test.ctl.Authorize(context.Background(), test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } diff --git a/authority/ssh.go b/authority/ssh.go index f9371d60..688bfd76 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (* } // SignSSH creates a signed SSH certificate with the given public key and options. -func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var ( certOptions []sshutil.Option mods []provisioner.SSHCertModifier @@ -205,7 +205,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } // Call enriching webhooks - if err := callEnrichingWebhooksSSH(webhookCtl, cr); err != nil { + if err := callEnrichingWebhooksSSH(ctx, webhookCtl, cr); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("signOptions", signOpts), @@ -277,7 +277,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } // Send certificate to webhooks for authorization - if err := callAuthorizingWebhooksSSH(webhookCtl, certificate, certTpl); err != nil { + if err := callAuthorizingWebhooksSSH(ctx, webhookCtl, certificate, certTpl); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), ) @@ -653,7 +653,7 @@ func (a *Authority) getAddUserCommand(principal string) string { return strings.ReplaceAll(cmd, "", principal) } -func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.CertificateRequest) error { +func callEnrichingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cr sshutil.CertificateRequest) error { if webhookCtl == nil { return nil } @@ -663,10 +663,10 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica if err != nil { return err } - return webhookCtl.Enrich(whEnrichReq) + return webhookCtl.Enrich(ctx, whEnrichReq) } -func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { +func callAuthorizingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { if webhookCtl == nil { return nil } @@ -676,5 +676,5 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert if err != nil { return err } - return webhookCtl.Authorize(whAuthBody) + return webhookCtl.Authorize(ctx, whAuthBody) } diff --git a/authority/tls.go b/authority/tls.go index 6e967920..900b1ff8 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -91,8 +91,17 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc { } } -// Sign creates a signed certificate from a certificate signing request. +// Sign creates a signed certificate from a certificate signing request. It +// creates a new context.Context, and calls into SignWithContext. +// +// Deprecated: Use authority.SignWithContext with an actual context.Context. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...) +} + +// SignWithContext creates a signed certificate from a certificate signing request, +// taking the provided context.Context. +func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { var ( certOptions []x509util.Option certValidators []provisioner.CertificateValidator @@ -163,7 +172,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } - if err := callEnrichingWebhooksX509(webhookCtl, attData, csr); err != nil { + if err := callEnrichingWebhooksX509(ctx, webhookCtl, attData, csr); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("csr", csr), @@ -256,7 +265,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } // Send certificate to webhooks for authorization - if err := callAuthorizingWebhooksX509(webhookCtl, cert, leaf, attData); err != nil { + if err := callAuthorizingWebhooksX509(ctx, webhookCtl, cert, leaf, attData); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., @@ -952,7 +961,7 @@ func templatingError(err error) error { return errors.Wrap(cause, "error applying certificate template") } -func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { +func callEnrichingWebhooksX509(ctx context.Context, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { if webhookCtl == nil { return nil } @@ -969,10 +978,10 @@ func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisione if err != nil { return err } - return webhookCtl.Enrich(whEnrichReq) + return webhookCtl.Enrich(ctx, whEnrichReq) } -func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { +func callAuthorizingWebhooksX509(ctx context.Context, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { if webhookCtl == nil { return nil } @@ -989,5 +998,5 @@ func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Ce if err != nil { return err } - return webhookCtl.Authorize(whAuthBody) + return webhookCtl.Authorize(ctx, whAuthBody) } diff --git a/authority/webhook.go b/authority/webhook.go index d887e077..29e3e6c3 100644 --- a/authority/webhook.go +++ b/authority/webhook.go @@ -1,8 +1,12 @@ package authority -import "github.com/smallstep/certificates/webhook" +import ( + "context" + + "github.com/smallstep/certificates/webhook" +) type webhookController interface { - Enrich(*webhook.RequestBody) error - Authorize(*webhook.RequestBody) error + Enrich(context.Context, *webhook.RequestBody) error + Authorize(context.Context, *webhook.RequestBody) error } diff --git a/authority/webhook_test.go b/authority/webhook_test.go index 0e713af7..75b59f63 100644 --- a/authority/webhook_test.go +++ b/authority/webhook_test.go @@ -1,6 +1,8 @@ package authority import ( + "context" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/webhook" ) @@ -14,7 +16,7 @@ type mockWebhookController struct { var _ webhookController = &mockWebhookController{} -func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { +func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error { for key, data := range wc.respData { wc.templateData.SetWebhook(key, data) } @@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { return wc.enrichErr } -func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { +func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error { return wc.authorizeErr } From 9e3807eaa3096d633e6f28437387fef4256d36d7 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 16:34:29 +0200 Subject: [PATCH 4/8] Use `SignWithContext` in the critical paths --- acme/order.go | 2 +- api/sign.go | 2 +- api/ssh.go | 2 +- scep/authority.go | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/acme/order.go b/acme/order.go index 8dfcf97a..5a86c2c8 100644 --- a/acme/order.go +++ b/acme/order.go @@ -263,7 +263,7 @@ func (o *Order) Finalize(ctx context.Context, db DB, csr *x509.CertificateReques signOps = append(signOps, extraOptions...) // Sign a new certificate. - certChain, err := auth.Sign(csr, provisioner.SignOptions{ + certChain, err := auth.SignWithContext(ctx, csr, provisioner.SignOptions{ NotBefore: provisioner.NewTimeDuration(o.NotBefore), NotAfter: provisioner.NewTimeDuration(o.NotAfter), }, signOps...) diff --git a/api/sign.go b/api/sign.go index c0c83ce2..26b3c396 100644 --- a/api/sign.go +++ b/api/sign.go @@ -78,7 +78,7 @@ func Sign(w http.ResponseWriter, r *http.Request) { return } - certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.SignWithContext(ctx, body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return diff --git a/api/ssh.go b/api/ssh.go index fbaa8c5a..a07dab29 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -330,7 +330,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.SignWithContext(ctx, cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return diff --git a/scep/authority.go b/scep/authority.go index 23c28813..a7333aa7 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -65,6 +65,7 @@ type AuthorityOptions struct { // SignAuthority is the interface for a signing authority type SignAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) LoadProvisionerByName(string) (provisioner.Interface, error) } @@ -296,7 +297,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m } signOps = append(signOps, templateOptions) - certChain, err := a.signAuth.Sign(csr, opts, signOps...) + certChain, err := a.signAuth.SignWithContext(ctx, csr, opts, signOps...) if err != nil { return nil, fmt.Errorf("error generating certificate for order: %w", err) } From 4ef093dc4b478c89b17ce761e0260a99265fa39c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 19 Sep 2023 16:55:59 +0200 Subject: [PATCH 5/8] Fix broken tests relying on `Sign` in mocks --- acme/order_test.go | 18 +++++++++--------- api/ssh_test.go | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/acme/order_test.go b/acme/order_test.go index 3fa99b9b..17060f11 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -588,7 +588,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return nil, errors.New("force") }, @@ -638,7 +638,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -695,7 +695,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -780,7 +780,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -873,7 +873,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -983,7 +983,7 @@ func TestOrder_Finalize(t *testing.T) { // using the mocking functions as a wrapper for actual test helpers generated per test case or per // function that's tested. ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{leaf, inter, root}, nil }, @@ -1054,7 +1054,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -1118,7 +1118,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, @@ -1185,7 +1185,7 @@ func TestOrder_Finalize(t *testing.T) { }, }, ca: &mockSignAuth{ - sign: func(_csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(_ context.Context, _csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { assert.Equals(t, _csr, csr) return []*x509.Certificate{foo, bar, baz}, nil }, diff --git a/api/ssh_test.go b/api/ssh_test.go index 57dd6775..2b90dc12 100644 --- a/api/ssh_test.go +++ b/api/ssh_test.go @@ -325,7 +325,7 @@ func Test_SSHSign(t *testing.T) { signSSHAddUser: func(ctx context.Context, key ssh.PublicKey, cert *ssh.Certificate) (*ssh.Certificate, error) { return tt.addUserCert, tt.addUserErr }, - sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + signWithContext: func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { return tt.tlsSignCerts, tt.tlsSignErr }, }) From 96895087098cdcc95299e2a6126c23f1c5260282 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 13:39:21 +0100 Subject: [PATCH 6/8] Add tests for webhook request IDs --- authority/provisioner/webhook_test.go | 126 ++++++++++++++++++-------- 1 file changed, 86 insertions(+), 40 deletions(-) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index 0ce3f36d..ced713d1 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -17,8 +17,11 @@ import ( "time" "github.com/pkg/errors" - "github.com/smallstep/assert" + sassert "github.com/smallstep/assert" + "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/webhook" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/pemutil" "go.step.sm/crypto/x509util" "go.step.sm/linkedca" @@ -94,19 +97,24 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - assert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) + sassert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } +// withRequestID is a helper that calls into [logging.WithRequestID] and returns +// a new context with the requestID added to the provided context. +func withRequestID(ctx context.Context, requestID string) context.Context { + return logging.WithRequestID(ctx, requestID) +} + func TestWebhookController_Enrich(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -131,6 +139,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, @@ -145,6 +154,7 @@ func TestWebhookController_Enrich(t *testing.T) { }, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -168,6 +178,7 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, certType: linkedca.Webhook_X509, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{ {Allow: true, Data: map[string]any{"role": "bar"}}, @@ -187,14 +198,15 @@ func TestWebhookController_Enrich(t *testing.T) { TemplateData: x509util.TemplateData{}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true, Data: map[string]any{"role": "bar"}}}, expectErr: false, expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + sassert.FatalError(t, err) + sassert.Equals(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -209,6 +221,7 @@ func TestWebhookController_Enrich(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "ENRICHING"}}, TemplateData: x509util.TemplateData{}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -223,6 +236,7 @@ func TestWebhookController_Enrich(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -234,19 +248,21 @@ func TestWebhookController_Enrich(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Enrich(context.Background(), test.req) + err := test.ctl.Enrich(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } - assert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) + sassert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } @@ -256,12 +272,11 @@ func TestWebhookController_Enrich(t *testing.T) { func TestWebhookController_Authorize(t *testing.T) { cert, err := pemutil.ReadCertificate("testdata/certs/x5c-leaf.crt", pemutil.WithFirstBlock()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type test struct { ctl *WebhookController + ctx context.Context req *webhook.RequestBody responses []*webhook.ResponseBody expectErr bool @@ -282,6 +297,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, @@ -292,6 +308,7 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING", CertType: linkedca.Webhook_X509.String()}}, certType: linkedca.Webhook_SSH, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: false, @@ -302,13 +319,14 @@ func TestWebhookController_Authorize(t *testing.T) { webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, options: []webhook.RequestBodyOption{webhook.WithX5CCertificate(cert)}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: true}}, expectErr: false, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - assert.FatalError(t, err) - assert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + sassert.Equals(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -322,6 +340,7 @@ func TestWebhookController_Authorize(t *testing.T) { client: http.DefaultClient, webhooks: []*Webhook{{Name: "people", Kind: "AUTHORIZING"}}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -334,6 +353,7 @@ func TestWebhookController_Authorize(t *testing.T) { PublicKey: []byte("bad"), })}, }, + ctx: withRequestID(context.Background(), "reqID"), req: &webhook.RequestBody{}, responses: []*webhook.ResponseBody{{Allow: false}}, expectErr: true, @@ -344,15 +364,17 @@ func TestWebhookController_Authorize(t *testing.T) { for i, wh := range test.ctl.webhooks { var j = i ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "reqID", r.Header.Get("X-Request-ID")) + err := json.NewEncoder(w).Encode(test.responses[j]) - assert.FatalError(t, err) + require.NoError(t, err) })) // nolint: gocritic // defer in loop isn't a memory leak defer ts.Close() wh.URL = ts.URL } - err := test.ctl.Authorize(context.Background(), test.req) + err := test.ctl.Authorize(test.ctx, test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -368,6 +390,7 @@ func TestWebhook_Do(t *testing.T) { type test struct { webhook Webhook dataArg any + requestID string webhookResponse webhook.ResponseBody expectPath string errStatusCode int @@ -377,6 +400,16 @@ func TestWebhook_Do(t *testing.T) { } tests := map[string]test{ "ok": { + webhook: Webhook{ + ID: "abc123", + Secret: "c2VjcmV0Cg==", + }, + requestID: "reqID", + webhookResponse: webhook.ResponseBody{ + Data: map[string]interface{}{"role": "dba"}, + }, + }, + "ok/no-request-id": { webhook: Webhook{ ID: "abc123", Secret: "c2VjcmV0Cg==", @@ -391,6 +424,7 @@ func TestWebhook_Do(t *testing.T) { Secret: "c2VjcmV0Cg==", BearerToken: "mytoken", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -407,6 +441,7 @@ func TestWebhook_Do(t *testing.T) { Password: "mypass", }, }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -418,7 +453,8 @@ func TestWebhook_Do(t *testing.T) { URL: "/users/{{ .username }}?region={{ .region }}", Secret: "c2VjcmV0Cg==", }, - dataArg: map[string]interface{}{"username": "areed", "region": "central"}, + requestID: "reqID", + dataArg: map[string]interface{}{"username": "areed", "region": "central"}, webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, @@ -453,6 +489,7 @@ func TestWebhook_Do(t *testing.T) { ID: "abc123", Secret: "c2VjcmV0Cg==", }, + requestID: "reqID", webhookResponse: webhook.ResponseBody{ Allow: true, }, @@ -465,6 +502,7 @@ func TestWebhook_Do(t *testing.T) { webhookResponse: webhook.ResponseBody{ Data: map[string]interface{}{"role": "dba"}, }, + requestID: "reqID", errStatusCode: 404, serverErrMsg: "item not found", expectErr: errors.New("Webhook server responded with 404"), @@ -473,38 +511,42 @@ func TestWebhook_Do(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tc.requestID != "" { + assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID")) + } + id := r.Header.Get("X-Smallstep-Webhook-ID") - assert.Equals(t, tc.webhook.ID, id) + sassert.Equals(t, tc.webhook.ID, id) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) - assert.FatalError(t, err) + assert.NoError(t, err) body, err := io.ReadAll(r.Body) - assert.FatalError(t, err) + assert.NoError(t, err) secret, err := base64.StdEncoding.DecodeString(tc.webhook.Secret) - assert.FatalError(t, err) + assert.NoError(t, err) h := hmac.New(sha256.New, secret) h.Write(body) mac := h.Sum(nil) - assert.True(t, hmac.Equal(sig, mac)) + sassert.True(t, hmac.Equal(sig, mac)) switch { case tc.webhook.BearerToken != "": ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) - assert.Equals(t, ah, r.Header.Get("Authorization")) + sassert.Equals(t, ah, r.Header.Get("Authorization")) case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": whReq, err := http.NewRequest("", "", http.NoBody) - assert.FatalError(t, err) + assert.NoError(t, err) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) ah := whReq.Header.Get("Authorization") - assert.Equals(t, ah, whReq.Header.Get("Authorization")) + sassert.Equals(t, ah, whReq.Header.Get("Authorization")) default: - assert.Equals(t, "", r.Header.Get("Authorization")) + sassert.Equals(t, "", r.Header.Get("Authorization")) } if tc.expectPath != "" { - assert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) + sassert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) } if tc.errStatusCode != 0 { @@ -514,30 +556,34 @@ func TestWebhook_Do(t *testing.T) { reqBody := new(webhook.RequestBody) err = json.Unmarshal(body, reqBody) - assert.FatalError(t, err) - // assert.Equals(t, tc.expectToken, reqBody.Token) + require.NoError(t, err) + // sassert.Equals(t, tc.expectToken, reqBody.Token) err = json.NewEncoder(w).Encode(tc.webhookResponse) - assert.FatalError(t, err) + require.NoError(t, err) })) defer ts.Close() tc.webhook.URL = ts.URL + tc.webhook.URL reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) + require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx := context.Background() + if tc.requestID != "" { + ctx = withRequestID(context.Background(), tc.requestID) + } + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { - assert.Equals(t, tc.expectErr.Error(), err.Error()) + sassert.Equals(t, tc.expectErr.Error(), err.Error()) return } - assert.FatalError(t, err) + assert.NoError(t, err) - assert.Equals(t, got, &tc.webhookResponse) + sassert.Equals(t, got, &tc.webhookResponse) }) } @@ -550,7 +596,7 @@ func TestWebhook_Do(t *testing.T) { URL: ts.URL, } cert, err := tls.LoadX509KeyPair("testdata/certs/foo.crt", "testdata/secrets/foo.key") - assert.FatalError(t, err) + require.NoError(t, err) transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, @@ -560,19 +606,19 @@ func TestWebhook_Do(t *testing.T) { Transport: transport, } reqBody, err := webhook.NewRequestBody(webhook.WithX509CertificateRequest(csr)) - assert.FatalError(t, err) + require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() _, err = wh.DoWithContext(ctx, client, reqBody, nil) - assert.FatalError(t, err) + require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*10) defer cancel() wh.DisableTLSClientAuth = true _, err = wh.DoWithContext(ctx, client, reqBody, nil) - assert.Error(t, err) + require.Error(t, err) }) } From c16a0b70ee31ef59fdb652dadb8705f6f0a49012 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 13:44:44 +0100 Subject: [PATCH 7/8] Remove `smallstep/assert` and `pkg/errors` from webhook tests --- authority/provisioner/webhook_test.go | 33 ++++++++++++--------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index ced713d1..60dcdbc7 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -9,6 +9,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -16,8 +17,6 @@ import ( "testing" "time" - "github.com/pkg/errors" - sassert "github.com/smallstep/assert" "github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/webhook" "github.com/stretchr/testify/assert" @@ -97,7 +96,7 @@ func TestWebhookController_isCertTypeOK(t *testing.T) { } for name, test := range tests { t.Run(name, func(t *testing.T) { - sassert.Equals(t, test.want, test.wc.isCertTypeOK(test.wh)) + assert.Equal(t, test.want, test.wc.isCertTypeOK(test.wh)) }) } } @@ -205,8 +204,8 @@ func TestWebhookController_Enrich(t *testing.T) { expectTemplateData: x509util.TemplateData{"Webhooks": map[string]any{"people": map[string]any{"role": "bar"}}}, assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) - sassert.FatalError(t, err) - sassert.Equals(t, &webhook.X5CCertificate{ + require.NoError(t, err) + assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -262,7 +261,7 @@ func TestWebhookController_Enrich(t *testing.T) { if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } - sassert.Equals(t, test.expectTemplateData, test.ctl.TemplateData) + assert.Equal(t, test.expectTemplateData, test.ctl.TemplateData) if test.assertRequest != nil { test.assertRequest(t, test.req) } @@ -326,7 +325,7 @@ func TestWebhookController_Authorize(t *testing.T) { assertRequest: func(t *testing.T, req *webhook.RequestBody) { key, err := x509.MarshalPKIXPublicKey(cert.PublicKey) require.NoError(t, err) - sassert.Equals(t, &webhook.X5CCertificate{ + assert.Equal(t, &webhook.X5CCertificate{ Raw: cert.Raw, PublicKey: key, PublicKeyAlgorithm: cert.PublicKeyAlgorithm.String(), @@ -515,8 +514,7 @@ func TestWebhook_Do(t *testing.T) { assert.Equal(t, tc.requestID, r.Header.Get("X-Request-ID")) } - id := r.Header.Get("X-Smallstep-Webhook-ID") - sassert.Equals(t, tc.webhook.ID, id) + assert.Equal(t, tc.webhook.ID, r.Header.Get("X-Smallstep-Webhook-ID")) sig, err := hex.DecodeString(r.Header.Get("X-Smallstep-Signature")) assert.NoError(t, err) @@ -529,24 +527,24 @@ func TestWebhook_Do(t *testing.T) { h := hmac.New(sha256.New, secret) h.Write(body) mac := h.Sum(nil) - sassert.True(t, hmac.Equal(sig, mac)) + assert.True(t, hmac.Equal(sig, mac)) switch { case tc.webhook.BearerToken != "": ah := fmt.Sprintf("Bearer %s", tc.webhook.BearerToken) - sassert.Equals(t, ah, r.Header.Get("Authorization")) + assert.Equal(t, ah, r.Header.Get("Authorization")) case tc.webhook.BasicAuth.Username != "" || tc.webhook.BasicAuth.Password != "": whReq, err := http.NewRequest("", "", http.NoBody) - assert.NoError(t, err) + require.NoError(t, err) whReq.SetBasicAuth(tc.webhook.BasicAuth.Username, tc.webhook.BasicAuth.Password) ah := whReq.Header.Get("Authorization") - sassert.Equals(t, ah, whReq.Header.Get("Authorization")) + assert.Equal(t, ah, whReq.Header.Get("Authorization")) default: - sassert.Equals(t, "", r.Header.Get("Authorization")) + assert.Equal(t, "", r.Header.Get("Authorization")) } if tc.expectPath != "" { - sassert.Equals(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) + assert.Equal(t, tc.expectPath, r.URL.Path+"?"+r.URL.RawQuery) } if tc.errStatusCode != 0 { @@ -557,7 +555,6 @@ func TestWebhook_Do(t *testing.T) { reqBody := new(webhook.RequestBody) err = json.Unmarshal(body, reqBody) require.NoError(t, err) - // sassert.Equals(t, tc.expectToken, reqBody.Token) err = json.NewEncoder(w).Encode(tc.webhookResponse) require.NoError(t, err) @@ -578,12 +575,12 @@ func TestWebhook_Do(t *testing.T) { got, err := tc.webhook.DoWithContext(ctx, http.DefaultClient, reqBody, tc.dataArg) if tc.expectErr != nil { - sassert.Equals(t, tc.expectErr.Error(), err.Error()) + assert.Equal(t, tc.expectErr.Error(), err.Error()) return } assert.NoError(t, err) - sassert.Equals(t, got, &tc.webhookResponse) + assert.Equal(t, &tc.webhookResponse, got) }) } From 041b486c556017aac05a3dc12c1b5681190ac55d Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 27 Feb 2024 14:00:09 +0100 Subject: [PATCH 8/8] Remove usages of `Sign` without context --- acme/api/revoke_test.go | 4 ---- acme/common.go | 1 - acme/order_test.go | 10 ---------- api/api.go | 1 - api/api_test.go | 8 -------- authority/authority_test.go | 3 ++- authority/authorize_test.go | 2 +- authority/provisioners_test.go | 2 +- authority/tls_test.go | 8 ++++---- scep/authority.go | 1 - 10 files changed, 8 insertions(+), 32 deletions(-) diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index 5d274faf..85b9a032 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -281,10 +281,6 @@ type mockCA struct { MockAreSANsallowed func(ctx context.Context, sans []string) error } -func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { - return nil, nil -} - func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { return nil, nil } diff --git a/acme/common.go b/acme/common.go index 46e86ae6..e86b23e9 100644 --- a/acme/common.go +++ b/acme/common.go @@ -21,7 +21,6 @@ var clock Clock // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) diff --git a/acme/order_test.go b/acme/order_test.go index 17060f11..07372af0 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -271,7 +271,6 @@ func TestOrder_UpdateStatus(t *testing.T) { } type mockSignAuth struct { - sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) @@ -279,15 +278,6 @@ type mockSignAuth struct { err error } -func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(csr, signOpts, extraOpts...) - } else if m.err != nil { - return nil, m.err - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signWithContext != nil { return m.signWithContext(ctx, csr, signOpts, extraOpts...) diff --git a/api/api.go b/api/api.go index 1d367f7d..a12e7e19 100644 --- a/api/api.go +++ b/api/api.go @@ -42,7 +42,6 @@ type Authority interface { AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index 4266dff3..cf988593 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -189,7 +189,6 @@ type mockAuthority struct { authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) - sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) @@ -252,13 +251,6 @@ func (m *mockAuthority) Root(shasum string) (*x509.Certificate, error) { return m.ret1.(*x509.Certificate), m.err } -func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { - if m.sign != nil { - return m.sign(cr, opts, signOpts...) - } - return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err -} - func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { if m.signWithContext != nil { return m.signWithContext(ctx, cr, opts, signOpts...) diff --git a/authority/authority_test.go b/authority/authority_test.go index 45c7cd86..3787dab7 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -1,6 +1,7 @@ package authority import ( + "context" "crypto" "crypto/rand" "crypto/sha256" @@ -414,7 +415,7 @@ func TestNewEmbedded_Sign(t *testing.T) { csr, err := x509.ParseCertificateRequest(cr) assert.FatalError(t, err) - cert, err := a.Sign(csr, provisioner.SignOptions{}) + cert, err := a.SignWithContext(context.Background(), csr, provisioner.SignOptions{}) assert.FatalError(t, err) assert.Equals(t, []string{"foo.bar.zar"}, cert[0].DNSNames) assert.Equals(t, crt, cert[1]) diff --git a/authority/authorize_test.go b/authority/authorize_test.go index 3d748f69..8f3c1ae2 100644 --- a/authority/authorize_test.go +++ b/authority/authorize_test.go @@ -1375,7 +1375,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) { } generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) { - chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + chain, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) if err != nil { t.Fatal(err) } diff --git a/authority/provisioners_test.go b/authority/provisioners_test.go index f6af6f54..f62f8127 100644 --- a/authority/provisioners_test.go +++ b/authority/provisioners_test.go @@ -149,7 +149,7 @@ func TestAuthority_LoadProvisionerByCertificate(t *testing.T) { opts, err := a.Authorize(ctx, token) require.NoError(t, err) opts = append(opts, extraOpts...) - certs, err := a.Sign(csr, provisioner.SignOptions{}, opts...) + certs, err := a.SignWithContext(ctx, csr, provisioner.SignOptions{}, opts...) require.NoError(t, err) return certs[0] } diff --git a/authority/tls_test.go b/authority/tls_test.go index 1fb8411a..b481ca68 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -239,7 +239,7 @@ func (e *testEnforcer) Enforce(cert *x509.Certificate) error { return nil } -func TestAuthority_Sign(t *testing.T) { +func TestAuthority_SignWithContext(t *testing.T) { pub, priv, err := keyutil.GenerateDefaultKeyPair() require.NoError(t, err) @@ -848,7 +848,7 @@ ZYtQ9Ot36qc= t.Run(name, func(t *testing.T) { tc := genTestCase(t) - certChain, err := tc.auth.Sign(tc.csr, tc.signOpts, tc.extraOpts...) + certChain, err := tc.auth.SignWithContext(context.Background(), tc.csr, tc.signOpts, tc.extraOpts...) if err != nil { if assert.NotNil(t, tc.err, fmt.Sprintf("unexpected error: %s", err)) { assert.Nil(t, certChain) @@ -1797,9 +1797,9 @@ func TestAuthority_constraints(t *testing.T) { t.Fatal(err) } - _, err = auth.Sign(csr, provisioner.SignOptions{}, templateOption) + _, err = auth.SignWithContext(context.Background(), csr, provisioner.SignOptions{}, templateOption) if (err != nil) != tt.wantErr { - t.Errorf("Authority.Sign() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Authority.SignWithContext() error = %v, wantErr %v", err, tt.wantErr) } _, err = auth.Renew(cert) diff --git a/scep/authority.go b/scep/authority.go index e2aa759e..8ed065fb 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -60,7 +60,6 @@ func MustFromContext(ctx context.Context) *Authority { // SignAuthority is the interface for a signing authority type SignAuthority interface { - Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) LoadProvisionerByName(string) (provisioner.Interface, error) }