diff --git a/acme/api/revoke_test.go b/acme/api/revoke_test.go index a225aa19..e8edcc41 100644 --- a/acme/api/revoke_test.go +++ b/acme/api/revoke_test.go @@ -285,6 +285,10 @@ func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...prov return nil, nil } +func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) { + return nil, nil +} + func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error { if m.MockAreSANsallowed != nil { return m.MockAreSANsallowed(ctx, sans) diff --git a/acme/common.go b/acme/common.go index 7d58305f..afab13b2 100644 --- a/acme/common.go +++ b/acme/common.go @@ -22,6 +22,7 @@ var clock Clock // CertificateAuthority is the interface implemented by a CA authority. type CertificateAuthority interface { Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) AreSANsAllowed(ctx context.Context, sans []string) error IsRevoked(sn string) (bool, error) Revoke(context.Context, *authority.RevokeOptions) error diff --git a/acme/order_test.go b/acme/order_test.go index 2851bb19..3fa99b9b 100644 --- a/acme/order_test.go +++ b/acme/order_test.go @@ -272,6 +272,7 @@ func TestOrder_UpdateStatus(t *testing.T) { type mockSignAuth struct { sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) areSANsAllowed func(ctx context.Context, sans []string) error loadProvisionerByName func(string) (provisioner.Interface, error) ret1, ret2 interface{} @@ -287,6 +288,15 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, csr, signOpts, extraOpts...) + } else if m.err != nil { + return nil, m.err + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error { if m.areSANsAllowed != nil { return m.areSANsAllowed(ctx, sans) diff --git a/api/api.go b/api/api.go index c9820351..2d6c0bf7 100644 --- a/api/api.go +++ b/api/api.go @@ -42,6 +42,7 @@ type Authority interface { GetTLSOptions() *config.TLSOptions Root(shasum string) (*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Renew(peer *x509.Certificate) ([]*x509.Certificate, error) RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) diff --git a/api/api_test.go b/api/api_test.go index d96015f9..90acf759 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -193,6 +193,7 @@ type mockAuthority struct { getTLSOptions func() *authority.TLSOptions root func(shasum string) (*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) + signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) renew func(cert *x509.Certificate) ([]*x509.Certificate, error) rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error) @@ -261,6 +262,13 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignO return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err } +func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + if m.signWithContext != nil { + return m.signWithContext(ctx, cr, opts, signOpts...) + } + return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err +} + func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) { if m.renew != nil { return m.renew(cert) diff --git a/authority/provisioner/webhook.go b/authority/provisioner/webhook.go index 14d357f1..1cc2047c 100644 --- a/authority/provisioner/webhook.go +++ b/authority/provisioner/webhook.go @@ -37,7 +37,7 @@ type WebhookController struct { // Enrich fetches data from remote servers and adds returned data to the // templateData -func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { +func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -56,11 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { if !wc.isCertTypeOK(wh) { continue } - // TODO(hs): propagate context from above - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout + + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } @@ -73,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error { } // Authorize checks that all remote servers allow the request -func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { +func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error { if wc == nil { return nil } @@ -93,11 +93,10 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error { continue } - // TODO(hs): propagate context from above - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() + whCtx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() //nolint:gocritic // every request canceled with its own timeout - resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData) + resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData) if err != nil { return err } diff --git a/authority/provisioner/webhook_test.go b/authority/provisioner/webhook_test.go index a61da39c..cc79a09b 100644 --- a/authority/provisioner/webhook_test.go +++ b/authority/provisioner/webhook_test.go @@ -242,7 +242,7 @@ func TestWebhookController_Enrich(t *testing.T) { wh.URL = ts.URL } - err := test.ctl.Enrich(test.req) + err := test.ctl.Enrich(context.Background(), test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } @@ -352,7 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) { wh.URL = ts.URL } - err := test.ctl.Authorize(test.req) + err := test.ctl.Authorize(context.Background(), test.req) if (err != nil) != test.expectErr { t.Fatalf("Got err %v, want %v", err, test.expectErr) } diff --git a/authority/ssh.go b/authority/ssh.go index f9371d60..688bfd76 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (* } // SignSSH creates a signed SSH certificate with the given public key and options. -func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { +func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { var ( certOptions []sshutil.Option mods []provisioner.SSHCertModifier @@ -205,7 +205,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } // Call enriching webhooks - if err := callEnrichingWebhooksSSH(webhookCtl, cr); err != nil { + if err := callEnrichingWebhooksSSH(ctx, webhookCtl, cr); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("signOptions", signOpts), @@ -277,7 +277,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision } // Send certificate to webhooks for authorization - if err := callAuthorizingWebhooksSSH(webhookCtl, certificate, certTpl); err != nil { + if err := callAuthorizingWebhooksSSH(ctx, webhookCtl, certificate, certTpl); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"), ) @@ -653,7 +653,7 @@ func (a *Authority) getAddUserCommand(principal string) string { return strings.ReplaceAll(cmd, "", principal) } -func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.CertificateRequest) error { +func callEnrichingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cr sshutil.CertificateRequest) error { if webhookCtl == nil { return nil } @@ -663,10 +663,10 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica if err != nil { return err } - return webhookCtl.Enrich(whEnrichReq) + return webhookCtl.Enrich(ctx, whEnrichReq) } -func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { +func callAuthorizingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error { if webhookCtl == nil { return nil } @@ -676,5 +676,5 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert if err != nil { return err } - return webhookCtl.Authorize(whAuthBody) + return webhookCtl.Authorize(ctx, whAuthBody) } diff --git a/authority/tls.go b/authority/tls.go index 6e967920..900b1ff8 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -91,8 +91,17 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc { } } -// Sign creates a signed certificate from a certificate signing request. +// Sign creates a signed certificate from a certificate signing request. It +// creates a new context.Context, and calls into SignWithContext. +// +// Deprecated: Use authority.SignWithContext with an actual context.Context. func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { + return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...) +} + +// SignWithContext creates a signed certificate from a certificate signing request, +// taking the provided context.Context. +func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { var ( certOptions []x509util.Option certValidators []provisioner.CertificateValidator @@ -163,7 +172,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } } - if err := callEnrichingWebhooksX509(webhookCtl, attData, csr); err != nil { + if err := callEnrichingWebhooksX509(ctx, webhookCtl, attData, csr); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, err.Error()), errs.WithKeyVal("csr", csr), @@ -256,7 +265,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign } // Send certificate to webhooks for authorization - if err := callAuthorizingWebhooksX509(webhookCtl, cert, leaf, attData); err != nil { + if err := callAuthorizingWebhooksX509(ctx, webhookCtl, cert, leaf, attData); err != nil { return nil, errs.ApplyOptions( errs.ForbiddenErr(err, "error creating certificate"), opts..., @@ -952,7 +961,7 @@ func templatingError(err error) error { return errors.Wrap(cause, "error applying certificate template") } -func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { +func callEnrichingWebhooksX509(ctx context.Context, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error { if webhookCtl == nil { return nil } @@ -969,10 +978,10 @@ func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisione if err != nil { return err } - return webhookCtl.Enrich(whEnrichReq) + return webhookCtl.Enrich(ctx, whEnrichReq) } -func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { +func callAuthorizingWebhooksX509(ctx context.Context, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error { if webhookCtl == nil { return nil } @@ -989,5 +998,5 @@ func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Ce if err != nil { return err } - return webhookCtl.Authorize(whAuthBody) + return webhookCtl.Authorize(ctx, whAuthBody) } diff --git a/authority/webhook.go b/authority/webhook.go index d887e077..29e3e6c3 100644 --- a/authority/webhook.go +++ b/authority/webhook.go @@ -1,8 +1,12 @@ package authority -import "github.com/smallstep/certificates/webhook" +import ( + "context" + + "github.com/smallstep/certificates/webhook" +) type webhookController interface { - Enrich(*webhook.RequestBody) error - Authorize(*webhook.RequestBody) error + Enrich(context.Context, *webhook.RequestBody) error + Authorize(context.Context, *webhook.RequestBody) error } diff --git a/authority/webhook_test.go b/authority/webhook_test.go index 0e713af7..75b59f63 100644 --- a/authority/webhook_test.go +++ b/authority/webhook_test.go @@ -1,6 +1,8 @@ package authority import ( + "context" + "github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/webhook" ) @@ -14,7 +16,7 @@ type mockWebhookController struct { var _ webhookController = &mockWebhookController{} -func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { +func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error { for key, data := range wc.respData { wc.templateData.SetWebhook(key, data) } @@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error { return wc.enrichErr } -func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error { +func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error { return wc.authorizeErr }