Add `SignWithContext` method to authority and mocks

pull/1542/head
Herman Slatman 7 months ago
parent b2301ea127
commit 4e06bdbc51
No known key found for this signature in database
GPG Key ID: F4D8A44EA0A75A4F

@ -285,6 +285,10 @@ func (m *mockCA) Sign(*x509.CertificateRequest, provisioner.SignOptions, ...prov
return nil, nil
}
func (m *mockCA) SignWithContext(context.Context, *x509.CertificateRequest, provisioner.SignOptions, ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil
}
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.MockAreSANsallowed != nil {
return m.MockAreSANsallowed(ctx, sans)

@ -22,6 +22,7 @@ var clock Clock
// CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error

@ -272,6 +272,7 @@ func TestOrder_UpdateStatus(t *testing.T) {
type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{}
@ -287,6 +288,15 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockSignAuth) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, csr, signOpts, extraOpts...)
} else if m.err != nil {
return nil, m.err
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.areSANsAllowed != nil {
return m.areSANsAllowed(ctx, sans)

@ -42,6 +42,7 @@ type Authority interface {
GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
Renew(peer *x509.Certificate) ([]*x509.Certificate, error)
RenewContext(ctx context.Context, peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
Rekey(peer *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)

@ -193,6 +193,7 @@ type mockAuthority struct {
getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
signWithContext func(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
renew func(cert *x509.Certificate) ([]*x509.Certificate, error)
rekey func(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
renewContext func(ctx context.Context, oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x509.Certificate, error)
@ -261,6 +262,13 @@ func (m *mockAuthority) Sign(cr *x509.CertificateRequest, opts provisioner.SignO
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) SignWithContext(ctx context.Context, cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
if m.signWithContext != nil {
return m.signWithContext(ctx, cr, opts, signOpts...)
}
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
}
func (m *mockAuthority) Renew(cert *x509.Certificate) ([]*x509.Certificate, error) {
if m.renew != nil {
return m.renew(cert)

@ -37,7 +37,7 @@ type WebhookController struct {
// Enrich fetches data from remote servers and adds returned data to the
// templateData
func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
func (wc *WebhookController) Enrich(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
@ -56,11 +56,11 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
if !wc.isCertTypeOK(wh) {
continue
}
// TODO(hs): propagate context from above
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData)
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}
@ -73,7 +73,7 @@ func (wc *WebhookController) Enrich(req *webhook.RequestBody) error {
}
// Authorize checks that all remote servers allow the request
func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
func (wc *WebhookController) Authorize(ctx context.Context, req *webhook.RequestBody) error {
if wc == nil {
return nil
}
@ -93,11 +93,10 @@ func (wc *WebhookController) Authorize(req *webhook.RequestBody) error {
continue
}
// TODO(hs): propagate context from above
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
whCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel() //nolint:gocritic // every request canceled with its own timeout
resp, err := wh.DoWithContext(ctx, wc.client, req, wc.TemplateData)
resp, err := wh.DoWithContext(whCtx, wc.client, req, wc.TemplateData)
if err != nil {
return err
}

@ -242,7 +242,7 @@ func TestWebhookController_Enrich(t *testing.T) {
wh.URL = ts.URL
}
err := test.ctl.Enrich(test.req)
err := test.ctl.Enrich(context.Background(), test.req)
if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr)
}
@ -352,7 +352,7 @@ func TestWebhookController_Authorize(t *testing.T) {
wh.URL = ts.URL
}
err := test.ctl.Authorize(test.req)
err := test.ctl.Authorize(context.Background(), test.req)
if (err != nil) != test.expectErr {
t.Fatalf("Got err %v, want %v", err, test.expectErr)
}

@ -146,7 +146,7 @@ func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*
}
// SignSSH creates a signed SSH certificate with the given public key and options.
func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
var (
certOptions []sshutil.Option
mods []provisioner.SSHCertModifier
@ -205,7 +205,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision
}
// Call enriching webhooks
if err := callEnrichingWebhooksSSH(webhookCtl, cr); err != nil {
if err := callEnrichingWebhooksSSH(ctx, webhookCtl, cr); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts),
@ -277,7 +277,7 @@ func (a *Authority) SignSSH(_ context.Context, key ssh.PublicKey, opts provision
}
// Send certificate to webhooks for authorization
if err := callAuthorizingWebhooksSSH(webhookCtl, certificate, certTpl); err != nil {
if err := callAuthorizingWebhooksSSH(ctx, webhookCtl, certificate, certTpl); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "authority.SignSSH: error signing certificate"),
)
@ -653,7 +653,7 @@ func (a *Authority) getAddUserCommand(principal string) string {
return strings.ReplaceAll(cmd, "<principal>", principal)
}
func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.CertificateRequest) error {
func callEnrichingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cr sshutil.CertificateRequest) error {
if webhookCtl == nil {
return nil
}
@ -663,10 +663,10 @@ func callEnrichingWebhooksSSH(webhookCtl webhookController, cr sshutil.Certifica
if err != nil {
return err
}
return webhookCtl.Enrich(whEnrichReq)
return webhookCtl.Enrich(ctx, whEnrichReq)
}
func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error {
func callAuthorizingWebhooksSSH(ctx context.Context, webhookCtl webhookController, cert *sshutil.Certificate, certTpl *ssh.Certificate) error {
if webhookCtl == nil {
return nil
}
@ -676,5 +676,5 @@ func callAuthorizingWebhooksSSH(webhookCtl webhookController, cert *sshutil.Cert
if err != nil {
return err
}
return webhookCtl.Authorize(whAuthBody)
return webhookCtl.Authorize(ctx, whAuthBody)
}

@ -91,8 +91,17 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
}
}
// Sign creates a signed certificate from a certificate signing request.
// Sign creates a signed certificate from a certificate signing request. It
// creates a new context.Context, and calls into SignWithContext.
//
// Deprecated: Use authority.SignWithContext with an actual context.Context.
func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return a.SignWithContext(context.Background(), csr, signOpts, extraOpts...)
}
// SignWithContext creates a signed certificate from a certificate signing request,
// taking the provided context.Context.
func (a *Authority) SignWithContext(ctx context.Context, csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
var (
certOptions []x509util.Option
certValidators []provisioner.CertificateValidator
@ -163,7 +172,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
}
}
if err := callEnrichingWebhooksX509(webhookCtl, attData, csr); err != nil {
if err := callEnrichingWebhooksX509(ctx, webhookCtl, attData, csr); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, err.Error()),
errs.WithKeyVal("csr", csr),
@ -256,7 +265,7 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
}
// Send certificate to webhooks for authorization
if err := callAuthorizingWebhooksX509(webhookCtl, cert, leaf, attData); err != nil {
if err := callAuthorizingWebhooksX509(ctx, webhookCtl, cert, leaf, attData); err != nil {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(err, "error creating certificate"),
opts...,
@ -952,7 +961,7 @@ func templatingError(err error) error {
return errors.Wrap(cause, "error applying certificate template")
}
func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error {
func callEnrichingWebhooksX509(ctx context.Context, webhookCtl webhookController, attData *provisioner.AttestationData, csr *x509.CertificateRequest) error {
if webhookCtl == nil {
return nil
}
@ -969,10 +978,10 @@ func callEnrichingWebhooksX509(webhookCtl webhookController, attData *provisione
if err != nil {
return err
}
return webhookCtl.Enrich(whEnrichReq)
return webhookCtl.Enrich(ctx, whEnrichReq)
}
func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error {
func callAuthorizingWebhooksX509(ctx context.Context, webhookCtl webhookController, cert *x509util.Certificate, leaf *x509.Certificate, attData *provisioner.AttestationData) error {
if webhookCtl == nil {
return nil
}
@ -989,5 +998,5 @@ func callAuthorizingWebhooksX509(webhookCtl webhookController, cert *x509util.Ce
if err != nil {
return err
}
return webhookCtl.Authorize(whAuthBody)
return webhookCtl.Authorize(ctx, whAuthBody)
}

@ -1,8 +1,12 @@
package authority
import "github.com/smallstep/certificates/webhook"
import (
"context"
"github.com/smallstep/certificates/webhook"
)
type webhookController interface {
Enrich(*webhook.RequestBody) error
Authorize(*webhook.RequestBody) error
Enrich(context.Context, *webhook.RequestBody) error
Authorize(context.Context, *webhook.RequestBody) error
}

@ -1,6 +1,8 @@
package authority
import (
"context"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/webhook"
)
@ -14,7 +16,7 @@ type mockWebhookController struct {
var _ webhookController = &mockWebhookController{}
func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
func (wc *mockWebhookController) Enrich(context.Context, *webhook.RequestBody) error {
for key, data := range wc.respData {
wc.templateData.SetWebhook(key, data)
}
@ -22,6 +24,6 @@ func (wc *mockWebhookController) Enrich(*webhook.RequestBody) error {
return wc.enrichErr
}
func (wc *mockWebhookController) Authorize(*webhook.RequestBody) error {
func (wc *mockWebhookController) Authorize(context.Context, *webhook.RequestBody) error {
return wc.authorizeErr
}

Loading…
Cancel
Save