From b2bf2c330bf83ce3d2060a6d14ebfef08759f8d6 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Thu, 1 Jun 2023 16:22:00 +0200 Subject: [PATCH] Simplify SCEP provisioner context handling --- authority/authority.go | 4 ++-- scep/api/api.go | 2 +- scep/authority.go | 25 +++++-------------------- scep/provisioner.go | 32 +++++++++++++------------------- 4 files changed, 21 insertions(+), 42 deletions(-) diff --git a/authority/authority.go b/authority/authority.go index 29cbf846..8be23ed3 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -858,8 +858,8 @@ func (a *Authority) IsRevoked(sn string) (bool, error) { return a.db.IsRevoked(sn) } -// requiresSCEPService iterates over the configured provisioners -// and determines if one of them is a SCEP provisioner. +// requiresSCEP iterates over the configured provisioners +// and determines if at least one of them is a SCEP provisioner. func (a *Authority) requiresSCEP() bool { for _, p := range a.config.AuthorityConfig.Provisioners { if p.GetType() == provisioner.TypeSCEP { diff --git a/scep/api/api.go b/scep/api/api.go index 98da818b..1615313f 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -221,7 +221,7 @@ func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { return } - ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov)) + ctx = scep.NewProvisionerContext(ctx, scep.Provisioner(prov)) next(w, r.WithContext(ctx)) } } diff --git a/scep/authority.go b/scep/authority.go index 55fd2086..5e02468d 100644 --- a/scep/authority.go +++ b/scep/authority.go @@ -136,10 +136,7 @@ func (a *Authority) LoadProvisionerByName(name string) (provisioner.Interface, e // Using an RA does not seem to exist in https://tools.ietf.org/html/rfc8894, but is mentioned in // https://tools.ietf.org/id/draft-nourse-scep-21.html. func (a *Authority) GetCACertificates(ctx context.Context) (certs []*x509.Certificate, err error) { - p, err := provisionerFromContext(ctx) - if err != nil { - return - } + p := provisionerFromContext(ctx) // if a provisioner specific RSA decrypter is available, it is returned as // the first certificate. @@ -214,10 +211,7 @@ func (a *Authority) DecryptPKIEnvelope(ctx context.Context, msg *PKIMessage) err } func (a *Authority) selectDecrypter(ctx context.Context) (cert *x509.Certificate, pkey crypto.PrivateKey, err error) { - p, err := provisionerFromContext(ctx) - if err != nil { - return nil, nil, err - } + p := provisionerFromContext(ctx) // return provisioner specific decrypter, if available if cert, pkey = p.GetDecrypter(); cert != nil && pkey != nil { @@ -239,10 +233,7 @@ func (a *Authority) SignCSR(ctx context.Context, csr *x509.CertificateRequest, m // poll for the status. It seems to be similar as what can happen in ACME, so might want to model // the implementation after the one in the ACME authority. Requires storage, etc. - p, err := provisionerFromContext(ctx) - if err != nil { - return nil, err - } + p := provisionerFromContext(ctx) // check if CSRReqMessage has already been decrypted if msg.CSRReqMessage.CSR == nil { @@ -463,10 +454,7 @@ func (a *Authority) CreateFailureResponse(_ context.Context, _ *x509.Certificate // GetCACaps returns the CA capabilities func (a *Authority) GetCACaps(ctx context.Context) []string { - p, err := provisionerFromContext(ctx) - if err != nil { - return defaultCapabilities - } + p := provisionerFromContext(ctx) caps := p.GetCapabilities() if len(caps) == 0 { @@ -483,9 +471,6 @@ func (a *Authority) GetCACaps(ctx context.Context) []string { } func (a *Authority) ValidateChallenge(ctx context.Context, challenge, transactionID string) error { - p, err := provisionerFromContext(ctx) - if err != nil { - return err - } + p := provisionerFromContext(ctx) return p.ValidateChallenge(ctx, challenge, transactionID) } diff --git a/scep/provisioner.go b/scep/provisioner.go index 79852e22..a1796b5b 100644 --- a/scep/provisioner.go +++ b/scep/provisioner.go @@ -4,7 +4,6 @@ import ( "context" "crypto" "crypto/x509" - "errors" "time" "github.com/smallstep/certificates/authority/provisioner" @@ -24,25 +23,20 @@ type Provisioner interface { ValidateChallenge(ctx context.Context, challenge, transactionID string) error } -// ContextKey is the key type for storing and searching for SCEP request -// essentials in the context of a request. -type ContextKey string - -const ( - // ProvisionerContextKey provisioner key - ProvisionerContextKey = ContextKey("provisioner") -) +// provisionerKey is the key type for storing and searching a +// SCEP provisioner in the context. +type provisionerKey struct{} // provisionerFromContext searches the context for a SCEP provisioner. -// Returns the provisioner or an error. -func provisionerFromContext(ctx context.Context) (Provisioner, error) { - val := ctx.Value(ProvisionerContextKey) - if val == nil { - return nil, errors.New("provisioner expected in request context") - } - p, ok := val.(Provisioner) - if !ok || p == nil { - return nil, errors.New("provisioner in context is not a SCEP provisioner") +// Returns the provisioner or panics if no SCEP provisioner is found. +func provisionerFromContext(ctx context.Context) Provisioner { + p, ok := ctx.Value(provisionerKey{}).(Provisioner) + if !ok { + panic("SCEP provisioner expected in request context") } - return p, nil + return p +} + +func NewProvisionerContext(ctx context.Context, p Provisioner) context.Context { + return context.WithValue(ctx, provisionerKey{}, p) }